xuwx1 / LightX2V · Commits

Commit 60f0b6c1, authored Aug 01, 2025 by helloyongyang

    Support cfg parallel & hybrid parallel (cfg + seq)

Parent: a395cc0a
Changes: 26 files in total; this page shows 20 changed files with 207 additions and 89 deletions (+207 -89).
configs/dist_infer/wan_i2v_dist_cfg_ulysses.json  +19 -0
configs/dist_infer/wan_i2v_dist_ring.json  +4 -2
configs/dist_infer/wan_i2v_dist_ulysses.json  +4 -2
configs/dist_infer/wan_t2v_dist_cfg.json  +18 -0
configs/dist_infer/wan_t2v_dist_cfg_ring.json  +20 -0
configs/dist_infer/wan_t2v_dist_cfg_ulysses.json  +20 -0
configs/dist_infer/wan_t2v_dist_ring.json  +4 -2
configs/dist_infer/wan_t2v_dist_ulysses.json  +4 -2
lightx2v/common/ops/attn/ring_attn.py  +4 -4
lightx2v/common/ops/attn/ulysses_attn.py  +8 -8
lightx2v/common/ops/attn/utils/all2all.py  +6 -6
lightx2v/infer.py  +7 -1
lightx2v/models/networks/wan/infer/causvid/transformer_infer.py  +1 -1
lightx2v/models/networks/wan/infer/dist_infer/transformer_infer.py  +49 -8
lightx2v/models/networks/wan/infer/transformer_infer.py  +3 -2
lightx2v/models/networks/wan/infer/utils.py  +0 -43
lightx2v/models/networks/wan/model.py  +30 -2
lightx2v/models/networks/wan/weights/transformer_weights.py  +2 -2
lightx2v/models/runners/default_runner.py  +2 -2
lightx2v/models/runners/wan/wan_runner.py  +2 -2
configs/dist_infer/wan_i2v_dist_cfg_ulysses.json (new file, mode 100644)

{
    "infer_steps": 40,
    "target_video_length": 81,
    "target_height": 480,
    "target_width": 832,
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 5,
    "sample_shift": 5,
    "enable_cfg": true,
    "cpu_offload": false,
    "parallel": {
        "seq_p_size": 4,
        "seq_p_attn_type": "ulysses",
        "cfg_p_size": 2
    }
}
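Note on GPU budget: init_runner (see the lightx2v/infer.py diff below) asserts that cfg_p_size * seq_p_size equals the launched world size, so the config above must be started on exactly 8 ranks (4 sequence-parallel ranks per CFG branch, 2 CFG branches). A minimal sketch of that arithmetic, with the parallel block inlined here for illustration rather than read from the JSON file:

    # Sketch: how the "parallel" block maps to a required world size.
    parallel = {"seq_p_size": 4, "seq_p_attn_type": "ulysses", "cfg_p_size": 2}

    cfg_p_size = parallel.get("cfg_p_size", 1)  # replicas for the cond/uncond CFG branches
    seq_p_size = parallel.get("seq_p_size", 1)  # ranks that shard the token sequence

    required_world_size = cfg_p_size * seq_p_size
    assert required_world_size == 8  # launch 8 processes, e.g. via torchrun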
configs/dist_infer/wan_i2v_dist_ring.json

@@ -11,6 +11,8 @@
     "sample_shift": 5,
     "enable_cfg": true,
     "cpu_offload": false,
-    "parallel_attn_type": "ring",
-    "parallel_vae": true
+    "parallel": {
+        "seq_p_size": 4,
+        "seq_p_attn_type": "ring"
+    }
 }
configs/dist_infer/wan_i2v_dist_ulysses.json

@@ -11,6 +11,8 @@
     "sample_shift": 5,
     "enable_cfg": true,
     "cpu_offload": false,
-    "parallel_attn_type": "ulysses",
-    "parallel_vae": true
+    "parallel": {
+        "seq_p_size": 4,
+        "seq_p_attn_type": "ulysses"
+    }
 }
configs/dist_infer/wan_t2v_dist_cfg.json (new file, mode 100755)

{
    "infer_steps": 50,
    "target_video_length": 81,
    "text_len": 512,
    "target_height": 480,
    "target_width": 832,
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 6,
    "sample_shift": 8,
    "enable_cfg": true,
    "cpu_offload": false,
    "parallel": {
        "cfg_p_size": 2
    }
}
configs/dist_infer/wan_t2v_dist_cfg_ring.json (new file, mode 100755)

{
    "infer_steps": 50,
    "target_video_length": 81,
    "text_len": 512,
    "target_height": 480,
    "target_width": 832,
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 6,
    "sample_shift": 8,
    "enable_cfg": true,
    "cpu_offload": false,
    "parallel": {
        "seq_p_size": 4,
        "seq_p_attn_type": "ring",
        "cfg_p_size": 2
    }
}
configs/dist_infer/wan_t2v_dist_cfg_ulysses.json (new file, mode 100755)

{
    "infer_steps": 50,
    "target_video_length": 81,
    "text_len": 512,
    "target_height": 480,
    "target_width": 832,
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 6,
    "sample_shift": 8,
    "enable_cfg": true,
    "cpu_offload": false,
    "parallel": {
        "seq_p_size": 4,
        "seq_p_attn_type": "ulysses",
        "cfg_p_size": 2
    }
}
configs/dist_infer/wan_t2v_dist_ring.json

@@ -12,6 +12,8 @@
     "sample_shift": 8,
     "enable_cfg": true,
     "cpu_offload": false,
-    "parallel_attn_type": "ring",
-    "parallel_vae": true
+    "parallel": {
+        "seq_p_size": 4,
+        "seq_p_attn_type": "ring"
+    }
 }
configs/dist_infer/wan_t2v_dist_ulysses.json

@@ -12,6 +12,8 @@
     "sample_shift": 8,
     "enable_cfg": true,
     "cpu_offload": false,
-    "parallel_attn_type": "ulysses",
-    "parallel_vae": true
+    "parallel": {
+        "seq_p_size": 4,
+        "seq_p_attn_type": "ulysses"
+    }
 }
lightx2v/common/ops/attn/ring_attn.py

@@ -38,7 +38,7 @@ class RingAttnWeight(AttnWeightTemplate):
     def __init__(self):
         self.config = {}

-    def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None):
+    def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None):
         """
         Performs the ring attention mechanism over the combined image and text queries, keys, and values.
@@ -54,8 +54,8 @@ class RingAttnWeight(AttnWeightTemplate):
             torch.Tensor: the computed attention result
         """
         # Get the rank of the current process and the size of the process group
-        cur_rank = dist.get_rank()
-        world_size = dist.get_world_size()
+        cur_rank = dist.get_rank(seq_p_group)
+        world_size = dist.get_world_size(seq_p_group)
         if len(cu_seqlens_qkv) == 3:
             txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text queries, keys, and values
@@ -67,7 +67,7 @@ class RingAttnWeight(AttnWeightTemplate):
         # if RING_COMM is None:
         #     init_ring_comm()
-        RING_COMM = RingComm()
+        RING_COMM = RingComm(seq_p_group)
         # if len(cu_seqlens_qkv) == 3:
         #     txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text queries, keys, and values
lightx2v/common/ops/attn/ulysses_attn.py

@@ -10,7 +10,7 @@ class UlyssesAttnWeight(AttnWeightTemplate):
     def __init__(self):
         self.config = {}

-    def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None):
+    def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None):
         """
         Performs the Ulysses attention mechanism over the combined image and text queries, keys, and values.
@@ -26,8 +26,8 @@ class UlyssesAttnWeight(AttnWeightTemplate):
             torch.Tensor: the computed attention result
         """
         # Get the rank of the current process and the size of the process group
-        cur_rank = dist.get_rank()
-        world_size = dist.get_world_size()
+        world_size = dist.get_world_size(seq_p_group)
+        cur_rank = dist.get_rank(seq_p_group)
         # Get the sequence length and the text-related lengths
         seq_len = q.shape[0]
@@ -48,9 +48,9 @@ class UlyssesAttnWeight(AttnWeightTemplate):
         txt_q, txt_k, txt_v = q[img_qkv_len:, :, :].contiguous(), k[img_qkv_len:, :, :].contiguous(), v[img_qkv_len:, :, :].contiguous()
         # Convert the image queries, keys, and values to the head-sharded layout
-        img_q = all2all_seq2head(img_q)
-        img_k = all2all_seq2head(img_k)
-        img_v = all2all_seq2head(img_v)
+        img_q = all2all_seq2head(img_q, group=seq_p_group)
+        img_k = all2all_seq2head(img_k, group=seq_p_group)
+        img_v = all2all_seq2head(img_v, group=seq_p_group)
         torch.cuda.synchronize()  # make sure the CUDA ops have completed
         # Process the text queries, keys, and values, selecting this process's heads
@@ -82,11 +82,11 @@ class UlyssesAttnWeight(AttnWeightTemplate):
         # Gather the text attention results from all processes
         gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)]
-        dist.all_gather(gathered_txt_attn, txt_attn)
+        dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group)
         # Process the image attention result
         img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims)  # reshape the image attention result
-        img_attn = all2all_head2seq(img_attn)  # convert the head-sharded layout back to the sequence layout
+        img_attn = all2all_head2seq(img_attn, group=seq_p_group)  # convert the head-sharded layout back to the sequence layout
         img_attn = img_attn.reshape(shard_seqlen, -1)  # reshape to the [shard_seqlen, -1] shape
         torch.cuda.synchronize()  # make sure the CUDA ops have completed
lightx2v/common/ops/attn/utils/all2all.py

@@ -4,7 +4,7 @@ import torch.distributed as dist

 @dynamo.disable
-def all2all_seq2head(input):
+def all2all_seq2head(input, group=None):
     """
     Converts the input tensor from the [seq_len/N, heads, hidden_dims] layout to the [seq_len, heads/N, hidden_dims] layout.
@@ -18,7 +18,7 @@ def all2all_seq2head(input):
     assert input.dim() == 3, f"input must be 3D tensor"
     # Get the world size of the process group
-    world_size = dist.get_world_size()
+    world_size = dist.get_world_size(group=group)
     # Get the shape of the input tensor
     shard_seq_len, heads, hidden_dims = input.shape
@@ -36,7 +36,7 @@ def all2all_seq2head(input):
     output = torch.empty_like(input_t)
     # Run the all-to-all operation, distributing the input tensor's contents across all processes
-    dist.all_to_all_single(output, input_t)
+    dist.all_to_all_single(output, input_t, group=group)
     # Reshape the output tensor to the [seq_len, heads/N, hidden_dims] shape
     output = output.reshape(seq_len, shard_heads, hidden_dims).contiguous()
@@ -45,7 +45,7 @@ def all2all_seq2head(input):

 @dynamo.disable
-def all2all_head2seq(input):
+def all2all_head2seq(input, group=None):
     """
     Converts the input tensor from the [seq_len, heads/N, hidden_dims] layout to the [seq_len/N, heads, hidden_dims] layout.
@@ -59,7 +59,7 @@ def all2all_head2seq(input):
     assert input.dim() == 3, f"input must be 3D tensor"
     # Get the world size of the process group
-    world_size = dist.get_world_size()
+    world_size = dist.get_world_size(group=group)
     # Get the shape of the input tensor
     seq_len, shard_heads, hidden_dims = input.shape
@@ -78,7 +78,7 @@ def all2all_head2seq(input):
     output = torch.empty_like(input_t)
     # Run the all-to-all operation, distributing the input tensor's contents across all processes
-    dist.all_to_all_single(output, input_t)
+    dist.all_to_all_single(output, input_t, group=group)
     # Reshape the output tensor to the [heads, shard_seq_len, hidden_dims] shape
     output = output.reshape(heads, shard_seq_len, hidden_dims)
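For intuition about the data movement these helpers implement: with a sequence-parallel group of size N, all2all_seq2head turns each rank's [seq_len/N, heads, hidden_dims] shard into a [seq_len, heads/N, hidden_dims] shard, and all2all_head2seq is the inverse exchange. A shape-only, single-process sketch (toy sizes assumed; no process group involved):

    import torch

    # Each of N ranks holds a sequence shard before the call and a head shard after it.
    N, S, H, D = 4, 16, 8, 32          # toy group size / seq len / heads / hidden dims
    full = torch.randn(S, H, D)        # the logical, unsharded tensor

    seq_shards = full.chunk(N, dim=0)   # per-rank inputs to all2all_seq2head
    head_shards = full.chunk(N, dim=1)  # per-rank outputs of all2all_seq2head

    assert seq_shards[0].shape == (S // N, H, D)
    assert head_shards[0].shape == (S, H // N, D)
    # all2all_head2seq performs the inverse exchange: [S, H/N, D] -> [S/N, H, D].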
lightx2v/infer.py

 import argparse
 import torch
 import torch.distributed as dist
+from torch.distributed.device_mesh import init_device_mesh
 import json
 from lightx2v.utils.envs import *
@@ -25,10 +26,15 @@ from loguru import logger

 def init_runner(config):
     seed_all(config.seed)
-    if config.parallel_attn_type:
+    if config.parallel:
         if not dist.is_initialized():
             dist.init_process_group(backend="nccl")
+        cfg_p_size = config.parallel.get("cfg_p_size", 1)
+        seq_p_size = config.parallel.get("seq_p_size", 1)
+        assert cfg_p_size * seq_p_size == dist.get_world_size(), f"cfg_p_size * seq_p_size must be equal to world_size"
+        config["device_mesh"] = init_device_mesh("cuda", (cfg_p_size, seq_p_size), mesh_dim_names=("cfg_p", "seq_p"))
     if CHECK_ENABLE_GRAPH_MODE():
         default_runner = RUNNER_REGISTER[config.model_cls](config)
         runner = GraphRunner(default_runner)
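A quick illustration of what the 2-D mesh means for rank placement. With cfg_p_size=2 and seq_p_size=4, global rank r sits at mesh coordinate (r // 4, r % 4): ranks 0-3 form one sequence-parallel group (the positive CFG branch) and ranks 4-7 the other. The sketch below is illustrative only, assuming 8 processes launched with torchrun on CUDA devices; it is not part of the commit:

    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh

    # Assumes an 8-process NCCL launch, e.g. via torchrun.
    dist.init_process_group(backend="nccl")
    mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("cfg_p", "seq_p"))

    seq_p_group = mesh.get_group(mesh_dim="seq_p")  # 4 ranks sharding the sequence
    cfg_p_group = mesh.get_group(mesh_dim="cfg_p")  # 2 ranks splitting cond/uncond

    # Group-local ranks: which CFG branch am I, and which sequence shard do I own?
    print(dist.get_rank(), dist.get_rank(cfg_p_group), dist.get_rank(seq_p_group))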
lightx2v/models/networks/wan/infer/causvid/transformer_infer.py

 import torch
 import math
-from ..utils import compute_freqs, compute_freqs_causvid, compute_freqs_dist, apply_rotary_emb
+from ..utils import compute_freqs, compute_freqs_causvid, apply_rotary_emb
 from lightx2v.utils.envs import *
 from ..transformer_infer import WanTransformerInfer
lightx2v/models/networks/wan/infer/dist_infer/transformer_infer.py

@@ -2,12 +2,13 @@ import torch
 from lightx2v.models.networks.wan.infer.transformer_infer import WanTransformerInfer
 import torch.distributed as dist
 import torch.nn.functional as F
-from lightx2v.models.networks.wan.infer.utils import compute_freqs_dist, compute_freqs_audio_dist
+from lightx2v.models.networks.wan.infer.utils import pad_freqs

 class WanTransformerDistInfer(WanTransformerInfer):
     def __init__(self, config):
         super().__init__(config)
+        self.seq_p_group = self.config["device_mesh"].get_group(mesh_dim="seq_p")

     def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         x = self.dist_pre_process(x)
@@ -17,14 +18,14 @@ class WanTransformerDistInfer(WanTransformerInfer):
     def compute_freqs(self, q, grid_sizes, freqs):
         if "audio" in self.config.get("model_cls", ""):
-            freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
+            freqs_i = self.compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
         else:
-            freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
+            freqs_i = self.compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
         return freqs_i

     def dist_pre_process(self, x):
-        world_size = dist.get_world_size()
-        cur_rank = dist.get_rank()
+        world_size = dist.get_world_size(self.seq_p_group)
+        cur_rank = dist.get_rank(self.seq_p_group)
         padding_size = (world_size - (x.shape[0] % world_size)) % world_size
@@ -36,16 +37,56 @@ class WanTransformerDistInfer(WanTransformerInfer):
         return x

     def dist_post_process(self, x):
         # Get the size of the process group
-        world_size = dist.get_world_size()
+        world_size = dist.get_world_size(self.seq_p_group)
         # Create a list to hold the outputs of all processes
         gathered_x = [torch.empty_like(x) for _ in range(world_size)]
         # Gather the outputs of all processes
-        dist.all_gather(gathered_x, x)
+        dist.all_gather(gathered_x, x, group=self.seq_p_group)
         # Concatenate the outputs of all processes along the given dimension
         combined_output = torch.cat(gathered_x, dim=0)
         return combined_output  # return the combined output

+    def compute_freqs_dist(self, s, c, grid_sizes, freqs):
+        world_size = dist.get_world_size(self.seq_p_group)
+        cur_rank = dist.get_rank(self.seq_p_group)
+        freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+        f, h, w = grid_sizes[0].tolist()
+        seq_len = f * h * w
+        freqs_i = torch.cat(
+            [
+                freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+                freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+                freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
+            ],
+            dim=-1,
+        ).reshape(seq_len, 1, -1)
+        freqs_i = pad_freqs(freqs_i, s * world_size)
+        s_per_rank = s
+        freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
+        return freqs_i_rank
+
+    def compute_freqs_audio_dist(self, s, c, grid_sizes, freqs):
+        world_size = dist.get_world_size(self.seq_p_group)
+        cur_rank = dist.get_rank(self.seq_p_group)
+        freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+        f, h, w = grid_sizes[0].tolist()
+        f = f + 1
+        seq_len = f * h * w
+        freqs_i = torch.cat(
+            [
+                freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+                freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+                freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
+            ],
+            dim=-1,
+        ).reshape(seq_len, 1, -1)
+        freqs_i = pad_freqs(freqs_i, s * world_size)
+        s_per_rank = s
+        freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
+        return freqs_i_rank
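The per-rank slicing at the end of both methods is the core of the distribution: the full RoPE frequency table is padded to s * world_size rows, and each sequence-parallel rank keeps the contiguous block matching its sequence shard. A toy sketch of just that indexing (sizes assumed; pad_freqs is emulated here by zero-padding along dim 0):

    import torch

    world_size, s = 4, 5               # 4 sequence-parallel ranks, 5 tokens per rank
    seq_len = 18                       # real sequence length before padding
    freqs_i = torch.randn(seq_len, 1, 64)

    pad = world_size * s - seq_len     # pad up to s * world_size = 20 rows
    freqs_i = torch.cat([freqs_i, torch.zeros(pad, 1, 64)], dim=0)

    for cur_rank in range(world_size):
        shard = freqs_i[cur_rank * s : (cur_rank + 1) * s]  # this rank's block
        assert shard.shape == (s, 1, 64)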
lightx2v/models/networks/wan/infer/transformer_infer.py

 import torch
-from .utils import compute_freqs, compute_freqs_dist, compute_freqs_audio, compute_freqs_audio_dist, apply_rotary_emb, apply_rotary_emb_chunk
+from .utils import compute_freqs, compute_freqs_audio, apply_rotary_emb, apply_rotary_emb_chunk
 from lightx2v.common.offload.manager import (
     WeightAsyncStreamManager,
     LazyWeightAsyncStreamManager,
@@ -367,7 +367,7 @@ class WanTransformerInfer(BaseTransformerInfer):
             del freqs_i, norm1_out, norm1_weight, norm1_bias
             torch.cuda.empty_cache()

-        if self.config.get("parallel_attn_type", None):
+        if self.config.parallel and self.config.parallel.get("seq_p_size", False) and self.config.parallel.seq_p_size > 1:
             attn_out = weights.self_attn_1_parallel.apply(
                 q=q,
                 k=k,
@@ -375,6 +375,7 @@ class WanTransformerInfer(BaseTransformerInfer):
                 img_qkv_len=q.shape[0],
                 cu_seqlens_qkv=cu_seqlens_q,
                 attention_module=weights.self_attn_1,
+                seq_p_group=self.seq_p_group,
             )
         else:
             attn_out = weights.self_attn_1.apply(
lightx2v/models/networks/wan/infer/utils.py

@@ -37,28 +37,6 @@ def compute_freqs_audio(c, grid_sizes, freqs):
     return freqs_i

-def compute_freqs_audio_dist(s, c, grid_sizes, freqs):
-    world_size = dist.get_world_size()
-    cur_rank = dist.get_rank()
-    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
-    f, h, w = grid_sizes[0].tolist()
-    f = f + 1
-    seq_len = f * h * w
-    freqs_i = torch.cat(
-        [
-            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
-            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
-            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
-        ],
-        dim=-1,
-    ).reshape(seq_len, 1, -1)
-    freqs_i = pad_freqs(freqs_i, s * world_size)
-    s_per_rank = s
-    freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
-    return freqs_i_rank

 def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
     f, h, w = grid_sizes[0].tolist()
@@ -83,27 +61,6 @@ def pad_freqs(original_tensor, target_len):
     return padded_tensor

-def compute_freqs_dist(s, c, grid_sizes, freqs):
-    world_size = dist.get_world_size()
-    cur_rank = dist.get_rank()
-    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
-    f, h, w = grid_sizes[0].tolist()
-    seq_len = f * h * w
-    freqs_i = torch.cat(
-        [
-            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
-            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
-            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
-        ],
-        dim=-1,
-    ).reshape(seq_len, 1, -1)
-    freqs_i = pad_freqs(freqs_i, s * world_size)
-    s_per_rank = s
-    freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
-    return freqs_i_rank

 def apply_rotary_emb(x, freqs_i):
     n = x.size(1)
     seq_len = freqs_i.size(0)
lightx2v/models/networks/wan/model.py

 import os
 import torch
+import torch.distributed as dist
 import glob
 import json
 from lightx2v.common.ops.attn import MaskMap
@@ -69,7 +70,7 @@ class WanModel:
     def _init_infer_class(self):
         self.pre_infer_class = WanPreInfer
         self.post_infer_class = WanPostInfer
-        if self.config.get("parallel_attn_type", None):
+        if self.config.parallel and self.config.parallel.get("seq_p_size", False) and self.config.parallel.seq_p_size > 1:
             self.transformer_infer_class = WanTransformerDistInfer
         else:
             if self.config["feature_caching"] == "NoCaching":
@@ -186,6 +187,10 @@ class WanModel:
         self.pre_infer = self.pre_infer_class(self.config)
         self.post_infer = self.post_infer_class(self.config)
         self.transformer_infer = self.transformer_infer_class(self.config)
+        if self.config.parallel and self.config.parallel.get("cfg_p_size", False) and self.config.parallel.cfg_p_size > 1:
+            self.infer = self.infer_with_cfg_parallel
+        else:
+            self.infer = self.infer_wo_cfg_parallel

     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
@@ -204,7 +209,7 @@ class WanModel:
         self.transformer_weights.to_cuda()

     @torch.no_grad()
-    def infer(self, inputs):
+    def infer_wo_cfg_parallel(self, inputs):
         if self.cpu_offload:
             if self.offload_granularity == "model" and self.scheduler.step_index == 0:
                 self.to_cuda()
@@ -245,6 +250,29 @@ class WanModel:
             self.pre_weight.to_cpu()
             self.post_weight.to_cpu()

+    @torch.no_grad()
+    def infer_with_cfg_parallel(self, inputs):
+        assert self.config["enable_cfg"], "enable_cfg must be True"
+        cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
+        assert dist.get_world_size(cfg_p_group) == 2, f"cfg_p_world_size must be equal to 2"
+        cfg_p_rank = dist.get_rank(cfg_p_group)
+        if cfg_p_rank == 0:
+            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
+            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
+            noise_pred = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
+        else:
+            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
+            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
+            noise_pred = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
+        noise_pred_list = [torch.zeros_like(noise_pred) for _ in range(2)]
+        dist.all_gather(noise_pred_list, noise_pred, group=cfg_p_group)
+        noise_pred_cond = noise_pred_list[0]  # cfg_p_rank == 0
+        noise_pred_uncond = noise_pred_list[1]  # cfg_p_rank == 1
+        self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)

 class Wan22MoeModel(WanModel):
     def _load_ckpt(self, use_bf16, skip_bf16):
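The all_gather at the end is the payoff of CFG parallelism: the conditional and unconditional branches, which a single GPU would run one after the other at every step, run concurrently on the two cfg_p ranks, and each rank then reconstructs the guided prediction locally with the usual classifier-free-guidance combine. A single-process sketch of that combine (tensor values assumed for illustration):

    import torch

    sample_guide_scale = 5.0                 # matches the i2v config above
    noise_pred_cond = torch.randn(16, 32)    # gathered from cfg_p_rank == 0
    noise_pred_uncond = torch.randn(16, 32)  # gathered from cfg_p_rank == 1

    # Guided prediction: push the estimate away from uncond, toward cond.
    noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_cond - noise_pred_uncond)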
lightx2v/models/networks/wan/weights/transformer_weights.py

@@ -191,8 +191,8 @@ class WanSelfAttention(WeightModule):
         else:
             self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["self_attn_1_type"]]())
-        if self.config.get("parallel_attn_type", None):
-            self.add_module("self_attn_1_parallel", ATTN_WEIGHT_REGISTER[self.config["parallel_attn_type"]]())
+        if self.config.parallel and self.config.parallel.get("seq_p_size", False) and self.config.parallel.seq_p_size > 1:
+            self.add_module("self_attn_1_parallel", ATTN_WEIGHT_REGISTER[self.config.parallel.get("seq_p_attn_type", "ulysses")]())
         if self.quant_method in ["advanced_ptq"]:
             self.add_module(
lightx2v/models/runners/default_runner.py

@@ -43,7 +43,7 @@ class DefaultRunner(BaseRunner):
             self.run_input_encoder = self._run_input_encoder_local_t2v

     def set_init_device(self):
-        if self.config.parallel_attn_type:
+        if self.config.parallel:
             cur_rank = dist.get_rank()
             torch.cuda.set_device(cur_rank)
         if self.config.cpu_offload:
@@ -237,7 +237,7 @@ class DefaultRunner(BaseRunner):
         else:
             fps = self.config.get("fps", 16)
-        if not self.config.get("parallel_attn_type", None) or dist.get_rank() == 0:
+        if not dist.is_initialized() or dist.get_rank() == 0:
             logger.info(f"Saving video to {self.config.save_video_path}")
             if self.config["model_cls"] != "wan2.2":
lightx2v/models/runners/wan/wan_runner.py

@@ -124,7 +124,7 @@ class WanRunner(DefaultRunner):
         vae_config = {
             "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
             "device": self.init_device,
-            "parallel": self.config.parallel_vae,
+            "parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
             "use_tiling": self.config.get("use_tiling_vae", False),
         }
         if self.config.task != "i2v":
@@ -136,7 +136,7 @@ class WanRunner(DefaultRunner):
         vae_config = {
             "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
             "device": self.init_device,
-            "parallel": self.config.parallel_vae,
+            "parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
             "use_tiling": self.config.get("use_tiling_vae", False),
         }
         if self.config.get("use_tiny_vae", False):