Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
f4b343f6
Commit
f4b343f6
authored
Apr 15, 2025
by
helloyongyang
Browse files
update sage_attn2
parent
1c4bd4d8
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
21 deletions
+15
-21
lightx2v/attentions/common/sage_attn2.py
lightx2v/attentions/common/sage_attn2.py
+15
-21
No files found.
lightx2v/attentions/common/sage_attn2.py
View file @
f4b343f6
...
@@ -13,34 +13,28 @@ else:
...
@@ -13,34 +13,28 @@ else:
def
sage_attn2
(
q
,
k
,
v
,
cu_seqlens_q
=
None
,
cu_seqlens_kv
=
None
,
max_seqlen_q
=
None
,
max_seqlen_kv
=
None
,
model_cls
=
"hunyuan"
):
def
sage_attn2
(
q
,
k
,
v
,
cu_seqlens_q
=
None
,
cu_seqlens_kv
=
None
,
max_seqlen_q
=
None
,
max_seqlen_kv
=
None
,
model_cls
=
"hunyuan"
):
q
,
k
,
v
=
(
q
,
k
,
v
=
q
.
contiguous
(),
k
.
contiguous
(),
v
.
contiguous
()
q
.
transpose
(
1
,
0
).
contiguous
(),
k
.
transpose
(
1
,
0
).
contiguous
(),
v
.
transpose
(
1
,
0
).
contiguous
(),
)
if
model_cls
==
"hunyuan"
:
if
model_cls
==
"hunyuan"
:
x1
=
sageattn
(
x1
=
sageattn
(
q
[:,
:
cu_seqlens_q
[
1
],
:].
unsqueeze
(
0
),
q
[:
cu_seqlens_q
[
1
]].
unsqueeze
(
0
),
k
[:,
:
cu_seqlens_kv
[
1
],
:].
unsqueeze
(
0
),
k
[:
cu_seqlens_kv
[
1
]].
unsqueeze
(
0
),
v
[:,
:
cu_seqlens_kv
[
1
],
:].
unsqueeze
(
0
),
v
[:
cu_seqlens_kv
[
1
]].
unsqueeze
(
0
),
tensor_layout
=
"NHD"
,
)
)
x2
=
sageattn
(
x2
=
sageattn
(
q
[:,
cu_seqlens_q
[
1
]
:,
:].
unsqueeze
(
0
),
q
[
cu_seqlens_q
[
1
]
:].
unsqueeze
(
0
),
k
[:,
cu_seqlens_kv
[
1
]
:,
:].
unsqueeze
(
0
),
k
[
cu_seqlens_kv
[
1
]
:].
unsqueeze
(
0
),
v
[:,
cu_seqlens_kv
[
1
]
:,
:].
unsqueeze
(
0
),
v
[
cu_seqlens_kv
[
1
]
:].
unsqueeze
(
0
),
tensor_layout
=
"NHD"
,
)
)
x
=
torch
.
cat
((
x1
,
x2
),
dim
=
-
2
).
transpose
(
2
,
1
).
contiguous
(
)
x
=
torch
.
cat
((
x1
,
x2
),
dim
=
1
)
x
=
x
.
view
(
max_seqlen_q
,
-
1
)
x
=
x
.
view
(
max_seqlen_q
,
-
1
)
elif
model_cls
==
"wan2.1"
:
elif
model_cls
==
"wan2.1"
:
x
=
(
x
=
sageattn
(
sageattn
(
q
.
unsqueeze
(
0
),
q
[:,
:
cu_seqlens_q
[
1
],
:].
unsqueeze
(
0
),
k
.
unsqueeze
(
0
),
k
[:,
:
cu_seqlens_kv
[
1
],
:].
unsqueeze
(
0
),
v
.
unsqueeze
(
0
),
v
[:,
:
cu_seqlens_kv
[
1
],
:].
unsqueeze
(
0
),
tensor_layout
=
"NHD"
,
)
.
transpose
(
2
,
1
)
.
contiguous
()
)
)
x
=
x
.
view
(
max_seqlen_q
,
-
1
)
x
=
x
.
view
(
max_seqlen_q
,
-
1
)
return
x
return
x
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment