xuwx1 / LightX2V

Commit 88b7a2dd, authored Aug 03, 2025 by helloyongyang
Parent: 2e5794c7

Support cfg parallel for T5 model

Showing 6 changed files with 55 additions and 21 deletions
Changed files:
- lightx2v/infer.py (+2 / -12)
- lightx2v/models/networks/wan/infer/transformer_infer.py (+1 / -1)
- lightx2v/models/networks/wan/model.py (+2 / -2)
- lightx2v/models/networks/wan/weights/transformer_weights.py (+1 / -1)
- lightx2v/models/runners/wan/wan_runner.py (+19 / -5)
- lightx2v/utils/set_config.py (+30 / -0)
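In short: with classifier-free guidance (cfg) enabled, the runner previously ran the T5 text encoder twice on every rank, once for the positive prompt and once for the negative one. This commit makes that work cfg-parallel, so cfg rank 0 encodes only the positive prompt and the other cfg rank only the negative one. It also moves all parallel setup into a new `set_parallel_config` helper, which builds the 2-D (`cfg_p`, `seq_p`) device mesh once and precomputes `seq_parallel`/`cfg_parallel` booleans that the model code branches on. A config along these lines drives the new code paths (a sketch: the key names are taken from the diffs below, but the surrounding file layout is an assumption):

```python
# Illustrative config for this commit (not a file from the repo); the keys
# match what set_parallel_config and the WanSelfAttention weights read.
config = {
    "enable_cfg": True,  # new default added to get_default_config in this commit
    "parallel": {
        "cfg_p_size": 2,  # classifier-free-guidance parallel degree
        "seq_p_size": 2,  # sequence-parallel degree
        "seq_p_attn_type": "ulysses",  # attention backend for the seq-parallel path
    },
}
# set_parallel_config asserts cfg_p_size * seq_p_size == world size,
# so this layout expects 4 processes.
```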
lightx2v/infer.py

```diff
 import argparse
 import torch
 import torch.distributed as dist
-from torch.distributed.device_mesh import init_device_mesh
 import json
 from lightx2v.utils.envs import *
 from lightx2v.utils.utils import seed_all
 from lightx2v.utils.profiler import ProfilingContext
-from lightx2v.utils.set_config import set_config
+from lightx2v.utils.set_config import set_config, print_config
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
@@ -26,15 +25,6 @@ from loguru import logger
 def init_runner(config):
     seed_all(config.seed)
-    if config.parallel:
-        if not dist.is_initialized():
-            dist.init_process_group(backend="nccl")
-        cfg_p_size = config.parallel.get("cfg_p_size", 1)
-        seq_p_size = config.parallel.get("seq_p_size", 1)
-        assert cfg_p_size * seq_p_size == dist.get_world_size(), f"cfg_p_size * seq_p_size must be equal to world_size"
-        config["device_mesh"] = init_device_mesh("cuda", (cfg_p_size, seq_p_size), mesh_dim_names=("cfg_p", "seq_p"))
     if CHECK_ENABLE_GRAPH_MODE():
         default_runner = RUNNER_REGISTER[config.model_cls](config)
         runner = GraphRunner(default_runner)
@@ -73,7 +63,7 @@ def main():
     with ProfilingContext("Total Cost"):
         config = set_config(args)
-        logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
+        print_config(config)
         runner = init_runner(config)
         runner.run_pipeline()
```
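The device-mesh setup removed from `init_runner` is not gone; it reappears in `set_parallel_config` in `lightx2v/utils/set_config.py` (last diff below), so parallel state is established while the config is built. The `logger.info` dump is replaced by `print_config`, which drops the non-JSON-serializable `device_mesh` entry before logging. For orientation, a minimal sketch (an assumed harness, not repo code) of how the 2-D mesh maps ranks:

```python
# Run under `torchrun --nproc_per_node=4 mesh_demo.py` on a 4-GPU machine.
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group(backend="nccl")
# cfg_p_size=2, seq_p_size=2: ranks 0-1 sit at cfg_p coordinate 0,
# ranks 2-3 at cfg_p coordinate 1.
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("cfg_p", "seq_p"))
cfg_group = mesh.get_group(mesh_dim="cfg_p")
seq_group = mesh.get_group(mesh_dim="seq_p")
# dist.get_rank(group) returns the rank *within* that group, which is what
# run_text_encoder below uses to decide who encodes which prompt.
print(dist.get_rank(), dist.get_rank(cfg_group), dist.get_rank(seq_group))
dist.destroy_process_group()
```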
lightx2v/models/networks/wan/infer/transformer_infer.py

```diff
@@ -367,7 +367,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         del freqs_i, norm1_out, norm1_weight, norm1_bias
         torch.cuda.empty_cache()
-        if self.config.parallel and self.config.parallel.get("seq_p_size", False) and self.config.parallel.seq_p_size > 1:
+        if self.config["seq_parallel"]:
             attn_out = weights.self_attn_1_parallel.apply(
                 q=q,
                 k=k,
```
lightx2v/models/networks/wan/model.py

```diff
@@ -70,7 +70,7 @@ class WanModel:
     def _init_infer_class(self):
         self.pre_infer_class = WanPreInfer
         self.post_infer_class = WanPostInfer
-        if self.config.parallel and self.config.parallel.get("seq_p_size", False) and self.config.parallel.seq_p_size > 1:
+        if self.config["seq_parallel"]:
             self.transformer_infer_class = WanTransformerDistInfer
         else:
             if self.config["feature_caching"] == "NoCaching":
@@ -187,7 +187,7 @@ class WanModel:
         self.pre_infer = self.pre_infer_class(self.config)
         self.post_infer = self.post_infer_class(self.config)
         self.transformer_infer = self.transformer_infer_class(self.config)
-        if self.config["enable_cfg"] and self.config.parallel and self.config.parallel.get("cfg_p_size", False) and self.config.parallel.cfg_p_size > 1:
+        if self.config["cfg_parallel"]:
             self.infer_func = self.infer_with_cfg_parallel
         else:
             self.infer_func = self.infer_wo_cfg_parallel
```
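As in `transformer_infer.py`, the long inline condition chains are replaced by the single precomputed flags, with `cfg_parallel` additionally folding in `enable_cfg`. The selected `infer_with_cfg_parallel` is not part of this diff; purely as a hypothetical sketch of the pattern such a method usually follows (every name below is illustrative), each cfg rank denoises with its own context and the two predictions are exchanged to apply guidance:

```python
import torch
import torch.distributed as dist

def cfg_parallel_step_sketch(model, latents, text_encoder_output, cfg_p_group,
                             guidance_scale=5.0):
    # Each cfg_p rank received only one context from run_text_encoder:
    # "context" on cfg rank 0, "context_null" on the other rank.
    if "context" in text_encoder_output:
        ctx = text_encoder_output["context"]
    else:
        ctx = text_encoder_output["context_null"]
    noise_pred = model(latents, ctx)  # one forward pass per rank instead of two
    # Exchange predictions across the cfg_p group (all_gather orders results
    # by group rank, so index 0 is the conditional branch).
    gathered = [torch.empty_like(noise_pred)
                for _ in range(dist.get_world_size(cfg_p_group))]
    dist.all_gather(gathered, noise_pred, group=cfg_p_group)
    cond, uncond = gathered[0], gathered[1]
    return uncond + guidance_scale * (cond - uncond)
```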
lightx2v/models/networks/wan/weights/transformer_weights.py

```diff
@@ -191,7 +191,7 @@ class WanSelfAttention(WeightModule):
         else:
             self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["self_attn_1_type"]]())
-        if self.config.parallel and self.config.parallel.get("seq_p_size", False) and self.config.parallel.seq_p_size > 1:
+        if self.config["seq_parallel"]:
             self.add_module("self_attn_1_parallel", ATTN_WEIGHT_REGISTER[self.config.parallel.get("seq_p_attn_type", "ulysses")]())
         if self.quant_method in ["advanced_ptq"]:
```
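The parallel attention wrapper is looked up in `ATTN_WEIGHT_REGISTER` by `parallel.seq_p_attn_type`, defaulting to `"ulysses"`. The registered implementation is not in this diff; as a rough sketch of what a Ulysses-style layout does (assumed shapes and helper name), it trades sequence sharding for head sharding with an all-to-all before and after the attention kernel:

```python
import torch
import torch.distributed as dist

def ulysses_all_to_all(x: torch.Tensor, group) -> torch.Tensor:
    """Re-shard [seq/P, heads, dim] -> [seq, heads/P, dim] across P ranks.
    Assumes the head count is divisible by the group size."""
    p = dist.get_world_size(group)
    # Split the local sequence shard along the head dimension...
    chunks = [c.contiguous() for c in x.chunk(p, dim=1)]
    out = [torch.empty_like(c) for c in chunks]
    # ...and swap shards with the other ranks in one collective.
    dist.all_to_all(out, chunks, group=group)
    # Stitch the received pieces back into the full sequence.
    return torch.cat(out, dim=0)
```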
lightx2v/models/runners/wan/wan_runner.py

```diff
@@ -176,16 +176,30 @@ class WanRunner(DefaultRunner):
     def run_text_encoder(self, text, img):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.text_encoders = self.load_text_encoder()
-        text_encoder_output = {}
         n_prompt = self.config.get("negative_prompt", "")
-        context = self.text_encoders[0].infer([text])
-        context_null = self.text_encoders[0].infer([n_prompt if n_prompt else ""])
+        if self.config["cfg_parallel"]:
+            cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
+            cfg_p_rank = dist.get_rank(cfg_p_group)
+            if cfg_p_rank == 0:
+                context = self.text_encoders[0].infer([text])
+                text_encoder_output = {"context": context}
+            else:
+                context_null = self.text_encoders[0].infer([n_prompt])
+                text_encoder_output = {"context_null": context_null}
+        else:
+            context = self.text_encoders[0].infer([text])
+            context_null = self.text_encoders[0].infer([n_prompt])
+            text_encoder_output = {
+                "context": context,
+                "context_null": context_null,
+            }
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.text_encoders[0]
             torch.cuda.empty_cache()
             gc.collect()
-        text_encoder_output["context"] = context
-        text_encoder_output["context_null"] = context_null
         return text_encoder_output

     def run_image_encoder(self, img):
```
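This is the change the commit title refers to: under cfg parallelism, each cfg_p rank runs the T5 encoder once rather than twice, and the returned dict carries only the context that rank will denoise with. (The old `n_prompt if n_prompt else ""` guard was redundant, since `negative_prompt` already defaults to `""` via `config.get`.) A standalone illustration with a stub encoder (assumed harness, mirroring the branch above):

```python
class StubT5:
    """Stands in for self.text_encoders[0]; counts forward passes."""
    def __init__(self):
        self.calls = 0
    def infer(self, prompts):
        self.calls += 1
        return f"emb({prompts[0]!r})"

def run_text_encoder_sketch(encoder, text, n_prompt, cfg_parallel, cfg_p_rank=0):
    if cfg_parallel:
        if cfg_p_rank == 0:
            return {"context": encoder.infer([text])}
        return {"context_null": encoder.infer([n_prompt])}
    return {"context": encoder.infer([text]),
            "context_null": encoder.infer([n_prompt])}

enc = StubT5()
print(run_text_encoder_sketch(enc, "a cat surfing", "", cfg_parallel=True), enc.calls)
# {'context': "emb('a cat surfing')"} 1   <- one T5 pass on this rank, not two
```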
lightx2v/utils/set_config.py

```diff
@@ -2,6 +2,8 @@ import json
 import os
 from easydict import EasyDict
 from loguru import logger
+import torch.distributed as dist
+from torch.distributed.tensor.device_mesh import init_device_mesh


 def get_default_config():
@@ -19,6 +21,7 @@ def get_default_config():
         "mm_config": {},
         "use_prompt_enhancer": False,
         "parallel": False,
+        "enable_cfg": False,
     }
     return default_config
@@ -57,4 +60,31 @@ def set_config(args):
         logger.warning(f"`num_frames - 1` has to be divisible by {config.vae_stride[0]}. Rounding to the nearest number.")
         config.target_video_length = config.target_video_length // config.vae_stride[0] * config.vae_stride[0] + 1

+    set_parallel_config(config)  # parallel config
+
     return config
+
+
+def set_parallel_config(config):
+    config["seq_parallel"] = False
+    config["cfg_parallel"] = False
+    if config.parallel:
+        if not dist.is_initialized():
+            dist.init_process_group(backend="nccl")
+        cfg_p_size = config.parallel.get("cfg_p_size", 1)
+        seq_p_size = config.parallel.get("seq_p_size", 1)
+        assert cfg_p_size * seq_p_size == dist.get_world_size(), f"cfg_p_size * seq_p_size must be equal to world_size"
+        config["device_mesh"] = init_device_mesh("cuda", (cfg_p_size, seq_p_size), mesh_dim_names=("cfg_p", "seq_p"))
+    if config.parallel and config.parallel.get("seq_p_size", False) and config.parallel.seq_p_size > 1:
+        config["seq_parallel"] = True
+    if config.get("enable_cfg", False) and config.parallel and config.parallel.get("cfg_p_size", False) and config.parallel.cfg_p_size > 1:
+        config["cfg_parallel"] = True
+
+
+def print_config(config):
+    config_to_print = config.copy()
+    config_to_print.pop("device_mesh", None)  # Remove device_mesh if it exists
+    logger.info(f"config:\n{json.dumps(config_to_print, ensure_ascii=False, indent=4)}")
```
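`set_parallel_config` re-hosts the process-group and device-mesh setup removed from `init_runner`, then derives the two booleans everything else branches on; `print_config` exists because `device_mesh` is not JSON-serializable. The flag logic in isolation, as a runnable sketch with the distributed setup stripped out so it executes anywhere:

```python
def derive_parallel_flags(parallel, enable_cfg):
    """Mirrors the two trailing conditionals in set_parallel_config."""
    seq_parallel = bool(parallel) and parallel.get("seq_p_size", 0) > 1
    cfg_parallel = enable_cfg and bool(parallel) and parallel.get("cfg_p_size", 0) > 1
    return seq_parallel, cfg_parallel

print(derive_parallel_flags({"cfg_p_size": 2, "seq_p_size": 2}, enable_cfg=True))  # (True, True)
print(derive_parallel_flags({"seq_p_size": 4}, enable_cfg=True))                   # (True, False)
print(derive_parallel_flags({"cfg_p_size": 2}, enable_cfg=False))                  # (False, False)
```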