xuwx1 / LightX2V · Commit 1a881d63

Refactor the parallel module

Authored Jul 28, 2025 by helloyongyang
Parent: 18e2b23a
Changes: 47 files in this commit; shown below are 7 changed files with 100 additions and 51 deletions (+100 −51).
- lightx2v/models/networks/wan/causvid_model.py (+1 −1)
- lightx2v/models/networks/wan/infer/dist_infer/__init__.py (+0 −0, moved)
- lightx2v/models/networks/wan/infer/dist_infer/transformer_infer.py (+51 −0, new file)
- lightx2v/models/networks/wan/infer/transformer_infer.py (+18 −20)
- lightx2v/models/networks/wan/model.py (+22 −28)
- lightx2v/models/networks/wan/weights/transformer_weights.py (+4 −0)
- lightx2v/models/runners/default_runner.py (+4 −2)
lightx2v/models/networks/wan/causvid_model.py

```diff
 import os
 import torch
-from lightx2v.attentions.common.radial_attn import MaskMap
+from lightx2v.common.ops.attn.radial_attn import MaskMap
 from lightx2v.models.networks.wan.model import WanModel
 from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
 from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
 ...
```
lightx2v/attentions/common/__init__.py → lightx2v/models/networks/wan/infer/dist_infer/__init__.py

File moved (no content changes).
lightx2v/models/networks/wan/infer/dist_infer/transformer_infer.py (new file, 0 → 100755)

```python
import torch
from lightx2v.models.networks.wan.infer.transformer_infer import WanTransformerInfer
import torch.distributed as dist
import torch.nn.functional as F
from lightx2v.models.networks.wan.infer.utils import compute_freqs_dist, compute_freqs_audio_dist


class WanTransformerDistInfer(WanTransformerInfer):
    def __init__(self, config):
        super().__init__(config)

    def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
        x = self.dist_pre_process(x)
        x = super().infer(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks)
        x = self.dist_post_process(x)
        return x

    def compute_freqs(self, q, grid_sizes, freqs):
        if "audio" in self.config.get("model_cls", ""):
            freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
        else:
            freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
        return freqs_i

    def dist_pre_process(self, x):
        world_size = dist.get_world_size()
        cur_rank = dist.get_rank()
        padding_size = (world_size - (x.shape[0] % world_size)) % world_size
        if padding_size > 0:
            # Use F.pad to pad the first dimension
            # (pad spec is ordered from the last dimension to the first)
            x = F.pad(x, (0, 0, 0, padding_size))
        x = torch.chunk(x, world_size, dim=0)[cur_rank]
        return x

    def dist_post_process(self, x):
        # Get the world size of the current process group
        world_size = dist.get_world_size()
        # Create a list to hold every rank's output
        gathered_x = [torch.empty_like(x) for _ in range(world_size)]
        # Gather the outputs from all ranks
        dist.all_gather(gathered_x, x)
        # Concatenate all ranks' outputs along the given dimension
        combined_output = torch.cat(gathered_x, dim=0)
        # Return the combined output
        return combined_output
```
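Taken together, `dist_pre_process` and `dist_post_process` form a scatter/gather pair around the inherited `infer`: the token dimension is padded up to a multiple of `world_size`, each rank keeps one contiguous chunk, and `all_gather` reassembles the full sequence afterwards. A minimal single-process sketch of that round trip (the ranks are simulated with a plain list instead of `torch.distributed`; the shapes are illustrative, not taken from LightX2V):

```python
import torch
import torch.nn.functional as F

world_size = 4
x = torch.randn(10, 8)  # 10 tokens do not divide evenly across 4 ranks

# Same padding rule as dist_pre_process: round dim 0 up to a multiple of world_size.
padding_size = (world_size - (x.shape[0] % world_size)) % world_size  # -> 2
padded = F.pad(x, (0, 0, 0, padding_size))  # (12, 8); pad spec runs last dim first

# Each rank would keep exactly one of these equal chunks.
shards = torch.chunk(padded, world_size, dim=0)  # 4 shards of shape (3, 8)

# dist_post_process all_gathers the shards; concatenation restores the padded tensor.
combined = torch.cat(shards, dim=0)
assert torch.equal(combined[: x.shape[0]], x)  # original tokens survive the round trip
```

Note that the gathered result still carries the padding rows; nothing in this diff trims them, so a consumer that needs the exact original length would have to slice downstream.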
lightx2v/models/networks/wan/infer/transformer_infer.py

```diff
@@ -318,6 +318,13 @@ class WanTransformerInfer(BaseTransformerInfer):
         return shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa
 
+    def compute_freqs(self, q, grid_sizes, freqs):
+        if "audio" in self.config.get("model_cls", ""):
+            freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
+        else:
+            freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
+        return freqs_i
+
     def infer_self_attn(self, weights, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
         if hasattr(weights, "smooth_norm1_weight"):
             norm1_weight = (1 + scale_msa.squeeze(0)) * weights.smooth_norm1_weight.tensor
@@ -342,16 +349,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         k = weights.self_attn_norm_k.apply(weights.self_attn_k.apply(norm1_out)).view(s, n, d)
         v = weights.self_attn_v.apply(norm1_out).view(s, n, d)
 
-        if not self.parallel_attention:
-            if "audio" in self.config.get("model_cls", ""):
-                freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
-            else:
-                freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
-        else:
-            if "audio" in self.config.get("model_cls", ""):
-                freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
-            else:
-                freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
+        freqs_i = self.compute_freqs(q, grid_sizes, freqs)
 
         freqs_i = self.zero_temporal_component_in_3DRoPE(seq_lens, freqs_i)
@@ -365,7 +363,16 @@ class WanTransformerInfer(BaseTransformerInfer):
         del freqs_i, norm1_out, norm1_weight, norm1_bias
         torch.cuda.empty_cache()
 
-        if not self.parallel_attention:
+        if self.config.get("parallel_attn_type", None):
+            attn_out = weights.self_attn_1_parallel.apply(
+                q=q,
+                k=k,
+                v=v,
+                img_qkv_len=q.shape[0],
+                cu_seqlens_qkv=cu_seqlens_q,
+                attention_module=weights.self_attn_1,
+            )
+        else:
             attn_out = weights.self_attn_1.apply(
                 q=q,
                 k=k,
@@ -377,15 +384,6 @@ class WanTransformerInfer(BaseTransformerInfer):
                 model_cls=self.config["model_cls"],
                 mask_map=self.mask_map,
             )
-        else:
-            attn_out = self.parallel_attention(
-                attention_type=self.attention_type,
-                q=q,
-                k=k,
-                v=v,
-                img_qkv_len=q.shape[0],
-                cu_seqlens_qkv=cu_seqlens_q,
-            )
 
         y = weights.self_attn_o.apply(attn_out)
 ...
```
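The net effect of these hunks is a template-method refactor: the serial/distributed branching that used to live inline in `infer_self_attn` moves into a `compute_freqs` hook that `WanTransformerDistInfer` overrides, and the parallel attention path now goes through the `self_attn_1_parallel` weight module instead of a monkey-patched `self.parallel_attention` callable. A minimal sketch of the pattern (stand-in classes, not the LightX2V ones):

```python
class BaseInfer:
    """Stand-in for WanTransformerInfer."""

    def compute_freqs(self, q):
        return f"freqs({q}) on one GPU"

    def infer_self_attn(self, q):
        # The shared attention body no longer branches on parallelism;
        # it calls the hook, which subclasses may override.
        return self.compute_freqs(q)


class DistInfer(BaseInfer):
    """Stand-in for WanTransformerDistInfer."""

    def compute_freqs(self, q):
        return f"freqs({q}) sharded across ranks"


print(BaseInfer().infer_self_attn("q"))  # freqs(q) on one GPU
print(DistInfer().infer_self_attn("q"))  # freqs(q) sharded across ranks
```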
lightx2v/models/networks/wan/model.py

```diff
@@ -2,7 +2,7 @@ import os
 import torch
 import glob
 import json
-from lightx2v.attentions.common.radial_attn import MaskMap
+from lightx2v.common.ops.attn import MaskMap
 from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
 from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
 from lightx2v.models.networks.wan.weights.transformer_weights import (
@@ -22,9 +22,8 @@ from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import
     WanTransformerInferDualBlock,
     WanTransformerInferDynamicBlock,
 )
+from lightx2v.models.networks.wan.infer.dist_infer.transformer_infer import WanTransformerDistInfer
 from safetensors import safe_open
-import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
-import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
 from lightx2v.utils.envs import *
 from lightx2v.utils.utils import *
 from loguru import logger
@@ -58,35 +57,30 @@ class WanModel:
         self._init_weights()
         self._init_infer()
 
-        if config["parallel_attn_type"]:
-            if config["parallel_attn_type"] == "ulysses":
-                ulysses_dist_wrap.parallelize_wan(self)
-            elif config["parallel_attn_type"] == "ring":
-                ring_dist_wrap.parallelize_wan(self)
-            else:
-                raise Exception(f"Unsuppotred parallel_attn_type")
-
     def _init_infer_class(self):
         self.pre_infer_class = WanPreInfer
         self.post_infer_class = WanPostInfer
-        if self.config["feature_caching"] == "NoCaching":
-            self.transformer_infer_class = WanTransformerInfer
-        elif self.config["feature_caching"] == "Tea":
-            self.transformer_infer_class = WanTransformerInferTeaCaching
-        elif self.config["feature_caching"] == "TaylorSeer":
-            self.transformer_infer_class = WanTransformerInferTaylorCaching
-        elif self.config["feature_caching"] == "Ada":
-            self.transformer_infer_class = WanTransformerInferAdaCaching
-        elif self.config["feature_caching"] == "Custom":
-            self.transformer_infer_class = WanTransformerInferCustomCaching
-        elif self.config["feature_caching"] == "FirstBlock":
-            self.transformer_infer_class = WanTransformerInferFirstBlock
-        elif self.config["feature_caching"] == "DualBlock":
-            self.transformer_infer_class = WanTransformerInferDualBlock
-        elif self.config["feature_caching"] == "DynamicBlock":
-            self.transformer_infer_class = WanTransformerInferDynamicBlock
-        else:
-            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
+        if self.config.get("parallel_attn_type", None):
+            self.transformer_infer_class = WanTransformerDistInfer
+        else:
+            if self.config["feature_caching"] == "NoCaching":
+                self.transformer_infer_class = WanTransformerInfer
+            elif self.config["feature_caching"] == "Tea":
+                self.transformer_infer_class = WanTransformerInferTeaCaching
+            elif self.config["feature_caching"] == "TaylorSeer":
+                self.transformer_infer_class = WanTransformerInferTaylorCaching
+            elif self.config["feature_caching"] == "Ada":
+                self.transformer_infer_class = WanTransformerInferAdaCaching
+            elif self.config["feature_caching"] == "Custom":
+                self.transformer_infer_class = WanTransformerInferCustomCaching
+            elif self.config["feature_caching"] == "FirstBlock":
+                self.transformer_infer_class = WanTransformerInferFirstBlock
+            elif self.config["feature_caching"] == "DualBlock":
+                self.transformer_infer_class = WanTransformerInferDualBlock
+            elif self.config["feature_caching"] == "DynamicBlock":
+                self.transformer_infer_class = WanTransformerInferDynamicBlock
+            else:
+                raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
 
     def _load_safetensor_to_dict(self, file_path, use_bf16, skip_bf16):
         with safe_open(file_path, framework="pt") as f:
 ...
```
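With the ulysses/ring wrapper calls gone from the constructor, parallel inference is now chosen declaratively in `_init_infer_class`: any truthy `parallel_attn_type` selects `WanTransformerDistInfer` and takes precedence over `feature_caching`. A hedged sketch of the two config shapes (keys other than `parallel_attn_type` and `feature_caching` omitted; "ulysses" appears as a value in the removed code, but the exact registry keys are not confirmed by this diff):

```python
# Single-GPU run: the feature_caching value picks the infer class.
config_serial = {
    "feature_caching": "Tea",  # -> WanTransformerInferTeaCaching
}

# Multi-GPU run: a truthy parallel_attn_type short-circuits the caching branch
# entirely and selects WanTransformerDistInfer.
config_parallel = {
    "parallel_attn_type": "ulysses",  # assumed key into ATTN_WEIGHT_REGISTER
    "feature_caching": "NoCaching",   # ignored when parallel_attn_type is set
}
```

One consequence worth noting: because the parallel branch is checked first, feature caching and sequence parallelism are mutually exclusive after this commit.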
lightx2v/models/networks/wan/weights/transformer_weights.py

```diff
@@ -190,6 +190,10 @@ class WanSelfAttention(WeightModule):
             self.self_attn_1.load(sparge_ckpt)
         else:
             self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["self_attn_1_type"]]())
+
+        if self.config.get("parallel_attn_type", None):
+            self.add_module("self_attn_1_parallel", ATTN_WEIGHT_REGISTER[self.config["parallel_attn_type"]]())
+
         if self.quant_method in ["advanced_ptq"]:
             self.add_module(
                 "smooth_norm1_weight",
 ...
```
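The parallel attention implementation is fetched from the same `ATTN_WEIGHT_REGISTER` that supplies `self_attn_1`, keyed by `parallel_attn_type`. A minimal sketch of that string-keyed registry pattern (generic stand-ins; the real register and its keys are not shown in this diff):

```python
ATTN_REGISTRY = {}

def register(name):
    """Decorator that maps a config string to an attention weight class."""
    def deco(cls):
        ATTN_REGISTRY[name] = cls
        return cls
    return deco

@register("serial")
class SerialAttnWeight:
    def apply(self, **kwargs):
        return "serial attention"

@register("ulysses")
class UlyssesAttnWeight:
    def apply(self, **kwargs):
        return "sequence-parallel attention"

# Selection mirrors the diff: the config string picks the class to instantiate.
attn = ATTN_REGISTRY["ulysses"]()
print(attn.apply())  # sequence-parallel attention
```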
lightx2v/models/runners/default_runner.py

```diff
@@ -235,8 +235,10 @@ class DefaultRunner(BaseRunner):
             fps = self.config["video_frame_interpolation"]["target_fps"]
         else:
             fps = self.config.get("fps", 16)
-        logger.info(f"Saving video to {self.config.save_video_path}")
-        save_to_video(images, self.config.save_video_path, fps=fps, method="ffmpeg")  # type: ignore
+
+        if not self.config.get("parallel_attn_type", None) or dist.get_rank() == 0:
+            logger.info(f"Saving video to {self.config.save_video_path}")
+            save_to_video(images, self.config.save_video_path, fps=fps, method="ffmpeg")  # type: ignore
 
         del latents, generator
         torch.cuda.empty_cache()
 ...
```
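The runner change gates file output on rank: after `dist_post_process` every rank should hold the same gathered frames, so only rank 0 needs to write the video. A minimal sketch of the guard (assuming `torch.distributed` is initialized whenever `parallel_attn_type` is set; `save_fn` is a hypothetical stand-in for `save_to_video`):

```python
import torch.distributed as dist

def save_once(config, save_fn, *args, **kwargs):
    # Serial run: no parallel_attn_type, save unconditionally.
    # Parallel run: identical outputs on all ranks, so rank 0 saves exactly once.
    if not config.get("parallel_attn_type", None) or dist.get_rank() == 0:
        save_fn(*args, **kwargs)
```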