Commit 62d8881a authored by gushiqiao, committed by GitHub

Merge pull request #99 from ModelTC/dev_fixbugs

Fix torch sdpa op
parents 8abfb2c6 375a52d0
@@ -82,6 +82,12 @@ def get_available_attn_ops():
     else:
         available_ops.append(("sage_attn2", False))
+
+    torch_installed = is_module_installed("torch")
+    if torch_installed:
+        available_ops.append(("torch_sdpa", True))
+    else:
+        available_ops.append(("torch_sdpa", False))
     return available_ops
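The helper is_module_installed is not shown in this diff; a minimal sketch of what such a check could look like, assuming a plain importlib probe (the real helper lives elsewhere in the repo and may differ):

import importlib.util

def is_module_installed(name: str) -> bool:
    # find_spec returns None when the module cannot be located; this
    # probes availability without fully importing (initializing) the package.
    return importlib.util.find_spec(name) is not None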
@@ -468,7 +474,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
     else:
         quant_type = "int8"
-    attn_priority = ["sage_attn2", "flash_attn3", "flash_attn2"]
+    attn_priority = ["sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"]
     quant_op_priority = ["sgl", "vllm", "q8f"]
     for op in attn_priority:
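The body of the selection loop is truncated above; presumably it picks the first op in attn_priority whose dependency is installed, which is why appending torch_sdpa last makes it a universal fallback. A hypothetical sketch of that logic, using only names visible in the diff:

# Hypothetical reconstruction of the truncated loop: take the first op
# in priority order that get_available_attn_ops() reported as installed;
# torch_sdpa sits last, so it is chosen whenever no accelerated kernel
# is present (torch itself is effectively always available).
available = dict(get_available_attn_ops())  # name -> installed flag
attn_op = None
for op in attn_priority:
    if available.get(op):
        attn_op = op
        break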
@@ -83,6 +83,12 @@ def get_available_attn_ops():
     else:
         available_ops.append(("sage_attn2", False))
+
+    torch_installed = is_module_installed("torch")
+    if torch_installed:
+        available_ops.append(("torch_sdpa", True))
+    else:
+        available_ops.append(("torch_sdpa", False))
     return available_ops
@@ -468,7 +474,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
     else:
         quant_type = "int8"
-    attn_priority = ["sage_attn2", "flash_attn3", "flash_attn2"]
+    attn_priority = ["sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"]
     quant_op_priority = ["sgl", "vllm", "q8f"]
     for op in attn_priority:
@@ -73,7 +73,18 @@ class FlashAttn2Weight(AttnWeightTemplate):
     def __init__(self):
         self.config = {}
 
-    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, mask_map=None):
+    def apply(
+        self,
+        q,
+        k,
+        v,
+        cu_seqlens_q=None,
+        cu_seqlens_kv=None,
+        max_seqlen_q=None,
+        max_seqlen_kv=None,
+        model_cls=None,
+        mask_map=None,
+    ):
         x = flash_attn_varlen_func(
             q,
             k,
@@ -91,7 +102,18 @@ class FlashAttn3Weight(AttnWeightTemplate):
     def __init__(self):
         self.config = {}
 
-    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, mask_map=None):
+    def apply(
+        self,
+        q,
+        k,
+        v,
+        cu_seqlens_q=None,
+        cu_seqlens_kv=None,
+        max_seqlen_q=None,
+        max_seqlen_kv=None,
+        model_cls=None,
+        mask_map=None,
+    ):
         x = flash_attn_varlen_func_v3(
             q,
             k,
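For context, the cu_seqlens/max_seqlen arguments in these signatures follow the flash-attn varlen convention: all sequences in a batch are packed along one axis and delimited by cumulative offsets. An illustration of that convention (not taken from this diff):

import torch

# Two sequences of lengths 3 and 5 packed along the first axis.
seq_lens = torch.tensor([3, 5], dtype=torch.int32)
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32), seq_lens.cumsum(0).to(torch.int32)])  # [0, 3, 8]
max_seqlen = int(seq_lens.max())  # 5
# q, k and v would then each have shape (8, num_heads, head_dim).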
@@ -109,7 +131,21 @@ class RadialAttnWeight(AttnWeightTemplate):
     def __init__(self):
         self.config = {}
 
-    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, mask_map=None, sparsity_type="radial", block_size=128, decay_factor=1, model_cls="wan"):
+    def apply(
+        self,
+        q,
+        k,
+        v,
+        cu_seqlens_q=None,
+        cu_seqlens_kv=None,
+        max_seqlen_q=None,
+        max_seqlen_kv=None,
+        mask_map=None,
+        sparsity_type="radial",
+        block_size=128,
+        decay_factor=1,
+        model_cls="wan",
+    ):
         assert len(q.shape) == 3
         x = radial_attn(
@@ -175,7 +211,22 @@ class TorchSDPAWeight(AttnWeightTemplate):
     def __init__(self):
         self.config = {}
 
-    def apply(self, q, k, v, drop_rate=0, attn_mask=None, causal=False):
+    def apply(
+        self,
+        q,
+        k,
+        v,
+        drop_rate=0,
+        attn_mask=None,
+        causal=False,
+        cu_seqlens_q=None,
+        cu_seqlens_kv=None,
+        max_seqlen_q=None,
+        max_seqlen_kv=None,
+        model_cls=None,
+        mask_map=None,
+    ):
         q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
         q = q.transpose(1, 2)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
@@ -185,12 +236,20 @@ class TorchSDPAWeight(AttnWeightTemplate):
         x = x.transpose(1, 2)
         b, s, a, d = x.shape
         out = x.reshape(b, s, -1)
-        return out
+        return out.squeeze(0)
 
 
 @ATTN_WEIGHT_REGISTER("Sparge")
 class SpargeAttnWeight(AttnWeightTemplate):
-    def __init__(self, weight_name, verbose=False, l1=0.07, pv_l1=0.08, tune_pv=True, inner_attn_type="flash_attn3"):
+    def __init__(
+        self,
+        weight_name,
+        verbose=False,
+        l1=0.07,
+        pv_l1=0.08,
+        tune_pv=True,
+        inner_attn_type="flash_attn3",
+    ):
         self.verbose = (verbose,)
         self.l1 = (l1,)
         self.pv_l1 = (pv_l1,)
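Taken together, the two TorchSDPAWeight hunks make torch_sdpa accept the same varlen-style keyword arguments as the flash-attention ops (ignoring the ones it does not use) and return an unbatched (seq_len, num_heads * head_dim) tensor like the other ops. A self-contained sketch of the shape flow, assuming the elided middle lines call torch's scaled_dot_product_attention:

import torch
import torch.nn.functional as F

def torch_sdpa_apply(q, k, v, drop_rate=0, attn_mask=None, causal=False, **_unused):
    # _unused absorbs cu_seqlens_q/cu_seqlens_kv etc., which this op ignores.
    # Inputs arrive as (seq_len, num_heads, head_dim), the layout the
    # flash-attn varlen ops use; add a batch dim and move heads forward.
    q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
    q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
    x = x.transpose(1, 2)  # back to (1, seq_len, num_heads, head_dim)
    b, s, a, d = x.shape
    out = x.reshape(b, s, -1)  # merge heads: (1, seq_len, num_heads * head_dim)
    return out.squeeze(0)  # drop the batch dim added above, per the fix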
@@ -204,9 +263,23 @@ class SpargeAttnWeight(AttnWeightTemplate):
         for key in weight_dict.keys():
             if key.startswith(self.weight_name):
                 sub_name = key.split(".")[-1]
-                setattr(self.inner_cls, sub_name, nn.Parameter(weight_dict[key], requires_grad=False))
+                setattr(
+                    self.inner_cls,
+                    sub_name,
+                    nn.Parameter(weight_dict[key], requires_grad=False),
+                )
 
-    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
+    def apply(
+        self,
+        q,
+        k,
+        v,
+        cu_seqlens_q=None,
+        cu_seqlens_kv=None,
+        max_seqlen_q=None,
+        max_seqlen_kv=None,
+        model_cls=None,
+    ):
         if len(q.shape) == 3:
             q = q.unsqueeze(0)
             k = k.unsqueeze(0)
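The point of widening the apply() signatures across these classes is that call sites can pass one uniform set of keyword arguments regardless of which registered op was selected. A hypothetical call site (the registry lookup syntax is assumed, not taken from this diff):

# Hypothetical usage; ATTN_WEIGHT_REGISTER lookup syntax is assumed.
op = ATTN_WEIGHT_REGISTER["torch_sdpa"]()
out = op.apply(
    q, k, v,
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_kv=cu_seqlens_kv,
    max_seqlen_q=max_seqlen_q,
    max_seqlen_kv=max_seqlen_kv,
)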