xuwx1 / LightX2V · Commit 62d8881a

Authored Jul 10, 2025 by gushiqiao; committed by GitHub on Jul 10, 2025.

Merge pull request #99 from ModelTC/dev_fixbugs

Fix torch sdpa op

Parents: 8abfb2c6, 375a52d0

Showing 3 changed files with 95 additions and 10 deletions:

    app/gradio_demo.py                        +7   -1
    app/gradio_demo_zh.py                     +7   -1
    lightx2v/common/ops/attn/attn_weight.py   +81  -8
app/gradio_demo.py

@@ -82,6 +82,12 @@ def get_available_attn_ops():
     else:
         available_ops.append(("sage_attn2", False))
 
+    torch_installed = is_module_installed("torch")
+    if torch_installed:
+        available_ops.append(("torch_sdpa", True))
+    else:
+        available_ops.append(("torch_sdpa", False))
+
     return available_ops
@@ -468,7 +474,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
     else:
         quant_type = "int8"
-    attn_priority = ["sage_attn2", "flash_attn3", "flash_attn2"]
+    attn_priority = ["sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"]
     quant_op_priority = ["sgl", "vllm", "q8f"]
     for op in attn_priority:
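The demo change has two parts: get_available_attn_ops() now probes for torch the same way it probes for the other backends, and auto_configure() appends "torch_sdpa" as the last entry of the attention priority list, so it is chosen only when no flash/sage backend is installed. Below is a minimal sketch of that probe-and-fallback logic. It assumes is_module_installed() is an importlib-based availability check (its real implementation is not shown in this diff), and pick_attn_op() is a hypothetical name for the selection loop inside auto_configure().

import importlib.util

def is_module_installed(name: str) -> bool:
    # Assumed implementation: check whether the module can be located
    # without actually importing it.
    return importlib.util.find_spec(name) is not None

def pick_attn_op(available_ops):
    # Hypothetical helper mirroring the loop in auto_configure():
    # walk the priority list and return the first installed backend.
    attn_priority = ["sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"]
    installed = {name for name, ok in available_ops if ok}
    for op in attn_priority:
        if op in installed:
            return op
    return None

# Example: on a machine with only PyTorch installed, torch_sdpa is selected.
ops = [("sage_attn2", False), ("flash_attn3", False), ("flash_attn2", False), ("torch_sdpa", True)]
print(pick_attn_op(ops))  # -> "torch_sdpa"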
app/gradio_demo_zh.py

@@ -83,6 +83,12 @@ def get_available_attn_ops():
     else:
         available_ops.append(("sage_attn2", False))
 
+    torch_installed = is_module_installed("torch")
+    if torch_installed:
+        available_ops.append(("torch_sdpa", True))
+    else:
+        available_ops.append(("torch_sdpa", False))
+
     return available_ops
@@ -468,7 +474,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
     else:
         quant_type = "int8"
-    attn_priority = ["sage_attn2", "flash_attn3", "flash_attn2"]
+    attn_priority = ["sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"]
     quant_op_priority = ["sgl", "vllm", "q8f"]
     for op in attn_priority:
lightx2v/common/ops/attn/attn_weight.py

@@ -73,7 +73,18 @@ class FlashAttn2Weight(AttnWeightTemplate):
     def __init__(self):
         self.config = {}
 
-    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, mask_map=None):
+    def apply(
+        self,
+        q,
+        k,
+        v,
+        cu_seqlens_q=None,
+        cu_seqlens_kv=None,
+        max_seqlen_q=None,
+        max_seqlen_kv=None,
+        model_cls=None,
+        mask_map=None,
+    ):
         x = flash_attn_varlen_func(
             q,
             k,
@@ -91,7 +102,18 @@ class FlashAttn3Weight(AttnWeightTemplate):
     def __init__(self):
         self.config = {}
 
-    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, mask_map=None):
+    def apply(
+        self,
+        q,
+        k,
+        v,
+        cu_seqlens_q=None,
+        cu_seqlens_kv=None,
+        max_seqlen_q=None,
+        max_seqlen_kv=None,
+        model_cls=None,
+        mask_map=None,
+    ):
         x = flash_attn_varlen_func_v3(
             q,
             k,
@@ -109,7 +131,21 @@ class RadialAttnWeight(AttnWeightTemplate):
     def __init__(self):
         self.config = {}
 
-    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, mask_map=None, sparsity_type="radial", block_size=128, decay_factor=1, model_cls="wan"):
+    def apply(
+        self,
+        q,
+        k,
+        v,
+        cu_seqlens_q=None,
+        cu_seqlens_kv=None,
+        max_seqlen_q=None,
+        max_seqlen_kv=None,
+        mask_map=None,
+        sparsity_type="radial",
+        block_size=128,
+        decay_factor=1,
+        model_cls="wan",
+    ):
         assert len(q.shape) == 3
         x = radial_attn(
@@ -175,7 +211,22 @@ class TorchSDPAWeight(AttnWeightTemplate):
     def __init__(self):
         self.config = {}
 
-    def apply(self, q, k, v, drop_rate=0, attn_mask=None, causal=False):
+    def apply(
+        self,
+        q,
+        k,
+        v,
+        drop_rate=0,
+        attn_mask=None,
+        causal=False,
+        cu_seqlens_q=None,
+        cu_seqlens_kv=None,
+        max_seqlen_q=None,
+        max_seqlen_kv=None,
+        model_cls=None,
+        mask_map=None,
+    ):
         q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
         q = q.transpose(1, 2)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
@@ -185,12 +236,20 @@ class TorchSDPAWeight(AttnWeightTemplate):
         x = x.transpose(1, 2)
         b, s, a, d = x.shape
         out = x.reshape(b, s, -1)
-        return out
+        return out.squeeze(0)
 
 
 @ATTN_WEIGHT_REGISTER("Sparge")
 class SpargeAttnWeight(AttnWeightTemplate):
-    def __init__(self, weight_name, verbose=False, l1=0.07, pv_l1=0.08, tune_pv=True, inner_attn_type="flash_attn3"):
+    def __init__(
+        self,
+        weight_name,
+        verbose=False,
+        l1=0.07,
+        pv_l1=0.08,
+        tune_pv=True,
+        inner_attn_type="flash_attn3",
+    ):
         self.verbose = (verbose,)
         self.l1 = (l1,)
         self.pv_l1 = (pv_l1,)
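Taken together, the TorchSDPAWeight hunks make apply() accept the same varlen-style keyword arguments as the flash-attention backends while keeping the existing shape handling: the unbatched (seq_len, num_heads, head_dim) inputs are unsqueezed and transposed for SDPA, and the output is transposed back, flattened over heads, and now squeezed to drop the batch dimension. The sketch below reconstructs that flow. The attention call itself sits in the elided region between the two hunks, so the F.scaled_dot_product_attention line is an assumption based on the class name rather than code copied from the file, and **_ignored stands in for the newly accepted keywords that the real method lists explicitly.

import torch
import torch.nn.functional as F

def torch_sdpa_apply(q, k, v, drop_rate=0.0, attn_mask=None, causal=False, **_ignored):
    # q, k, v arrive unbatched as (seq_len, num_heads, head_dim); add a batch
    # dim and move heads in front of the sequence dim for SDPA.
    q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
    q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

    # Assumed call (elided in the diff): PyTorch's built-in scaled dot-product attention.
    x = F.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
    )

    # Back to (batch, seq, heads, dim), flatten the heads, and drop the batch
    # dim, matching the new `return out.squeeze(0)`.
    x = x.transpose(1, 2)
    b, s, a, d = x.shape
    return x.reshape(b, s, -1).squeeze(0)

# Shape check: (seq, heads, dim) in, (seq, heads * dim) out; the extra varlen
# keywords are accepted and ignored.
q = k = v = torch.randn(16, 8, 64)
print(torch_sdpa_apply(q, k, v, cu_seqlens_q=None, max_seqlen_q=None).shape)  # torch.Size([16, 512])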
@@ -204,9 +263,23 @@ class SpargeAttnWeight(AttnWeightTemplate):
         for key in weight_dict.keys():
             if key.startswith(self.weight_name):
                 sub_name = key.split(".")[-1]
-                setattr(self.inner_cls, sub_name, nn.Parameter(weight_dict[key], requires_grad=False))
+                setattr(
+                    self.inner_cls,
+                    sub_name,
+                    nn.Parameter(weight_dict[key], requires_grad=False),
+                )
 
-    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
+    def apply(
+        self,
+        q,
+        k,
+        v,
+        cu_seqlens_q=None,
+        cu_seqlens_kv=None,
+        max_seqlen_q=None,
+        max_seqlen_kv=None,
+        model_cls=None,
+    ):
         if len(q.shape) == 3:
             q = q.unsqueeze(0)
             k = k.unsqueeze(0)