xuwx1 / LightX2V · Commits

Commit 95b58beb authored Aug 14, 2025 by helloyongyang

update parallel

parent f05a99da

Showing 11 changed files with 246 additions and 222 deletions
lightx2v/models/networks/wan/audio_model.py                         +0   -84
lightx2v/models/networks/wan/infer/audio/post_wan_audio_infer.py    +3   -22
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py     +12  -1
lightx2v/models/networks/wan/infer/module_io.py                     +17  -0
lightx2v/models/networks/wan/infer/post_infer.py                    +2   -22
lightx2v/models/networks/wan/infer/pre_infer.py                     +10  -4
lightx2v/models/networks/wan/infer/transformer_infer.py             +52  -7
lightx2v/models/networks/wan/infer/utils.py                         +47  -0
lightx2v/models/networks/wan/model.py                               +98  -71
lightx2v/models/networks/wan/weights/post_weights.py                +0   -11
lightx2v/models/networks/wan/weights/transformer_weights.py         +5   -0
lightx2v/models/networks/wan/audio_model.py
import glob
import os

import torch

from lightx2v.common.ops.attn.radial_attn import MaskMap
from lightx2v.models.networks.wan.infer.audio.post_wan_audio_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.audio.pre_wan_audio_infer import WanAudioPreInfer
from lightx2v.models.networks.wan.model import WanModel
...
...
@@ -27,87 +24,6 @@ class WanAudioModel(WanModel):
        self.pre_infer_class = WanAudioPreInfer
        self.post_infer_class = WanAudioPostInfer

    @torch.no_grad()
    def infer(self, inputs):
        if self.config["cpu_offload"]:
            self.pre_weight.to_cuda()
            self.post_weight.to_cuda()

        if self.transformer_infer.mask_map is None:
            _, c, h, w = self.scheduler.latents.shape
            num_frame = c + 1  # for r2v
            video_token_num = num_frame * (h // 2) * (w // 2)
            self.transformer_infer.mask_map = MaskMap(video_token_num, num_frame)

        embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]

        if self.config["feature_caching"] == "Tea":
            self.scheduler.cnt += 1
            if self.scheduler.cnt >= self.scheduler.num_steps:
                self.scheduler.cnt = 0

        self.scheduler.noise_pred = noise_pred_cond

        if self.config["enable_cfg"]:
            embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]

            if self.config["feature_caching"] == "Tea":
                self.scheduler.cnt += 1
                if self.scheduler.cnt >= self.scheduler.num_steps:
                    self.scheduler.cnt = 0

            self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)

        if self.config["cpu_offload"]:
            self.pre_weight.to_cpu()
            self.post_weight.to_cpu()

    @torch.no_grad()
    def infer_wo_cfg_parallel(self, inputs):
        if self.cpu_offload:
            if self.offload_granularity == "model" and self.scheduler.step_index == 0:
                self.to_cuda()
            elif self.offload_granularity != "model":
                self.pre_weight.to_cuda()
                self.post_weight.to_cuda()

        if self.transformer_infer.mask_map is None:
            _, c, h, w = self.scheduler.latents.shape
            num_frame = c + 1  # for r2v
            video_token_num = num_frame * (h // 2) * (w // 2)
            self.transformer_infer.mask_map = MaskMap(video_token_num, num_frame)

        embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]
        self.scheduler.noise_pred = noise_pred_cond

        if self.clean_cuda_cache:
            del x, embed, pre_infer_out, noise_pred_cond, grid_sizes
            torch.cuda.empty_cache()

        if self.config["enable_cfg"]:
            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
            self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)

            if self.clean_cuda_cache:
                del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
                torch.cuda.empty_cache()

        if self.cpu_offload:
            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
                self.to_cpu()
            elif self.offload_granularity != "model":
                self.pre_weight.to_cpu()
                self.post_weight.to_cpu()


class Wan22MoeAudioModel(WanAudioModel):
    def _load_ckpt(self, unified_dtype, sensitive_layer):
...
...
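Both guidance branches in the audio-model code above finish with the standard classifier-free guidance update, combining the conditional and unconditional predictions through sample_guide_scale. A minimal standalone sketch of that update (tensor shapes and the scale value are illustrative assumptions, not taken from the repository):

import torch

# Classifier-free guidance combination as used above; shapes and scale are assumptions.
noise_pred_cond = torch.randn(16, 21, 60, 104)
noise_pred_uncond = torch.randn(16, 21, 60, 104)
sample_guide_scale = 5.0

noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
print(noise_pred.shape)  # same shape as the inputs
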
lightx2v/models/networks/wan/infer/audio/post_wan_audio_infer.py
...
...
@@ -18,30 +18,11 @@ class WanAudioPostInfer(WanPostInfer):
        self.scheduler = scheduler

    @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
    def infer(self, weights, x, e, grid_sizes, valid_patch_length):
        if e.dim() == 2:
            modulation = weights.head_modulation.tensor  # 1, 2, dim
            e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
        elif e.dim() == 3:
            # For Diffustion forcing
            modulation = weights.head_modulation.tensor.unsqueeze(2)  # 1, 2, seq, dim
            e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
            e = [ei.squeeze(1) for ei in e]

        x = weights.norm.apply(x)

        if self.sensitive_layer_dtype != self.infer_dtype:
            x = x.to(self.sensitive_layer_dtype)

        x.mul_(1 + e[1].squeeze()).add_(e[0].squeeze())

        if self.sensitive_layer_dtype != self.infer_dtype:
            x = x.to(self.infer_dtype)

        x = weights.head.apply(x)
        x = x[:, :valid_patch_length]
        x = self.unpatchify(x, grid_sizes)

    def infer(self, weights, x, pre_infer_out):
        x = x[:, :pre_infer_out.valid_patch_length]
        x = self.unpatchify(x, pre_infer_out.grid_sizes)

        if self.clean_cuda_cache:
            del e, grid_sizes
            torch.cuda.empty_cache()

        return [u.float() for u in x]
...
...
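In the audio post-infer path above, the head output is trimmed to valid_patch_length before unpatchify, dropping the padded reference-frame tokens the audio model appends (num_frame = c + 1 in audio_model.py). A hedged sketch of that trimming; all sizes below are illustrative assumptions:

import torch

# Trim padded patch tokens before unpatchify; sizes are assumptions.
valid_patch_length = 75600
x = torch.randn(1, 79200, 64)   # head output including padded tokens
x = x[:, :valid_patch_length]   # keep only the valid video patches
print(x.shape)                  # torch.Size([1, 75600, 64])
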
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py
...
...
@@ -3,6 +3,7 @@ import torch
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.utils.envs import *
from ..module_io import WanPreInferModuleOutput
from ..utils import rope_params, sinusoidal_embedding_1d
...
...
@@ -126,4 +127,14 @@ class WanAudioPreInfer(WanPreInfer):
            del context_clip
            torch.cuda.empty_cache()

        return (embed, x_grid_sizes, (x.squeeze(0), embed0.squeeze(0), seq_lens, self.freqs, context, audio_dit_blocks), valid_patch_length)
        return WanPreInferModuleOutput(
            embed=embed,
            grid_sizes=x_grid_sizes,
            x=x.squeeze(0),
            embed0=embed0.squeeze(0),
            seq_lens=seq_lens,
            freqs=self.freqs,
            context=context,
            audio_dit_blocks=audio_dit_blocks,
            valid_patch_length=valid_patch_length,
        )
lightx2v/models/networks/wan/infer/module_io.py
new file mode 100644
from dataclasses import dataclass
from typing import List

import torch


@dataclass
class WanPreInferModuleOutput:
    embed: torch.Tensor
    grid_sizes: torch.Tensor
    x: torch.Tensor
    embed0: torch.Tensor
    seq_lens: torch.Tensor
    freqs: torch.Tensor
    context: torch.Tensor
    audio_dit_blocks: List = None
    valid_patch_length: int = None
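
A minimal usage sketch of the new dataclass, assuming the lightx2v package is importable; the tensor sizes below are dummy values chosen for illustration, not shapes used by the model:

import torch
from lightx2v.models.networks.wan.infer.module_io import WanPreInferModuleOutput

# Bundle the pre-infer outputs by name instead of a positional tuple; sizes are dummies.
out = WanPreInferModuleOutput(
    embed=torch.randn(1, 1536),
    grid_sizes=torch.tensor([[21, 45, 80]]),
    x=torch.randn(75600, 1536),
    embed0=torch.randn(6, 1536),
    seq_lens=torch.tensor([75600]),
    freqs=torch.randn(1024, 64),
    context=torch.randn(512, 1536),
)
# Downstream consumers read fields by attribute, e.g. transformer_infer.infer(weights, out).
print(out.valid_patch_length)  # None unless set, as in the audio pre-infer path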
lightx2v/models/networks/wan/infer/post_infer.py
...
...
@@ -10,35 +10,15 @@ class WanPostInfer:
        self.out_dim = config["out_dim"]
        self.patch_size = (1, 2, 2)
        self.clean_cuda_cache = config.get("clean_cuda_cache", False)
        self.infer_dtype = GET_DTYPE()
        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler

    @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
    def infer(self, weights, x, e, grid_sizes):
        if e.dim() == 2:
            modulation = weights.head_modulation.tensor  # 1, 2, dim
            e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
        elif e.dim() == 3:
            # For Diffustion forcing
            modulation = weights.head_modulation.tensor.unsqueeze(2)  # 1, 2, seq, dim
            e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
            e = [ei.squeeze(1) for ei in e]

        x = weights.norm.apply(x)

        if self.sensitive_layer_dtype != self.infer_dtype:
            x = x.to(self.sensitive_layer_dtype)

        x.mul_(1 + e[1].squeeze()).add_(e[0].squeeze())

        if self.sensitive_layer_dtype != self.infer_dtype:
            x = x.to(self.infer_dtype)

        x = weights.head.apply(x)
        x = self.unpatchify(x, grid_sizes)

    def infer(self, weights, x, pre_infer_out):
        x = self.unpatchify(x, pre_infer_out.grid_sizes)

        if self.clean_cuda_cache:
            del e, grid_sizes
            torch.cuda.empty_cache()

        return [u.float() for u in x]
...
...
lightx2v/models/networks/wan/infer/pre_infer.py
...
...
@@ -2,6 +2,7 @@ import torch
from lightx2v.utils.envs import *
from .module_io import WanPreInferModuleOutput
from .utils import guidance_scale_embedding, rope_params, sinusoidal_embedding_1d
...
...
@@ -132,8 +133,13 @@ class WanPreInfer:
        if self.config.get("use_image_encoder", True):
            del context_clip
            torch.cuda.empty_cache()

        return (
            embed,
            grid_sizes,
            (x.squeeze(0), embed0.squeeze(0), seq_lens, self.freqs, context),
        return WanPreInferModuleOutput(
            embed=embed,
            grid_sizes=grid_sizes,
            x=x.squeeze(0),
            embed0=embed0.squeeze(0),
            seq_lens=seq_lens,
            freqs=self.freqs,
            context=context,
        )
lightx2v/models/networks/wan/infer/transformer_infer.py
...
...
@@ -9,7 +9,7 @@ from lightx2v.common.offload.manager import (
from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from lightx2v.utils.envs import *
from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, compute_freqs_audio
from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, compute_freqs_audio, compute_freqs_audio_dist, compute_freqs_dist


class WanTransformerInfer(BaseTransformerInfer):
...
...
@@ -33,7 +33,11 @@ class WanTransformerInfer(BaseTransformerInfer):
        self.infer_dtype = GET_DTYPE()
        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
        self.seq_p_group = None
        if self.config["seq_parallel"]:
            self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
        else:
            self.seq_p_group = None
        if self.config.get("cpu_offload", False):
            if torch.cuda.get_device_capability(0) == (9, 0):
                assert self.config["self_attn_1_type"] != "sage_attn2"
...
...
@@ -86,15 +90,56 @@ class WanTransformerInfer(BaseTransformerInfer):
        return cu_seqlens_q, cu_seqlens_k

    def compute_freqs(self, q, grid_sizes, freqs):
        if "audio" in self.config.get("model_cls", ""):
            freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
        if self.config["seq_parallel"]:
            if "audio" in self.config.get("model_cls", ""):
                freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs, self.seq_p_group)
            else:
                freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs, self.seq_p_group)
        else:
            freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
            if "audio" in self.config.get("model_cls", ""):
                freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
            else:
                freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
        return freqs_i

    @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
    def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
        return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks)

    def infer(self, weights, pre_infer_out):
        x = self.infer_func(
            weights,
            pre_infer_out.grid_sizes,
            pre_infer_out.embed,
            pre_infer_out.x,
            pre_infer_out.embed0,
            pre_infer_out.seq_lens,
            pre_infer_out.freqs,
            pre_infer_out.context,
            pre_infer_out.audio_dit_blocks,
        )
        return self._infer_post_blocks(weights, x, pre_infer_out.embed)

    def _infer_post_blocks(self, weights, x, e):
        if e.dim() == 2:
            modulation = weights.head_modulation.tensor  # 1, 2, dim
            e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
        elif e.dim() == 3:
            # For Diffustion forcing
            modulation = weights.head_modulation.tensor.unsqueeze(2)  # 1, 2, seq, dim
            e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
            e = [ei.squeeze(1) for ei in e]

        x = weights.norm.apply(x)

        if self.sensitive_layer_dtype != self.infer_dtype:
            x = x.to(self.sensitive_layer_dtype)

        x.mul_(1 + e[1].squeeze()).add_(e[0].squeeze())

        if self.sensitive_layer_dtype != self.infer_dtype:
            x = x.to(self.infer_dtype)

        x = weights.head.apply(x)

        if self.clean_cuda_cache:
            del e
            torch.cuda.empty_cache()

        return x

    def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
        for block_idx in range(self.blocks_num):
...
...
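The reworked compute_freqs above now selects among four helpers based on two config flags: seq_parallel and whether model_cls contains "audio". A dependency-free sketch of that selection logic; the example config values below are assumptions:

# Sketch of the compute_freqs dispatch above; returns the helper name it would call.
def select_freqs_fn(config: dict) -> str:
    is_audio = "audio" in config.get("model_cls", "")
    if config.get("seq_parallel", False):
        return "compute_freqs_audio_dist" if is_audio else "compute_freqs_dist"
    return "compute_freqs_audio" if is_audio else "compute_freqs"

print(select_freqs_fn({"model_cls": "wan2.1_audio", "seq_parallel": True}))   # compute_freqs_audio_dist
print(select_freqs_fn({"model_cls": "wan2.1", "seq_parallel": False}))        # compute_freqs
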
lightx2v/models/networks/wan/infer/utils.py
import torch
import torch.distributed as dist

from lightx2v.utils.envs import *
...
...
@@ -39,6 +40,52 @@ def compute_freqs_audio(c, grid_sizes, freqs):
    return freqs_i


def compute_freqs_dist(s, c, grid_sizes, freqs, seq_p_group):
    world_size = dist.get_world_size(seq_p_group)
    cur_rank = dist.get_rank(seq_p_group)
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    f, h, w = grid_sizes[0]
    seq_len = f * h * w
    freqs_i = torch.cat(
        [
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ],
        dim=-1,
    ).reshape(seq_len, 1, -1)
    freqs_i = pad_freqs(freqs_i, s * world_size)
    s_per_rank = s
    freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
    return freqs_i_rank


def compute_freqs_audio_dist(s, c, grid_sizes, freqs, seq_p_group):
    world_size = dist.get_world_size(seq_p_group)
    cur_rank = dist.get_rank(seq_p_group)
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    f, h, w = grid_sizes[0]
    valid_token_length = f * h * w
    f = f + 1
    seq_len = f * h * w
    freqs_i = torch.cat(
        [
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ],
        dim=-1,
    ).reshape(seq_len, 1, -1)
    freqs_i[valid_token_length:, :, :f] = 0
    freqs_i = pad_freqs(freqs_i, s * world_size)
    s_per_rank = s
    freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
    return freqs_i_rank


def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0):
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    f, h, w = grid_sizes[0]
...
...
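Both *_dist helpers above build the full rotary-frequency table, pad it to s * world_size rows, and keep only the slice owned by the current rank. A single-process sketch of that slicing, with the distributed calls replaced by plain integers and pad_to_len standing in for the repository's pad_freqs helper (both stand-ins are assumptions):

import torch

def pad_to_len(freqs_i, target_len):
    # Zero-pad the first (sequence) dimension up to target_len; stand-in for pad_freqs.
    pad = target_len - freqs_i.shape[0]
    return torch.cat([freqs_i, freqs_i.new_zeros(pad, *freqs_i.shape[1:])]) if pad > 0 else freqs_i

world_size, cur_rank = 4, 1        # assumed sequence-parallel group layout
seq_len, s = 10, 3                 # s tokens per rank; 4 * 3 >= 10
freqs_i = torch.randn(seq_len, 1, 64)

freqs_i = pad_to_len(freqs_i, s * world_size)            # every rank gets exactly s rows
freqs_i_rank = freqs_i[cur_rank * s:(cur_rank + 1) * s]  # this rank's slice
print(freqs_i_rank.shape)                                # torch.Size([3, 1, 64])
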
lightx2v/models/networks/wan/model.py
...
...
@@ -7,7 +7,6 @@ from loguru import logger
from safetensors import safe_open

from lightx2v.common.ops.attn import MaskMap
from lightx2v.models.networks.wan.infer.dist_infer.transformer_infer import WanTransformerDistInfer
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import (
    WanTransformerInferAdaCaching,
    WanTransformerInferCustomCaching,
...
...
@@ -83,27 +82,25 @@ class WanModel:
    def _init_infer_class(self):
        self.pre_infer_class = WanPreInfer
        self.post_infer_class = WanPostInfer
        if self.seq_p_group is not None:
            self.transformer_infer_class = WanTransformerDistInfer
        if self.config["feature_caching"] == "NoCaching":
            self.transformer_infer_class = WanTransformerInfer
        elif self.config["feature_caching"] == "Tea":
            self.transformer_infer_class = WanTransformerInferTeaCaching
        elif self.config["feature_caching"] == "TaylorSeer":
            self.transformer_infer_class = WanTransformerInferTaylorCaching
        elif self.config["feature_caching"] == "Ada":
            self.transformer_infer_class = WanTransformerInferAdaCaching
        elif self.config["feature_caching"] == "Custom":
            self.transformer_infer_class = WanTransformerInferCustomCaching
        elif self.config["feature_caching"] == "FirstBlock":
            self.transformer_infer_class = WanTransformerInferFirstBlock
        elif self.config["feature_caching"] == "DualBlock":
            self.transformer_infer_class = WanTransformerInferDualBlock
        elif self.config["feature_caching"] == "DynamicBlock":
            self.transformer_infer_class = WanTransformerInferDynamicBlock
        else:
            if self.config["feature_caching"] == "NoCaching":
                self.transformer_infer_class = WanTransformerInfer
            elif self.config["feature_caching"] == "Tea":
                self.transformer_infer_class = WanTransformerInferTeaCaching
            elif self.config["feature_caching"] == "TaylorSeer":
                self.transformer_infer_class = WanTransformerInferTaylorCaching
            elif self.config["feature_caching"] == "Ada":
                self.transformer_infer_class = WanTransformerInferAdaCaching
            elif self.config["feature_caching"] == "Custom":
                self.transformer_infer_class = WanTransformerInferCustomCaching
            elif self.config["feature_caching"] == "FirstBlock":
                self.transformer_infer_class = WanTransformerInferFirstBlock
            elif self.config["feature_caching"] == "DualBlock":
                self.transformer_infer_class = WanTransformerInferDualBlock
            elif self.config["feature_caching"] == "DynamicBlock":
                self.transformer_infer_class = WanTransformerInferDynamicBlock
            else:
                raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

    def _should_load_weights(self):
        """Determine if current rank should load weights from disk."""
...
...
@@ -296,16 +293,7 @@ class WanModel:
    def _init_infer(self):
        self.pre_infer = self.pre_infer_class(self.config)
        self.post_infer = self.post_infer_class(self.config)
        if self.seq_p_group is not None:
            self.transformer_infer = self.transformer_infer_class(self.config, self.seq_p_group)
        else:
            self.transformer_infer = self.transformer_infer_class(self.config)
        if self.config["cfg_parallel"]:
            self.infer_func = self.infer_with_cfg_parallel
        else:
            self.infer_func = self.infer_wo_cfg_parallel
        self.transformer_infer = self.transformer_infer_class(self.config)

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler
...
...
@@ -325,10 +313,6 @@ class WanModel:
    @torch.no_grad()
    def infer(self, inputs):
        return self.infer_func(inputs)

    @torch.no_grad()
    def infer_wo_cfg_parallel(self, inputs):
        if self.cpu_offload:
            if self.offload_granularity == "model" and self.scheduler.step_index == 0:
                self.to_cuda()
...
...
@@ -341,26 +325,31 @@ class WanModel:
            video_token_num = c * (h // 2) * (w // 2)
            self.transformer_infer.mask_map = MaskMap(video_token_num, c)

        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
        self.scheduler.noise_pred = noise_pred_cond

        if self.clean_cuda_cache:
            del x, embed, pre_infer_out, noise_pred_cond, grid_sizes
            torch.cuda.empty_cache()

        if self.config["enable_cfg"]:
            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
            if self.config["cfg_parallel"]:
                # ==================== CFG Parallel Processing ====================
                cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
                assert dist.get_world_size(cfg_p_group) == 2, "cfg_p_world_size must be equal to 2"
                cfg_p_rank = dist.get_rank(cfg_p_group)
                if cfg_p_rank == 0:
                    noise_pred = self._infer_cond_uncond(inputs, positive=True)
                else:
                    noise_pred = self._infer_cond_uncond(inputs, positive=False)
            self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)
                noise_pred_list = [torch.zeros_like(noise_pred) for _ in range(2)]
                dist.all_gather(noise_pred_list, noise_pred, group=cfg_p_group)
                noise_pred_cond = noise_pred_list[0]  # cfg_p_rank == 0
                noise_pred_uncond = noise_pred_list[1]  # cfg_p_rank == 1
            else:
                # ==================== CFG Processing ====================
                noise_pred_cond = self._infer_cond_uncond(inputs, positive=True)
                noise_pred_uncond = self._infer_cond_uncond(inputs, positive=False)

            if self.clean_cuda_cache:
                del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
                torch.cuda.empty_cache()

            self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
        else:
            # ==================== No CFG ====================
            self.scheduler.noise_pred = self._infer_cond_uncond(inputs, positive=True)

        if self.cpu_offload:
            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
...
...
@@ -370,24 +359,62 @@ class WanModel:
                self.pre_weight.to_cpu()
                self.post_weight.to_cpu()

    @torch.no_grad()
    def infer_with_cfg_parallel(self, inputs):
        assert self.config["enable_cfg"], "enable_cfg must be True"
        cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
        assert dist.get_world_size(cfg_p_group) == 2, f"cfg_p_world_size must be equal to 2"
        cfg_p_rank = dist.get_rank(cfg_p_group)
        if cfg_p_rank == 0:
            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
        else:
            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

    def _infer_cond_uncond(self, inputs, positive=True):
        pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=positive)
        if self.config["seq_parallel"]:
            pre_infer_out = self._seq_parallel_pre_process(pre_infer_out)
        x = self.transformer_infer.infer(self.transformer_weights, pre_infer_out)
        if self.config["seq_parallel"]:
            x = self._seq_parallel_post_process(x)
        noise_pred = self.post_infer.infer(self.post_weight, x, pre_infer_out)[0]

        if self.clean_cuda_cache:
            del x, pre_infer_out
            torch.cuda.empty_cache()

        return noise_pred

    @torch.no_grad()
    def _seq_parallel_pre_process(self, pre_infer_out):
        embed, x, embed0 = pre_infer_out.embed, pre_infer_out.x, pre_infer_out.embed0
        world_size = dist.get_world_size(self.seq_p_group)
        cur_rank = dist.get_rank(self.seq_p_group)
        padding_size = (world_size - (x.shape[0] % world_size)) % world_size
        if padding_size > 0:
            # use F.pad to pad the first (token) dimension
            x = F.pad(x, (0, 0, 0, padding_size))  # (trailing-dim padding, leading-dim padding)
        x = torch.chunk(x, world_size, dim=0)[cur_rank]
        if self.config["model_cls"].startswith("wan2.2"):
            padding_size = (world_size - (embed0.shape[0] % world_size)) % world_size
            if padding_size > 0:
                embed0 = F.pad(embed0, (0, 0, 0, 0, 0, padding_size))  # (trailing-dim padding, leading-dim padding)
                embed = F.pad(embed, (0, 0, 0, padding_size))
        pre_infer_out.x = x
        pre_infer_out.embed = embed
        pre_infer_out.embed0 = embed0
        return pre_infer_out

    @torch.no_grad()
    def _seq_parallel_post_process(self, x):
        world_size = dist.get_world_size(self.seq_p_group)
        # create a list to hold the outputs from every rank
        gathered_x = [torch.empty_like(x) for _ in range(world_size)]
        # gather the outputs from all ranks
        dist.all_gather(gathered_x, x, group=self.seq_p_group)
        noise_pred_list = [torch.zeros_like(noise_pred) for _ in range(2)]
        dist.all_gather(noise_pred_list, noise_pred, group=cfg_p_group)
        # concatenate the per-rank outputs along the given dimension
        combined_output = torch.cat(gathered_x, dim=0)
        noise_pred_cond = noise_pred_list[0]  # cfg_p_rank == 0
        noise_pred_uncond = noise_pred_list[1]  # cfg_p_rank == 1
        self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
        return combined_output  # return the combined output
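
The new _seq_parallel_pre_process / _seq_parallel_post_process pair above pads the token dimension to a multiple of the sequence-parallel world size, gives each rank one chunk, and later gathers and concatenates the per-rank outputs. A single-process sketch of that round trip, with dist.all_gather replaced by a plain list (world size and shapes are assumptions):

import torch
import torch.nn.functional as F

world_size = 4
x = torch.randn(10, 64)  # token dimension not divisible by world_size

# Pre-process: pad the first (token) dimension, then take one chunk per rank.
padding_size = (world_size - (x.shape[0] % world_size)) % world_size
if padding_size > 0:
    x = F.pad(x, (0, 0, 0, padding_size))
shards = list(torch.chunk(x, world_size, dim=0))  # shard i would live on rank i

# Post-process: what all_gather + cat reconstructs across ranks.
combined_output = torch.cat(shards, dim=0)
print(combined_output.shape)  # torch.Size([12, 64]); the last padding_size rows are pad fill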
lightx2v/models/networks/wan/weights/post_weights.py
from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import (
    LN_WEIGHT_REGISTER,
    MM_WEIGHT_REGISTER,
    TENSOR_REGISTER,
)


class WanPostWeights(WeightModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.register_parameter("norm", LN_WEIGHT_REGISTER["Default"]())
        self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
        self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))
lightx2v/models/networks/wan/weights/transformer_weights.py
...
...
@@ -26,6 +26,11 @@ class WanTransformerWeights(WeightModule):
        self.blocks = WeightModuleList([WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)])
        self.add_module("blocks", self.blocks)

        # post blocks weights
        self.register_parameter("norm", LN_WEIGHT_REGISTER["Default"]())
        self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
        self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))

    def clear(self):
        for block in self.blocks:
            for phase in block.compute_phases:
...
...