xuwx1 / LightX2V · Commits

Commit 347a54a3
Authored Aug 11, 2025 by gushiqiao
Parent: 8d32295d

Fix load weights bug.

Showing 9 changed files with 154 additions and 122 deletions.
lightx2v/models/input_encoders/hf/t5/model.py              +3   -3
lightx2v/models/input_encoders/hf/xlm_roberta/model.py     +9   -14
lightx2v/models/networks/wan/infer/transformer_infer.py    +19  -25
lightx2v/models/networks/wan/model.py                      +70  -45
lightx2v/models/runners/wan/wan_audio_runner.py            +1   -1
lightx2v/models/runners/wan/wan_runner.py                  +3   -1
lightx2v/models/video_encoders/hf/wan/vae.py               +4   -14
lightx2v/models/video_encoders/hf/wan/vae_2_2.py           +6   -2
lightx2v/utils/utils.py                                    +39  -17
lightx2v/models/input_encoders/hf/t5/model.py  (view file @ 347a54a3)

@@ -10,7 +10,7 @@ from loguru import logger
 from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, TorchaoQuantLinearInt8, VllmQuantLinearFp8, VllmQuantLinearInt8
 from lightx2v.utils.envs import *
-from lightx2v.utils.utils import load_weights_distributed
+from lightx2v.utils.utils import load_weights
 from .tokenizer import HuggingfaceTokenizer

@@ -571,8 +571,8 @@ class T5EncoderModel:
             .requires_grad_(False)
         )
-        weights_ditc = load_weights_distributed(self.checkpoint_path)
-        model.load_state_dict(weights_ditc)
+        weights_dict = load_weights(self.checkpoint_path, cpu_offload=cpu_offload)
+        model.load_state_dict(weights_dict)
         self.model = model
         if shard_fn is not None:
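Note: a minimal sketch of the call-site pattern this file now uses. The `cpu_offload` (and, elsewhere, `remove_key`) parameters come from the new `load_weights` signature in lightx2v/utils/utils.py later in this commit; the checkpoint path and the stand-in module below are placeholders, not part of the repository.

import torch
from lightx2v.utils.utils import load_weights  # replaces load_weights_distributed

checkpoint_path = "/path/to/t5_encoder.pth"  # placeholder path
model = torch.nn.Linear(8, 8)                # stand-in module; the real caller is T5EncoderModel

# One call now covers both single-GPU and distributed runs; with cpu_offload=True
# the returned tensors are kept on CPU instead of being moved to the current GPU.
weights_dict = load_weights(checkpoint_path, cpu_offload=False)
model.load_state_dict(weights_dict, strict=False)  # strict=False only because the module is a stand-in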
lightx2v/models/input_encoders/hf/xlm_roberta/model.py  (view file @ 347a54a3)

@@ -11,7 +11,7 @@ from loguru import logger
 # from lightx2v.attentions import attention
 from lightx2v.common.ops.attn import TorchSDPAWeight
 from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, TorchaoQuantLinearInt8, VllmQuantLinearFp8, VllmQuantLinearInt8
-from lightx2v.utils.utils import load_weights_distributed
+from lightx2v.utils.utils import load_weights

 __all__ = [
     "XLMRobertaCLIP",

@@ -418,10 +418,12 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r
 class CLIPModel:
-    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, seq_p_group=None):
+    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, seq_p_group=None):
         self.dtype = dtype
         self.device = device
         self.quantized = clip_quantized
+        self.cpu_offload = cpu_offload
+        self.use_31_block = use_31_block
         self.seq_p_group = seq_p_group
         if self.quantized:

@@ -434,28 +436,21 @@ class CLIPModel:
             pretrained=False,
             return_transforms=True,
             return_tokenizer=False,
             dtype=dtype,
             device=device,
             quantized=self.quantized,
             quant_scheme=quant_scheme,
         )
         self.model = self.model.eval().requires_grad_(False)
-        weight_dict = load_weights_distributed(self.checkpoint_path)
-        keys = list(weight_dict.keys())
-        for key in keys:
-            if "textual" in key:
-                weight_dict.pop(key)
+        weight_dict = load_weights(self.checkpoint_path, cpu_offload=cpu_offload, remove_key="textual")
         self.model.load_state_dict(weight_dict)

-    def visual(self, videos, args):
-        if hasattr(args, "cpu_offload") and args.cpu_offload:
+    def visual(self, videos):
+        if self.cpu_offload:
             self.to_cuda()
-        use_31_block = getattr(args, "use_31_block", True)

         # preprocess
         size = (self.model.image_size,) * 2
         videos = torch.cat([F.interpolate(u, size=size, mode="bicubic", align_corners=False) for u in videos])
         videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))

         # forward
         with torch.amp.autocast("cuda", dtype=self.dtype):
-            out = self.model.visual(videos, use_31_block=use_31_block)
+            out = self.model.visual(videos, use_31_block=self.use_31_block)
-        if hasattr(args, "cpu_offload") and args.cpu_offload:
+        if self.cpu_offload:
             self.to_cpu()
         return out
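Note: a minimal usage sketch of the reworked CLIPModel interface, based only on the constructor and visual() signatures shown above; the checkpoint path, dtype, and tensor shapes are illustrative placeholders.

import torch
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel

# Offload and block-selection behaviour now live on the instance instead of being
# read from an `args` object on every visual() call.
clip = CLIPModel(
    dtype=torch.float16,
    device="cuda",
    checkpoint_path="/path/to/clip.pth",  # placeholder
    clip_quantized=False,
    clip_quantized_ckpt=None,
    quant_scheme=None,
    cpu_offload=True,    # new: replaces the hasattr(args, "cpu_offload") checks
    use_31_block=True,   # new: replaces getattr(args, "use_31_block", True)
)

videos = [torch.randn(1, 3, 224, 224)]  # toy input batch
out = clip.visual(videos)               # no `args` argument any more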
lightx2v/models/networks/wan/infer/transformer_infer.py  (view file @ 347a54a3)

@@ -98,6 +98,7 @@ class WanTransformerInfer(BaseTransformerInfer):
     def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         for block_idx in range(self.blocks_num):
+            self.block_idx = block_idx
             if block_idx == 0:
                 self.weights_stream_mgr.active_weights[0] = weights.blocks[0]
                 self.weights_stream_mgr.active_weights[0].to_cuda()

@@ -115,10 +116,8 @@ class WanTransformerInfer(BaseTransformerInfer):
                 seq_lens,
                 freqs,
                 context,
                 audio_dit_blocks,
             )
-            if audio_dit_blocks:
-                x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
             self.weights_stream_mgr.swap_weights()
         return x

@@ -145,9 +144,8 @@ class WanTransformerInfer(BaseTransformerInfer):
                 seq_lens,
                 freqs,
                 context,
                 audio_dit_blocks,
             )
-            if audio_dit_blocks:
-                x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
             self.weights_stream_mgr.swap_weights()

@@ -164,6 +162,7 @@ class WanTransformerInfer(BaseTransformerInfer):
     def _infer_with_phases_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         for block_idx in range(weights.blocks_num):
+            self.block_idx = block_idx
             for phase_idx in range(self.phases_num):
                 if block_idx == 0 and phase_idx == 0:
                     phase = weights.blocks[block_idx].compute_phases[phase_idx]

@@ -189,9 +188,7 @@ class WanTransformerInfer(BaseTransformerInfer):
                     x, attn_out = self.infer_cross_attn(cur_phase, x, context, y_out, gate_msa)
                 elif cur_phase_idx == 3:
                     y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
-                    x = self.post_process(x, y, c_gate_msa)
-                    if audio_dit_blocks:
-                        x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
+                    x = self.post_process(x, y, c_gate_msa, grid_sizes, audio_dit_blocks)

                 is_last_phase = block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1
                 if not is_last_phase:

@@ -216,6 +213,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.weights_stream_mgr.prefetch_weights_from_disk(weights.blocks)
         for block_idx in range(weights.blocks_num):
+            self.block_idx = block_idx
             for phase_idx in range(self.weights_stream_mgr.phases_num):
                 if block_idx == 0 and phase_idx == 0:
                     obj_key = (block_idx, phase_idx)

@@ -251,9 +249,7 @@ class WanTransformerInfer(BaseTransformerInfer):
                     x, attn_out = self.infer_cross_attn(cur_phase, x, context, y_out, gate_msa)
                 elif cur_phase_idx == 3:
                     y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
-                    x = self.post_process(x, y, c_gate_msa)
-                    if audio_dit_blocks:
-                        x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
+                    x = self.post_process(x, y, c_gate_msa, grid_sizes, audio_dit_blocks)

                 if not (block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1):
                     next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx

@@ -290,16 +286,9 @@ class WanTransformerInfer(BaseTransformerInfer):
         freqs_cis[valid_token_length:, :, : rope_t_dim // 2] = 0
         return freqs_cis

-    @torch._dynamo.disable
-    def _apply_audio_dit(self, x, block_idx, grid_sizes, audio_dit_blocks):
-        for ipa_out in audio_dit_blocks:
-            if block_idx in ipa_out:
-                cur_modify = ipa_out[block_idx]
-                x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
-        return x
-
     def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         for block_idx in range(self.blocks_num):
+            self.block_idx = block_idx
             x = self.infer_block(
                 weights.blocks[block_idx],
                 grid_sizes,

@@ -309,13 +298,11 @@ class WanTransformerInfer(BaseTransformerInfer):
                 seq_lens,
                 freqs,
                 context,
                 audio_dit_blocks,
             )
-            if audio_dit_blocks:
-                x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
         return x

-    def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+    def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(
             weights.compute_phases[0],
             embed0,

@@ -331,7 +318,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         )
         x, attn_out = self.infer_cross_attn(weights.compute_phases[2], x, context, y_out, gate_msa)
         y = self.infer_ffn(weights.compute_phases[3], x, attn_out, c_shift_msa, c_scale_msa)
-        x = self.post_process(x, y, c_gate_msa)
+        x = self.post_process(x, y, c_gate_msa, grid_sizes, audio_dit_blocks)
         return x

     def infer_modulation(self, weights, embed0):

@@ -516,12 +503,19 @@ class WanTransformerInfer(BaseTransformerInfer):
         return y

-    def post_process(self, x, y, c_gate_msa):
+    def post_process(self, x, y, c_gate_msa, grid_sizes, audio_dit_blocks=None):
         if self.sensitive_layer_dtype != self.infer_dtype:
             x = x.to(self.sensitive_layer_dtype) + y.to(self.sensitive_layer_dtype) * c_gate_msa.squeeze()
         else:
             x.add_(y * c_gate_msa.squeeze())
+
+        # Apply audio_dit if available
+        if audio_dit_blocks is not None and hasattr(self, "block_idx"):
+            for ipa_out in audio_dit_blocks:
+                if self.block_idx in ipa_out:
+                    cur_modify = ipa_out[self.block_idx]
+                    x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
+
         if self.clean_cuda_cache:
             del y, c_gate_msa
             torch.cuda.empty_cache()
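Note: judging from the loop that post_process now runs, an audio_dit_blocks argument is a list of dicts keyed by transformer block index, each entry holding a modify_func and its kwargs. A toy sketch of that shape (the identity-style modify_func is purely illustrative, not the project's real audio-DiT hook):

import torch

def identity_modify(x, grid_sizes, scale=1.0):
    # Placeholder for a real audio-DiT injection function.
    return x * scale

# One dict per audio adapter; keys are the block indices that adapter touches.
audio_dit_blocks = [
    {0: {"modify_func": identity_modify, "kwargs": {"scale": 1.0}},
     5: {"modify_func": identity_modify, "kwargs": {"scale": 0.5}}},
]

# What post_process now does for the current block after the gated residual add:
block_idx, grid_sizes = 5, None
x = torch.randn(2, 4)
for ipa_out in audio_dit_blocks:
    if block_idx in ipa_out:
        cur_modify = ipa_out[block_idx]
        x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])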
lightx2v/models/networks/wan/model.py  (view file @ 347a54a3)

@@ -105,6 +105,18 @@ class WanModel:
         else:
             raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

+    def _should_load_weights(self):
+        """Determine if current rank should load weights from disk."""
+        if self.config.get("device_mesh") is None:
+            # Single GPU mode
+            return True
+        elif dist.is_initialized():
+            # Multi-GPU mode, only rank 0 loads
+            if dist.get_rank() == 0:
+                logger.info(f"Loading weights from {self.model_path}")
+                return True
+        return False
+
     def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
         with safe_open(file_path, framework="pt") as f:
             return {

@@ -190,64 +202,31 @@ class WanModel:
         }

         if weight_dict is None:
-            is_weight_loader = False
-            if self.config.get("device_mesh") is None:
-                is_weight_loader = True
-                logger.info(f"Loading original dit model from {self.model_path}")
-            elif dist.is_initialized():
-                if dist.get_rank() == 0:
-                    is_weight_loader = True
-                    logger.info(f"Loading original dit model from {self.model_path}")
-
-            cpu_weight_dict = {}
+            is_weight_loader = self._should_load_weights()
             if is_weight_loader:
                 if not self.dit_quantized or self.weight_auto_quant:
-                    cpu_weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
+                    # Load original weights
+                    weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
                 else:
+                    # Load quantized weights
                     if not self.config.get("lazy_load", False):
-                        cpu_weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
+                        weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
                     else:
-                        cpu_weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
-
-            if self.config.get("device_mesh") is None:
-                # Single-GPU mode
-                self.original_weight_dict = {}
-                init_device = "cpu" if self.cpu_offload else "cuda"
-                for key, tensor in cpu_weight_dict.items():
-                    self.original_weight_dict[key] = tensor.to(init_device, non_blocking=True)
-            else:
-                global_src_rank = 0
-                meta_dict = {}
-                if is_weight_loader:
-                    for key, tensor in cpu_weight_dict.items():
-                        meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype}
-                obj_list = [meta_dict] if is_weight_loader else [None]
-                dist.broadcast_object_list(obj_list, src=global_src_rank)
-                synced_meta_dict = obj_list[0]
-
-                self.original_weight_dict = {}
-                for key, meta in synced_meta_dict.items():
-                    self.original_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device="cuda")
-
-                dist.barrier(device_ids=[torch.cuda.current_device()])
-                for key in sorted(synced_meta_dict.keys()):
-                    tensor_to_broadcast = self.original_weight_dict[key]
-                    if is_weight_loader:
-                        tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
-                    dist.broadcast(tensor_to_broadcast, src=global_src_rank)
-
-                if is_weight_loader:
-                    del cpu_weight_dict
+                        weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
+
+            if self.config.get("device_mesh") is not None:
+                weight_dict = self._distribute_weights_multi_gpu(weight_dict, is_weight_loader)
+
+            self.original_weight_dict = weight_dict
         else:
             self.original_weight_dict = weight_dict

-        # init weights
+        # Initialize weight containers
         self.pre_weight = self.pre_weight_class(self.config)
         self.post_weight = self.post_weight_class(self.config)
         self.transformer_weights = self.transformer_weight_class(self.config)

-        # load weights
+        # Load weights into containers
         self.pre_weight.load(self.original_weight_dict)
         self.post_weight.load(self.original_weight_dict)
         self.transformer_weights.load(self.original_weight_dict)

@@ -255,6 +234,52 @@ class WanModel:
             del self.original_weight_dict
             torch.cuda.empty_cache()

+    def _distribute_weights_multi_gpu(self, weight_dict, is_weight_loader):
+        """Distribute weights across multiple GPUs or CPUs based on offload config."""
+        global_src_rank = 0
+
+        # Determine target device for distribution
+        target_device = "cpu" if self.cpu_offload else "cuda"
+
+        if is_weight_loader:
+            # Create metadata for broadcasting
+            meta_dict = {}
+            for key, tensor in weight_dict.items():
+                meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype}
+            # Broadcast metadata to all ranks
+            obj_list = [meta_dict]
+            dist.broadcast_object_list(obj_list, src=global_src_rank)
+            synced_meta_dict = obj_list[0]
+        else:
+            # Non-loader ranks receive metadata
+            obj_list = [None]
+            dist.broadcast_object_list(obj_list, src=global_src_rank)
+            synced_meta_dict = obj_list[0]
+
+        # Create empty tensors on target device for all ranks
+        distributed_weight_dict = {}
+        for key, meta in synced_meta_dict.items():
+            distributed_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device)
+
+        # Synchronize before broadcasting
+        if target_device == "cuda":
+            dist.barrier(device_ids=[torch.cuda.current_device()])
+        else:
+            dist.barrier()
+
+        # Broadcast weights from rank 0 to all ranks
+        for key in sorted(synced_meta_dict.keys()):
+            if is_weight_loader:
+                # Copy weights to broadcast tensor
+                distributed_weight_dict[key].copy_(weight_dict[key], non_blocking=True)
+            # Broadcast to all ranks
+            dist.broadcast(distributed_weight_dict[key], src=global_src_rank)
+
+        logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}")
+        return distributed_weight_dict
+
     def _init_infer(self):
         self.pre_infer = self.pre_infer_class(self.config)
         self.post_infer = self.post_infer_class(self.config)
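Note: _distribute_weights_multi_gpu follows a common torch.distributed pattern: rank 0 first broadcasts tensor metadata so every rank can allocate matching buffers, then broadcasts the tensors themselves. A stripped-down sketch of that pattern outside the class (it assumes a process group is already initialized; the function name and defaults are illustrative):

import torch
import torch.distributed as dist

def distribute_state_dict(weight_dict, is_loader, device="cuda", src=0):
    # 1) Share {key: (shape, dtype)} so non-loader ranks know what to allocate.
    obj_list = [{k: (v.shape, v.dtype) for k, v in weight_dict.items()}] if is_loader else [None]
    dist.broadcast_object_list(obj_list, src=src)
    meta = obj_list[0]

    # 2) Allocate empty buffers on every rank; only the loader rank fills them.
    out = {k: torch.empty(shape, dtype=dtype, device=device) for k, (shape, dtype) in meta.items()}
    for k in sorted(meta):
        if is_loader:
            out[k].copy_(weight_dict[k], non_blocking=True)
        # 3) Broadcast each tensor from the loader rank to all ranks.
        dist.broadcast(out[k], src=src)
    return out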
lightx2v/models/runners/wan/wan_audio_runner.py  (view file @ 347a54a3)

@@ -668,7 +668,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
             cond_frms = torch.nn.functional.interpolate(ref_img, size=(config.tgt_h, config.tgt_w), mode="bicubic")
         # clip encoder
-        clip_encoder_out = self.image_encoder.visual([cond_frms], self.config).squeeze(0).to(GET_DTYPE()) if self.config.get("use_image_encoder", True) else None
+        clip_encoder_out = self.image_encoder.visual([cond_frms]).squeeze(0).to(GET_DTYPE()) if self.config.get("use_image_encoder", True) else None
         # vae encode
         cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
lightx2v/models/runners/wan/wan_runner.py  (view file @ 347a54a3)

@@ -84,6 +84,8 @@ class WanRunner(DefaultRunner):
             clip_quantized_ckpt=clip_quantized_ckpt,
             quant_scheme=clip_quant_scheme,
             seq_p_group=self.seq_p_group,
+            cpu_offload=self.config.get("clip_cpu_offload", self.config.get("cpu_offload", False)),
+            use_31_block=self.config.get("use_31_block", True),
         )
         return image_encoder

@@ -233,7 +235,7 @@ class WanRunner(DefaultRunner):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.image_encoder = self.load_image_encoder()
         img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
-        clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]], self.config).squeeze(0).to(GET_DTYPE())
+        clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]]).squeeze(0).to(GET_DTYPE())
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.image_encoder
             torch.cuda.empty_cache()
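Note: the runner now threads two extra options from the config into CLIPModel, with clip_cpu_offload falling back to the global cpu_offload flag when it is absent. A sketch of that lookup with example values only:

config = {
    "cpu_offload": False,       # global offload switch
    "clip_cpu_offload": True,   # optional per-encoder override
    "use_31_block": True,       # forwarded to CLIPModel and on to model.visual(..., use_31_block=...)
}

clip_cpu_offload = config.get("clip_cpu_offload", config.get("cpu_offload", False))
use_31_block = config.get("use_31_block", True)
print(clip_cpu_offload, use_31_block)  # True True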
lightx2v/models/video_encoders/hf/wan/vae.py  (view file @ 347a54a3)

@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from einops import rearrange
 from loguru import logger

-from lightx2v.utils.utils import load_weights_distributed
+from lightx2v.utils.utils import load_weights

 __all__ = [
     "WanVAE",

@@ -759,7 +759,7 @@ class WanVAE_(nn.Module):
         self._enc_feat_map = [None] * self._enc_conv_num


-def _video_vae(pretrained_path=None, z_dim=None, device="cpu", seq_p_group=None, **kwargs):
+def _video_vae(pretrained_path=None, z_dim=None, device="cpu", seq_p_group=None, cpu_offload=False, **kwargs):
     """
     Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
     """

@@ -780,8 +780,7 @@ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", seq_p_group=None,
     model = WanVAE_(**cfg)

     # load checkpoint
-    weights_dict = load_weights_distributed(pretrained_path)
-
+    weights_dict = load_weights(pretrained_path, cpu_offload=cpu_offload)
     model.load_state_dict(weights_dict, assign=True)

     return model

@@ -846,16 +845,7 @@ class WanVAE:
         self.scale = [self.mean, self.inv_std]

         # init model
-        self.model = (
-            _video_vae(
-                pretrained_path=vae_pth,
-                z_dim=z_dim,
-                seq_p_group=seq_p_group,
-            )
-            .eval()
-            .requires_grad_(False)
-            .to(device)
-        )
+        self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, seq_p_group=seq_p_group, cpu_offload=cpu_offload).eval().requires_grad_(False).to(device)

     def current_device(self):
         return next(self.model.parameters()).device
lightx2v/models/video_encoders/hf/wan/vae_2_2.py  (view file @ 347a54a3)

@@ -6,6 +6,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange

+from lightx2v.utils.utils import load_weights
+
 __all__ = [
     "Wan2_2_VAE",
 ]

@@ -806,7 +808,7 @@ class WanVAE_(nn.Module):
         self._enc_feat_map = [None] * self._enc_conv_num


-def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs):
+def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", cpu_offload=False, **kwargs):
     # params
     cfg = dict(
         dim=dim,

@@ -825,7 +827,8 @@ def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs):
     # load checkpoint
     logging.info(f"loading {pretrained_path}")
-    model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True)
+    weights_dict = load_weights(pretrained_path, cpu_offload=cpu_offload)
+    model.load_state_dict(weights_dict)

     return model

@@ -955,6 +958,7 @@ class Wan2_2_VAE:
                 dim=c_dim,
                 dim_mult=dim_mult,
                 temperal_downsample=temperal_downsample,
+                cpu_offload=cpu_offload,
             )
             .eval()
             .requires_grad_(False)
lightx2v/utils/utils.py  (view file @ 347a54a3)

@@ -324,51 +324,73 @@ def find_gguf_model_path(config, ckpt_config_key=None, subdir=None):
     raise FileNotFoundError(f"No GGUF model files (.gguf) found.\nPlease download the model from: https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")


-def load_weights_distributed(checkpoint_path):
+def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
     if not dist.is_initialized():
         # Single GPU mode
         logger.info(f"Loading weights from {checkpoint_path}")
         return torch.load(checkpoint_path, map_location="cpu", weights_only=True)

-    is_leader = False
+    # Multi-GPU mode
+    is_weight_loader = False
     current_rank = dist.get_rank()
     if current_rank == 0:
-        is_leader = True
+        is_weight_loader = True

     cpu_weight_dict = {}
-    if is_leader:
+    if is_weight_loader:
         # rank 0 loads the full weight dict on the CPU
         logger.info(f"Loading weights from {checkpoint_path}")
         cpu_weight_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
+        for key in list(cpu_weight_dict.keys()):
+            if remove_key and remove_key in key:
+                cpu_weight_dict.pop(key)

     # Synchronize the structure of the dict
     meta_dict = {}
-    if is_leader:
+    if is_weight_loader:
         for key, tensor in cpu_weight_dict.items():
             meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype}
-    obj_list = [meta_dict] if is_leader else [None]
+    obj_list = [meta_dict] if is_weight_loader else [None]

     # Global rank of rank 0, used as the broadcast source
     src_global_rank = 0
     dist.broadcast_object_list(obj_list, src=src_global_rank)
     synced_meta_dict = obj_list[0]

-    # Create an empty weight dict on every process's GPU
-    target_device = torch.device(f"cuda:{current_rank}")
-    gpu_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()}
-    dist.barrier(device_ids=[torch.cuda.current_device()])
+    # Choose the target device based on the offload config
+    if cpu_offload:
+        # Multi-GPU + offload: weights on CPU
+        target_device = "cpu"
+        distributed_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()}
+        # CPU distribution uses a plain barrier
+        dist.barrier()
+    else:
+        # Multi-GPU + non-offload: weights on GPU
+        target_device = torch.device(f"cuda:{current_rank}")
+        distributed_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()}
+        # GPU distribution uses a CUDA barrier
+        dist.barrier(device_ids=[torch.cuda.current_device()])

     # Broadcast the weights
     for key in sorted(synced_meta_dict.keys()):
-        tensor_to_broadcast = gpu_weight_dict[key]
-        if is_leader:
-            # rank 0 copies the CPU weights to the target GPU before broadcasting
-            tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
+        tensor_to_broadcast = distributed_weight_dict[key]
+        if is_weight_loader:
+            # rank 0 copies the CPU weights to the target device before broadcasting
+            if cpu_offload:
+                # CPU mode: copy directly
+                tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
+            else:
+                # GPU mode: copy to the current GPU first, then broadcast
+                tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
+        # Broadcast to all ranks
         dist.broadcast(tensor_to_broadcast, src=src_global_rank)

-    if is_leader:
+    if is_weight_loader:
         del cpu_weight_dict

-    return gpu_weight_dict
+    logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}")
+    return distributed_weight_dict


 def masks_like(tensor, zero=False, generator=None, p=0.2):
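Note: the new remove_key argument drops any state-dict entry whose name contains the given substring before the dict is returned (the CLIP path above passes remove_key="textual"). A toy illustration of that filtering step:

import torch

state = {
    "visual.proj.weight": torch.zeros(2, 2),
    "textual.embed.weight": torch.zeros(2, 2),
}
remove_key = "textual"

for key in list(state.keys()):
    if remove_key and remove_key in key:
        state.pop(key)

print(sorted(state))  # ['visual.proj.weight']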