"docs/vscode:/vscode.git/clone" did not exist on "8e3affc6690d30a3adbb3bbe0869171ea863ccef"
Commit d502fab6 authored by gushiqiao, committed by GitHub

Fix load weights bug.

parents 8d32295d 347a54a3
@@ -10,7 +10,7 @@ from loguru import logger
 from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, TorchaoQuantLinearInt8, VllmQuantLinearFp8, VllmQuantLinearInt8
 from lightx2v.utils.envs import *
-from lightx2v.utils.utils import load_weights_distributed
+from lightx2v.utils.utils import load_weights
 from .tokenizer import HuggingfaceTokenizer
@@ -571,8 +571,8 @@ class T5EncoderModel:
             .requires_grad_(False)
         )
-        weights_ditc = load_weights_distributed(self.checkpoint_path)
-        model.load_state_dict(weights_ditc)
+        weights_dict = load_weights(self.checkpoint_path, cpu_offload=cpu_offload)
+        model.load_state_dict(weights_dict)
         self.model = model
         if shard_fn is not None:
...
@@ -11,7 +11,7 @@ from loguru import logger
 # from lightx2v.attentions import attention
 from lightx2v.common.ops.attn import TorchSDPAWeight
 from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, TorchaoQuantLinearInt8, VllmQuantLinearFp8, VllmQuantLinearInt8
-from lightx2v.utils.utils import load_weights_distributed
+from lightx2v.utils.utils import load_weights
 __all__ = [
     "XLMRobertaCLIP",
@@ -418,10 +418,12 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r
 class CLIPModel:
-    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, seq_p_group=None):
+    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, seq_p_group=None):
         self.dtype = dtype
         self.device = device
         self.quantized = clip_quantized
+        self.cpu_offload = cpu_offload
+        self.use_31_block = use_31_block
         self.seq_p_group = seq_p_group
         if self.quantized:
@@ -434,28 +436,21 @@ class CLIPModel:
             pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device, quantized=self.quantized, quant_scheme=quant_scheme
         )
         self.model = self.model.eval().requires_grad_(False)
-        weight_dict = load_weights_distributed(self.checkpoint_path)
-        keys = list(weight_dict.keys())
-        for key in keys:
-            if "textual" in key:
-                weight_dict.pop(key)
+        weight_dict = load_weights(self.checkpoint_path, cpu_offload=cpu_offload, remove_key="textual")
         self.model.load_state_dict(weight_dict)

-    def visual(self, videos, args):
-        if hasattr(args, "cpu_offload") and args.cpu_offload:
+    def visual(self, videos):
+        if self.cpu_offload:
             self.to_cuda()
-        use_31_block = getattr(args, "use_31_block", True)
         # preprocess
         size = (self.model.image_size,) * 2
         videos = torch.cat([F.interpolate(u, size=size, mode="bicubic", align_corners=False) for u in videos])
         videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
         # forward
         with torch.amp.autocast("cuda", dtype=self.dtype):
-            out = self.model.visual(videos, use_31_block=use_31_block)
-        if hasattr(args, "cpu_offload") and args.cpu_offload:
+            out = self.model.visual(videos, use_31_block=self.use_31_block)
+        if self.cpu_offload:
             self.to_cpu()
         return out
...
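With this change the CLIP encoder resolves `cpu_offload` and `use_31_block` once at construction time instead of reading them from an `args`/config object on every `visual()` call, and key filtering moves into `load_weights` via `remove_key="textual"`. A minimal usage sketch of the new interface follows; the checkpoint path, dtype, and input shape are placeholders rather than values from this repository.

```python
import torch

# Hedged sketch of the new CLIPModel call pattern (paths and shapes are hypothetical).
clip = CLIPModel(
    dtype=torch.float16,
    device="cuda",
    checkpoint_path="/path/to/clip_vit_h_14.pth",  # placeholder
    clip_quantized=False,
    clip_quantized_ckpt=None,
    quant_scheme=None,
    cpu_offload=True,    # weights moved to GPU only for the duration of visual()
    use_31_block=True,   # fixed here instead of being read from args on each call
)

# visual() now takes only the frames; offload and block-depth handling are internal.
frames = [torch.randn(1, 3, 480, 832)]  # assumed (N, C, H, W) layout per clip
features = clip.visual(frames)
```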
@@ -98,6 +98,7 @@ class WanTransformerInfer(BaseTransformerInfer):
     def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         for block_idx in range(self.blocks_num):
+            self.block_idx = block_idx
             if block_idx == 0:
                 self.weights_stream_mgr.active_weights[0] = weights.blocks[0]
                 self.weights_stream_mgr.active_weights[0].to_cuda()
@@ -115,10 +116,8 @@ class WanTransformerInfer(BaseTransformerInfer):
                 seq_lens,
                 freqs,
                 context,
+                audio_dit_blocks,
             )
-            if audio_dit_blocks:
-                x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
             self.weights_stream_mgr.swap_weights()
         return x
@@ -145,9 +144,8 @@ class WanTransformerInfer(BaseTransformerInfer):
                 seq_lens,
                 freqs,
                 context,
+                audio_dit_blocks,
             )
-            if audio_dit_blocks:
-                x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
             self.weights_stream_mgr.swap_weights()
@@ -164,6 +162,7 @@ class WanTransformerInfer(BaseTransformerInfer):
     def _infer_with_phases_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         for block_idx in range(weights.blocks_num):
+            self.block_idx = block_idx
             for phase_idx in range(self.phases_num):
                 if block_idx == 0 and phase_idx == 0:
                     phase = weights.blocks[block_idx].compute_phases[phase_idx]
@@ -189,9 +188,7 @@ class WanTransformerInfer(BaseTransformerInfer):
                     x, attn_out = self.infer_cross_attn(cur_phase, x, context, y_out, gate_msa)
                 elif cur_phase_idx == 3:
                     y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
-                    x = self.post_process(x, y, c_gate_msa)
-                    if audio_dit_blocks:
-                        x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
+                    x = self.post_process(x, y, c_gate_msa, grid_sizes, audio_dit_blocks)
                 is_last_phase = block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1
                 if not is_last_phase:
@@ -216,6 +213,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.weights_stream_mgr.prefetch_weights_from_disk(weights.blocks)
         for block_idx in range(weights.blocks_num):
+            self.block_idx = block_idx
             for phase_idx in range(self.weights_stream_mgr.phases_num):
                 if block_idx == 0 and phase_idx == 0:
                     obj_key = (block_idx, phase_idx)
@@ -251,9 +249,7 @@ class WanTransformerInfer(BaseTransformerInfer):
                     x, attn_out = self.infer_cross_attn(cur_phase, x, context, y_out, gate_msa)
                 elif cur_phase_idx == 3:
                     y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
-                    x = self.post_process(x, y, c_gate_msa)
-                    if audio_dit_blocks:
-                        x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
+                    x = self.post_process(x, y, c_gate_msa, grid_sizes, audio_dit_blocks)
                 if not (block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1):
                     next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
@@ -290,16 +286,9 @@ class WanTransformerInfer(BaseTransformerInfer):
         freqs_cis[valid_token_length:, :, : rope_t_dim // 2] = 0
         return freqs_cis

-    @torch._dynamo.disable
-    def _apply_audio_dit(self, x, block_idx, grid_sizes, audio_dit_blocks):
-        for ipa_out in audio_dit_blocks:
-            if block_idx in ipa_out:
-                cur_modify = ipa_out[block_idx]
-                x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
-        return x

     def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         for block_idx in range(self.blocks_num):
+            self.block_idx = block_idx
             x = self.infer_block(
                 weights.blocks[block_idx],
                 grid_sizes,
@@ -309,13 +298,11 @@ class WanTransformerInfer(BaseTransformerInfer):
                 seq_lens,
                 freqs,
                 context,
+                audio_dit_blocks,
             )
-            if audio_dit_blocks:
-                x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
         return x

-    def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+    def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(
             weights.compute_phases[0],
             embed0,
@@ -331,7 +318,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         )
         x, attn_out = self.infer_cross_attn(weights.compute_phases[2], x, context, y_out, gate_msa)
         y = self.infer_ffn(weights.compute_phases[3], x, attn_out, c_shift_msa, c_scale_msa)
-        x = self.post_process(x, y, c_gate_msa)
+        x = self.post_process(x, y, c_gate_msa, grid_sizes, audio_dit_blocks)
         return x

     def infer_modulation(self, weights, embed0):
@@ -516,12 +503,19 @@ class WanTransformerInfer(BaseTransformerInfer):
         return y

-    def post_process(self, x, y, c_gate_msa):
+    def post_process(self, x, y, c_gate_msa, grid_sizes, audio_dit_blocks=None):
         if self.sensitive_layer_dtype != self.infer_dtype:
             x = x.to(self.sensitive_layer_dtype) + y.to(self.sensitive_layer_dtype) * c_gate_msa.squeeze()
         else:
             x.add_(y * c_gate_msa.squeeze())
+        # Apply audio_dit if available
+        if audio_dit_blocks is not None and hasattr(self, "block_idx"):
+            for ipa_out in audio_dit_blocks:
+                if self.block_idx in ipa_out:
+                    cur_modify = ipa_out[self.block_idx]
+                    x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
         if self.clean_cuda_cache:
             del y, c_gate_msa
             torch.cuda.empty_cache()
...
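The standalone `_apply_audio_dit` helper is removed; `post_process` now applies the audio-DiT hooks itself, keyed on the `self.block_idx` that each infer loop records before entering a block. The hook container keeps the structure implied by the old helper: a list of dicts mapping a block index to a `modify_func` and its kwargs. A self-contained sketch of that traversal, with a made-up hook, is below.

```python
import torch

def apply_audio_hooks(x, block_idx, grid_sizes, audio_dit_blocks):
    # Same traversal post_process now performs after the gated residual add.
    if audio_dit_blocks is None:
        return x
    for ipa_out in audio_dit_blocks:              # one dict per registered adapter
        if block_idx in ipa_out:                  # hook registered for this block?
            cur_modify = ipa_out[block_idx]
            x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
    return x

# Hypothetical hook that just scales the hidden states, for illustration only.
hooks = [{5: {"modify_func": lambda x, grid_sizes, scale: x * scale, "kwargs": {"scale": 0.5}}}]
x = torch.randn(16, 1536)  # assumed (tokens, dim) hidden states
x = apply_audio_hooks(x, block_idx=5, grid_sizes=None, audio_dit_blocks=hooks)
```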
@@ -105,6 +105,18 @@ class WanModel:
         else:
             raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

+    def _should_load_weights(self):
+        """Determine if current rank should load weights from disk."""
+        if self.config.get("device_mesh") is None:
+            # Single GPU mode
+            return True
+        elif dist.is_initialized():
+            # Multi-GPU mode, only rank 0 loads
+            if dist.get_rank() == 0:
+                logger.info(f"Loading weights from {self.model_path}")
+                return True
+        return False

     def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
         with safe_open(file_path, framework="pt") as f:
             return {
@@ -190,64 +202,31 @@ class WanModel:
             }

         if weight_dict is None:
-            is_weight_loader = False
-            if self.config.get("device_mesh") is None:
-                is_weight_loader = True
-                logger.info(f"Loading original dit model from {self.model_path}")
-            elif dist.is_initialized():
-                if dist.get_rank() == 0:
-                    is_weight_loader = True
-                    logger.info(f"Loading original dit model from {self.model_path}")
-            cpu_weight_dict = {}
+            is_weight_loader = self._should_load_weights()
             if is_weight_loader:
                 if not self.dit_quantized or self.weight_auto_quant:
-                    cpu_weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
+                    # Load original weights
+                    weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
                 else:
+                    # Load quantized weights
                     if not self.config.get("lazy_load", False):
-                        cpu_weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
+                        weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
                     else:
-                        cpu_weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
+                        weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)

-            if self.config.get("device_mesh") is None:  # single-GPU mode
-                self.original_weight_dict = {}
-                init_device = "cpu" if self.cpu_offload else "cuda"
-                for key, tensor in cpu_weight_dict.items():
-                    self.original_weight_dict[key] = tensor.to(init_device, non_blocking=True)
-            else:
-                global_src_rank = 0
-                meta_dict = {}
-                if is_weight_loader:
-                    for key, tensor in cpu_weight_dict.items():
-                        meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype}
-                obj_list = [meta_dict] if is_weight_loader else [None]
-                dist.broadcast_object_list(obj_list, src=global_src_rank)
-                synced_meta_dict = obj_list[0]
-                self.original_weight_dict = {}
-                for key, meta in synced_meta_dict.items():
-                    self.original_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device="cuda")
-                dist.barrier(device_ids=[torch.cuda.current_device()])
-                for key in sorted(synced_meta_dict.keys()):
-                    tensor_to_broadcast = self.original_weight_dict[key]
-                    if is_weight_loader:
-                        tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
-                    dist.broadcast(tensor_to_broadcast, src=global_src_rank)
-                if is_weight_loader:
-                    del cpu_weight_dict
+            if self.config.get("device_mesh") is not None:
+                weight_dict = self._distribute_weights_multi_gpu(weight_dict, is_weight_loader)
+
+            self.original_weight_dict = weight_dict
         else:
             self.original_weight_dict = weight_dict

-        # init weights
+        # Initialize weight containers
         self.pre_weight = self.pre_weight_class(self.config)
         self.post_weight = self.post_weight_class(self.config)
         self.transformer_weights = self.transformer_weight_class(self.config)

-        # load weights
+        # Load weights into containers
         self.pre_weight.load(self.original_weight_dict)
         self.post_weight.load(self.original_weight_dict)
         self.transformer_weights.load(self.original_weight_dict)
@@ -255,6 +234,52 @@ class WanModel:
             del self.original_weight_dict
             torch.cuda.empty_cache()

+    def _distribute_weights_multi_gpu(self, weight_dict, is_weight_loader):
+        """Distribute weights across multiple GPUs or CPUs based on offload config."""
+        global_src_rank = 0
+        # Determine target device for distribution
+        target_device = "cpu" if self.cpu_offload else "cuda"
+
+        if is_weight_loader:
+            # Create metadata for broadcasting
+            meta_dict = {}
+            for key, tensor in weight_dict.items():
+                meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype}
+            # Broadcast metadata to all ranks
+            obj_list = [meta_dict]
+            dist.broadcast_object_list(obj_list, src=global_src_rank)
+            synced_meta_dict = obj_list[0]
+        else:
+            # Non-loader ranks receive metadata
+            obj_list = [None]
+            dist.broadcast_object_list(obj_list, src=global_src_rank)
+            synced_meta_dict = obj_list[0]
+
+        # Create empty tensors on target device for all ranks
+        distributed_weight_dict = {}
+        for key, meta in synced_meta_dict.items():
+            distributed_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device)
+
+        # Synchronize before broadcasting
+        if target_device == "cuda":
+            dist.barrier(device_ids=[torch.cuda.current_device()])
+        else:
+            dist.barrier()
+
+        # Broadcast weights from rank 0 to all ranks
+        for key in sorted(synced_meta_dict.keys()):
+            if is_weight_loader:
+                # Copy weights to broadcast tensor
+                distributed_weight_dict[key].copy_(weight_dict[key], non_blocking=True)
+            # Broadcast to all ranks
+            dist.broadcast(distributed_weight_dict[key], src=global_src_rank)
+
+        logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}")
+        return distributed_weight_dict

     def _init_infer(self):
         self.pre_infer = self.pre_infer_class(self.config)
         self.post_infer = self.post_infer_class(self.config)
...
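The rank-0-loads-then-broadcasts logic that previously lived inline in the weight-init path is now factored into `_should_load_weights` and `_distribute_weights_multi_gpu`. The pattern itself is standard torch.distributed: share the state-dict layout first with `broadcast_object_list`, then broadcast each tensor into a pre-allocated buffer on the target device. A trimmed, self-contained sketch of that pattern (assuming the process group is already initialized, e.g. under torchrun) follows.

```python
import torch
import torch.distributed as dist

def broadcast_state_dict(weight_dict, src_rank=0, device="cuda"):
    """Sketch of the metadata-then-tensors broadcast used by the new helpers."""
    is_loader = dist.get_rank() == src_rank
    # 1) Share the layout so every rank can allocate matching buffers.
    meta = {k: (v.shape, v.dtype) for k, v in weight_dict.items()} if is_loader else None
    obj = [meta]
    dist.broadcast_object_list(obj, src=src_rank)
    meta = obj[0]
    # 2) Allocate empty tensors on the target device, then broadcast the values.
    out = {k: torch.empty(shape, dtype=dtype, device=device) for k, (shape, dtype) in meta.items()}
    for k in sorted(out):
        if is_loader:
            out[k].copy_(weight_dict[k], non_blocking=True)
        dist.broadcast(out[k], src=src_rank)
    return out
```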
@@ -668,7 +668,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         cond_frms = torch.nn.functional.interpolate(ref_img, size=(config.tgt_h, config.tgt_w), mode="bicubic")
         # clip encoder
-        clip_encoder_out = self.image_encoder.visual([cond_frms], self.config).squeeze(0).to(GET_DTYPE()) if self.config.get("use_image_encoder", True) else None
+        clip_encoder_out = self.image_encoder.visual([cond_frms]).squeeze(0).to(GET_DTYPE()) if self.config.get("use_image_encoder", True) else None
         # vae encode
         cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
...
@@ -84,6 +84,8 @@ class WanRunner(DefaultRunner):
             clip_quantized_ckpt=clip_quantized_ckpt,
             quant_scheme=clip_quant_scheme,
             seq_p_group=self.seq_p_group,
+            cpu_offload=self.config.get("clip_cpu_offload", self.config.get("cpu_offload", False)),
+            use_31_block=self.config.get("use_31_block", True),
         )
         return image_encoder
@@ -233,7 +235,7 @@ class WanRunner(DefaultRunner):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.image_encoder = self.load_image_encoder()
         img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
-        clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]], self.config).squeeze(0).to(GET_DTYPE())
+        clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]]).squeeze(0).to(GET_DTYPE())
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.image_encoder
             torch.cuda.empty_cache()
...
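The runner now decides the CLIP encoder's offload behaviour from the config at construction time rather than passing the config object into `visual()`: a dedicated `clip_cpu_offload` key takes precedence, falling back to the global `cpu_offload`, and finally to `False`. A small illustration of that resolution order (the config values are made up):

```python
config = {"cpu_offload": True}  # no clip-specific override set

clip_cpu_offload = config.get("clip_cpu_offload", config.get("cpu_offload", False))
use_31_block = config.get("use_31_block", True)

print(clip_cpu_offload, use_31_block)  # True True
```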
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from einops import rearrange
 from loguru import logger
-from lightx2v.utils.utils import load_weights_distributed
+from lightx2v.utils.utils import load_weights
 __all__ = [
     "WanVAE",
@@ -759,7 +759,7 @@ class WanVAE_(nn.Module):
         self._enc_feat_map = [None] * self._enc_conv_num
-def _video_vae(pretrained_path=None, z_dim=None, device="cpu", seq_p_group=None, **kwargs):
+def _video_vae(pretrained_path=None, z_dim=None, device="cpu", seq_p_group=None, cpu_offload=False, **kwargs):
     """
     Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
     """
@@ -780,8 +780,7 @@ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", seq_p_group=None,
     model = WanVAE_(**cfg)
     # load checkpoint
-    weights_dict = load_weights_distributed(pretrained_path)
+    weights_dict = load_weights(pretrained_path, cpu_offload=cpu_offload)
     model.load_state_dict(weights_dict, assign=True)
     return model
@@ -846,16 +845,7 @@ class WanVAE:
         self.scale = [self.mean, self.inv_std]
         # init model
-        self.model = (
-            _video_vae(
-                pretrained_path=vae_pth,
-                z_dim=z_dim,
-                seq_p_group=seq_p_group,
-            )
-            .eval()
-            .requires_grad_(False)
-            .to(device)
-        )
+        self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, seq_p_group=seq_p_group, cpu_offload=cpu_offload).eval().requires_grad_(False).to(device)
     def current_device(self):
         return next(self.model.parameters()).device
...
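`_video_vae` now accepts `cpu_offload` and forwards it to `load_weights`, so the VAE state dict can stay on the CPU until it is needed, and the long chained constructor in `WanVAE` collapses to a single call. A sketch of the new construction path; the checkpoint path is a placeholder and `z_dim=16` is assumed here as the latent width:

```python
# Minimal sketch of the collapsed constructor; device placement of downstream
# buffers and error handling are omitted.
model = (
    _video_vae(
        pretrained_path="/path/to/wan_vae.pth",  # placeholder checkpoint
        z_dim=16,                                # assumed latent channels
        seq_p_group=None,
        cpu_offload=True,                        # state dict loaded/kept on CPU
    )
    .eval()
    .requires_grad_(False)
    .to("cpu")
)
```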
@@ -6,6 +6,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
+from lightx2v.utils.utils import load_weights
 __all__ = [
     "Wan2_2_VAE",
 ]
@@ -806,7 +808,7 @@ class WanVAE_(nn.Module):
         self._enc_feat_map = [None] * self._enc_conv_num
-def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs):
+def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", cpu_offload=False, **kwargs):
     # params
     cfg = dict(
         dim=dim,
@@ -825,7 +827,8 @@ def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs):
     # load checkpoint
     logging.info(f"loading {pretrained_path}")
-    model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True)
+    weights_dict = load_weights(pretrained_path, cpu_offload=cpu_offload)
+    model.load_state_dict(weights_dict)
     return model
@@ -955,6 +958,7 @@ class Wan2_2_VAE:
                 dim=c_dim,
                 dim_mult=dim_mult,
                 temperal_downsample=temperal_downsample,
+                cpu_offload=cpu_offload,
             )
             .eval()
             .requires_grad_(False)
...
@@ -324,51 +324,73 @@ def find_gguf_model_path(config, ckpt_config_key=None, subdir=None):
     raise FileNotFoundError(f"No GGUF model files (.gguf) found.\nPlease download the model from: https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")

-def load_weights_distributed(checkpoint_path):
+def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
     if not dist.is_initialized():
+        # Single GPU mode
         logger.info(f"Loading weights from {checkpoint_path}")
         return torch.load(checkpoint_path, map_location="cpu", weights_only=True)

-    is_leader = False
+    # Multi-GPU mode
+    is_weight_loader = False
     current_rank = dist.get_rank()
     if current_rank == 0:
-        is_leader = True
+        is_weight_loader = True

     cpu_weight_dict = {}
-    if is_leader:  # rank 0 loads the full weight dict on the CPU
+    if is_weight_loader:  # rank 0 loads the full weight dict on the CPU
         logger.info(f"Loading weights from {checkpoint_path}")
         cpu_weight_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
+        for key in list(cpu_weight_dict.keys()):
+            if remove_key and remove_key in key:
+                cpu_weight_dict.pop(key)

     # Synchronize the structure of the dict
     meta_dict = {}
-    if is_leader:
+    if is_weight_loader:
         for key, tensor in cpu_weight_dict.items():
             meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype}
-    obj_list = [meta_dict] if is_leader else [None]
+    obj_list = [meta_dict] if is_weight_loader else [None]

     # Get rank 0's global rank for the broadcast
     src_global_rank = 0
     dist.broadcast_object_list(obj_list, src=src_global_rank)
     synced_meta_dict = obj_list[0]

-    # Create an empty weight dict on each process's GPU
-    target_device = torch.device(f"cuda:{current_rank}")
-    gpu_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()}
-
-    dist.barrier(device_ids=[torch.cuda.current_device()])
+    # Choose the target device based on the offload config
+    if cpu_offload:
+        # Multi-GPU + offload: weights on CPU
+        target_device = "cpu"
+        distributed_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()}
+        # CPU distribution uses a plain barrier
+        dist.barrier()
+    else:
+        # Multi-GPU + non-offload: weights on GPU
+        target_device = torch.device(f"cuda:{current_rank}")
+        distributed_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()}
+        # GPU distribution uses a CUDA barrier
+        dist.barrier(device_ids=[torch.cuda.current_device()])

+    # Broadcast the weights
     for key in sorted(synced_meta_dict.keys()):
-        tensor_to_broadcast = gpu_weight_dict[key]
-        if is_leader:
-            # rank 0 copies the CPU weights to the target GPU before broadcasting
-            tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
+        tensor_to_broadcast = distributed_weight_dict[key]
+        if is_weight_loader:
+            # rank 0 copies the CPU weights to the target device before broadcasting
+            if cpu_offload:
+                # CPU mode: copy directly
+                tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
+            else:
+                # GPU mode: copy to the current GPU first, then broadcast
+                tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
+        # Broadcast to all ranks
         dist.broadcast(tensor_to_broadcast, src=src_global_rank)

-    if is_leader:
+    if is_weight_loader:
         del cpu_weight_dict

-    return gpu_weight_dict
+    logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}")
+    return distributed_weight_dict

 def masks_like(tensor, zero=False, generator=None, p=0.2):
...
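`load_weights_distributed` is renamed to `load_weights` and gains two knobs: `cpu_offload` chooses whether the broadcast buffers live on the CPU or on each rank's GPU, and `remove_key` lets callers drop matching entries (the CLIP path uses it to strip the `"textual"` tower) before the metadata is synchronized. A hedged usage sketch, with a placeholder checkpoint path:

```python
ckpt = "/path/to/checkpoint.pth"  # placeholder; real paths come from the config

# Single process (dist not initialized): falls back to a plain CPU torch.load.
weights = load_weights(ckpt)

# Multi-GPU, weights replicated onto every rank's GPU:
#   weights = load_weights(ckpt, cpu_offload=False)

# Multi-GPU with offload, dropping CLIP's text tower before broadcasting:
#   weights = load_weights(ckpt, cpu_offload=True, remove_key="textual")
```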