xuwx1 / LightX2V · Commits · 95b58beb

Commit 95b58beb authored Aug 14, 2025 by helloyongyang

update parallel

parent f05a99da
Showing 11 changed files with 246 additions and 222 deletions (+246, -222)
lightx2v/models/networks/wan/audio_model.py                        +0   -84
lightx2v/models/networks/wan/infer/audio/post_wan_audio_infer.py   +3   -22
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py    +12  -1
lightx2v/models/networks/wan/infer/module_io.py                    +17  -0
lightx2v/models/networks/wan/infer/post_infer.py                   +2   -22
lightx2v/models/networks/wan/infer/pre_infer.py                    +10  -4
lightx2v/models/networks/wan/infer/transformer_infer.py            +52  -7
lightx2v/models/networks/wan/infer/utils.py                        +47  -0
lightx2v/models/networks/wan/model.py                              +98  -71
lightx2v/models/networks/wan/weights/post_weights.py               +0   -11
lightx2v/models/networks/wan/weights/transformer_weights.py        +5   -0
lightx2v/models/networks/wan/audio_model.py  (+0, -84)

 import glob
 import os

-import torch

-from lightx2v.common.ops.attn.radial_attn import MaskMap
 from lightx2v.models.networks.wan.infer.audio.post_wan_audio_infer import WanAudioPostInfer
 from lightx2v.models.networks.wan.infer.audio.pre_wan_audio_infer import WanAudioPreInfer
 from lightx2v.models.networks.wan.model import WanModel
...
@@ -27,87 +24,6 @@ class WanAudioModel(WanModel):
         self.pre_infer_class = WanAudioPreInfer
         self.post_infer_class = WanAudioPostInfer

-    @torch.no_grad()
-    def infer(self, inputs):
-        if self.config["cpu_offload"]:
-            self.pre_weight.to_cuda()
-            self.post_weight.to_cuda()
-
-        if self.transformer_infer.mask_map is None:
-            _, c, h, w = self.scheduler.latents.shape
-            num_frame = c + 1  # for r2v
-            video_token_num = num_frame * (h // 2) * (w // 2)
-            self.transformer_infer.mask_map = MaskMap(video_token_num, num_frame)
-
-        embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
-        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
-        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]
-
-        if self.config["feature_caching"] == "Tea":
-            self.scheduler.cnt += 1
-            if self.scheduler.cnt >= self.scheduler.num_steps:
-                self.scheduler.cnt = 0
-
-        self.scheduler.noise_pred = noise_pred_cond
-
-        if self.config["enable_cfg"]:
-            embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
-            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
-            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]
-
-            if self.config["feature_caching"] == "Tea":
-                self.scheduler.cnt += 1
-                if self.scheduler.cnt >= self.scheduler.num_steps:
-                    self.scheduler.cnt = 0
-
-            self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
-
-        if self.config["cpu_offload"]:
-            self.pre_weight.to_cpu()
-            self.post_weight.to_cpu()
-
-    @torch.no_grad()
-    def infer_wo_cfg_parallel(self, inputs):
-        if self.cpu_offload:
-            if self.offload_granularity == "model" and self.scheduler.step_index == 0:
-                self.to_cuda()
-            elif self.offload_granularity != "model":
-                self.pre_weight.to_cuda()
-                self.post_weight.to_cuda()
-
-        if self.transformer_infer.mask_map is None:
-            _, c, h, w = self.scheduler.latents.shape
-            num_frame = c + 1  # for r2v
-            video_token_num = num_frame * (h // 2) * (w // 2)
-            self.transformer_infer.mask_map = MaskMap(video_token_num, num_frame)
-
-        embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
-        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
-        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]
-        self.scheduler.noise_pred = noise_pred_cond
-
-        if self.clean_cuda_cache:
-            del x, embed, pre_infer_out, noise_pred_cond, grid_sizes
-            torch.cuda.empty_cache()
-
-        if self.config["enable_cfg"]:
-            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
-            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
-            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
-            self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)
-
-            if self.clean_cuda_cache:
-                del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
-                torch.cuda.empty_cache()
-
-        if self.cpu_offload:
-            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
-                self.to_cpu()
-            elif self.offload_granularity != "model":
-                self.pre_weight.to_cpu()
-                self.post_weight.to_cpu()
-
...

 class Wan22MoeAudioModel(WanAudioModel):
     def _load_ckpt(self, unified_dtype, sensitive_layer):
...
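Note: both removed methods size the radial-attention MaskMap directly from the latent shape; the audio (r2v) path counts one extra reference frame, and every latent frame contributes (h // 2) * (w // 2) tokens after 2x2 spatial patchification (the base-class infer shown further down uses c frames without the +1). A minimal sketch of that token arithmetic, with an invented latent shape:

    import torch

    latents = torch.zeros(1, 20, 60, 104)  # (batch, latent frames, height, width) -- invented shape

    _, c, h, w = latents.shape
    num_frame = c + 1                                   # one extra reference frame for r2v
    video_token_num = num_frame * (h // 2) * (w // 2)   # tokens after 2x2 spatial patchify

    print(num_frame, video_token_num)  # 21 32760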
lightx2v/models/networks/wan/infer/audio/post_wan_audio_infer.py  (+3, -22)

...
@@ -18,30 +18,11 @@ class WanAudioPostInfer(WanPostInfer):
         self.scheduler = scheduler

     @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
-    def infer(self, weights, x, e, grid_sizes, valid_patch_length):
-        if e.dim() == 2:
-            modulation = weights.head_modulation.tensor  # 1, 2, dim
-            e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
-        elif e.dim() == 3:
-            # For Diffustion forcing
-            modulation = weights.head_modulation.tensor.unsqueeze(2)  # 1, 2, seq, dim
-            e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
-            e = [ei.squeeze(1) for ei in e]
-
-        x = weights.norm.apply(x)
-        if self.sensitive_layer_dtype != self.infer_dtype:
-            x = x.to(self.sensitive_layer_dtype)
-        x.mul_(1 + e[1].squeeze()).add_(e[0].squeeze())
-        if self.sensitive_layer_dtype != self.infer_dtype:
-            x = x.to(self.infer_dtype)
-        x = weights.head.apply(x)
-        x = x[:, :valid_patch_length]
-        x = self.unpatchify(x, grid_sizes)
+    def infer(self, weights, x, pre_infer_out):
+        x = x[:, : pre_infer_out.valid_patch_length]
+        x = self.unpatchify(x, pre_infer_out.grid_sizes)

         if self.clean_cuda_cache:
-            del e, grid_sizes
            torch.cuda.empty_cache()

         return [u.float() for u in x]
...
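Note: after this change the audio post step only trims padding and unpatchifies; the modulation/norm/head math it used to duplicate now lives in WanTransformerInfer._infer_post_blocks (see the transformer_infer.py hunk below). The trim itself is a plain slice along the token dimension; a tiny stand-alone illustration with invented sizes:

    import torch

    valid_patch_length = 1500
    x = torch.randn(1, 1536, 64)   # (batch, padded patch tokens, channels) -- invented sizes

    x = x[:, :valid_patch_length]  # drop the alignment padding before unpatchify
    print(x.shape)                 # torch.Size([1, 1500, 64])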
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py  (+12, -1)

...
@@ -3,6 +3,7 @@ import torch
 from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
 from lightx2v.utils.envs import *
+from ..module_io import WanPreInferModuleOutput
 from ..utils import rope_params, sinusoidal_embedding_1d
...
@@ -126,4 +127,14 @@ class WanAudioPreInfer:
         del context_clip
         torch.cuda.empty_cache()

-        return (embed, x_grid_sizes, (x.squeeze(0), embed0.squeeze(0), seq_lens, self.freqs, context, audio_dit_blocks), valid_patch_length)
+        return WanPreInferModuleOutput(
+            embed=embed,
+            grid_sizes=x_grid_sizes,
+            x=x.squeeze(0),
+            embed0=embed0.squeeze(0),
+            seq_lens=seq_lens,
+            freqs=self.freqs,
+            context=context,
+            audio_dit_blocks=audio_dit_blocks,
+            valid_patch_length=valid_patch_length,
+        )
lightx2v/models/networks/wan/infer/module_io.py  (new file, +17)

from dataclasses import dataclass
from typing import List

import torch


@dataclass
class WanPreInferModuleOutput:
    embed: torch.Tensor
    grid_sizes: torch.Tensor
    x: torch.Tensor
    embed0: torch.Tensor
    seq_lens: torch.Tensor
    freqs: torch.Tensor
    context: torch.Tensor
    audio_dit_blocks: List = None
    valid_patch_length: int = None
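Note: the new dataclass replaces the positional tuple the pre-infer step used to return, so callers read named fields instead of unpacking by position. A small usage sketch with dummy tensors (the shapes are placeholders, not the model's real ones, and it assumes the lightx2v package is importable):

    import torch

    from lightx2v.models.networks.wan.infer.module_io import WanPreInferModuleOutput

    out = WanPreInferModuleOutput(
        embed=torch.randn(1, 5120),          # placeholder shapes throughout
        grid_sizes=torch.tensor([[21, 30, 52]]),
        x=torch.randn(32760, 5120),
        embed0=torch.randn(6, 5120),
        seq_lens=torch.tensor([32760]),
        freqs=torch.randn(1024, 64),
        context=torch.randn(512, 4096),
    )

    # Named access instead of positional unpacking; optional fields default to None.
    print(out.grid_sizes.tolist(), out.valid_patch_length)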
lightx2v/models/networks/wan/infer/post_infer.py  (+2, -22)

...
@@ -10,35 +10,15 @@ class WanPostInfer:
         self.out_dim = config["out_dim"]
         self.patch_size = (1, 2, 2)
         self.clean_cuda_cache = config.get("clean_cuda_cache", False)
-        self.infer_dtype = GET_DTYPE()
-        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

     def set_scheduler(self, scheduler):
         self.scheduler = scheduler

     @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
-    def infer(self, weights, x, e, grid_sizes):
-        if e.dim() == 2:
-            modulation = weights.head_modulation.tensor  # 1, 2, dim
-            e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
-        elif e.dim() == 3:
-            # For Diffustion forcing
-            modulation = weights.head_modulation.tensor.unsqueeze(2)  # 1, 2, seq, dim
-            e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
-            e = [ei.squeeze(1) for ei in e]
-
-        x = weights.norm.apply(x)
-        if self.sensitive_layer_dtype != self.infer_dtype:
-            x = x.to(self.sensitive_layer_dtype)
-        x.mul_(1 + e[1].squeeze()).add_(e[0].squeeze())
-        if self.sensitive_layer_dtype != self.infer_dtype:
-            x = x.to(self.infer_dtype)
-        x = weights.head.apply(x)
-        x = self.unpatchify(x, grid_sizes)
+    def infer(self, weights, x, pre_infer_out):
+        x = self.unpatchify(x, pre_infer_out.grid_sizes)

         if self.clean_cuda_cache:
-            del e, grid_sizes
            torch.cuda.empty_cache()

         return [u.float() for u in x]
...
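Note: the modulation/norm/head logic deleted here is re-added below as WanTransformerInfer._infer_post_blocks; its core is an AdaLN-style affine on the normalized features, x <- norm(x) * (1 + scale) + shift, where shift and scale come from adding the step embedding to the head modulation table and splitting it in two. A rough sketch of just that math on dummy tensors (plain layer_norm stands in for the repo's weight wrappers):

    import torch
    import torch.nn.functional as F

    dim = 8
    x = torch.randn(16, dim)             # token features (dummy sizes)
    e = torch.randn(1, dim)              # step embedding, e.dim() == 2 as in the diff
    modulation = torch.randn(1, 2, dim)  # stands in for weights.head_modulation.tensor

    # Add, then split into (shift, scale) along dim=1, exactly as the removed code does.
    shift, scale = (modulation + e.unsqueeze(1)).chunk(2, dim=1)

    x_norm = F.layer_norm(x, (dim,))     # stands in for weights.norm.apply(x)
    x_mod = x_norm * (1 + scale.squeeze()) + shift.squeeze()
    print(x_mod.shape)                   # torch.Size([16, 8])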
lightx2v/models/networks/wan/infer/pre_infer.py  (+10, -4)

...
@@ -2,6 +2,7 @@ import torch
 from lightx2v.utils.envs import *
+from .module_io import WanPreInferModuleOutput
 from .utils import guidance_scale_embedding, rope_params, sinusoidal_embedding_1d
...
@@ -132,8 +133,13 @@ class WanPreInfer:
         if self.config.get("use_image_encoder", True):
             del context_clip
             torch.cuda.empty_cache()

-        return (
-            embed,
-            grid_sizes,
-            (x.squeeze(0), embed0.squeeze(0), seq_lens, self.freqs, context),
+        return WanPreInferModuleOutput(
+            embed=embed,
+            grid_sizes=grid_sizes,
+            x=x.squeeze(0),
+            embed0=embed0.squeeze(0),
+            seq_lens=seq_lens,
+            freqs=self.freqs,
+            context=context,
         )
lightx2v/models/networks/wan/infer/transformer_infer.py  (+52, -7)

...
@@ -9,7 +9,7 @@ from lightx2v.common.offload.manager import (
 from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
 from lightx2v.utils.envs import *
-from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, compute_freqs_audio
+from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, compute_freqs_audio, compute_freqs_audio_dist, compute_freqs_dist


 class WanTransformerInfer(BaseTransformerInfer):
...
@@ -33,7 +33,11 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.infer_dtype = GET_DTYPE()
         self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
+        self.seq_p_group = None
+        if self.config["seq_parallel"]:
+            self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
+        else:
+            self.seq_p_group = None
         if self.config.get("cpu_offload", False):
             if torch.cuda.get_device_capability(0) == (9, 0):
                 assert self.config["self_attn_1_type"] != "sage_attn2"
...
@@ -86,15 +90,56 @@ class WanTransformerInfer(BaseTransformerInfer):
         return cu_seqlens_q, cu_seqlens_k

     def compute_freqs(self, q, grid_sizes, freqs):
-        if "audio" in self.config.get("model_cls", ""):
-            freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
+        if self.config["seq_parallel"]:
+            if "audio" in self.config.get("model_cls", ""):
+                freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs, self.seq_p_group)
+            else:
+                freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs, self.seq_p_group)
         else:
-            freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
+            if "audio" in self.config.get("model_cls", ""):
+                freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
+            else:
+                freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
         return freqs_i

     @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
-    def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
-        return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks)
+    def infer(self, weights, pre_infer_out):
+        x = self.infer_func(
+            weights,
+            pre_infer_out.grid_sizes,
+            pre_infer_out.embed,
+            pre_infer_out.x,
+            pre_infer_out.embed0,
+            pre_infer_out.seq_lens,
+            pre_infer_out.freqs,
+            pre_infer_out.context,
+            pre_infer_out.audio_dit_blocks,
+        )
+        return self._infer_post_blocks(weights, x, pre_infer_out.embed)
+
+    def _infer_post_blocks(self, weights, x, e):
+        if e.dim() == 2:
+            modulation = weights.head_modulation.tensor  # 1, 2, dim
+            e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
+        elif e.dim() == 3:
+            # For Diffustion forcing
+            modulation = weights.head_modulation.tensor.unsqueeze(2)  # 1, 2, seq, dim
+            e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
+            e = [ei.squeeze(1) for ei in e]
+        x = weights.norm.apply(x)
+        if self.sensitive_layer_dtype != self.infer_dtype:
+            x = x.to(self.sensitive_layer_dtype)
+        x.mul_(1 + e[1].squeeze()).add_(e[0].squeeze())
+        if self.sensitive_layer_dtype != self.infer_dtype:
+            x = x.to(self.infer_dtype)
+        x = weights.head.apply(x)
+        if self.clean_cuda_cache:
+            del e
+            torch.cuda.empty_cache()
+        return x

     def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         for block_idx in range(self.blocks_num):
...
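Note: seq_p_group is pulled from a device mesh stored in the config, so this path presumes the launcher built a mesh with a "seq_p" dimension (and, for CFG parallelism, a "cfg_p" dimension) before the model is constructed. A hedged sketch of how such a mesh could be set up with stock PyTorch; the 2x2 layout and launch command are assumptions, not taken from this commit:

    # launched with e.g.: torchrun --nproc_per_node=4 mesh_sketch.py
    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh

    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    # Assumed layout: 2 CFG branches x 2 sequence-parallel ranks = 4 processes.
    device_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("cfg_p", "seq_p"))

    seq_p_group = device_mesh.get_group(mesh_dim="seq_p")  # what the infer code reads
    cfg_p_group = device_mesh.get_group(mesh_dim="cfg_p")
    print(dist.get_rank(), dist.get_world_size(seq_p_group), dist.get_world_size(cfg_p_group))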
lightx2v/models/networks/wan/infer/utils.py  (+47, -0)

 import torch
+import torch.distributed as dist

 from lightx2v.utils.envs import *
...
@@ -39,6 +40,52 @@ def compute_freqs_audio(c, grid_sizes, freqs):
     return freqs_i


+def compute_freqs_dist(s, c, grid_sizes, freqs, seq_p_group):
+    world_size = dist.get_world_size(seq_p_group)
+    cur_rank = dist.get_rank(seq_p_group)
+    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+    f, h, w = grid_sizes[0]
+    seq_len = f * h * w
+    freqs_i = torch.cat(
+        [
+            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
+        ],
+        dim=-1,
+    ).reshape(seq_len, 1, -1)
+    freqs_i = pad_freqs(freqs_i, s * world_size)
+    s_per_rank = s
+    freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
+    return freqs_i_rank
+
+
+def compute_freqs_audio_dist(s, c, grid_sizes, freqs, seq_p_group):
+    world_size = dist.get_world_size(seq_p_group)
+    cur_rank = dist.get_rank(seq_p_group)
+    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+    f, h, w = grid_sizes[0]
+    valid_token_length = f * h * w
+    f = f + 1
+    seq_len = f * h * w
+    freqs_i = torch.cat(
+        [
+            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
+        ],
+        dim=-1,
+    ).reshape(seq_len, 1, -1)
+    freqs_i[valid_token_length:, :, :f] = 0
+    freqs_i = pad_freqs(freqs_i, s * world_size)
+    s_per_rank = s
+    freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
+    return freqs_i_rank
+
+
 def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
     f, h, w = grid_sizes[0]
...
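Note: both new helpers follow the same pattern: build the RoPE frequency table for the full (padded) sequence, pad it out to s * world_size rows, then keep only the contiguous block of s rows owned by the current rank. A single-process sketch of just that padding-and-slicing arithmetic (pad_rows is a stand-in for the repo's pad_freqs, whose definition is not part of this hunk; all sizes are invented):

    import torch

    def pad_rows(t, target_rows):
        # stand-in for pad_freqs: zero-pad dim 0 up to target_rows (assumption)
        pad = target_rows - t.shape[0]
        return torch.cat([t, t.new_zeros(pad, *t.shape[1:])]) if pad > 0 else t

    world_size, s = 4, 9        # s = per-rank sequence length (invented)
    seq_len = 34                # full token count before padding (invented)
    freqs_i = torch.randn(seq_len, 1, 64)

    freqs_i = pad_rows(freqs_i, s * world_size)          # 34 rows -> 36 rows
    for cur_rank in range(world_size):
        chunk = freqs_i[cur_rank * s : (cur_rank + 1) * s]
        print(cur_rank, tuple(chunk.shape))              # every rank gets exactly s rows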
lightx2v/models/networks/wan/model.py  (+98, -71)

...
@@ -7,7 +7,6 @@ from loguru import logger
 from safetensors import safe_open

 from lightx2v.common.ops.attn import MaskMap
-from lightx2v.models.networks.wan.infer.dist_infer.transformer_infer import WanTransformerDistInfer
 from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import (
     WanTransformerInferAdaCaching,
     WanTransformerInferCustomCaching,
...
@@ -83,27 +82,25 @@ class WanModel:
     def _init_infer_class(self):
         self.pre_infer_class = WanPreInfer
         self.post_infer_class = WanPostInfer
-        if self.seq_p_group is not None:
-            self.transformer_infer_class = WanTransformerDistInfer
-        else:
-            if self.config["feature_caching"] == "NoCaching":
-                self.transformer_infer_class = WanTransformerInfer
-            elif self.config["feature_caching"] == "Tea":
-                self.transformer_infer_class = WanTransformerInferTeaCaching
-            elif self.config["feature_caching"] == "TaylorSeer":
-                self.transformer_infer_class = WanTransformerInferTaylorCaching
-            elif self.config["feature_caching"] == "Ada":
-                self.transformer_infer_class = WanTransformerInferAdaCaching
-            elif self.config["feature_caching"] == "Custom":
-                self.transformer_infer_class = WanTransformerInferCustomCaching
-            elif self.config["feature_caching"] == "FirstBlock":
-                self.transformer_infer_class = WanTransformerInferFirstBlock
-            elif self.config["feature_caching"] == "DualBlock":
-                self.transformer_infer_class = WanTransformerInferDualBlock
-            elif self.config["feature_caching"] == "DynamicBlock":
-                self.transformer_infer_class = WanTransformerInferDynamicBlock
-            else:
-                raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
+        if self.config["feature_caching"] == "NoCaching":
+            self.transformer_infer_class = WanTransformerInfer
+        elif self.config["feature_caching"] == "Tea":
+            self.transformer_infer_class = WanTransformerInferTeaCaching
+        elif self.config["feature_caching"] == "TaylorSeer":
+            self.transformer_infer_class = WanTransformerInferTaylorCaching
+        elif self.config["feature_caching"] == "Ada":
+            self.transformer_infer_class = WanTransformerInferAdaCaching
+        elif self.config["feature_caching"] == "Custom":
+            self.transformer_infer_class = WanTransformerInferCustomCaching
+        elif self.config["feature_caching"] == "FirstBlock":
+            self.transformer_infer_class = WanTransformerInferFirstBlock
+        elif self.config["feature_caching"] == "DualBlock":
+            self.transformer_infer_class = WanTransformerInferDualBlock
+        elif self.config["feature_caching"] == "DynamicBlock":
+            self.transformer_infer_class = WanTransformerInferDynamicBlock
+        else:
+            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

     def _should_load_weights(self):
         """Determine if current rank should load weights from disk."""
...
@@ -296,16 +293,7 @@ class WanModel:
     def _init_infer(self):
         self.pre_infer = self.pre_infer_class(self.config)
         self.post_infer = self.post_infer_class(self.config)
-        if self.seq_p_group is not None:
-            self.transformer_infer = self.transformer_infer_class(self.config, self.seq_p_group)
-        else:
-            self.transformer_infer = self.transformer_infer_class(self.config)
-
-        if self.config["cfg_parallel"]:
-            self.infer_func = self.infer_with_cfg_parallel
-        else:
-            self.infer_func = self.infer_wo_cfg_parallel
+        self.transformer_infer = self.transformer_infer_class(self.config)

     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
...
@@ -325,10 +313,6 @@ class WanModel:
     @torch.no_grad()
     def infer(self, inputs):
-        return self.infer_func(inputs)
-
-    @torch.no_grad()
-    def infer_wo_cfg_parallel(self, inputs):
         if self.cpu_offload:
             if self.offload_granularity == "model" and self.scheduler.step_index == 0:
                 self.to_cuda()
...
@@ -341,26 +325,31 @@ class WanModel:
             video_token_num = c * (h // 2) * (w // 2)
             self.transformer_infer.mask_map = MaskMap(video_token_num, c)

-        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
-        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
-        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
-        self.scheduler.noise_pred = noise_pred_cond
-
-        if self.clean_cuda_cache:
-            del x, embed, pre_infer_out, noise_pred_cond, grid_sizes
-            torch.cuda.empty_cache()
-
         if self.config["enable_cfg"]:
-            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
-            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
-            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
-            self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)
-
-            if self.clean_cuda_cache:
-                del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
-                torch.cuda.empty_cache()
+            if self.config["cfg_parallel"]:
+                # ==================== CFG Parallel Processing ====================
+                cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
+                assert dist.get_world_size(cfg_p_group) == 2, "cfg_p_world_size must be equal to 2"
+                cfg_p_rank = dist.get_rank(cfg_p_group)
+                if cfg_p_rank == 0:
+                    noise_pred = self._infer_cond_uncond(inputs, positive=True)
+                else:
+                    noise_pred = self._infer_cond_uncond(inputs, positive=False)
+                noise_pred_list = [torch.zeros_like(noise_pred) for _ in range(2)]
+                dist.all_gather(noise_pred_list, noise_pred, group=cfg_p_group)
+                noise_pred_cond = noise_pred_list[0]  # cfg_p_rank == 0
+                noise_pred_uncond = noise_pred_list[1]  # cfg_p_rank == 1
+            else:
+                # ==================== CFG Processing ====================
+                noise_pred_cond = self._infer_cond_uncond(inputs, positive=True)
+                noise_pred_uncond = self._infer_cond_uncond(inputs, positive=False)
+            self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
+        else:
+            # ==================== No CFG ====================
+            self.scheduler.noise_pred = self._infer_cond_uncond(inputs, positive=True)

         if self.cpu_offload:
             if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
...
@@ -370,24 +359,62 @@ class WanModel:
             self.post_weight.to_cpu()

     @torch.no_grad()
-    def infer_with_cfg_parallel(self, inputs):
-        assert self.config["enable_cfg"], "enable_cfg must be True"
-        cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
-        assert dist.get_world_size(cfg_p_group) == 2, f"cfg_p_world_size must be equal to 2"
-        cfg_p_rank = dist.get_rank(cfg_p_group)
-        if cfg_p_rank == 0:
-            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
-            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
-            noise_pred = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
-        else:
-            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
-            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
-            noise_pred = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
-        noise_pred_list = [torch.zeros_like(noise_pred) for _ in range(2)]
-        dist.all_gather(noise_pred_list, noise_pred, group=cfg_p_group)
-        noise_pred_cond = noise_pred_list[0]  # cfg_p_rank == 0
-        noise_pred_uncond = noise_pred_list[1]  # cfg_p_rank == 1
-        self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
+    def _infer_cond_uncond(self, inputs, positive=True):
+        pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=positive)
+        if self.config["seq_parallel"]:
+            pre_infer_out = self._seq_parallel_pre_process(pre_infer_out)
+        x = self.transformer_infer.infer(self.transformer_weights, pre_infer_out)
+        if self.config["seq_parallel"]:
+            x = self._seq_parallel_post_process(x)
+        noise_pred = self.post_infer.infer(self.post_weight, x, pre_infer_out)[0]
+        if self.clean_cuda_cache:
+            del x, pre_infer_out
+            torch.cuda.empty_cache()
+        return noise_pred
+
+    @torch.no_grad()
+    def _seq_parallel_pre_process(self, pre_infer_out):
+        embed, x, embed0 = pre_infer_out.embed, pre_infer_out.x, pre_infer_out.embed0
+        world_size = dist.get_world_size(self.seq_p_group)
+        cur_rank = dist.get_rank(self.seq_p_group)
+        padding_size = (world_size - (x.shape[0] % world_size)) % world_size
+        if padding_size > 0:
+            # pad the first (sequence) dimension with F.pad
+            x = F.pad(x, (0, 0, 0, padding_size))  # (trailing-dim padding, leading-dim padding)
+        x = torch.chunk(x, world_size, dim=0)[cur_rank]
+        if self.config["model_cls"].startswith("wan2.2"):
+            padding_size = (world_size - (embed0.shape[0] % world_size)) % world_size
+            if padding_size > 0:
+                embed0 = F.pad(embed0, (0, 0, 0, 0, 0, padding_size))  # (trailing-dim padding, leading-dim padding)
+                embed = F.pad(embed, (0, 0, 0, padding_size))
+        pre_infer_out.x = x
+        pre_infer_out.embed = embed
+        pre_infer_out.embed0 = embed0
+        return pre_infer_out
+
+    @torch.no_grad()
+    def _seq_parallel_post_process(self, x):
+        world_size = dist.get_world_size(self.seq_p_group)
+        # create a list to hold every rank's output
+        gathered_x = [torch.empty_like(x) for _ in range(world_size)]
+        # gather the outputs from all ranks
+        dist.all_gather(gathered_x, x, group=self.seq_p_group)
+        # concatenate the per-rank outputs along the sequence dimension
+        combined_output = torch.cat(gathered_x, dim=0)
+        return combined_output  # return the merged output
...
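Note: _seq_parallel_pre_process pads the token dimension so it divides evenly by the sequence-parallel world size and keeps one rank's chunk, while _seq_parallel_post_process all_gathers the per-rank outputs and concatenates them back. A single-process sketch of that round trip, emulating the gather by reusing the local chunks (no process group; sizes invented):

    import torch
    import torch.nn.functional as F

    world_size = 4
    x = torch.randn(34, 16)  # (sequence length, hidden dim) -- invented sizes

    # Pre-process: pad the sequence dim to a multiple of world_size, then split.
    padding_size = (world_size - (x.shape[0] % world_size)) % world_size
    x_padded = F.pad(x, (0, 0, 0, padding_size))       # pad rows at the end
    chunks = torch.chunk(x_padded, world_size, dim=0)  # chunks[r] is what rank r would keep

    # Post-process: dist.all_gather would collect these chunks; here we already have them.
    combined = torch.cat(chunks, dim=0)
    assert torch.equal(combined[: x.shape[0]], x)      # only the padding rows differ
    print(combined.shape)                              # torch.Size([36, 16])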
lightx2v/models/networks/wan/weights/post_weights.py  (+0, -11)

 from lightx2v.common.modules.weight_module import WeightModule
-from lightx2v.utils.registry_factory import (
-    LN_WEIGHT_REGISTER,
-    MM_WEIGHT_REGISTER,
-    TENSOR_REGISTER,
-)


 class WanPostWeights(WeightModule):
     def __init__(self, config):
         super().__init__()
         self.config = config
-        self.register_parameter(
-            "norm",
-            LN_WEIGHT_REGISTER["Default"](),
-        )
-        self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
-        self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))
lightx2v/models/networks/wan/weights/transformer_weights.py  (+5, -0)

...
@@ -26,6 +26,11 @@ class WanTransformerWeights(WeightModule):
         self.blocks = WeightModuleList([WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)])
         self.add_module("blocks", self.blocks)
+        # post blocks weights
+        self.register_parameter("norm", LN_WEIGHT_REGISTER["Default"]())
+        self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
+        self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))

     def clear(self):
         for block in self.blocks:
             for phase in block.compute_phases:
...