xuwx1 / LightX2V · Commit b4496e64

Merge pull request #93 from ModelTC/dev_fix

Dev fix

Authored Jul 08, 2025 by gushiqiao; committed by GitHub on Jul 08, 2025.
Parents: 53eae786 · a4b666ca
Showing 5 changed files with 87 additions and 15 deletions (+87, −15).
- app/gradio_demo.py (+1, −1)
- app/gradio_demo_zh.py (+1, −1)
- lightx2v/common/offload/manager.py (+1, −1)
- lightx2v/models/networks/wan/infer/transformer_infer.py (+44, −12)
- lightx2v/models/networks/wan/infer/utils.py (+40, −0)
app/gradio_demo.py

The only change here is adding the missing trailing newline at the end of the file:

```diff
@@ -1022,4 +1022,4 @@ if __name__ == "__main__":
     model_path = args.model_path
     model_cls = args.model_cls
-    main()
\ No newline at end of file
+    main()
```
app/gradio_demo_zh.py

The same trailing-newline fix as in app/gradio_demo.py:

```diff
@@ -1022,4 +1022,4 @@ if __name__ == "__main__":
     model_path = args.model_path
     model_cls = args.model_cls
-    main()
\ No newline at end of file
+    main()
```
lightx2v/common/offload/manager.py

Again only the trailing newline at end of file changes:

```diff
@@ -360,4 +360,4 @@ class MemoryBuffer:
         self.insertion_index = 0
         self.used_mem = 0
         torch.cuda.empty_cache()
-        gc.collect()
\ No newline at end of file
+        gc.collect()
```
lightx2v/models/networks/wan/infer/transformer_infer.py

The import line picks up the two new audio frequency helpers:

```diff
 import torch
-from .utils import compute_freqs, compute_freqs_dist, apply_rotary_emb, apply_rotary_emb_chunk
+from .utils import compute_freqs, compute_freqs_dist, compute_freqs_audio, compute_freqs_audio_dist, apply_rotary_emb, apply_rotary_emb_chunk
 from lightx2v.common.offload.manager import (
     WeightAsyncStreamManager,
     LazyWeightAsyncStreamManager,
```
```diff
@@ -26,6 +26,8 @@ class WanTransformerInfer(BaseTransformerInfer):
         else:
             self.apply_rotary_emb_func = apply_rotary_emb
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
+        self.mask_map = None
         if self.config["cpu_offload"]:
             if "offload_ratio" in self.config:
                 offload_ratio = self.config["offload_ratio"]
```
```diff
@@ -73,10 +75,10 @@ class WanTransformerInfer(BaseTransformerInfer):
         return cu_seqlens_q, cu_seqlens_k

     @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
-    def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
-        return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
+    def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
+        return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks)

-    def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+    def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         for block_idx in range(self.blocks_num):
             if block_idx == 0:
                 self.weights_stream_mgr.active_weights[0] = weights.blocks[0]
```
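This hunk is purely a signature extension: `audio_dit_blocks` is appended with a `None` default and forwarded to `infer_func`, so pre-existing callers need no changes. A minimal illustrative stub of that calling convention (only the parameter contract is taken from the diff; the real method forwards to `self.infer_func`):

```python
# Illustrative stub of the extended signature, not the real method.
def infer(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context,
          audio_dit_blocks=None):
    hooks = audio_dit_blocks if audio_dit_blocks is not None else []
    print(f"running with {len(hooks)} audio hook table(s)")
    return x

infer(None, None, None, 0, None, None, None, None)                       # legacy call, unchanged
infer(None, None, None, 0, None, None, None, None, audio_dit_blocks=[])  # audio-aware call
```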
```diff
@@ -138,7 +140,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         return x

-    def _infer_with_phases_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+    def _infer_with_phases_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         for block_idx in range(weights.blocks_num):
             for phase_idx in range(self.phases_num):
                 if block_idx == 0 and phase_idx == 0:
```
```diff
@@ -186,7 +188,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         return x

-    def _infer_with_phases_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+    def _infer_with_phases_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         self.weights_stream_mgr.prefetch_weights_from_disk(weights)
         for block_idx in range(weights.blocks_num):
```
```diff
@@ -247,7 +249,22 @@ class WanTransformerInfer(BaseTransformerInfer):
         return x

-    def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+    def zero_temporal_component_in_3DRoPE(self, valid_token_length, rotary_emb=None):
+        if rotary_emb is None:
+            return None
+        self.use_real = False
+        rope_t_dim = 44
+        if self.use_real:
+            freqs_cos, freqs_sin = rotary_emb
+            freqs_cos[valid_token_length:, :, :rope_t_dim] = 0
+            freqs_sin[valid_token_length:, :, :rope_t_dim] = 0
+            return freqs_cos, freqs_sin
+        else:
+            freqs_cis = rotary_emb
+            freqs_cis[valid_token_length:, :, : rope_t_dim // 2] = 0
+            return freqs_cis
+
+    def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         for block_idx in range(self.blocks_num):
             x = self.infer_block(
                 weights.blocks[block_idx],
```
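The new `zero_temporal_component_in_3DRoPE` hard-codes `use_real = False` and a temporal width of 44 channels, then zeroes the temporal slice of the rotary table for every token past `valid_token_length`. In the complex path only the first `44 // 2` channels are zeroed, since each complex entry covers two real channels. A self-contained sketch of that masking, with illustrative shapes:

```python
import torch

# Standalone sketch of the complex (use_real=False) path: padding tokens past
# the valid length lose their temporal rotary phase. Shapes are illustrative,
# not the model's.
seq_len, heads, half_dim = 10, 1, 64   # freqs_cis laid out as [seq, 1, c // 2]
rope_t_dim = 44                        # temporal channel count, as in the commit
valid_token_length = 6

freqs_cis = torch.ones(seq_len, heads, half_dim, dtype=torch.complex64)
freqs_cis[valid_token_length:, :, : rope_t_dim // 2] = 0

assert freqs_cis[valid_token_length:, :, : rope_t_dim // 2].abs().sum() == 0
assert freqs_cis[:valid_token_length].abs().sum() > 0  # valid tokens untouched
```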
```diff
@@ -259,6 +276,12 @@ class WanTransformerInfer(BaseTransformerInfer):
                 freqs,
                 context,
             )
+            if audio_dit_blocks is not None and len(audio_dit_blocks) > 0:
+                for ipa_out in audio_dit_blocks:
+                    if block_idx in ipa_out:
+                        cur_modify = ipa_out[block_idx]
+                        x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
+
         return x

     def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
```
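The hook contract implied by this loop: `audio_dit_blocks` is an iterable of dicts keyed by block index, each value holding a `modify_func` and its `kwargs`, and the hook runs on the hidden states right after its block finishes. A runnable sketch under that contract (the residual-blend hook itself is hypothetical):

```python
import torch

def add_audio_residual(x, grid_sizes, residual=None, scale=1.0):
    # Hypothetical modify_func: blends an audio-derived residual into the
    # hidden states after the matching transformer block has run.
    return x + scale * residual

seq_len, dim = 16, 8
audio_residual = torch.randn(seq_len, dim)

# One dict per hook table, keyed by block index, as the new loop expects.
audio_dit_blocks = [
    {0: {"modify_func": add_audio_residual,
         "kwargs": {"residual": audio_residual, "scale": 0.5}}}
]

x = torch.randn(seq_len, dim)
grid_sizes = torch.tensor([[4, 2, 2]])
for block_idx in range(2):
    # ... block inference would happen here ...
    if audio_dit_blocks is not None and len(audio_dit_blocks) > 0:
        for ipa_out in audio_dit_blocks:
            if block_idx in ipa_out:
                cur_modify = ipa_out[block_idx]
                x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
```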
```diff
@@ -318,14 +341,23 @@ class WanTransformerInfer(BaseTransformerInfer):
         v = weights.self_attn_v.apply(norm1_out).view(s, n, d)

         if not self.parallel_attention:
-            freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
+            if self.config.get("audio_sr", False):
+                freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
+            else:
+                freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
         else:
-            freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
+            if self.config.get("audio_sr", False):
+                freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
+            else:
+                freqs_i = compute_freqs_dist(q.size(2) // 2, grid_sizes, freqs)
+
+        freqs_i = self.zero_temporal_component_in_3DRoPE(seq_lens, freqs_i)

         q = self.apply_rotary_emb_func(q, freqs_i)
         k = self.apply_rotary_emb_func(k, freqs_i)

-        cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=seq_lens)
+        k_lens = torch.empty_like(seq_lens).fill_(freqs_i.size(0))
+        cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=k_lens)

         if self.clean_cuda_cache:
             del freqs_i, norm1_out, norm1_weight, norm1_bias
```
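Two behavioral changes land in this hunk: with `audio_sr` set in the config the audio variants of the frequency helpers are selected, and the key lengths passed to `_calculate_q_k_len` are now taken from the rotary table length `freqs_i.size(0)` rather than the raw `seq_lens`, presumably so the key length matches the r2v-extended token count. A small numeric sketch of the `k_lens` construction, with made-up sizes:

```python
import torch

# Every per-sample key length is overwritten with the rotary table length;
# 480 and 512 below are illustrative values only.
seq_lens = torch.tensor([480, 480], dtype=torch.int32)
freqs_i = torch.randn(512, 1, 64)  # 512 = (f + 1) * h * w after the extra r2v frame

k_lens = torch.empty_like(seq_lens).fill_(freqs_i.size(0))
print(k_lens)  # tensor([512, 512], dtype=torch.int32)
```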
```diff
@@ -341,6 +373,7 @@ class WanTransformerInfer(BaseTransformerInfer):
                     max_seqlen_q=q.size(0),
                     max_seqlen_kv=k.size(0),
                     model_cls=self.config["model_cls"],
+                    mask_map=self.mask_map,
                 )
             else:
                 attn_out = self.parallel_attention(
```
```diff
@@ -406,7 +439,6 @@ class WanTransformerInfer(BaseTransformerInfer):
             q,
             k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device),
         )
-
         img_attn_out = weights.cross_attn_2.apply(
             q=q,
             k=k_img,
```
```diff
@@ -471,4 +503,4 @@ class WanTransformerInfer(BaseTransformerInfer):
         if self.clean_cuda_cache:
             del y, c_gate_msa
             torch.cuda.empty_cache()
-        return x
\ No newline at end of file
+        return x
```
lightx2v/models/networks/wan/infer/utils.py

```diff
 import torch
 import torch.distributed as dist
 from loguru import logger

 from lightx2v.utils.envs import *
```
```diff
@@ -19,6 +20,45 @@ def compute_freqs(c, grid_sizes, freqs):
     return freqs_i


+def compute_freqs_audio(c, grid_sizes, freqs):
+    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+    f, h, w = grid_sizes[0].tolist()
+    f = f + 1  # for r2v, add 1 channel
+    seq_len = f * h * w
+    freqs_i = torch.cat(
+        [
+            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
+        ],
+        dim=-1,
+    ).reshape(seq_len, 1, -1)
+    return freqs_i
+
+
+def compute_freqs_audio_dist(s, c, grid_sizes, freqs):
+    world_size = dist.get_world_size()
+    cur_rank = dist.get_rank()
+    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+    f, h, w = grid_sizes[0].tolist()
+    f = f + 1
+    seq_len = f * h * w
+    freqs_i = torch.cat(
+        [
+            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
+        ],
+        dim=-1,
+    ).reshape(seq_len, 1, -1)
+    freqs_i = pad_freqs(freqs_i, s * world_size)
+    s_per_rank = s
+    freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
+    return freqs_i_rank
+
+
 def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
     f, h, w = grid_sizes[0].tolist()
```
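`compute_freqs_audio` differs from the existing `compute_freqs` essentially in the `f = f + 1` line: one extra frame of temporal positions is emitted for the r2v conditioning channel, so the table has `(f + 1) * h * w` rows. A self-contained shape check that reproduces the added function with stand-in sizes:

```python
import torch

# Reproduction of the added function for a shape check; real tables are
# complex-valued, but real tensors suffice to verify the row count.
def compute_freqs_audio(c, grid_sizes, freqs):
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    f, h, w = grid_sizes[0].tolist()
    f = f + 1  # for r2v, add 1 channel
    seq_len = f * h * w
    freqs_i = torch.cat(
        [
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ],
        dim=-1,
    ).reshape(seq_len, 1, -1)
    return freqs_i

c, max_pos = 64, 32              # c = head_dim // 2; max_pos bounds f + 1, h, w
freqs = torch.randn(max_pos, c)  # stand-in for the precomputed frequency table
grid_sizes = torch.tensor([[4, 3, 5]])
out = compute_freqs_audio(c, grid_sizes, freqs)
assert out.shape == (5 * 3 * 5, 1, c)  # (f + 1) * h * w rows
```

The `_dist` variant builds the same table, pads it to `s * world_size` rows with `pad_freqs`, and slices out the current rank's `s`-row shard.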