Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
7367d6c8
"vscode:/vscode.git/clone" did not exist on "0958967df8b1cebd488a58b60a61156525af3819"
Commit
7367d6c8
authored
Aug 14, 2025
by
helloyongyang
Browse files
remove unsed seq_p_group
parent
99a6f046
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
23 additions
and
39 deletions
+23
-39
lightx2v/models/input_encoders/hf/t5/model.py
lightx2v/models/input_encoders/hf/t5/model.py
+0
-1
lightx2v/models/input_encoders/hf/xlm_roberta/model.py
lightx2v/models/input_encoders/hf/xlm_roberta/model.py
+1
-2
lightx2v/models/networks/wan/audio_model.py
lightx2v/models/networks/wan/audio_model.py
+2
-2
lightx2v/models/networks/wan/causvid_model.py
lightx2v/models/networks/wan/causvid_model.py
+2
-2
lightx2v/models/networks/wan/distill_model.py
lightx2v/models/networks/wan/distill_model.py
+2
-2
lightx2v/models/networks/wan/infer/pre_infer.py
lightx2v/models/networks/wan/infer/pre_infer.py
+0
-4
lightx2v/models/networks/wan/model.py
lightx2v/models/networks/wan/model.py
+11
-7
lightx2v/models/runners/wan/wan_audio_runner.py
lightx2v/models/runners/wan/wan_audio_runner.py
+3
-3
lightx2v/models/runners/wan/wan_causvid_runner.py
lightx2v/models/runners/wan/wan_causvid_runner.py
+0
-1
lightx2v/models/runners/wan/wan_distill_runner.py
lightx2v/models/runners/wan/wan_distill_runner.py
+0
-3
lightx2v/models/runners/wan/wan_runner.py
lightx2v/models/runners/wan/wan_runner.py
+0
-9
lightx2v/models/video_encoders/hf/wan/vae.py
lightx2v/models/video_encoders/hf/wan/vae.py
+2
-3
No files found.
lightx2v/models/input_encoders/hf/t5/model.py
View file @
7367d6c8
...
...
@@ -540,7 +540,6 @@ class T5EncoderModel:
t5_quantized
=
False
,
t5_quantized_ckpt
=
None
,
quant_scheme
=
None
,
seq_p_group
=
None
,
):
self
.
text_len
=
text_len
self
.
dtype
=
dtype
...
...
lightx2v/models/input_encoders/hf/xlm_roberta/model.py
View file @
7367d6c8
...
...
@@ -418,13 +418,12 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r
class
CLIPModel
:
def
__init__
(
self
,
dtype
,
device
,
checkpoint_path
,
clip_quantized
,
clip_quantized_ckpt
,
quant_scheme
,
cpu_offload
=
False
,
use_31_block
=
True
,
seq_p_group
=
None
):
def
__init__
(
self
,
dtype
,
device
,
checkpoint_path
,
clip_quantized
,
clip_quantized_ckpt
,
quant_scheme
,
cpu_offload
=
False
,
use_31_block
=
True
):
self
.
dtype
=
dtype
self
.
device
=
device
self
.
quantized
=
clip_quantized
self
.
cpu_offload
=
cpu_offload
self
.
use_31_block
=
use_31_block
self
.
seq_p_group
=
seq_p_group
if
self
.
quantized
:
self
.
checkpoint_path
=
clip_quantized_ckpt
...
...
lightx2v/models/networks/wan/audio_model.py
View file @
7367d6c8
...
...
@@ -16,8 +16,8 @@ class WanAudioModel(WanModel):
post_weight_class
=
WanPostWeights
transformer_weight_class
=
WanTransformerWeights
def
__init__
(
self
,
model_path
,
config
,
device
,
seq_p_group
=
None
):
super
().
__init__
(
model_path
,
config
,
device
,
seq_p_group
)
def
__init__
(
self
,
model_path
,
config
,
device
):
super
().
__init__
(
model_path
,
config
,
device
)
def
_init_infer_class
(
self
):
super
().
_init_infer_class
()
...
...
lightx2v/models/networks/wan/causvid_model.py
View file @
7367d6c8
...
...
@@ -23,8 +23,8 @@ class WanCausVidModel(WanModel):
post_weight_class
=
WanPostWeights
transformer_weight_class
=
WanTransformerWeights
def
__init__
(
self
,
model_path
,
config
,
device
,
seq_p_group
=
None
):
super
().
__init__
(
model_path
,
config
,
device
,
seq_p_group
)
def
__init__
(
self
,
model_path
,
config
,
device
):
super
().
__init__
(
model_path
,
config
,
device
)
def
_init_infer_class
(
self
):
self
.
pre_infer_class
=
WanPreInfer
...
...
lightx2v/models/networks/wan/distill_model.py
View file @
7367d6c8
...
...
@@ -19,8 +19,8 @@ class WanDistillModel(WanModel):
post_weight_class
=
WanPostWeights
transformer_weight_class
=
WanTransformerWeights
def
__init__
(
self
,
model_path
,
config
,
device
,
seq_p_group
=
None
):
super
().
__init__
(
model_path
,
config
,
device
,
seq_p_group
)
def
__init__
(
self
,
model_path
,
config
,
device
):
super
().
__init__
(
model_path
,
config
,
device
)
def
_load_ckpt
(
self
,
unified_dtype
,
sensitive_layer
):
# For the old t2v distill model: https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill
...
...
lightx2v/models/networks/wan/infer/pre_infer.py
View file @
7367d6c8
...
...
@@ -74,10 +74,6 @@ class WanPreInfer:
x
=
x
.
flatten
(
2
).
transpose
(
1
,
2
).
contiguous
()
seq_lens
=
torch
.
tensor
(
x
.
size
(
1
),
dtype
=
torch
.
long
).
cuda
().
unsqueeze
(
0
)
# wan2.2_moe会对t做扩展,我们发现这里做不做影响不大,而且做了拓展会增加耗时,目前忠实原作代码,后续可以考虑去掉
if
self
.
config
[
"model_cls"
]
==
"wan2.2_moe"
:
t
=
t
.
expand
(
seq_lens
[
0
])
embed
=
sinusoidal_embedding_1d
(
self
.
freq_dim
,
t
.
flatten
())
if
self
.
enable_dynamic_cfg
:
s
=
torch
.
tensor
([
self
.
cfg_scale
],
dtype
=
torch
.
float32
).
to
(
x
.
device
)
...
...
lightx2v/models/networks/wan/model.py
View file @
7367d6c8
...
...
@@ -41,12 +41,16 @@ class WanModel:
post_weight_class
=
WanPostWeights
transformer_weight_class
=
WanTransformerWeights
def
__init__
(
self
,
model_path
,
config
,
device
,
seq_p_group
=
None
):
def
__init__
(
self
,
model_path
,
config
,
device
):
self
.
model_path
=
model_path
self
.
config
=
config
self
.
cpu_offload
=
self
.
config
.
get
(
"cpu_offload"
,
False
)
self
.
offload_granularity
=
self
.
config
.
get
(
"offload_granularity"
,
"block"
)
self
.
seq_p_group
=
seq_p_group
if
self
.
config
[
"seq_parallel"
]:
self
.
seq_p_group
=
self
.
config
.
get
(
"device_mesh"
).
get_group
(
mesh_dim
=
"seq_p"
)
else
:
self
.
seq_p_group
=
None
self
.
clean_cuda_cache
=
self
.
config
.
get
(
"clean_cuda_cache"
,
False
)
self
.
dit_quantized
=
self
.
config
.
mm_config
.
get
(
"mm_type"
,
"Default"
)
!=
"Default"
...
...
@@ -390,11 +394,11 @@ class WanModel:
x
=
F
.
pad
(
x
,
(
0
,
0
,
0
,
padding_size
))
# (后维度填充, 前维度填充)
x
=
torch
.
chunk
(
x
,
world_size
,
dim
=
0
)[
cur_rank
]
if
self
.
config
[
"model_cls"
]
.
startswith
(
"wan2.2"
)
:
padding_size
=
(
world_size
-
(
embed0
.
shape
[
0
]
%
world_size
))
%
world_size
if
padding_size
>
0
:
embed0
=
F
.
pad
(
embed0
,
(
0
,
0
,
0
,
0
,
0
,
padding_size
))
# (后维度填充, 前维度填充)
embed
=
F
.
pad
(
embed
,
(
0
,
0
,
0
,
padding_size
))
#
if self.config["model_cls"]
==
"wan2.2":
#
padding_size = (world_size - (embed0.shape[0] % world_size)) % world_size
#
if padding_size > 0:
#
embed0 = F.pad(embed0, (0, 0, 0, 0, 0, padding_size)) # (后维度填充, 前维度填充)
#
embed = F.pad(embed, (0, 0, 0, padding_size))
pre_infer_out
.
x
=
x
pre_infer_out
.
embed
=
embed
...
...
lightx2v/models/runners/wan/wan_audio_runner.py
View file @
7367d6c8
...
...
@@ -435,8 +435,8 @@ class WanAudioRunner(WanRunner): # type:ignore
device
=
torch
.
device
(
"cuda"
)
audio_encoder_repo
=
self
.
config
[
"model_path"
]
+
"/audio_encoder"
if
self
.
model
.
transformer_infer
.
seq_p_group
is
not
None
:
seq_p_group
=
self
.
model
.
transformer_infer
.
seq_p_group
if
self
.
config
[
"seq_parallel"
]
:
seq_p_group
=
self
.
config
.
get
(
"device_mesh"
).
get_group
(
mesh_dim
=
"seq_p"
)
else
:
seq_p_group
=
None
...
...
@@ -619,7 +619,7 @@ class WanAudioRunner(WanRunner): # type:ignore
def
load_transformer
(
self
):
"""Load transformer with LoRA support"""
base_model
=
WanAudioModel
(
self
.
config
.
model_path
,
self
.
config
,
self
.
init_device
,
self
.
seq_p_group
)
base_model
=
WanAudioModel
(
self
.
config
.
model_path
,
self
.
config
,
self
.
init_device
)
logger
.
info
(
f
"Loaded base model:
{
self
.
config
.
model_path
}
"
)
if
self
.
config
.
get
(
"lora_configs"
)
and
self
.
config
.
lora_configs
:
assert
not
self
.
config
.
get
(
"dit_quantized"
,
False
)
or
self
.
config
.
mm_config
.
get
(
"weight_auto_quant"
,
False
)
...
...
lightx2v/models/runners/wan/wan_causvid_runner.py
View file @
7367d6c8
...
...
@@ -29,7 +29,6 @@ class WanCausVidRunner(WanRunner):
self
.
config
.
model_path
,
self
.
config
,
self
.
init_device
,
self
.
seq_p_group
,
)
lora_wrapper
=
WanLoraWrapper
(
model
)
for
lora_config
in
self
.
config
.
lora_configs
:
...
...
lightx2v/models/runners/wan/wan_distill_runner.py
View file @
7367d6c8
...
...
@@ -21,7 +21,6 @@ class WanDistillRunner(WanRunner):
self
.
config
.
model_path
,
self
.
config
,
self
.
init_device
,
self
.
seq_p_group
,
)
lora_wrapper
=
WanLoraWrapper
(
model
)
for
lora_config
in
self
.
config
.
lora_configs
:
...
...
@@ -91,7 +90,6 @@ class Wan22MoeDistillRunner(WanDistillRunner):
os
.
path
.
join
(
self
.
config
.
model_path
,
"high_noise_model"
),
self
.
config
,
self
.
init_device
,
self
.
seq_p_group
,
)
high_lora_wrapper
=
WanLoraWrapper
(
high_noise_model
)
for
lora_config
in
self
.
config
.
lora_configs
:
...
...
@@ -106,7 +104,6 @@ class Wan22MoeDistillRunner(WanDistillRunner):
os
.
path
.
join
(
self
.
config
.
model_path
,
"distill_models"
,
"high_noise_model"
),
self
.
config
,
self
.
init_device
,
self
.
seq_p_group
,
)
if
use_low_lora
:
...
...
lightx2v/models/runners/wan/wan_runner.py
View file @
7367d6c8
...
...
@@ -34,18 +34,12 @@ from lightx2v.utils.utils import best_output_size, cache_video
class
WanRunner
(
DefaultRunner
):
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
device_mesh
=
self
.
config
.
get
(
"device_mesh"
)
if
device_mesh
is
not
None
:
self
.
seq_p_group
=
device_mesh
.
get_group
(
mesh_dim
=
"seq_p"
)
else
:
self
.
seq_p_group
=
None
def
load_transformer
(
self
):
model
=
WanModel
(
self
.
config
.
model_path
,
self
.
config
,
self
.
init_device
,
self
.
seq_p_group
,
)
if
self
.
config
.
get
(
"lora_configs"
)
and
self
.
config
.
lora_configs
:
assert
not
self
.
config
.
get
(
"dit_quantized"
,
False
)
or
self
.
config
.
mm_config
.
get
(
"weight_auto_quant"
,
False
)
...
...
@@ -83,7 +77,6 @@ class WanRunner(DefaultRunner):
clip_quantized
=
clip_quantized
,
clip_quantized_ckpt
=
clip_quantized_ckpt
,
quant_scheme
=
clip_quant_scheme
,
seq_p_group
=
self
.
seq_p_group
,
cpu_offload
=
self
.
config
.
get
(
"clip_cpu_offload"
,
self
.
config
.
get
(
"cpu_offload"
,
False
)),
use_31_block
=
self
.
config
.
get
(
"use_31_block"
,
True
),
)
...
...
@@ -127,7 +120,6 @@ class WanRunner(DefaultRunner):
t5_quantized
=
t5_quantized
,
t5_quantized_ckpt
=
t5_quantized_ckpt
,
quant_scheme
=
t5_quant_scheme
,
seq_p_group
=
self
.
seq_p_group
,
)
text_encoders
=
[
text_encoder
]
return
text_encoders
...
...
@@ -145,7 +137,6 @@ class WanRunner(DefaultRunner):
"device"
:
vae_device
,
"parallel"
:
self
.
config
.
parallel
and
self
.
config
.
parallel
.
get
(
"vae_p_size"
,
False
)
and
self
.
config
.
parallel
.
vae_p_size
>
1
,
"use_tiling"
:
self
.
config
.
get
(
"use_tiling_vae"
,
False
),
"seq_p_group"
:
self
.
seq_p_group
,
"cpu_offload"
:
vae_offload
,
}
if
self
.
config
.
task
!=
"i2v"
:
...
...
lightx2v/models/video_encoders/hf/wan/vae.py
View file @
7367d6c8
...
...
@@ -759,7 +759,7 @@ class WanVAE_(nn.Module):
self
.
_enc_feat_map
=
[
None
]
*
self
.
_enc_conv_num
def
_video_vae
(
pretrained_path
=
None
,
z_dim
=
None
,
device
=
"cpu"
,
seq_p_group
=
None
,
cpu_offload
=
False
,
**
kwargs
):
def
_video_vae
(
pretrained_path
=
None
,
z_dim
=
None
,
device
=
"cpu"
,
cpu_offload
=
False
,
**
kwargs
):
"""
Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
"""
...
...
@@ -795,7 +795,6 @@ class WanVAE:
device
=
"cuda"
,
parallel
=
False
,
use_tiling
=
False
,
seq_p_group
=
None
,
cpu_offload
=
False
,
):
self
.
dtype
=
dtype
...
...
@@ -845,7 +844,7 @@ class WanVAE:
self
.
scale
=
[
self
.
mean
,
self
.
inv_std
]
# init model
self
.
model
=
_video_vae
(
pretrained_path
=
vae_pth
,
z_dim
=
z_dim
,
seq_p_group
=
seq_p_group
,
cpu_offload
=
cpu_offload
).
eval
().
requires_grad_
(
False
).
to
(
device
)
self
.
model
=
_video_vae
(
pretrained_path
=
vae_pth
,
z_dim
=
z_dim
,
cpu_offload
=
cpu_offload
).
eval
().
requires_grad_
(
False
).
to
(
device
)
def
current_device
(
self
):
return
next
(
self
.
model
.
parameters
()).
device
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment