Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
9a686a73
Commit
9a686a73
authored
Apr 09, 2025
by
gushiqiao
Committed by
GitHub
Apr 09, 2025
Browse files
Support wan2.1 sageattn and fix oom for 720P. (#12)
parent
a951c882
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
78 additions
and
55 deletions
+78
-55
lightx2v/__main__.py
lightx2v/__main__.py
+12
-4
lightx2v/attentions/common/sage_attn2.py
lightx2v/attentions/common/sage_attn2.py
+36
-25
lightx2v/text2v/models/networks/wan/infer/feature_caching/transformer_infer.py
...s/networks/wan/infer/feature_caching/transformer_infer.py
+9
-2
lightx2v/text2v/models/networks/wan/infer/transformer_infer.py
...x2v/text2v/models/networks/wan/infer/transformer_infer.py
+3
-24
lightx2v/text2v/models/schedulers/wan/feature_caching/scheduler.py
...text2v/models/schedulers/wan/feature_caching/scheduler.py
+15
-0
lightx2v/text2v/models/schedulers/wan/scheduler.py
lightx2v/text2v/models/schedulers/wan/scheduler.py
+3
-0
No files found.
lightx2v/__main__.py
View file @
9a686a73
...
...
@@ -328,6 +328,7 @@ if __name__ == "__main__":
mm_config
=
None
model_config
=
{
"model_cls"
:
args
.
model_cls
,
"task"
:
args
.
task
,
"attention_type"
:
args
.
attention_type
,
"sample_neg_prompt"
:
args
.
sample_neg_prompt
,
...
...
@@ -348,6 +349,9 @@ if __name__ == "__main__":
model
,
text_encoders
,
vae_model
,
image_encoder
=
load_models
(
args
,
model_config
)
load_models_time
=
time
.
time
()
print
(
f
"Load models cost:
{
load_models_time
-
start_time
}
"
)
if
args
.
task
in
[
"i2v"
]:
image_encoder_output
=
run_image_encoder
(
args
,
image_encoder
,
vae_model
)
else
:
...
...
@@ -362,19 +366,23 @@ if __name__ == "__main__":
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
latents
,
generator
=
run_main_inference
(
args
,
model
,
text_encoder_output
,
image_encoder_output
)
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
if
args
.
cpu_offload
:
scheduler
.
clear
()
del
text_encoder_output
,
image_encoder_output
,
model
,
text_encoders
,
scheduler
torch
.
cuda
.
empty_cache
()
images
=
run_vae
(
latents
,
generator
,
args
)
if
not
args
.
parallel_attn_type
or
(
args
.
parallel_attn_type
and
dist
.
get_rank
()
==
0
):
save_video_st
=
time
.
time
()
if
args
.
model_cls
==
"wan2.1"
:
cache_video
(
tensor
=
images
,
save_file
=
args
.
save_video_path
,
fps
=
16
,
nrow
=
1
,
normalize
=
True
,
value_range
=
(
-
1
,
1
))
else
:
save_videos_grid
(
images
,
args
.
save_video_path
,
fps
=
24
)
save_video_et
=
time
.
time
()
print
(
f
"Save video cost:
{
save_video_et
-
save_video_st
}
"
)
end_time
=
time
.
time
()
print
(
f
"Total
time
:
{
end_time
-
start_time
}
"
)
print
(
f
"Total
cost
:
{
end_time
-
start_time
}
"
)
lightx2v/attentions/common/sage_attn2.py
View file @
9a686a73
import
torch
try
:
from
sageattention
import
sageattn
except
ImportError
:
sageattn
=
None
if
torch
.
cuda
.
get_device_capability
(
0
)
==
(
8
,
9
):
try
:
from
sageattention
import
sageattn_qk_int8_pv_fp16_triton
as
sageattn
except
ImportError
:
sageattn
=
None
,
None
else
:
try
:
from
sageattention
import
sageattn
except
ImportError
:
sageattn
=
None
def
sage_attn2
(
q
,
k
,
v
,
cu_seqlens_q
=
None
,
cu_seqlens_kv
=
None
,
max_seqlen_q
=
None
,
max_seqlen_kv
=
None
,
):
def
sage_attn2
(
q
,
k
,
v
,
cu_seqlens_q
=
None
,
cu_seqlens_kv
=
None
,
max_seqlen_q
=
None
,
max_seqlen_kv
=
None
,
model_cls
=
"hunyuan"
):
q
,
k
,
v
=
(
q
.
transpose
(
1
,
0
).
contiguous
(),
k
.
transpose
(
1
,
0
).
contiguous
(),
v
.
transpose
(
1
,
0
).
contiguous
(),
)
x1
=
sageattn
(
q
[:,
:
cu_seqlens_q
[
1
],
:].
unsqueeze
(
0
),
k
[:,
:
cu_seqlens_q
[
1
],
:].
unsqueeze
(
0
),
v
[:,
:
cu_seqlens_kv
[
1
],
:].
unsqueeze
(
0
),
)
x2
=
sageattn
(
q
[:,
cu_seqlens_q
[
1
]
:,
:].
unsqueeze
(
0
),
k
[:,
cu_seqlens_kv
[
1
]
:,
:].
unsqueeze
(
0
),
v
[:,
cu_seqlens_kv
[
1
]
:,
:].
unsqueeze
(
0
),
)
x
=
torch
.
cat
((
x1
,
x2
),
dim
=-
2
).
transpose
(
2
,
1
).
contiguous
()
x
=
x
.
view
(
max_seqlen_q
,
-
1
)
if
model_cls
==
"hunyuan"
:
x1
=
sageattn
(
q
[:,
:
cu_seqlens_q
[
1
],
:].
unsqueeze
(
0
),
k
[:,
:
cu_seqlens_q
[
1
],
:].
unsqueeze
(
0
),
v
[:,
:
cu_seqlens_kv
[
1
],
:].
unsqueeze
(
0
),
)
x2
=
sageattn
(
q
[:,
cu_seqlens_q
[
1
]
:,
:].
unsqueeze
(
0
),
k
[:,
cu_seqlens_kv
[
1
]
:,
:].
unsqueeze
(
0
),
v
[:,
cu_seqlens_kv
[
1
]
:,
:].
unsqueeze
(
0
),
)
x
=
torch
.
cat
((
x1
,
x2
),
dim
=-
2
).
transpose
(
2
,
1
).
contiguous
()
x
=
x
.
view
(
max_seqlen_q
,
-
1
)
elif
model_cls
==
"wan2.1"
:
x
=
(
sageattn
(
q
[:,
:
cu_seqlens_q
[
1
],
:].
unsqueeze
(
0
),
k
[:,
:
cu_seqlens_q
[
1
],
:].
unsqueeze
(
0
),
v
[:,
:
cu_seqlens_kv
[
1
],
:].
unsqueeze
(
0
),
)
.
transpose
(
2
,
1
)
.
contiguous
()
)
x
=
x
.
view
(
max_seqlen_q
,
-
1
)
return
x
lightx2v/text2v/models/networks/wan/infer/feature_caching/transformer_infer.py
View file @
9a686a73
import
numpy
as
np
from
..transformer_infer
import
WanTransformerInfer
from
lightx2v.attentions
import
attention
import
torch
class
WanTransformerInferFeatureCaching
(
WanTransformerInfer
):
...
...
@@ -61,6 +61,10 @@ class WanTransformerInferFeatureCaching(WanTransformerInfer):
context
,
)
self
.
scheduler
.
previous_residual_even
=
x
-
ori_x
if
self
.
config
[
"cpu_offload"
]:
ori_x
=
ori_x
.
to
(
"cpu"
)
del
ori_x
torch
.
cuda
.
empty_cache
()
else
:
if
not
should_calc_odd
:
x
+=
self
.
scheduler
.
previous_residual_odd
...
...
@@ -77,5 +81,8 @@ class WanTransformerInferFeatureCaching(WanTransformerInfer):
context
,
)
self
.
scheduler
.
previous_residual_odd
=
x
-
ori_x
if
self
.
config
[
"cpu_offload"
]:
ori_x
=
ori_x
.
to
(
"cpu"
)
del
ori_x
torch
.
cuda
.
empty_cache
()
return
x
lightx2v/text2v/models/networks/wan/infer/transformer_infer.py
View file @
9a686a73
...
...
@@ -98,14 +98,7 @@ class WanTransformerInfer:
if
not
self
.
parallel_attention
:
attn_out
=
attention
(
attention_type
=
self
.
attention_type
,
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_kv
=
cu_seqlens_k
,
max_seqlen_q
=
lq
,
max_seqlen_kv
=
lk
,
attention_type
=
self
.
attention_type
,
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_kv
=
cu_seqlens_k
,
max_seqlen_q
=
lq
,
max_seqlen_kv
=
lk
,
model_cls
=
self
.
config
[
"model_cls"
]
)
else
:
attn_out
=
self
.
parallel_attention
(
...
...
@@ -136,14 +129,7 @@ class WanTransformerInfer:
cu_seqlens_q
,
cu_seqlens_k
,
lq
,
lk
=
self
.
_calculate_q_k_len
(
q
,
k
,
k_lens
=
torch
.
tensor
([
k
.
size
(
0
)],
dtype
=
torch
.
int32
,
device
=
k
.
device
))
attn_out
=
attention
(
attention_type
=
self
.
attention_type
,
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_kv
=
cu_seqlens_k
,
max_seqlen_q
=
lq
,
max_seqlen_kv
=
lk
,
attention_type
=
self
.
attention_type
,
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_kv
=
cu_seqlens_k
,
max_seqlen_q
=
lq
,
max_seqlen_kv
=
lk
,
model_cls
=
self
.
config
[
"model_cls"
]
)
if
self
.
task
==
"i2v"
:
...
...
@@ -157,14 +143,7 @@ class WanTransformerInfer:
)
img_attn_out
=
attention
(
attention_type
=
self
.
attention_type
,
q
=
q
,
k
=
k_img
,
v
=
v_img
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_kv
=
cu_seqlens_k
,
max_seqlen_q
=
lq
,
max_seqlen_kv
=
lk
,
attention_type
=
self
.
attention_type
,
q
=
q
,
k
=
k_img
,
v
=
v_img
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_kv
=
cu_seqlens_k
,
max_seqlen_q
=
lq
,
max_seqlen_kv
=
lk
,
model_cls
=
self
.
config
[
"model_cls"
]
)
attn_out
=
attn_out
+
img_attn_out
...
...
lightx2v/text2v/models/schedulers/wan/feature_caching/scheduler.py
View file @
9a686a73
...
...
@@ -71,3 +71,18 @@ class WanSchedulerFeatureCaching(WanScheduler):
self
.
coefficients
=
[
-
5784.54975374
,
5449.50911966
,
-
1811.16591783
,
256.27178429
,
-
13.02252404
]
self
.
ret_steps
=
1
*
2
self
.
cutoff_steps
=
self
.
args
.
infer_steps
*
2
-
2
def
clear
(
self
):
if
self
.
previous_e0_even
is
not
None
:
self
.
previous_e0_even
=
self
.
previous_e0_even
.
cpu
()
if
self
.
previous_e0_odd
is
not
None
:
self
.
previous_e0_odd
=
self
.
previous_e0_odd
.
cpu
()
if
self
.
previous_residual_even
is
not
None
:
self
.
previous_residual_even
=
self
.
previous_residual_even
.
cpu
()
if
self
.
previous_residual_odd
is
not
None
:
self
.
previous_residual_odd
=
self
.
previous_residual_odd
.
cpu
()
self
.
previous_e0_even
=
None
self
.
previous_e0_odd
=
None
self
.
previous_residual_even
=
None
self
.
previous_residual_odd
=
None
torch
.
cuda
.
empty_cache
()
lightx2v/text2v/models/schedulers/wan/scheduler.py
View file @
9a686a73
...
...
@@ -341,3 +341,6 @@ class WanScheduler(BaseScheduler):
self
.
lower_order_nums
+=
1
self
.
latents
=
prev_sample
def
clear
(
self
):
pass
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment