LightX2V / Commits / 95b58beb
"model/models/vscode:/vscode.git/clone" did not exist on "fa7776fd2458fc3a8aeb7f12e4bc65b439955319"
Commit 95b58beb, authored Aug 14, 2025 by helloyongyang
update parallel

parent f05a99da
Showing 11 changed files with 246 additions and 222 deletions (+246, -222)
lightx2v/models/networks/wan/audio_model.py                          +0   -84
lightx2v/models/networks/wan/infer/audio/post_wan_audio_infer.py     +3   -22
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py      +12  -1
lightx2v/models/networks/wan/infer/module_io.py                      +17  -0
lightx2v/models/networks/wan/infer/post_infer.py                     +2   -22
lightx2v/models/networks/wan/infer/pre_infer.py                      +10  -4
lightx2v/models/networks/wan/infer/transformer_infer.py              +52  -7
lightx2v/models/networks/wan/infer/utils.py                          +47  -0
lightx2v/models/networks/wan/model.py                                +98  -71
lightx2v/models/networks/wan/weights/post_weights.py                 +0   -11
lightx2v/models/networks/wan/weights/transformer_weights.py          +5   -0
lightx2v/models/networks/wan/audio_model.py
import glob
import os

import torch

from lightx2v.common.ops.attn.radial_attn import MaskMap
from lightx2v.models.networks.wan.infer.audio.post_wan_audio_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.audio.pre_wan_audio_infer import WanAudioPreInfer
from lightx2v.models.networks.wan.model import WanModel

...
@@ -27,87 +24,6 @@ class WanAudioModel(WanModel):
        self.pre_infer_class = WanAudioPreInfer
        self.post_infer_class = WanAudioPostInfer

    @torch.no_grad()
    def infer(self, inputs):
        if self.config["cpu_offload"]:
            self.pre_weight.to_cuda()
            self.post_weight.to_cuda()

        if self.transformer_infer.mask_map is None:
            _, c, h, w = self.scheduler.latents.shape
            num_frame = c + 1  # for r2v
            video_token_num = num_frame * (h // 2) * (w // 2)
            self.transformer_infer.mask_map = MaskMap(video_token_num, num_frame)

        embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]

        if self.config["feature_caching"] == "Tea":
            self.scheduler.cnt += 1
            if self.scheduler.cnt >= self.scheduler.num_steps:
                self.scheduler.cnt = 0

        self.scheduler.noise_pred = noise_pred_cond

        if self.config["enable_cfg"]:
            embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]

            if self.config["feature_caching"] == "Tea":
                self.scheduler.cnt += 1
                if self.scheduler.cnt >= self.scheduler.num_steps:
                    self.scheduler.cnt = 0

            self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)

        if self.config["cpu_offload"]:
            self.pre_weight.to_cpu()
            self.post_weight.to_cpu()

    @torch.no_grad()
    def infer_wo_cfg_parallel(self, inputs):
        if self.cpu_offload:
            if self.offload_granularity == "model" and self.scheduler.step_index == 0:
                self.to_cuda()
            elif self.offload_granularity != "model":
                self.pre_weight.to_cuda()
                self.post_weight.to_cuda()

        if self.transformer_infer.mask_map is None:
            _, c, h, w = self.scheduler.latents.shape
            num_frame = c + 1  # for r2v
            video_token_num = num_frame * (h // 2) * (w // 2)
            self.transformer_infer.mask_map = MaskMap(video_token_num, num_frame)

        embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]
        self.scheduler.noise_pred = noise_pred_cond

        if self.clean_cuda_cache:
            del x, embed, pre_infer_out, noise_pred_cond, grid_sizes
            torch.cuda.empty_cache()

        if self.config["enable_cfg"]:
            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
            self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)

            if self.clean_cuda_cache:
                del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
                torch.cuda.empty_cache()

        if self.cpu_offload:
            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
                self.to_cpu()
            elif self.offload_granularity != "model":
                self.pre_weight.to_cpu()
                self.post_weight.to_cpu()


class Wan22MoeAudioModel(WanAudioModel):
    def _load_ckpt(self, unified_dtype, sensitive_layer):
...
...
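Note (illustrative, not part of the commit): the removed infer() above sizes the radial-attention MaskMap from the latent shape. A minimal sketch of that arithmetic, assuming an example latent shape:

import torch

latents = torch.zeros(16, 21, 60, 104)        # illustrative latent, shape (_, c, h, w)
_, c, h, w = latents.shape
num_frame = c + 1                              # extra reference frame for r2v
video_token_num = num_frame * (h // 2) * (w // 2)
print(num_frame, video_token_num)              # 22, 22 * 30 * 52 = 34320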
lightx2v/models/networks/wan/infer/audio/post_wan_audio_infer.py
...
...
@@ -18,30 +18,11 @@ class WanAudioPostInfer(WanPostInfer):
        self.scheduler = scheduler

    @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
    def infer(self, weights, x, e, grid_sizes, valid_patch_length):
        if e.dim() == 2:
            modulation = weights.head_modulation.tensor  # 1, 2, dim
            e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
        elif e.dim() == 3:
            # for diffusion forcing
            modulation = weights.head_modulation.tensor.unsqueeze(2)  # 1, 2, seq, dim
            e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
        e = [ei.squeeze(1) for ei in e]
        x = weights.norm.apply(x)
        if self.sensitive_layer_dtype != self.infer_dtype:
            x = x.to(self.sensitive_layer_dtype)
        x.mul_(1 + e[1].squeeze()).add_(e[0].squeeze())
        if self.sensitive_layer_dtype != self.infer_dtype:
            x = x.to(self.infer_dtype)
        x = weights.head.apply(x)
        x = x[:, :valid_patch_length]
        x = self.unpatchify(x, grid_sizes)

    def infer(self, weights, x, pre_infer_out):
        x = x[:, :pre_infer_out.valid_patch_length]
        x = self.unpatchify(x, pre_infer_out.grid_sizes)

        if self.clean_cuda_cache:
            del e, grid_sizes
            torch.cuda.empty_cache()

        return [u.float() for u in x]
...
...
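Note (illustrative, not part of the commit): a minimal shape walkthrough of the head modulation used above, (modulation + e.unsqueeze(1)).chunk(2, dim=1), with assumed dim and sequence length:

import torch

dim, seq = 1536, 4
modulation = torch.zeros(1, 2, dim)            # head modulation: (1, 2, dim)

e = torch.randn(1, dim)                        # e.dim() == 2 case
shift, scale = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
print(shift.shape, scale.shape)                # (1, 1, 1536) each

e3 = torch.randn(1, seq, dim)                  # e.dim() == 3 case (diffusion forcing)
mod3 = modulation.unsqueeze(2)                 # (1, 2, 1, dim)
shift3, scale3 = (mod3 + e3.unsqueeze(1)).chunk(2, dim=1)
print(shift3.shape)                            # (1, 1, 4, 1536)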
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py
...
...
@@ -3,6 +3,7 @@ import torch
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.utils.envs import *

from ..module_io import WanPreInferModuleOutput
from ..utils import rope_params, sinusoidal_embedding_1d
...
...
@@ -126,4 +127,14 @@ class WanAudioPreInfer(WanPreInfer):
            del context_clip
            torch.cuda.empty_cache()

        return (embed, x_grid_sizes, (x.squeeze(0), embed0.squeeze(0), seq_lens, self.freqs, context, audio_dit_blocks), valid_patch_length)
        return WanPreInferModuleOutput(
            embed=embed,
            grid_sizes=x_grid_sizes,
            x=x.squeeze(0),
            embed0=embed0.squeeze(0),
            seq_lens=seq_lens,
            freqs=self.freqs,
            context=context,
            audio_dit_blocks=audio_dit_blocks,
            valid_patch_length=valid_patch_length,
        )
lightx2v/models/networks/wan/infer/module_io.py
0 → 100644
from dataclasses import dataclass
from typing import List

import torch


@dataclass
class WanPreInferModuleOutput:
    embed: torch.Tensor
    grid_sizes: torch.Tensor
    x: torch.Tensor
    embed0: torch.Tensor
    seq_lens: torch.Tensor
    freqs: torch.Tensor
    context: torch.Tensor
    audio_dit_blocks: List = None
    valid_patch_length: int = None
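Note (illustrative, not part of the commit): the pre-infer stages now hand this dataclass to the transformer and post stages instead of a positional tuple, so downstream code reads named fields. A minimal sketch with placeholder tensor shapes:

import torch
from lightx2v.models.networks.wan.infer.module_io import WanPreInferModuleOutput

out = WanPreInferModuleOutput(
    embed=torch.zeros(1, 1536),
    grid_sizes=torch.tensor([[21, 30, 52]]),
    x=torch.zeros(32760, 1536),
    embed0=torch.zeros(6, 1536),
    seq_lens=torch.tensor([32760]),
    freqs=torch.zeros(1024, 64),
    context=torch.zeros(512, 4096),
)
print(out.grid_sizes, out.valid_patch_length)  # audio-only fields default to None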
lightx2v/models/networks/wan/infer/post_infer.py
...
...
@@ -10,35 +10,15 @@ class WanPostInfer:
        self.out_dim = config["out_dim"]
        self.patch_size = (1, 2, 2)
        self.clean_cuda_cache = config.get("clean_cuda_cache", False)
        self.infer_dtype = GET_DTYPE()
        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler

    @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
    def infer(self, weights, x, e, grid_sizes):
        if e.dim() == 2:
            modulation = weights.head_modulation.tensor  # 1, 2, dim
            e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
        elif e.dim() == 3:
            # for diffusion forcing
            modulation = weights.head_modulation.tensor.unsqueeze(2)  # 1, 2, seq, dim
            e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
        e = [ei.squeeze(1) for ei in e]
        x = weights.norm.apply(x)
        if self.sensitive_layer_dtype != self.infer_dtype:
            x = x.to(self.sensitive_layer_dtype)
        x.mul_(1 + e[1].squeeze()).add_(e[0].squeeze())
        if self.sensitive_layer_dtype != self.infer_dtype:
            x = x.to(self.infer_dtype)
        x = weights.head.apply(x)
        x = self.unpatchify(x, grid_sizes)

    def infer(self, weights, x, pre_infer_out):
        x = self.unpatchify(x, pre_infer_out.grid_sizes)

        if self.clean_cuda_cache:
            del e, grid_sizes
            torch.cuda.empty_cache()

        return [u.float() for u in x]
...
...
lightx2v/models/networks/wan/infer/pre_infer.py
...
...
@@ -2,6 +2,7 @@ import torch
from lightx2v.utils.envs import *

from .module_io import WanPreInferModuleOutput
from .utils import guidance_scale_embedding, rope_params, sinusoidal_embedding_1d
...
...
@@ -132,8 +133,13 @@ class WanPreInfer:
        if self.config.get("use_image_encoder", True):
            del context_clip
        torch.cuda.empty_cache()

        return (embed, grid_sizes, (x.squeeze(0), embed0.squeeze(0), seq_lens, self.freqs, context),
        return WanPreInferModuleOutput(
            embed=embed,
            grid_sizes=grid_sizes,
            x=x.squeeze(0),
            embed0=embed0.squeeze(0),
            seq_lens=seq_lens,
            freqs=self.freqs,
            context=context,
        )
lightx2v/models/networks/wan/infer/transformer_infer.py
...
...
@@ -9,7 +9,7 @@ from lightx2v.common.offload.manager import (
from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from lightx2v.utils.envs import *

from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, compute_freqs_audio
from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, compute_freqs_audio, compute_freqs_audio_dist, compute_freqs_dist


class WanTransformerInfer(BaseTransformerInfer):
...
...
@@ -33,7 +33,11 @@ class WanTransformerInfer(BaseTransformerInfer):
        self.infer_dtype = GET_DTYPE()
        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
        self.seq_p_group = None
        if self.config["seq_parallel"]:
            self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
        else:
            self.seq_p_group = None
        if self.config.get("cpu_offload", False):
            if torch.cuda.get_device_capability(0) == (9, 0):
                assert self.config["self_attn_1_type"] != "sage_attn2"
...
...
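Note (illustrative, not part of the commit): config["device_mesh"] is only consumed here via get_group(mesh_dim="seq_p") (and "cfg_p" in model.py). A minimal sketch of one way such a mesh could be built; the 2x2 shape and dim names are assumptions, and it must run under torchrun with 4 processes:

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

device_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("cfg_p", "seq_p"))
seq_p_group = device_mesh.get_group(mesh_dim="seq_p")
print(dist.get_world_size(seq_p_group), dist.get_rank(seq_p_group))  # 2, local rank in the group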
@@ -86,15 +90,56 @@ class WanTransformerInfer(BaseTransformerInfer):
        return cu_seqlens_q, cu_seqlens_k

    def compute_freqs(self, q, grid_sizes, freqs):
        if "audio" in self.config.get("model_cls", ""):
            freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
        if self.config["seq_parallel"]:
            if "audio" in self.config.get("model_cls", ""):
                freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs, self.seq_p_group)
            else:
                freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs, self.seq_p_group)
        else:
            freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
            if "audio" in self.config.get("model_cls", ""):
                freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
            else:
                freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
        return freqs_i

    @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
    def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
        return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks)

    def infer(self, weights, pre_infer_out):
        x = self.infer_func(
            weights,
            pre_infer_out.grid_sizes,
            pre_infer_out.embed,
            pre_infer_out.x,
            pre_infer_out.embed0,
            pre_infer_out.seq_lens,
            pre_infer_out.freqs,
            pre_infer_out.context,
            pre_infer_out.audio_dit_blocks,
        )
        return self._infer_post_blocks(weights, x, pre_infer_out.embed)
    def _infer_post_blocks(self, weights, x, e):
        if e.dim() == 2:
            modulation = weights.head_modulation.tensor  # 1, 2, dim
            e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
        elif e.dim() == 3:
            # for diffusion forcing
            modulation = weights.head_modulation.tensor.unsqueeze(2)  # 1, 2, seq, dim
            e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
        e = [ei.squeeze(1) for ei in e]
        x = weights.norm.apply(x)
        if self.sensitive_layer_dtype != self.infer_dtype:
            x = x.to(self.sensitive_layer_dtype)
        x.mul_(1 + e[1].squeeze()).add_(e[0].squeeze())
        if self.sensitive_layer_dtype != self.infer_dtype:
            x = x.to(self.infer_dtype)
        x = weights.head.apply(x)
        if self.clean_cuda_cache:
            del e
            torch.cuda.empty_cache()
        return x
    def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
        for block_idx in range(self.blocks_num):
...
...
lightx2v/models/networks/wan/infer/utils.py
import torch
import torch.distributed as dist

from lightx2v.utils.envs import *
...
...
@@ -39,6 +40,52 @@ def compute_freqs_audio(c, grid_sizes, freqs):
    return freqs_i


def compute_freqs_dist(s, c, grid_sizes, freqs, seq_p_group):
    world_size = dist.get_world_size(seq_p_group)
    cur_rank = dist.get_rank(seq_p_group)
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    f, h, w = grid_sizes[0]
    seq_len = f * h * w
    freqs_i = torch.cat(
        [
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ],
        dim=-1,
    ).reshape(seq_len, 1, -1)
    freqs_i = pad_freqs(freqs_i, s * world_size)
    s_per_rank = s
    freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
    return freqs_i_rank
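Note (illustrative, not part of the commit): the split sizes [c - 2*(c//3), c//3, c//3] above divide the rotary channels between the frame, height and width axes; c = head_dim // 2 = 64 below is an assumed value:

import torch

c = 64
freqs = torch.zeros(1024, c)                      # dummy rope table
parts = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
print([p.shape[1] for p in parts])                # [22, 21, 21] -> frame, height, width channels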
def compute_freqs_audio_dist(s, c, grid_sizes, freqs, seq_p_group):
    world_size = dist.get_world_size(seq_p_group)
    cur_rank = dist.get_rank(seq_p_group)
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    f, h, w = grid_sizes[0]
    valid_token_length = f * h * w
    f = f + 1
    seq_len = f * h * w
    freqs_i = torch.cat(
        [
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ],
        dim=-1,
    ).reshape(seq_len, 1, -1)
    freqs_i[valid_token_length:, :, :f] = 0
    freqs_i = pad_freqs(freqs_i, s * world_size)
    s_per_rank = s
    freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
    return freqs_i_rank
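Note (illustrative, not part of the commit): after padding the table to s * world_size rows, each rank keeps its contiguous slice, as in the two functions above. A minimal sketch of that padding and slicing; the values are assumptions and pad_freqs is replaced by plain F.pad:

import torch
import torch.nn.functional as F

world_size, s = 4, 8200                    # s = assumed per-rank (padded) row count
freqs_i = torch.zeros(32760, 1, 64)        # 21 * 30 * 52 rows before padding
pad_rows = s * world_size - freqs_i.shape[0]
freqs_i = F.pad(freqs_i, (0, 0, 0, 0, 0, pad_rows))   # pad first dim: 32760 -> 32800 rows
for cur_rank in range(world_size):
    chunk = freqs_i[cur_rank * s:(cur_rank + 1) * s]
    print(cur_rank, chunk.shape)           # each rank gets (8200, 1, 64)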
def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0):
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    f, h, w = grid_sizes[0]
...
...
lightx2v/models/networks/wan/model.py
...
...
@@ -7,7 +7,6 @@ from loguru import logger
from safetensors import safe_open

from lightx2v.common.ops.attn import MaskMap
from lightx2v.models.networks.wan.infer.dist_infer.transformer_infer import WanTransformerDistInfer
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import (
    WanTransformerInferAdaCaching,
    WanTransformerInferCustomCaching,
...
...
@@ -83,27 +82,25 @@ class WanModel:
    def _init_infer_class(self):
        self.pre_infer_class = WanPreInfer
        self.post_infer_class = WanPostInfer
        if self.seq_p_group is not None:
            self.transformer_infer_class = WanTransformerDistInfer
        if self.config["feature_caching"] == "NoCaching":
            self.transformer_infer_class = WanTransformerInfer
        elif self.config["feature_caching"] == "Tea":
            self.transformer_infer_class = WanTransformerInferTeaCaching
        elif self.config["feature_caching"] == "TaylorSeer":
            self.transformer_infer_class = WanTransformerInferTaylorCaching
        elif self.config["feature_caching"] == "Ada":
            self.transformer_infer_class = WanTransformerInferAdaCaching
        elif self.config["feature_caching"] == "Custom":
            self.transformer_infer_class = WanTransformerInferCustomCaching
        elif self.config["feature_caching"] == "FirstBlock":
            self.transformer_infer_class = WanTransformerInferFirstBlock
        elif self.config["feature_caching"] == "DualBlock":
            self.transformer_infer_class = WanTransformerInferDualBlock
        elif self.config["feature_caching"] == "DynamicBlock":
            self.transformer_infer_class = WanTransformerInferDynamicBlock
        else:
            if self.config["feature_caching"] == "NoCaching":
                self.transformer_infer_class = WanTransformerInfer
            elif self.config["feature_caching"] == "Tea":
                self.transformer_infer_class = WanTransformerInferTeaCaching
            elif self.config["feature_caching"] == "TaylorSeer":
                self.transformer_infer_class = WanTransformerInferTaylorCaching
            elif self.config["feature_caching"] == "Ada":
                self.transformer_infer_class = WanTransformerInferAdaCaching
            elif self.config["feature_caching"] == "Custom":
                self.transformer_infer_class = WanTransformerInferCustomCaching
            elif self.config["feature_caching"] == "FirstBlock":
                self.transformer_infer_class = WanTransformerInferFirstBlock
            elif self.config["feature_caching"] == "DualBlock":
                self.transformer_infer_class = WanTransformerInferDualBlock
            elif self.config["feature_caching"] == "DynamicBlock":
                self.transformer_infer_class = WanTransformerInferDynamicBlock
            else:
                raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

    def _should_load_weights(self):
        """Determine if current rank should load weights from disk."""
...
...
@@ -296,16 +293,7 @@ class WanModel:
    def _init_infer(self):
        self.pre_infer = self.pre_infer_class(self.config)
        self.post_infer = self.post_infer_class(self.config)
        if self.seq_p_group is not None:
            self.transformer_infer = self.transformer_infer_class(self.config, self.seq_p_group)
        else:
            self.transformer_infer = self.transformer_infer_class(self.config)
        if self.config["cfg_parallel"]:
            self.infer_func = self.infer_with_cfg_parallel
        else:
            self.infer_func = self.infer_wo_cfg_parallel
        self.transformer_infer = self.transformer_infer_class(self.config)

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler
...
...
@@ -325,10 +313,6 @@ class WanModel:
    @torch.no_grad()
    def infer(self, inputs):
        return self.infer_func(inputs)

    @torch.no_grad()
    def infer_wo_cfg_parallel(self, inputs):
        if self.cpu_offload:
            if self.offload_granularity == "model" and self.scheduler.step_index == 0:
                self.to_cuda()
...
...
@@ -341,26 +325,31 @@ class WanModel:
            video_token_num = c * (h // 2) * (w // 2)
            self.transformer_infer.mask_map = MaskMap(video_token_num, c)

        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
        self.scheduler.noise_pred = noise_pred_cond

        if self.clean_cuda_cache:
            del x, embed, pre_infer_out, noise_pred_cond, grid_sizes
            torch.cuda.empty_cache()

        if self.config["enable_cfg"]:
            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
            if self.config["cfg_parallel"]:
                # ==================== CFG Parallel Processing ====================
                cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
                assert dist.get_world_size(cfg_p_group) == 2, "cfg_p_world_size must be equal to 2"
                cfg_p_rank = dist.get_rank(cfg_p_group)
                if cfg_p_rank == 0:
                    noise_pred = self._infer_cond_uncond(inputs, positive=True)
                else:
                    noise_pred = self._infer_cond_uncond(inputs, positive=False)
                self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)
                noise_pred_list = [torch.zeros_like(noise_pred) for _ in range(2)]
                dist.all_gather(noise_pred_list, noise_pred, group=cfg_p_group)
                noise_pred_cond = noise_pred_list[0]  # cfg_p_rank == 0
                noise_pred_uncond = noise_pred_list[1]  # cfg_p_rank == 1
            else:
                # ==================== CFG Processing ====================
                noise_pred_cond = self._infer_cond_uncond(inputs, positive=True)
                noise_pred_uncond = self._infer_cond_uncond(inputs, positive=False)
            if self.clean_cuda_cache:
                del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
                torch.cuda.empty_cache()
            self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
        else:
            # ==================== No CFG ====================
            self.scheduler.noise_pred = self._infer_cond_uncond(inputs, positive=True)

        if self.cpu_offload:
            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
...
...
@@ -370,24 +359,62 @@ class WanModel:
                self.post_weight.to_cpu()

    @torch.no_grad()
    def infer_with_cfg_parallel(self, inputs):
        assert self.config["enable_cfg"], "enable_cfg must be True"
        cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
        assert dist.get_world_size(cfg_p_group) == 2, f"cfg_p_world_size must be equal to 2"
        cfg_p_rank = dist.get_rank(cfg_p_group)
        if cfg_p_rank == 0:
            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
        else:
            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

    def _infer_cond_uncond(self, inputs, positive=True):
        pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=positive)
        if self.config["seq_parallel"]:
            pre_infer_out = self._seq_parallel_pre_process(pre_infer_out)
        x = self.transformer_infer.infer(self.transformer_weights, pre_infer_out)
        if self.config["seq_parallel"]:
            x = self._seq_parallel_post_process(x)
        noise_pred = self.post_infer.infer(self.post_weight, x, pre_infer_out)[0]

        if self.clean_cuda_cache:
            del x, pre_infer_out
            torch.cuda.empty_cache()

        return noise_pred

    @torch.no_grad()
    def _seq_parallel_pre_process(self, pre_infer_out):
        embed, x, embed0 = pre_infer_out.embed, pre_infer_out.x, pre_infer_out.embed0
        world_size = dist.get_world_size(self.seq_p_group)
        cur_rank = dist.get_rank(self.seq_p_group)
        padding_size = (world_size - (x.shape[0] % world_size)) % world_size
        if padding_size > 0:
            # pad the first dimension with F.pad
            x = F.pad(x, (0, 0, 0, padding_size))  # (trailing-dim padding, leading-dim padding)
        x = torch.chunk(x, world_size, dim=0)[cur_rank]
        if self.config["model_cls"].startswith("wan2.2"):
            padding_size = (world_size - (embed0.shape[0] % world_size)) % world_size
            if padding_size > 0:
                embed0 = F.pad(embed0, (0, 0, 0, 0, 0, padding_size))  # (trailing-dim padding, leading-dim padding)
                embed = F.pad(embed, (0, 0, 0, padding_size))
        pre_infer_out.x = x
        pre_infer_out.embed = embed
        pre_infer_out.embed0 = embed0
        return pre_infer_out
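Note (illustrative, not part of the commit): the padding arithmetic used in _seq_parallel_pre_process above makes the token dimension divisible by the sequence-parallel world size before chunking. A minimal sketch with assumed shapes:

import torch
import torch.nn.functional as F

world_size = 4
x = torch.zeros(32762, 1536)                       # token count not divisible by 4
padding_size = (world_size - (x.shape[0] % world_size)) % world_size
x = F.pad(x, (0, 0, 0, padding_size))              # pad rows: 32762 -> 32764
chunks = torch.chunk(x, world_size, dim=0)
print(padding_size, [c.shape[0] for c in chunks])  # 2, [8191, 8191, 8191, 8191]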
    @torch.no_grad()
    def _seq_parallel_post_process(self, x):
        world_size = dist.get_world_size(self.seq_p_group)
        # create a list to hold the outputs from every rank
        gathered_x = [torch.empty_like(x) for _ in range(world_size)]
        # gather the outputs from all ranks
        dist.all_gather(gathered_x, x, group=self.seq_p_group)
        noise_pred_list = [torch.zeros_like(noise_pred) for _ in range(2)]
        dist.all_gather(noise_pred_list, noise_pred, group=cfg_p_group)
        # concatenate the gathered outputs along the given dimension
        combined_output = torch.cat(gathered_x, dim=0)
        noise_pred_cond = noise_pred_list[0]  # cfg_p_rank == 0
        noise_pred_uncond = noise_pred_list[1]  # cfg_p_rank == 1
        self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
        return combined_output  # return the merged output
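Note (illustrative, not part of the commit): the classifier-free guidance blend applied above after the cond/uncond passes (or after the 2-rank cfg_p all_gather). Tensor shapes and the guide scale are assumed values:

import torch

sample_guide_scale = 5.0
noise_pred_cond = torch.randn(16, 21, 60, 104)
noise_pred_uncond = torch.randn(16, 21, 60, 104)
noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
print(noise_pred.shape)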
lightx2v/models/networks/wan/weights/post_weights.py
from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import (
    LN_WEIGHT_REGISTER,
    MM_WEIGHT_REGISTER,
    TENSOR_REGISTER,
)


class WanPostWeights(WeightModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.register_parameter("norm", LN_WEIGHT_REGISTER["Default"]())
        self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
        self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))
lightx2v/models/networks/wan/weights/transformer_weights.py
...
...
@@ -26,6 +26,11 @@ class WanTransformerWeights(WeightModule):
        self.blocks = WeightModuleList([WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)])
        self.add_module("blocks", self.blocks)

        # post blocks weights
        self.register_parameter("norm", LN_WEIGHT_REGISTER["Default"]())
        self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
        self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))

    def clear(self):
        for block in self.blocks:
            for phase in block.compute_phases:
...
...