xuwx1 / LightX2V · Commits

Commit ed1a937b
Authored Sep 18, 2025 by gushiqiao; committed by GitHub on Sep 18, 2025.

[Fix] Update load weights params list (#317)

Parent: 3b896f9c

Showing 7 changed files with 33 additions and 12 deletions (+33 / -12).
lightx2v/models/input_encoders/hf/t5/model.py            +2 / -1
lightx2v/models/input_encoders/hf/xlm_roberta/model.py   +2 / -2
lightx2v/models/networks/wan/audio_model.py              +2 / -1
lightx2v/models/runners/wan/wan_audio_runner.py          +2 / -1
lightx2v/models/runners/wan/wan_runner.py                +4 / -0
lightx2v/models/video_encoders/hf/wan/vae.py             +4 / -3
lightx2v/models/video_encoders/hf/wan/vae_2_2.py         +17 / -4
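All seven changes thread a new `load_from_rank0` flag (default `False`) through the `load_weights` calls in the text encoders, the audio adapter, the runners, and both VAE wrappers, so existing call sites keep their current behavior.

The diff does not include `load_weights` itself, so the sketch below is only an illustration of the pattern the flag's name suggests: under `torch.distributed`, rank 0 alone reads the checkpoint from disk and the tensors are then broadcast to the remaining ranks, instead of every process reading the same file at once. Everything in the body, including the helper's name, is an assumption; only the keyword `load_from_rank0` comes from this commit.

import torch
import torch.distributed as dist

def load_weights_sketch(checkpoint_path, cpu_offload=False, load_from_rank0=False):
    # Hypothetical illustration; the real load_weights in LightX2V is not
    # shown in this diff. With load_from_rank0=True, only rank 0 touches
    # the filesystem and the other ranks receive the tensors by broadcast.
    use_rank0 = load_from_rank0 and dist.is_initialized()
    weights = None
    if not use_rank0 or dist.get_rank() == 0:
        weights = torch.load(checkpoint_path, map_location="cpu")
    if use_rank0:
        # Ship the whole state dict from rank 0 to every other rank.
        holder = [weights]
        dist.broadcast_object_list(holder, src=0)
        weights = holder[0]
    if not cpu_offload and torch.cuda.is_available():
        # Mirrors the .cuda() move the callers below apply when not offloading.
        weights = {k: v.cuda() for k, v in weights.items()}
    return weights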
lightx2v/models/input_encoders/hf/t5/model.py

@@ -540,6 +540,7 @@ class T5EncoderModel:
         t5_quantized=False,
         t5_quantized_ckpt=None,
         quant_scheme=None,
+        load_from_rank0=False,
     ):
         self.text_len = text_len
         self.dtype = dtype
@@ -570,7 +571,7 @@ class T5EncoderModel:
             .requires_grad_(False)
         )
-        weights_dict = load_weights(self.checkpoint_path, cpu_offload=cpu_offload)
+        weights_dict = load_weights(self.checkpoint_path, cpu_offload=cpu_offload, load_from_rank0=load_from_rank0)
         model.load_state_dict(weights_dict)
         self.model = model
lightx2v/models/input_encoders/hf/xlm_roberta/model.py

@@ -418,7 +418,7 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r
 class CLIPModel:
-    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True):
+    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, load_from_rank0=False):
         self.dtype = dtype
         self.device = device
         self.quantized = clip_quantized
@@ -435,7 +435,7 @@ class CLIPModel:
             pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device, quantized=self.quantized, quant_scheme=quant_scheme
         )
         self.model = self.model.eval().requires_grad_(False)
-        weight_dict = load_weights(self.checkpoint_path, cpu_offload=cpu_offload, remove_key="textual")
+        weight_dict = load_weights(self.checkpoint_path, cpu_offload=cpu_offload, remove_key="textual", load_from_rank0=load_from_rank0)
         self.model.load_state_dict(weight_dict)

     def visual(self, videos):
lightx2v/models/networks/wan/audio_model.py

@@ -38,7 +38,8 @@ class WanAudioModel(WanModel):
         self.config.adapter_model_path = os.path.join(self.config.model_path, adapter_model_name)
         adapter_offload = self.config.get("cpu_offload", False)
-        self.adapter_weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=adapter_offload, remove_key="audio")
+        load_from_rank0 = self.config.get("load_from_rank0", False)
+        self.adapter_weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=adapter_offload, remove_key="audio", load_from_rank0=load_from_rank0)
         if not dist.is_initialized() and not adapter_offload:
             for key in self.adapter_weights_dict:
                 self.adapter_weights_dict[key] = self.adapter_weights_dict[key].cuda()
lightx2v/models/runners/wan/wan_audio_runner.py

@@ -735,7 +735,8 @@ class WanAudioRunner(WanRunner): # type:ignore
         )
         audio_adapter.to(device)
-        weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=audio_adapter_offload, remove_key="ca")
+        load_from_rank0 = self.config.get("load_from_rank0", False)
+        weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=audio_adapter_offload, remove_key="ca", load_from_rank0=load_from_rank0)
         audio_adapter.load_state_dict(weights_dict, strict=False)
         return audio_adapter.to(dtype=GET_DTYPE())
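Note the difference in plumbing: the encoder and VAE classes take `load_from_rank0` as a constructor or function parameter, while `WanAudioModel` and `WanAudioRunner` read it directly from `self.config` at the `load_weights` call site.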
lightx2v/models/runners/wan/wan_runner.py

@@ -89,6 +89,7 @@ class WanRunner(DefaultRunner):
             quant_scheme=clip_quant_scheme,
             cpu_offload=clip_offload,
             use_31_block=self.config.get("use_31_block", True),
+            load_from_rank0=self.config.get("load_from_rank0", False),
         )
         return image_encoder
@@ -130,6 +131,7 @@ class WanRunner(DefaultRunner):
             t5_quantized=t5_quantized,
             t5_quantized_ckpt=t5_quantized_ckpt,
             quant_scheme=t5_quant_scheme,
+            load_from_rank0=self.config.get("load_from_rank0", False),
         )
         text_encoders = [text_encoder]
         return text_encoders
@@ -149,6 +151,7 @@ class WanRunner(DefaultRunner):
             "use_tiling": self.config.get("use_tiling_vae", False),
             "cpu_offload": vae_offload,
             "dtype": GET_DTYPE(),
+            "load_from_rank0": self.config.get("load_from_rank0", False),
         }
         if self.config.task not in ["i2v", "flf2v", "vace"]:
             return None
@@ -170,6 +173,7 @@ class WanRunner(DefaultRunner):
             "use_tiling": self.config.get("use_tiling_vae", False),
             "cpu_offload": vae_offload,
             "dtype": GET_DTYPE(),
+            "load_from_rank0": self.config.get("load_from_rank0", False),
         }
         if self.config.get("use_tiny_vae", False):
             tiny_vae_path = find_torch_model_path(self.config, "tiny_vae_path", self.tiny_vae_name)
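In `WanRunner`, the flag never becomes a new method parameter either: each component is handed `self.config.get("load_from_rank0", False)`, so the feature is driven by a single config key and configs without it fall back to the old behavior. A hypothetical config fragment follows; only `load_from_rank0` and the neighboring keys seen in this diff are real, and the overall config layout is an assumption:

config = {
    "cpu_offload": False,
    "use_tiling_vae": False,
    "load_from_rank0": True,  # rank 0 loads checkpoints; other ranks receive them
}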
lightx2v/models/video_encoders/hf/wan/vae.py

@@ -761,7 +761,7 @@ class WanVAE_(nn.Module):
         self._enc_feat_map = [None] * self._enc_conv_num


-def _video_vae(pretrained_path=None, z_dim=None, device="cpu", cpu_offload=False, dtype=torch.float, **kwargs):
+def _video_vae(pretrained_path=None, z_dim=None, device="cpu", cpu_offload=False, dtype=torch.float, load_from_rank0=False, **kwargs):
     """
     Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
     """
@@ -782,7 +782,7 @@ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", cpu_offload=False
     model = WanVAE_(**cfg)

     # load checkpoint
-    weights_dict = load_weights(pretrained_path, cpu_offload=cpu_offload)
+    weights_dict = load_weights(pretrained_path, cpu_offload=cpu_offload, load_from_rank0=load_from_rank0)
     for k in weights_dict.keys():
         if weights_dict[k].dtype != dtype:
             weights_dict[k] = weights_dict[k].to(dtype)
@@ -802,6 +802,7 @@ class WanVAE:
         use_tiling=False,
         cpu_offload=False,
         use_2d_split=True,
+        load_from_rank0=False,
     ):
         self.dtype = dtype
         self.device = device
@@ -888,7 +889,7 @@ class WanVAE:
         }

         # init model
-        self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, cpu_offload=cpu_offload, dtype=dtype).eval().requires_grad_(False).to(device).to(dtype)
+        self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, cpu_offload=cpu_offload, dtype=dtype, load_from_rank0=load_from_rank0).eval().requires_grad_(False).to(device).to(dtype)

     def _calculate_2d_grid(self, latent_height, latent_width, world_size):
         if (latent_height, latent_width, world_size) in self.grid_table:
lightx2v/models/video_encoders/hf/wan/vae_2_2.py

@@ -812,7 +812,7 @@ class WanVAE_(nn.Module):
         self._enc_feat_map = [None] * self._enc_conv_num


-def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", cpu_offload=False, dtype=torch.float32, **kwargs):
+def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", cpu_offload=False, dtype=torch.float32, load_from_rank0=False, **kwargs):
     # params
     cfg = dict(
         dim=dim,
@@ -831,7 +831,7 @@ def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", cpu_offloa
     # load checkpoint
     logging.info(f"loading {pretrained_path}")
-    weights_dict = load_weights(pretrained_path, cpu_offload=cpu_offload)
+    weights_dict = load_weights(pretrained_path, cpu_offload=cpu_offload, load_from_rank0=load_from_rank0)
     for k in weights_dict.keys():
         if weights_dict[k].dtype != dtype:
             weights_dict[k] = weights_dict[k].to(dtype)
@@ -842,7 +842,18 @@ def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", cpu_offloa
 class Wan2_2_VAE:
     def __init__(
-        self, z_dim=48, c_dim=160, vae_pth=None, dim_mult=[1, 2, 4, 4], temperal_downsample=[False, True, True], dtype=torch.float, device="cuda", cpu_offload=False, offload_cache=False, **kwargs
+        self,
+        z_dim=48,
+        c_dim=160,
+        vae_pth=None,
+        dim_mult=[1, 2, 4, 4],
+        temperal_downsample=[False, True, True],
+        dtype=torch.float,
+        device="cuda",
+        cpu_offload=False,
+        offload_cache=False,
+        load_from_rank0=False,
+        **kwargs,
     ):
         self.dtype = dtype
         self.device = device
@@ -961,7 +972,9 @@ class Wan2_2_VAE:
         self.scale = [self.mean, self.inv_std]

         # init model
         self.model = (
-            _video_vae(pretrained_path=vae_pth, z_dim=z_dim, dim=c_dim, dim_mult=dim_mult, temperal_downsample=temperal_downsample, cpu_offload=cpu_offload, dtype=dtype)
+            _video_vae(pretrained_path=vae_pth, z_dim=z_dim, dim=c_dim, dim_mult=dim_mult, temperal_downsample=temperal_downsample, cpu_offload=cpu_offload, dtype=dtype, load_from_rank0=load_from_rank0)
             .eval()
             .requires_grad_(False)
             .to(device)
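Besides threading the new flag through, this file also reflows the single-line `Wan2_2_VAE.__init__` signature to one parameter per line, which accounts for most of its +17 / -4 line count.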