Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
068a47db
Commit
068a47db
authored
Aug 14, 2025
by
helloyongyang
Browse files
update parallel
parent
dc296c2f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
1 addition
and
98 deletions
+1
-98
lightx2v/models/networks/wan/infer/dist_infer/__init__.py
lightx2v/models/networks/wan/infer/dist_infer/__init__.py
+0
-0
lightx2v/models/networks/wan/infer/dist_infer/transformer_infer.py
...models/networks/wan/infer/dist_infer/transformer_infer.py
+0
-98
lightx2v/models/networks/wan/model.py
lightx2v/models/networks/wan/model.py
+1
-0
No files found.
lightx2v/models/networks/wan/infer/dist_infer/__init__.py
deleted
100644 → 0
View file @
dc296c2f
lightx2v/models/networks/wan/infer/dist_infer/transformer_infer.py
deleted
100755 → 0
View file @
dc296c2f
import
torch
import
torch.distributed
as
dist
import
torch.nn.functional
as
F
from
lightx2v.models.networks.wan.infer.transformer_infer
import
WanTransformerInfer
from
lightx2v.models.networks.wan.infer.utils
import
pad_freqs
class WanTransformerDistInfer(WanTransformerInfer):
    """Sequence-parallel variant of ``WanTransformerInfer``.

    Splits the token sequence across the ranks of ``seq_p_group`` before
    delegating to the base transformer inference, then all-gathers the
    per-rank outputs back into a full sequence. Rotary-frequency helpers
    likewise return only this rank's slice.
    """

    def __init__(self, config, seq_p_group=None):
        # seq_p_group: torch.distributed process group used for sequence
        # parallelism; None selects the default (world) group in the
        # dist.* calls below.
        super().__init__(config)
        self.seq_p_group = seq_p_group

    def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
        """Run inference with the sequence dimension sharded across ranks.

        Scatters ``x`` (and ``embed0`` for wan2.2 models) across the
        group, runs the parent ``infer`` on the local shard, then gathers
        and concatenates the shards back together.
        """
        x, embed0 = self.dist_pre_process(x, embed0)
        x = super().infer(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks)
        x = self.dist_post_process(x)
        return x

    def compute_freqs(self, q, grid_sizes, freqs):
        """Return this rank's slice of the rotary frequencies for ``q``.

        Dispatches to the audio-aware variant when the configured model
        class name contains "audio". ``q.size(0)`` is the local sequence
        length and ``q.size(2) // 2`` the complex channel count.
        """
        if "audio" in self.config.get("model_cls", ""):
            freqs_i = self.compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
        else:
            freqs_i = self.compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
        return freqs_i

    def dist_pre_process(self, x, embed0):
        """Pad ``x`` to a multiple of the group size and keep this rank's chunk.

        For wan2.2 models ``embed0`` is chunked the same way; otherwise it
        is returned unchanged.
        """
        world_size = dist.get_world_size(self.seq_p_group)
        cur_rank = dist.get_rank(self.seq_p_group)
        # Amount needed to make dim 0 divisible by world_size (0 if already divisible).
        padding_size = (world_size - (x.shape[0] % world_size)) % world_size
        if padding_size > 0:
            # Use F.pad to zero-pad the sequence dimension. The pad tuple is
            # (last-dim left, last-dim right, prev-dim left, prev-dim right),
            # so this pads the second-to-last dim at its end — assumes x is
            # 2-D (seq_len, hidden) so that dim 0 is padded; TODO confirm.
            x = F.pad(x, (0, 0, 0, padding_size))
        x = torch.chunk(x, world_size, dim=0)[cur_rank]
        # NOTE(review): uses config["model_cls"] here but config.get(...) in
        # compute_freqs — would raise KeyError if "model_cls" is absent.
        if self.config["model_cls"].startswith("wan2.2"):
            embed0 = torch.chunk(embed0, world_size, dim=0)[cur_rank]
        return x, embed0

    def dist_post_process(self, x):
        """All-gather the per-rank shards and concatenate along dim 0.

        NOTE(review): any padding added by ``dist_pre_process`` is still
        present in the gathered result — presumably trimmed by the caller;
        verify.
        """
        world_size = dist.get_world_size(self.seq_p_group)
        # One receive buffer per rank for the collective.
        gathered_x = [torch.empty_like(x) for _ in range(world_size)]
        # Collect every rank's local output.
        dist.all_gather(gathered_x, x, group=self.seq_p_group)
        # Stitch the shards back into the full sequence.
        combined_output = torch.cat(gathered_x, dim=0)
        return combined_output

    def compute_freqs_dist(self, s, c, grid_sizes, freqs):
        """Build full-video rotary freqs and return this rank's slice.

        s: per-rank sequence length; c: frequency channel count. The
        channels are split into frame/height/width groups of sizes
        [c - 2*(c//3), c//3, c//3] and broadcast over the (f, h, w) grid.
        """
        world_size = dist.get_world_size(self.seq_p_group)
        cur_rank = dist.get_rank(self.seq_p_group)
        freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
        f, h, w = grid_sizes[0]
        seq_len = f * h * w
        # Expand each axis' frequencies over the grid, concatenate on the
        # channel dim, and flatten to (seq_len, 1, channels).
        freqs_i = torch.cat(
            [
                freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
                freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
                freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
            ],
            dim=-1,
        ).reshape(seq_len, 1, -1)
        # Pad to the global (all ranks) padded length before slicing.
        freqs_i = pad_freqs(freqs_i, s * world_size)
        s_per_rank = s
        freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
        return freqs_i_rank

    def compute_freqs_audio_dist(self, s, c, grid_sizes, freqs):
        """Audio variant of ``compute_freqs_dist`` with one extra frame.

        The grid is extended by one frame (f + 1) — presumably to cover
        appended audio tokens; verify against the audio model. Positions
        past the original video token count get their leading channels
        zeroed.
        """
        world_size = dist.get_world_size(self.seq_p_group)
        cur_rank = dist.get_rank(self.seq_p_group)
        freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
        f, h, w = grid_sizes[0]
        # Token count of the real video content, before the extra frame.
        valid_token_length = f * h * w
        f = f + 1
        seq_len = f * h * w
        freqs_i = torch.cat(
            [
                freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
                freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
                freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
            ],
            dim=-1,
        ).reshape(seq_len, 1, -1)
        # Zero the first f channels for every position beyond the valid
        # video tokens. NOTE(review): the bound ``f`` (a frame count) used
        # to slice the channel dimension looks suspicious — confirm intent.
        freqs_i[valid_token_length:, :, :f] = 0
        freqs_i = pad_freqs(freqs_i, s * world_size)
        s_per_rank = s
        freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
        return freqs_i_rank
lightx2v/models/networks/wan/model.py
View file @
068a47db
...
@@ -3,6 +3,7 @@ import os
...
@@ -3,6 +3,7 @@ import os
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
import
torch.nn.functional
as
F
from
loguru
import
logger
from
loguru
import
logger
from
safetensors
import
safe_open
from
safetensors
import
safe_open
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment