Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
0dd7ca09
Commit
0dd7ca09
authored
Aug 21, 2025
by
gushiqiao
Committed by
GitHub
Aug 21, 2025
Browse files
[Fix] Fix sage-attn distribute bug (#235)
parent
79c3caa2
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
2 deletions
+3
-2
lightx2v/common/ops/attn/ulysses_attn.py
lightx2v/common/ops/attn/ulysses_attn.py
+2
-2
lightx2v/models/networks/wan/infer/transformer_infer.py
lightx2v/models/networks/wan/infer/transformer_infer.py
+1
-0
No files found.
lightx2v/common/ops/attn/ulysses_attn.py
100644 → 100755
View file @
0dd7ca09
...
...
@@ -12,7 +12,7 @@ class UlyssesAttnWeight(AttnWeightTemplate):
def
__init__
(
self
):
self
.
config
=
{}
def
apply
(
self
,
q
,
k
,
v
,
img_qkv_len
,
cu_seqlens_qkv
,
attention_module
=
None
,
seq_p_group
=
None
):
def
apply
(
self
,
q
,
k
,
v
,
img_qkv_len
,
cu_seqlens_qkv
,
attention_module
=
None
,
seq_p_group
=
None
,
model_cls
=
None
):
"""
执行 Ulysses 注意力机制,结合图像和文本的查询、键和值。
...
...
@@ -77,7 +77,7 @@ class UlyssesAttnWeight(AttnWeightTemplate):
# 调用注意力函数计算注意力结果
# attn = attention(attention_type=attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv)
attn
=
attention_module
.
apply
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
cu_seqlens_qkv
,
cu_seqlens_kv
=
cu_seqlens_qkv
,
max_seqlen_q
=
max_seqlen_qkv
,
max_seqlen_kv
=
max_seqlen_qkv
)
attn
=
attention_module
.
apply
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
cu_seqlens_qkv
,
cu_seqlens_kv
=
cu_seqlens_qkv
,
max_seqlen_q
=
max_seqlen_qkv
,
max_seqlen_kv
=
max_seqlen_qkv
,
model_cls
=
model_cls
)
# 分割图像和文本的注意力结果
img_attn
,
txt_attn
=
attn
[:
img_q
.
shape
[
0
],
:],
attn
[
img_q
.
shape
[
0
]
:,]
...
...
lightx2v/models/networks/wan/infer/transformer_infer.py
View file @
0dd7ca09
...
...
@@ -164,6 +164,7 @@ class WanTransformerInfer(BaseTransformerInfer):
cu_seqlens_qkv
=
cu_seqlens_q
,
attention_module
=
weights
.
self_attn_1
,
seq_p_group
=
self
.
seq_p_group
,
model_cls
=
self
.
config
[
"model_cls"
],
)
else
:
attn_out
=
weights
.
self_attn_1
.
apply
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment