xuwx1 / LightX2V · Commits

Commit d502fab6
Authored Aug 11, 2025 by gushiqiao; committed by GitHub on Aug 11, 2025
Parents: 8d32295d 347a54a3

Fix load weights bug.
Showing 9 changed files with 154 additions and 122 deletions (+154 -122):

    lightx2v/models/input_encoders/hf/t5/model.py              +3   -3
    lightx2v/models/input_encoders/hf/xlm_roberta/model.py     +9   -14
    lightx2v/models/networks/wan/infer/transformer_infer.py    +19  -25
    lightx2v/models/networks/wan/model.py                      +70  -45
    lightx2v/models/runners/wan/wan_audio_runner.py            +1   -1
    lightx2v/models/runners/wan/wan_runner.py                  +3   -1
    lightx2v/models/video_encoders/hf/wan/vae.py               +4   -14
    lightx2v/models/video_encoders/hf/wan/vae_2_2.py           +6   -2
    lightx2v/utils/utils.py                                    +39  -17
lightx2v/models/input_encoders/hf/t5/model.py

@@ -10,7 +10,7 @@ from loguru import logger
 from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, TorchaoQuantLinearInt8, VllmQuantLinearFp8, VllmQuantLinearInt8
 from lightx2v.utils.envs import *
-from lightx2v.utils.utils import load_weights_distributed
+from lightx2v.utils.utils import load_weights
 from .tokenizer import HuggingfaceTokenizer

@@ -571,8 +571,8 @@ class T5EncoderModel:
             .requires_grad_(False)
         )
-        weights_ditc = load_weights_distributed(self.checkpoint_path)
-        model.load_state_dict(weights_ditc)
+        weights_dict = load_weights(self.checkpoint_path, cpu_offload=cpu_offload)
+        model.load_state_dict(weights_dict)
         self.model = model
         if shard_fn is not None:
lightx2v/models/input_encoders/hf/xlm_roberta/model.py

@@ -11,7 +11,7 @@ from loguru import logger
 # from lightx2v.attentions import attention
 from lightx2v.common.ops.attn import TorchSDPAWeight
 from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, TorchaoQuantLinearInt8, VllmQuantLinearFp8, VllmQuantLinearInt8
-from lightx2v.utils.utils import load_weights_distributed
+from lightx2v.utils.utils import load_weights

 __all__ = [
     "XLMRobertaCLIP",

@@ -418,10 +418,12 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r
 class CLIPModel:
-    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, seq_p_group=None):
+    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, seq_p_group=None):
         self.dtype = dtype
         self.device = device
         self.quantized = clip_quantized
+        self.cpu_offload = cpu_offload
+        self.use_31_block = use_31_block
         self.seq_p_group = seq_p_group
         if self.quantized:

@@ -434,28 +436,21 @@ class CLIPModel:
             pretrained=False,
             return_transforms=True,
             return_tokenizer=False,
             dtype=dtype,
             device=device,
             quantized=self.quantized,
             quant_scheme=quant_scheme,
         )
         self.model = self.model.eval().requires_grad_(False)
-        weight_dict = load_weights_distributed(self.checkpoint_path)
-        keys = list(weight_dict.keys())
-        for key in keys:
-            if "textual" in key:
-                weight_dict.pop(key)
+        weight_dict = load_weights(self.checkpoint_path, cpu_offload=cpu_offload, remove_key="textual")
         self.model.load_state_dict(weight_dict)

-    def visual(self, videos, args):
-        if hasattr(args, "cpu_offload") and args.cpu_offload:
+    def visual(self, videos):
+        if self.cpu_offload:
             self.to_cuda()
-        use_31_block = getattr(args, "use_31_block", True)

         # preprocess
         size = (self.model.image_size,) * 2
         videos = torch.cat([F.interpolate(u, size=size, mode="bicubic", align_corners=False) for u in videos])
         videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))

         # forward
         with torch.amp.autocast("cuda", dtype=self.dtype):
-            out = self.model.visual(videos, use_31_block=use_31_block)
-        if hasattr(args, "cpu_offload") and args.cpu_offload:
+            out = self.model.visual(videos, use_31_block=self.use_31_block)
+        if self.cpu_offload:
             self.to_cpu()
         return out
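Note: after this change, `CLIPModel.visual` no longer inspects a per-call `args` object; offload and block-depth behavior are fixed at construction time. A minimal sketch of the new call pattern, assuming the constructor signature from the diff (the checkpoint path and input tensor below are illustrative placeholders, not values from the repository):

```python
import torch
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel

clip = CLIPModel(
    dtype=torch.float16,
    device="cuda",
    checkpoint_path="/path/to/clip_checkpoint.pth",  # placeholder path
    clip_quantized=False,
    clip_quantized_ckpt=None,
    quant_scheme=None,
    cpu_offload=True,   # new: stored on the instance, replaces args.cpu_offload
    use_31_block=True,  # new: fixed at construction, replaces getattr(args, "use_31_block", True)
)

# visual() now takes only the video list; callers drop the old config/args argument.
frames = torch.randn(1, 3, 512, 512)  # illustrative input
out = clip.visual([frames])
```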
lightx2v/models/networks/wan/infer/transformer_infer.py

@@ -98,6 +98,7 @@ class WanTransformerInfer(BaseTransformerInfer):
     def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         for block_idx in range(self.blocks_num):
+            self.block_idx = block_idx
             if block_idx == 0:
                 self.weights_stream_mgr.active_weights[0] = weights.blocks[0]
                 self.weights_stream_mgr.active_weights[0].to_cuda()

@@ -115,10 +116,8 @@ class WanTransformerInfer(BaseTransformerInfer):
                 seq_lens,
                 freqs,
                 context,
                 audio_dit_blocks,
             )
-            if audio_dit_blocks:
-                x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
             self.weights_stream_mgr.swap_weights()
         return x

@@ -145,9 +144,8 @@ class WanTransformerInfer(BaseTransformerInfer):
                     seq_lens,
                     freqs,
                     context,
                     audio_dit_blocks,
                 )
-                if audio_dit_blocks:
-                    x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
                 self.weights_stream_mgr.swap_weights()

@@ -164,6 +162,7 @@ class WanTransformerInfer(BaseTransformerInfer):
     def _infer_with_phases_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         for block_idx in range(weights.blocks_num):
+            self.block_idx = block_idx
             for phase_idx in range(self.phases_num):
                 if block_idx == 0 and phase_idx == 0:
                     phase = weights.blocks[block_idx].compute_phases[phase_idx]

@@ -189,9 +188,7 @@ class WanTransformerInfer(BaseTransformerInfer):
                     x, attn_out = self.infer_cross_attn(cur_phase, x, context, y_out, gate_msa)
                 elif cur_phase_idx == 3:
                     y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
-                    x = self.post_process(x, y, c_gate_msa)
-                    if audio_dit_blocks:
-                        x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
+                    x = self.post_process(x, y, c_gate_msa, grid_sizes, audio_dit_blocks)

                 is_last_phase = block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1
                 if not is_last_phase:

@@ -216,6 +213,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.weights_stream_mgr.prefetch_weights_from_disk(weights.blocks)
         for block_idx in range(weights.blocks_num):
+            self.block_idx = block_idx
             for phase_idx in range(self.weights_stream_mgr.phases_num):
                 if block_idx == 0 and phase_idx == 0:
                     obj_key = (block_idx, phase_idx)

@@ -251,9 +249,7 @@ class WanTransformerInfer(BaseTransformerInfer):
                     x, attn_out = self.infer_cross_attn(cur_phase, x, context, y_out, gate_msa)
                 elif cur_phase_idx == 3:
                     y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
-                    x = self.post_process(x, y, c_gate_msa)
-                    if audio_dit_blocks:
-                        x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
+                    x = self.post_process(x, y, c_gate_msa, grid_sizes, audio_dit_blocks)

                 if not (block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1):
                     next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx

@@ -290,16 +286,9 @@ class WanTransformerInfer(BaseTransformerInfer):
         freqs_cis[valid_token_length:, :, : rope_t_dim // 2] = 0
         return freqs_cis

-    @torch._dynamo.disable
-    def _apply_audio_dit(self, x, block_idx, grid_sizes, audio_dit_blocks):
-        for ipa_out in audio_dit_blocks:
-            if block_idx in ipa_out:
-                cur_modify = ipa_out[block_idx]
-                x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
-        return x
-
     def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         for block_idx in range(self.blocks_num):
+            self.block_idx = block_idx
             x = self.infer_block(
                 weights.blocks[block_idx],
                 grid_sizes,

@@ -309,13 +298,11 @@ class WanTransformerInfer(BaseTransformerInfer):
                 seq_lens,
                 freqs,
                 context,
                 audio_dit_blocks,
             )
-            if audio_dit_blocks:
-                x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
         return x

-    def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+    def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(
             weights.compute_phases[0],
             embed0,

@@ -331,7 +318,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         )
         x, attn_out = self.infer_cross_attn(weights.compute_phases[2], x, context, y_out, gate_msa)
         y = self.infer_ffn(weights.compute_phases[3], x, attn_out, c_shift_msa, c_scale_msa)
-        x = self.post_process(x, y, c_gate_msa)
+        x = self.post_process(x, y, c_gate_msa, grid_sizes, audio_dit_blocks)
         return x

     def infer_modulation(self, weights, embed0):

@@ -516,12 +503,19 @@ class WanTransformerInfer(BaseTransformerInfer):
         return y

-    def post_process(self, x, y, c_gate_msa):
+    def post_process(self, x, y, c_gate_msa, grid_sizes, audio_dit_blocks=None):
         if self.sensitive_layer_dtype != self.infer_dtype:
             x = x.to(self.sensitive_layer_dtype) + y.to(self.sensitive_layer_dtype) * c_gate_msa.squeeze()
         else:
             x.add_(y * c_gate_msa.squeeze())
+
+        # Apply audio_dit if available
+        if audio_dit_blocks is not None and hasattr(self, "block_idx"):
+            for ipa_out in audio_dit_blocks:
+                if self.block_idx in ipa_out:
+                    cur_modify = ipa_out[self.block_idx]
+                    x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
+
         if self.clean_cuda_cache:
             del y, c_gate_msa
             torch.cuda.empty_cache()
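The removed `_apply_audio_dit` pass now lives at the end of `post_process`, keyed off `self.block_idx` (which each infer loop sets, as shown above). From the diff, `audio_dit_blocks` is an iterable of dicts mapping a block index to a `{"modify_func": ..., "kwargs": ...}` entry. A hedged sketch of that contract; `audio_residual` and its `scale` kwarg are stand-ins for the real audio adapter hook, which the source does not show:

```python
# Stand-in for the audio adapter's hook; the real modify_func is supplied by
# the audio pipeline. Per the diff it must accept (x, grid_sizes, **kwargs).
def audio_residual(x, grid_sizes, scale=1.0):
    return x * scale  # placeholder transformation

audio_dit_blocks = [
    {
        0: {"modify_func": audio_residual, "kwargs": {"scale": 0.5}},  # runs after block 0
        7: {"modify_func": audio_residual, "kwargs": {"scale": 1.0}},  # runs after block 7
    }
]

# post_process then does, per the diff:
#   if audio_dit_blocks is not None and hasattr(self, "block_idx"):
#       for ipa_out in audio_dit_blocks:
#           if self.block_idx in ipa_out:
#               cur_modify = ipa_out[self.block_idx]
#               x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
```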
lightx2v/models/networks/wan/model.py

@@ -105,6 +105,18 @@ class WanModel:
         else:
             raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

+    def _should_load_weights(self):
+        """Determine if current rank should load weights from disk."""
+        if self.config.get("device_mesh") is None:
+            # Single GPU mode
+            return True
+        elif dist.is_initialized():
+            # Multi-GPU mode, only rank 0 loads
+            if dist.get_rank() == 0:
+                logger.info(f"Loading weights from {self.model_path}")
+                return True
+        return False
+
     def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
         with safe_open(file_path, framework="pt") as f:
             return {

@@ -190,64 +202,31 @@ class WanModel:
         }

         if weight_dict is None:
-            is_weight_loader = False
-            if self.config.get("device_mesh") is None:
-                is_weight_loader = True
-                logger.info(f"Loading original dit model from {self.model_path}")
-            elif dist.is_initialized():
-                if dist.get_rank() == 0:
-                    is_weight_loader = True
-                    logger.info(f"Loading original dit model from {self.model_path}")
-            cpu_weight_dict = {}
+            is_weight_loader = self._should_load_weights()
             if is_weight_loader:
                 if not self.dit_quantized or self.weight_auto_quant:
-                    cpu_weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
+                    # Load original weights
+                    weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
                 else:
+                    # Load quantized weights
                     if not self.config.get("lazy_load", False):
-                        cpu_weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
+                        weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
                     else:
-                        cpu_weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
-            if self.config.get("device_mesh") is None:
-                # Single-GPU mode
-                self.original_weight_dict = {}
-                init_device = "cpu" if self.cpu_offload else "cuda"
-                for key, tensor in cpu_weight_dict.items():
-                    self.original_weight_dict[key] = tensor.to(init_device, non_blocking=True)
-            else:
-                global_src_rank = 0
-                meta_dict = {}
-                if is_weight_loader:
-                    for key, tensor in cpu_weight_dict.items():
-                        meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype}
-                obj_list = [meta_dict] if is_weight_loader else [None]
-                dist.broadcast_object_list(obj_list, src=global_src_rank)
-                synced_meta_dict = obj_list[0]
-                self.original_weight_dict = {}
-                for key, meta in synced_meta_dict.items():
-                    self.original_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device="cuda")
-                dist.barrier(device_ids=[torch.cuda.current_device()])
-                for key in sorted(synced_meta_dict.keys()):
-                    tensor_to_broadcast = self.original_weight_dict[key]
-                    if is_weight_loader:
-                        tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
-                    dist.broadcast(tensor_to_broadcast, src=global_src_rank)
-                if is_weight_loader:
-                    del cpu_weight_dict
+                        weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
+
+            if self.config.get("device_mesh") is not None:
+                weight_dict = self._distribute_weights_multi_gpu(weight_dict, is_weight_loader)
+
+            self.original_weight_dict = weight_dict
         else:
             self.original_weight_dict = weight_dict

-        # init weights
+        # Initialize weight containers
         self.pre_weight = self.pre_weight_class(self.config)
         self.post_weight = self.post_weight_class(self.config)
         self.transformer_weights = self.transformer_weight_class(self.config)
-        # load weights
+        # Load weights into containers
         self.pre_weight.load(self.original_weight_dict)
         self.post_weight.load(self.original_weight_dict)
         self.transformer_weights.load(self.original_weight_dict)

@@ -255,6 +234,52 @@ class WanModel:
             del self.original_weight_dict
             torch.cuda.empty_cache()

+    def _distribute_weights_multi_gpu(self, weight_dict, is_weight_loader):
+        """Distribute weights across multiple GPUs or CPUs based on offload config."""
+        global_src_rank = 0
+
+        # Determine target device for distribution
+        target_device = "cpu" if self.cpu_offload else "cuda"
+
+        if is_weight_loader:
+            # Create metadata for broadcasting
+            meta_dict = {}
+            for key, tensor in weight_dict.items():
+                meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype}
+            # Broadcast metadata to all ranks
+            obj_list = [meta_dict]
+            dist.broadcast_object_list(obj_list, src=global_src_rank)
+            synced_meta_dict = obj_list[0]
+        else:
+            # Non-loader ranks receive metadata
+            obj_list = [None]
+            dist.broadcast_object_list(obj_list, src=global_src_rank)
+            synced_meta_dict = obj_list[0]
+
+        # Create empty tensors on target device for all ranks
+        distributed_weight_dict = {}
+        for key, meta in synced_meta_dict.items():
+            distributed_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device)
+
+        # Synchronize before broadcasting
+        if target_device == "cuda":
+            dist.barrier(device_ids=[torch.cuda.current_device()])
+        else:
+            dist.barrier()
+
+        # Broadcast weights from rank 0 to all ranks
+        for key in sorted(synced_meta_dict.keys()):
+            if is_weight_loader:
+                # Copy weights to broadcast tensor
+                distributed_weight_dict[key].copy_(weight_dict[key], non_blocking=True)
+            # Broadcast to all ranks
+            dist.broadcast(distributed_weight_dict[key], src=global_src_rank)
+
+        logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}")
+        return distributed_weight_dict
+
     def _init_infer(self):
         self.pre_infer = self.pre_infer_class(self.config)
         self.post_infer = self.post_infer_class(self.config)
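`_distribute_weights_multi_gpu` factors the previously inline logic into a reusable two-phase broadcast: first the key/shape/dtype metadata via `broadcast_object_list`, then each tensor via `broadcast` into pre-allocated buffers. A self-contained sketch of the same pattern, simplified to the GPU-only case and launched under `torchrun`; the function name and test tensor are ours, not the repository's:

```python
import torch
import torch.distributed as dist

def broadcast_state_dict(weight_dict, src=0):
    """Replicate a state dict that only rank `src` loaded onto every rank's GPU."""
    rank = dist.get_rank()
    # Phase 1: broadcast {key: (shape, dtype)} so non-source ranks can pre-allocate.
    obj = [{k: (t.shape, t.dtype) for k, t in weight_dict.items()} if rank == src else None]
    dist.broadcast_object_list(obj, src=src)
    meta = obj[0]
    # Phase 2: allocate empty tensors and broadcast each one; sorted keys keep
    # every rank issuing the collectives in the same order.
    out = {}
    for key in sorted(meta):
        shape, dtype = meta[key]
        buf = torch.empty(shape, dtype=dtype, device="cuda")
        if rank == src:
            buf.copy_(weight_dict[key], non_blocking=True)
        dist.broadcast(buf, src=src)
        out[key] = buf
    return out

if __name__ == "__main__":
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    sd = {"w": torch.randn(4, 4)} if dist.get_rank() == 0 else {}
    replicated = broadcast_state_dict(sd)
    dist.destroy_process_group()
```

The production method additionally handles `cpu_offload`, allocating the buffers on CPU and using a plain `dist.barrier()` instead of the CUDA-device barrier.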
lightx2v/models/runners/wan/wan_audio_runner.py

@@ -668,7 +668,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         cond_frms = torch.nn.functional.interpolate(ref_img, size=(config.tgt_h, config.tgt_w), mode="bicubic")
         # clip encoder
-        clip_encoder_out = self.image_encoder.visual([cond_frms], self.config).squeeze(0).to(GET_DTYPE()) if self.config.get("use_image_encoder", True) else None
+        clip_encoder_out = self.image_encoder.visual([cond_frms]).squeeze(0).to(GET_DTYPE()) if self.config.get("use_image_encoder", True) else None
         # vae encode
         cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
lightx2v/models/runners/wan/wan_runner.py

@@ -84,6 +84,8 @@ class WanRunner(DefaultRunner):
             clip_quantized_ckpt=clip_quantized_ckpt,
             quant_scheme=clip_quant_scheme,
             seq_p_group=self.seq_p_group,
+            cpu_offload=self.config.get("clip_cpu_offload", self.config.get("cpu_offload", False)),
+            use_31_block=self.config.get("use_31_block", True),
         )
         return image_encoder

@@ -233,7 +235,7 @@ class WanRunner(DefaultRunner):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.image_encoder = self.load_image_encoder()
         img = TF.to_tensor(img).sub_((0.5)).div_(0.5).cuda() if False else TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
-        clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]], self.config).squeeze(0).to(GET_DTYPE())
+        clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]]).squeeze(0).to(GET_DTYPE())
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.image_encoder
             torch.cuda.empty_cache()
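The new `clip_cpu_offload` key lets CLIP offload be toggled independently of the global flag, falling back to `cpu_offload` when unset. The chained `config.get` calls resolve as in this small sketch (the config dicts are illustrative):

```python
def resolve_clip_offload(config):
    # Per-module key wins; otherwise fall back to the global flag, then False.
    return config.get("clip_cpu_offload", config.get("cpu_offload", False))

assert resolve_clip_offload({}) is False
assert resolve_clip_offload({"cpu_offload": True}) is True
assert resolve_clip_offload({"cpu_offload": True, "clip_cpu_offload": False}) is False
```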
lightx2v/models/video_encoders/hf/wan/vae.py

@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from einops import rearrange
 from loguru import logger

-from lightx2v.utils.utils import load_weights_distributed
+from lightx2v.utils.utils import load_weights

 __all__ = [
     "WanVAE",

@@ -759,7 +759,7 @@ class WanVAE_(nn.Module):
         self._enc_feat_map = [None] * self._enc_conv_num

-def _video_vae(pretrained_path=None, z_dim=None, device="cpu", seq_p_group=None, **kwargs):
+def _video_vae(pretrained_path=None, z_dim=None, device="cpu", seq_p_group=None, cpu_offload=False, **kwargs):
     """
     Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
     """

@@ -780,8 +780,7 @@ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", seq_p_group=None,
     model = WanVAE_(**cfg)

     # load checkpoint
-    weights_dict = load_weights_distributed(pretrained_path)
-
+    weights_dict = load_weights(pretrained_path, cpu_offload=cpu_offload)
     model.load_state_dict(weights_dict, assign=True)

     return model

@@ -846,16 +845,7 @@ class WanVAE:
         self.scale = [self.mean, self.inv_std]

         # init model
-        self.model = (
-            _video_vae(
-                pretrained_path=vae_pth,
-                z_dim=z_dim,
-                seq_p_group=seq_p_group,
-            )
-            .eval()
-            .requires_grad_(False)
-            .to(device)
-        )
+        self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, seq_p_group=seq_p_group, cpu_offload=cpu_offload).eval().requires_grad_(False).to(device)

     def current_device(self):
         return next(self.model.parameters()).device
lightx2v/models/video_encoders/hf/wan/vae_2_2.py

@@ -6,6 +6,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange

+from lightx2v.utils.utils import load_weights
+
 __all__ = [
     "Wan2_2_VAE",
 ]

@@ -806,7 +808,7 @@ class WanVAE_(nn.Module):
         self._enc_feat_map = [None] * self._enc_conv_num

-def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs):
+def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", cpu_offload=False, **kwargs):
     # params
     cfg = dict(
         dim=dim,

@@ -825,7 +827,8 @@ def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs):
     # load checkpoint
     logging.info(f"loading {pretrained_path}")
-    model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True)
+    weights_dict = load_weights(pretrained_path, cpu_offload=cpu_offload)
+    model.load_state_dict(weights_dict)

     return model

@@ -955,6 +958,7 @@ class Wan2_2_VAE:
                 dim=c_dim,
                 dim_mult=dim_mult,
                 temperal_downsample=temperal_downsample,
+                cpu_offload=cpu_offload,
             )
             .eval()
             .requires_grad_(False)
lightx2v/utils/utils.py

@@ -324,51 +324,73 @@ def find_gguf_model_path(config, ckpt_config_key=None, subdir=None):
     raise FileNotFoundError(f"No GGUF model files (.gguf) found.\nPlease download the model from: https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")

-def load_weights_distributed(checkpoint_path):
+def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
     if not dist.is_initialized():
+        # Single GPU mode
         logger.info(f"Loading weights from {checkpoint_path}")
         return torch.load(checkpoint_path, map_location="cpu", weights_only=True)

-    is_leader = False
+    # Multi-GPU mode
+    is_weight_loader = False
     current_rank = dist.get_rank()
     if current_rank == 0:
-        is_leader = True
+        is_weight_loader = True

     cpu_weight_dict = {}
-    if is_leader:
-        ## rank 0 loads the full weight dict on CPU
+    if is_weight_loader:
+        # rank 0 loads the full weight dict on CPU
         logger.info(f"Loading weights from {checkpoint_path}")
         cpu_weight_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
+        for key in list(cpu_weight_dict.keys()):
+            if remove_key and remove_key in key:
+                cpu_weight_dict.pop(key)

     # sync the structure of the dict
     meta_dict = {}
-    if is_leader:
+    if is_weight_loader:
         for key, tensor in cpu_weight_dict.items():
             meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype}
-    obj_list = [meta_dict] if is_leader else [None]
+    obj_list = [meta_dict] if is_weight_loader else [None]

     # use rank 0's global rank as the broadcast source
     src_global_rank = 0
     dist.broadcast_object_list(obj_list, src=src_global_rank)
     synced_meta_dict = obj_list[0]

-    # create an empty weight dict on every rank's GPU
-    target_device = torch.device(f"cuda:{current_rank}")
-    gpu_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()}
-    dist.barrier(device_ids=[torch.cuda.current_device()])
+    # pick the target device based on the offload config
+    if cpu_offload:
+        # Multi-GPU + offload: weights on CPU
+        target_device = "cpu"
+        distributed_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()}
+        # plain barrier for CPU distribution
+        dist.barrier()
+    else:
+        # Multi-GPU + non-offload: weights on GPU
+        target_device = torch.device(f"cuda:{current_rank}")
+        distributed_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()}
+        # CUDA barrier for GPU distribution
+        dist.barrier(device_ids=[torch.cuda.current_device()])

     # broadcast the weights
     for key in sorted(synced_meta_dict.keys()):
-        tensor_to_broadcast = gpu_weight_dict[key]
-        if is_leader:
-            # rank 0 copies the CPU weights to the target GPU before broadcasting
-            tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
+        tensor_to_broadcast = distributed_weight_dict[key]
+        if is_weight_loader:
+            # rank 0 copies the CPU weights to the target device before broadcasting
+            if cpu_offload:
+                # CPU mode: copy directly
+                tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
+            else:
+                # GPU mode: copy to the current GPU first, then broadcast
+                tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
+        # broadcast to all ranks
        dist.broadcast(tensor_to_broadcast, src=src_global_rank)

-    if is_leader:
+    if is_weight_loader:
         del cpu_weight_dict
-    return gpu_weight_dict
+
+    logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}")
+    return distributed_weight_dict

 def masks_like(tensor, zero=False, generator=None, p=0.2):
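With this, every loader above funnels through one helper: single-process calls reduce to a plain `torch.load` onto CPU; distributed calls load on rank 0 and broadcast either to CPU buffers (`cpu_offload=True`) or to each rank's GPU; and `remove_key` drops any entry whose name contains the given substring (the CLIP path uses `"textual"`). A hedged usage sketch; the checkpoint paths are placeholders:

```python
import torch.distributed as dist
from lightx2v.utils.utils import load_weights

# Single process: no process group initialized, so this reduces to torch.load.
state = load_weights("/path/to/model.pth")  # placeholder path

# Distributed (e.g. under torchrun): rank 0 reads the checkpoint, all ranks
# end up with identical copies on their own cuda:{rank} device.
dist.init_process_group("nccl")
state = load_weights("/path/to/model.pth", cpu_offload=False)

# Offloaded variant keeps the broadcast buffers on CPU and drops every key
# containing "textual", as CLIPModel does. Note CPU broadcasts require a
# backend with CPU support (e.g. gloo).
state = load_weights("/path/to/clip.pth", cpu_offload=True, remove_key="textual")
```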