Commit 375a52d0 authored by gushiqiao's avatar gushiqiao

Fix torch sdpa op

parent 8abfb2c6
@@ -82,6 +82,12 @@ def get_available_attn_ops():
     else:
         available_ops.append(("sage_attn2", False))
+    torch_installed = is_module_installed("torch")
+    if torch_installed:
+        available_ops.append(("torch_sdpa", True))
+    else:
+        available_ops.append(("torch_sdpa", False))
     return available_ops
@@ -468,7 +474,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
     else:
         quant_type = "int8"
-    attn_priority = ["sage_attn2", "flash_attn3", "flash_attn2"]
+    attn_priority = ["sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"]
     quant_op_priority = ["sgl", "vllm", "q8f"]
     for op in attn_priority:
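For context, `is_module_installed` is defined elsewhere in the repo and not shown in this diff. A common way to implement such a probe, given here as an assumption rather than the project's actual code, is to query the import machinery without importing anything:

```python
import importlib.util

def is_module_installed(name: str) -> bool:
    # find_spec returns None when the package cannot be located,
    # so this checks availability without paying the import cost.
    return importlib.util.find_spec(name) is not None
```

Since torch is effectively always present in this project, the new `("torch_sdpa", True)` entry gives the op selector a guaranteed fallback.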
@@ -83,6 +83,12 @@ def get_available_attn_ops():
     else:
         available_ops.append(("sage_attn2", False))
+    torch_installed = is_module_installed("torch")
+    if torch_installed:
+        available_ops.append(("torch_sdpa", True))
+    else:
+        available_ops.append(("torch_sdpa", False))
     return available_ops
@@ -468,7 +474,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
     else:
         quant_type = "int8"
-    attn_priority = ["sage_attn2", "flash_attn3", "flash_attn2"]
+    attn_priority = ["sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"]
     quant_op_priority = ["sgl", "vllm", "q8f"]
     for op in attn_priority:
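The same change lands in a second copy of these helpers. The loop closing `auto_configure` is truncated in the diff; presumably it walks `attn_priority` and takes the first op reported as installed. A minimal sketch of that selection logic (the helper name is hypothetical):

```python
def pick_first_available(priority, available_ops):
    # available_ops is the list built by get_available_attn_ops(),
    # e.g. [("sage_attn2", False), ..., ("torch_sdpa", True)].
    installed = {name for name, ok in available_ops if ok}
    for op in priority:
        if op in installed:
            return op
    return None
```

Appending "torch_sdpa" last means it is chosen only when no optimized kernel is installed, which is exactly the fallback role a pure-PyTorch op should play.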
@@ -73,7 +73,18 @@ class FlashAttn2Weight(AttnWeightTemplate):
     def __init__(self):
         self.config = {}

-    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, mask_map=None):
+    def apply(
+        self,
+        q,
+        k,
+        v,
+        cu_seqlens_q=None,
+        cu_seqlens_kv=None,
+        max_seqlen_q=None,
+        max_seqlen_kv=None,
+        model_cls=None,
+        mask_map=None,
+    ):
         x = flash_attn_varlen_func(
             q,
             k,
@@ -91,7 +102,18 @@ class FlashAttn3Weight(AttnWeightTemplate):
     def __init__(self):
         self.config = {}

-    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, mask_map=None):
+    def apply(
+        self,
+        q,
+        k,
+        v,
+        cu_seqlens_q=None,
+        cu_seqlens_kv=None,
+        max_seqlen_q=None,
+        max_seqlen_kv=None,
+        model_cls=None,
+        mask_map=None,
+    ):
         x = flash_attn_varlen_func_v3(
             q,
             k,
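Both flash-attn wrappers use the varlen calling convention: sequences of different lengths are packed along a single token axis and delimited by cumulative offsets. A small illustration of how those `cu_seqlens`/`max_seqlen` arguments are built (independent of this repo):

```python
import torch
import torch.nn.functional as F

# Three sequences of lengths 3, 5 and 2 packed into one (10, heads, dim) tensor.
seqlens = torch.tensor([3, 5, 2], dtype=torch.int32)
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
# cu_seqlens -> tensor([0, 3, 8, 10], dtype=torch.int32)
max_seqlen = int(seqlens.max())  # 5
# flash_attn_varlen_func then attends only within each
# [cu_seqlens[i], cu_seqlens[i + 1]) span of the packed token axis.
```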
@@ -109,7 +131,21 @@ class RadialAttnWeight(AttnWeightTemplate):
     def __init__(self):
         self.config = {}

-    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, mask_map=None, sparsity_type="radial", block_size=128, decay_factor=1, model_cls="wan"):
+    def apply(
+        self,
+        q,
+        k,
+        v,
+        cu_seqlens_q=None,
+        cu_seqlens_kv=None,
+        max_seqlen_q=None,
+        max_seqlen_kv=None,
+        mask_map=None,
+        sparsity_type="radial",
+        block_size=128,
+        decay_factor=1,
+        model_cls="wan",
+    ):
         assert len(q.shape) == 3
         x = radial_attn(
@@ -175,7 +211,22 @@ class TorchSDPAWeight(AttnWeightTemplate):
     def __init__(self):
         self.config = {}

-    def apply(self, q, k, v, drop_rate=0, attn_mask=None, causal=False):
+    def apply(
+        self,
+        q,
+        k,
+        v,
+        drop_rate=0,
+        attn_mask=None,
+        causal=False,
+        cu_seqlens_q=None,
+        cu_seqlens_kv=None,
+        max_seqlen_q=None,
+        max_seqlen_kv=None,
+        model_cls=None,
+        mask_map=None,
+    ):
+        q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
         q = q.transpose(1, 2)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
@@ -185,12 +236,20 @@ class TorchSDPAWeight(AttnWeightTemplate):
         x = x.transpose(1, 2)
         b, s, a, d = x.shape
         out = x.reshape(b, s, -1)
-        return out
+        return out.squeeze(0)


 @ATTN_WEIGHT_REGISTER("Sparge")
 class SpargeAttnWeight(AttnWeightTemplate):
-    def __init__(self, weight_name, verbose=False, l1=0.07, pv_l1=0.08, tune_pv=True, inner_attn_type="flash_attn3"):
+    def __init__(
+        self,
+        weight_name,
+        verbose=False,
+        l1=0.07,
+        pv_l1=0.08,
+        tune_pv=True,
+        inner_attn_type="flash_attn3",
+    ):
         self.verbose = (verbose,)
         self.l1 = (l1,)
         self.pv_l1 = (pv_l1,)
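The substance of the fix is in `TorchSDPAWeight.apply` above: it now accepts the same keyword set as the other attention ops (the varlen arguments are accepted and ignored), adds a batch dimension before calling into SDPA, and strips it again on return. A minimal sketch of the resulting path, assuming the elided middle of the hunk calls `torch.nn.functional.scaled_dot_product_attention`:

```python
import torch
import torch.nn.functional as F

def sdpa_apply(q, k, v, drop_rate=0.0, attn_mask=None, causal=False):
    # Inputs arrive unbatched as (seq_len, heads, head_dim), the same
    # packed layout the flash-attn wrappers use.
    q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)  # (1, S, H, D)
    q = q.transpose(1, 2)  # SDPA expects (batch, heads, seq, dim)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    x = F.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
    )
    x = x.transpose(1, 2)        # back to (1, S, H, D)
    b, s, h, d = x.shape
    out = x.reshape(b, s, -1)    # merge heads: (1, S, H*D)
    return out.squeeze(0)        # drop the temporary batch dim: (S, H*D)
```

Without the `unsqueeze(0)`/`squeeze(0)` pair, a 3-D input would be read as (batch, seq, dim) and the transposes would mix up the sequence and head axes, which appears to be the mismatch this commit fixes.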
@@ -204,9 +263,23 @@ class SpargeAttnWeight(AttnWeightTemplate):
         for key in weight_dict.keys():
             if key.startswith(self.weight_name):
                 sub_name = key.split(".")[-1]
-                setattr(self.inner_cls, sub_name, nn.Parameter(weight_dict[key], requires_grad=False))
+                setattr(
+                    self.inner_cls,
+                    sub_name,
+                    nn.Parameter(weight_dict[key], requires_grad=False),
+                )

-    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
+    def apply(
+        self,
+        q,
+        k,
+        v,
+        cu_seqlens_q=None,
+        cu_seqlens_kv=None,
+        max_seqlen_q=None,
+        max_seqlen_kv=None,
+        model_cls=None,
+    ):
         if len(q.shape) == 3:
             q = q.unsqueeze(0)
             k = k.unsqueeze(0)
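All of these classes register under string keys via `ATTN_WEIGHT_REGISTER`, which is what lets `auto_configure` return a name like "torch_sdpa" and have the runtime instantiate the matching op. The registry itself is not in this diff; a hypothetical sketch of the pattern, not the repo's actual implementation:

```python
class AttnWeightRegistry:
    """Maps string keys to attention-op classes via a decorator."""

    def __init__(self):
        self._classes = {}

    def __call__(self, name):
        def decorator(cls):
            self._classes[name] = cls
            return cls
        return decorator

    def __getitem__(self, name):
        return self._classes[name]

ATTN_WEIGHT_REGISTER = AttnWeightRegistry()

# Usage, mirroring the decorators above:
#   @ATTN_WEIGHT_REGISTER("torch_sdpa")
#   class TorchSDPAWeight(...): ...
#   attn = ATTN_WEIGHT_REGISTER["torch_sdpa"]()
```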