Commit d502fab6 authored by gushiqiao, committed by GitHub

Fix load weights bug.

parents 8d32295d 347a54a3
......@@ -10,7 +10,7 @@ from loguru import logger
from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, TorchaoQuantLinearInt8, VllmQuantLinearFp8, VllmQuantLinearInt8
from lightx2v.utils.envs import *
from lightx2v.utils.utils import load_weights_distributed
from lightx2v.utils.utils import load_weights
from .tokenizer import HuggingfaceTokenizer
......@@ -571,8 +571,8 @@ class T5EncoderModel:
.requires_grad_(False)
)
weights_ditc = load_weights_distributed(self.checkpoint_path)
model.load_state_dict(weights_ditc)
weights_dict = load_weights(self.checkpoint_path, cpu_offload=cpu_offload)
model.load_state_dict(weights_dict)
self.model = model
if shard_fn is not None:
......
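
The T5 text encoder now loads through the renamed `load_weights` helper and forwards a `cpu_offload` flag, so offloaded runs keep the state dict on CPU instead of materializing it on the GPU. A minimal sketch of the new call, assuming a plain PyTorch checkpoint (the path is a placeholder):

```python
from lightx2v.utils.utils import load_weights

# Hedged sketch; the checkpoint path is a placeholder and `model` is the encoder built above.
weights_dict = load_weights("umt5-xxl-enc-bf16.pth", cpu_offload=True)
# model.load_state_dict(weights_dict)  # tensors arrive on CPU when cpu_offload=True
```
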
......@@ -11,7 +11,7 @@ from loguru import logger
# from lightx2v.attentions import attention
from lightx2v.common.ops.attn import TorchSDPAWeight
from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, TorchaoQuantLinearInt8, VllmQuantLinearFp8, VllmQuantLinearInt8
from lightx2v.utils.utils import load_weights_distributed
from lightx2v.utils.utils import load_weights
__all__ = [
"XLMRobertaCLIP",
......@@ -418,10 +418,12 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r
class CLIPModel:
def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, seq_p_group=None):
def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, seq_p_group=None):
self.dtype = dtype
self.device = device
self.quantized = clip_quantized
self.cpu_offload = cpu_offload
self.use_31_block = use_31_block
self.seq_p_group = seq_p_group
if self.quantized:
......@@ -434,28 +436,21 @@ class CLIPModel:
pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device, quantized=self.quantized, quant_scheme=quant_scheme
)
self.model = self.model.eval().requires_grad_(False)
weight_dict = load_weights_distributed(self.checkpoint_path)
keys = list(weight_dict.keys())
for key in keys:
if "textual" in key:
weight_dict.pop(key)
weight_dict = load_weights(self.checkpoint_path, cpu_offload=cpu_offload, remove_key="textual")
self.model.load_state_dict(weight_dict)
def visual(self, videos, args):
if hasattr(args, "cpu_offload") and args.cpu_offload:
def visual(self, videos):
if self.cpu_offload:
self.to_cuda()
use_31_block = getattr(args, "use_31_block", True)
# preprocess
size = (self.model.image_size,) * 2
videos = torch.cat([F.interpolate(u, size=size, mode="bicubic", align_corners=False) for u in videos])
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
# forward
with torch.amp.autocast("cuda", dtype=self.dtype):
out = self.model.visual(videos, use_31_block=use_31_block)
out = self.model.visual(videos, use_31_block=self.use_31_block)
if hasattr(args, "cpu_offload") and args.cpu_offload:
if self.cpu_offload:
self.to_cpu()
return out
......
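
In the CLIP hunks, `cpu_offload` and `use_31_block` move from the per-call `args` object into the `CLIPModel` constructor, so `visual()` now takes only the list of video tensors. A hypothetical call site under the new signature (checkpoint path, dtype, and tensor shapes are placeholders):

```python
import torch

clip = CLIPModel(
    dtype=torch.float16,
    device=torch.device("cuda"),
    checkpoint_path="clip_vit_h_14.pth",  # placeholder path
    clip_quantized=False,
    clip_quantized_ckpt=None,
    quant_scheme=None,
    cpu_offload=True,    # weights are moved to CUDA around each call and back afterwards
    use_31_block=True,   # keep the intermediate 31st-block features, as before
)
videos = [torch.rand(1, 3, 512, 512, device="cuda")]
out = clip.visual(videos)  # previously clip.visual(videos, args)
```
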
......@@ -98,6 +98,7 @@ class WanTransformerInfer(BaseTransformerInfer):
def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
for block_idx in range(self.blocks_num):
self.block_idx = block_idx
if block_idx == 0:
self.weights_stream_mgr.active_weights[0] = weights.blocks[0]
self.weights_stream_mgr.active_weights[0].to_cuda()
......@@ -115,10 +116,8 @@ class WanTransformerInfer(BaseTransformerInfer):
seq_lens,
freqs,
context,
audio_dit_blocks,
)
if audio_dit_blocks:
x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
self.weights_stream_mgr.swap_weights()
return x
......@@ -145,9 +144,8 @@ class WanTransformerInfer(BaseTransformerInfer):
seq_lens,
freqs,
context,
audio_dit_blocks,
)
if audio_dit_blocks:
x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
self.weights_stream_mgr.swap_weights()
......@@ -164,6 +162,7 @@ class WanTransformerInfer(BaseTransformerInfer):
def _infer_with_phases_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
for block_idx in range(weights.blocks_num):
self.block_idx = block_idx
for phase_idx in range(self.phases_num):
if block_idx == 0 and phase_idx == 0:
phase = weights.blocks[block_idx].compute_phases[phase_idx]
......@@ -189,9 +188,7 @@ class WanTransformerInfer(BaseTransformerInfer):
x, attn_out = self.infer_cross_attn(cur_phase, x, context, y_out, gate_msa)
elif cur_phase_idx == 3:
y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
x = self.post_process(x, y, c_gate_msa)
if audio_dit_blocks:
x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
x = self.post_process(x, y, c_gate_msa, grid_sizes, audio_dit_blocks)
is_last_phase = block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1
if not is_last_phase:
......@@ -216,6 +213,7 @@ class WanTransformerInfer(BaseTransformerInfer):
self.weights_stream_mgr.prefetch_weights_from_disk(weights.blocks)
for block_idx in range(weights.blocks_num):
self.block_idx = block_idx
for phase_idx in range(self.weights_stream_mgr.phases_num):
if block_idx == 0 and phase_idx == 0:
obj_key = (block_idx, phase_idx)
......@@ -251,9 +249,7 @@ class WanTransformerInfer(BaseTransformerInfer):
x, attn_out = self.infer_cross_attn(cur_phase, x, context, y_out, gate_msa)
elif cur_phase_idx == 3:
y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
x = self.post_process(x, y, c_gate_msa)
if audio_dit_blocks:
x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
x = self.post_process(x, y, c_gate_msa, grid_sizes, audio_dit_blocks)
if not (block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1):
next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
......@@ -290,16 +286,9 @@ class WanTransformerInfer(BaseTransformerInfer):
freqs_cis[valid_token_length:, :, : rope_t_dim // 2] = 0
return freqs_cis
@torch._dynamo.disable
def _apply_audio_dit(self, x, block_idx, grid_sizes, audio_dit_blocks):
for ipa_out in audio_dit_blocks:
if block_idx in ipa_out:
cur_modify = ipa_out[block_idx]
x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
return x
def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
for block_idx in range(self.blocks_num):
self.block_idx = block_idx
x = self.infer_block(
weights.blocks[block_idx],
grid_sizes,
......@@ -309,13 +298,11 @@ class WanTransformerInfer(BaseTransformerInfer):
seq_lens,
freqs,
context,
audio_dit_blocks,
)
if audio_dit_blocks:
x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
return x
def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(
weights.compute_phases[0],
embed0,
......@@ -331,7 +318,7 @@ class WanTransformerInfer(BaseTransformerInfer):
)
x, attn_out = self.infer_cross_attn(weights.compute_phases[2], x, context, y_out, gate_msa)
y = self.infer_ffn(weights.compute_phases[3], x, attn_out, c_shift_msa, c_scale_msa)
x = self.post_process(x, y, c_gate_msa)
x = self.post_process(x, y, c_gate_msa, grid_sizes, audio_dit_blocks)
return x
def infer_modulation(self, weights, embed0):
......@@ -516,12 +503,19 @@ class WanTransformerInfer(BaseTransformerInfer):
return y
def post_process(self, x, y, c_gate_msa):
def post_process(self, x, y, c_gate_msa, grid_sizes, audio_dit_blocks=None):
if self.sensitive_layer_dtype != self.infer_dtype:
x = x.to(self.sensitive_layer_dtype) + y.to(self.sensitive_layer_dtype) * c_gate_msa.squeeze()
else:
x.add_(y * c_gate_msa.squeeze())
# Apply audio_dit if available
if audio_dit_blocks is not None and hasattr(self, "block_idx"):
for ipa_out in audio_dit_blocks:
if self.block_idx in ipa_out:
cur_modify = ipa_out[self.block_idx]
x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
if self.clean_cuda_cache:
del y, c_gate_msa
torch.cuda.empty_cache()
......
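
With `_apply_audio_dit` removed, the audio conditioning hook now runs inside `post_process`, keyed on `self.block_idx`, and `audio_dit_blocks` is threaded through `infer_block` instead of being applied afterwards. The loop expects an iterable of dicts mapping a block index to a `modify_func` plus its `kwargs`; a hypothetical entry (the function body, embedding, and arguments are illustrative only):

```python
import torch

audio_emb = torch.zeros(1)  # placeholder for a real audio embedding

def inject_audio(x, grid_sizes, audio_emb=None, scale=1.0):
    # blend audio features into the hidden states of this block
    return x + scale * audio_emb

audio_dit_blocks = [
    {
        0: {"modify_func": inject_audio, "kwargs": {"audio_emb": audio_emb, "scale": 1.0}},
        8: {"modify_func": inject_audio, "kwargs": {"audio_emb": audio_emb, "scale": 0.5}},
    }
]
# post_process() then calls, for every ipa_out dict that contains the current block index:
#   x = ipa_out[self.block_idx]["modify_func"](x, grid_sizes, **ipa_out[self.block_idx]["kwargs"])
```
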
......@@ -105,6 +105,18 @@ class WanModel:
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
def _should_load_weights(self):
"""Determine if current rank should load weights from disk."""
if self.config.get("device_mesh") is None:
# Single GPU mode
return True
elif dist.is_initialized():
# Multi-GPU mode, only rank 0 loads
if dist.get_rank() == 0:
logger.info(f"Loading weights from {self.model_path}")
return True
return False
def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
with safe_open(file_path, framework="pt") as f:
return {
......@@ -190,64 +202,31 @@ class WanModel:
}
if weight_dict is None:
is_weight_loader = False
if self.config.get("device_mesh") is None:
is_weight_loader = True
logger.info(f"Loading original dit model from {self.model_path}")
elif dist.is_initialized():
if dist.get_rank() == 0:
is_weight_loader = True
logger.info(f"Loading original dit model from {self.model_path}")
cpu_weight_dict = {}
is_weight_loader = self._should_load_weights()
if is_weight_loader:
if not self.dit_quantized or self.weight_auto_quant:
cpu_weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
# Load original weights
weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
else:
# Load quantized weights
if not self.config.get("lazy_load", False):
cpu_weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
else:
cpu_weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
if self.config.get("device_mesh") is None: # 单卡模式
self.original_weight_dict = {}
init_device = "cpu" if self.cpu_offload else "cuda"
for key, tensor in cpu_weight_dict.items():
self.original_weight_dict[key] = tensor.to(init_device, non_blocking=True)
else:
global_src_rank = 0
meta_dict = {}
if is_weight_loader:
for key, tensor in cpu_weight_dict.items():
meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype}
obj_list = [meta_dict] if is_weight_loader else [None]
dist.broadcast_object_list(obj_list, src=global_src_rank)
synced_meta_dict = obj_list[0]
weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
self.original_weight_dict = {}
for key, meta in synced_meta_dict.items():
self.original_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device="cuda")
if self.config.get("device_mesh") is not None:
weight_dict = self._distribute_weights_multi_gpu(weight_dict, is_weight_loader)
dist.barrier(device_ids=[torch.cuda.current_device()])
for key in sorted(synced_meta_dict.keys()):
tensor_to_broadcast = self.original_weight_dict[key]
if is_weight_loader:
tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
dist.broadcast(tensor_to_broadcast, src=global_src_rank)
if is_weight_loader:
del cpu_weight_dict
self.original_weight_dict = weight_dict
else:
self.original_weight_dict = weight_dict
# init weights
# Initialize weight containers
self.pre_weight = self.pre_weight_class(self.config)
self.post_weight = self.post_weight_class(self.config)
self.transformer_weights = self.transformer_weight_class(self.config)
# load weights
# Load weights into containers
self.pre_weight.load(self.original_weight_dict)
self.post_weight.load(self.original_weight_dict)
self.transformer_weights.load(self.original_weight_dict)
......@@ -255,6 +234,52 @@ class WanModel:
del self.original_weight_dict
torch.cuda.empty_cache()
def _distribute_weights_multi_gpu(self, weight_dict, is_weight_loader):
"""Distribute weights across multiple GPUs or CPUs based on offload config."""
global_src_rank = 0
# Determine target device for distribution
target_device = "cpu" if self.cpu_offload else "cuda"
if is_weight_loader:
# Create metadata for broadcasting
meta_dict = {}
for key, tensor in weight_dict.items():
meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype}
# Broadcast metadata to all ranks
obj_list = [meta_dict]
dist.broadcast_object_list(obj_list, src=global_src_rank)
synced_meta_dict = obj_list[0]
else:
# Non-loader ranks receive metadata
obj_list = [None]
dist.broadcast_object_list(obj_list, src=global_src_rank)
synced_meta_dict = obj_list[0]
# Create empty tensors on target device for all ranks
distributed_weight_dict = {}
for key, meta in synced_meta_dict.items():
distributed_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device)
# Synchronize before broadcasting
if target_device == "cuda":
dist.barrier(device_ids=[torch.cuda.current_device()])
else:
dist.barrier()
# Broadcast weights from rank 0 to all ranks
for key in sorted(synced_meta_dict.keys()):
if is_weight_loader:
# Copy weights to broadcast tensor
distributed_weight_dict[key].copy_(weight_dict[key], non_blocking=True)
# Broadcast to all ranks
dist.broadcast(distributed_weight_dict[key], src=global_src_rank)
logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}")
return distributed_weight_dict
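
The helper above is essentially a rank-0 load followed by a metadata sync and a per-tensor broadcast. A minimal, self-contained sketch of the same pattern (launch with `torchrun`; the checkpoint path is a placeholder and error handling is omitted):

```python
import torch
import torch.distributed as dist

def broadcast_state_dict(path, device="cuda"):
    # Rank 0 loads the checkpoint on CPU; everyone else starts empty.
    state = None
    if dist.get_rank() == 0:
        state = torch.load(path, map_location="cpu", weights_only=True)

    # 1. Share tensor metadata so non-loader ranks can allocate matching buffers.
    meta = [{k: (v.shape, v.dtype) for k, v in state.items()}] if state is not None else [None]
    dist.broadcast_object_list(meta, src=0)

    # 2. Allocate empty tensors and broadcast the values key by key.
    out = {}
    for key, (shape, dtype) in sorted(meta[0].items()):
        out[key] = torch.empty(shape, dtype=dtype, device=device)
        if state is not None:
            out[key].copy_(state[key], non_blocking=True)
        dist.broadcast(out[key], src=0)
    return out

if __name__ == "__main__":
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    weights = broadcast_state_dict("dit_model.pth")  # placeholder path
```
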
def _init_infer(self):
self.pre_infer = self.pre_infer_class(self.config)
self.post_infer = self.post_infer_class(self.config)
......
......@@ -668,7 +668,7 @@ class WanAudioRunner(WanRunner): # type:ignore
cond_frms = torch.nn.functional.interpolate(ref_img, size=(config.tgt_h, config.tgt_w), mode="bicubic")
# clip encoder
clip_encoder_out = self.image_encoder.visual([cond_frms], self.config).squeeze(0).to(GET_DTYPE()) if self.config.get("use_image_encoder", True) else None
clip_encoder_out = self.image_encoder.visual([cond_frms]).squeeze(0).to(GET_DTYPE()) if self.config.get("use_image_encoder", True) else None
# vae encode
cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
......
......@@ -84,6 +84,8 @@ class WanRunner(DefaultRunner):
clip_quantized_ckpt=clip_quantized_ckpt,
quant_scheme=clip_quant_scheme,
seq_p_group=self.seq_p_group,
cpu_offload=self.config.get("clip_cpu_offload", self.config.get("cpu_offload", False)),
use_31_block=self.config.get("use_31_block", True),
)
return image_encoder
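
The runner resolves the new constructor arguments from the config, letting a CLIP-specific `clip_cpu_offload` key override the global `cpu_offload` switch. A hypothetical config fragment (key names follow the lookups above; values are placeholders):

```python
config = {
    "cpu_offload": True,        # global offload switch, used as the fallback
    "clip_cpu_offload": False,  # optional per-module override for the CLIP image encoder
    "use_31_block": True,       # request the intermediate 31st-block CLIP visual features
}
```
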
......@@ -233,7 +235,7 @@ class WanRunner(DefaultRunner):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.image_encoder = self.load_image_encoder()
img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]], self.config).squeeze(0).to(GET_DTYPE())
clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]]).squeeze(0).to(GET_DTYPE())
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
del self.image_encoder
torch.cuda.empty_cache()
......
......@@ -7,7 +7,7 @@ import torch.nn.functional as F
from einops import rearrange
from loguru import logger
from lightx2v.utils.utils import load_weights_distributed
from lightx2v.utils.utils import load_weights
__all__ = [
"WanVAE",
......@@ -759,7 +759,7 @@ class WanVAE_(nn.Module):
self._enc_feat_map = [None] * self._enc_conv_num
def _video_vae(pretrained_path=None, z_dim=None, device="cpu", seq_p_group=None, **kwargs):
def _video_vae(pretrained_path=None, z_dim=None, device="cpu", seq_p_group=None, cpu_offload=False, **kwargs):
"""
Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
"""
......@@ -780,8 +780,7 @@ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", seq_p_group=None,
model = WanVAE_(**cfg)
# load checkpoint
weights_dict = load_weights_distributed(pretrained_path)
weights_dict = load_weights(pretrained_path, cpu_offload=cpu_offload)
model.load_state_dict(weights_dict, assign=True)
return model
......@@ -846,16 +845,7 @@ class WanVAE:
self.scale = [self.mean, self.inv_std]
# init model
self.model = (
_video_vae(
pretrained_path=vae_pth,
z_dim=z_dim,
seq_p_group=seq_p_group,
)
.eval()
.requires_grad_(False)
.to(device)
)
self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, seq_p_group=seq_p_group, cpu_offload=cpu_offload).eval().requires_grad_(False).to(device)
def current_device(self):
return next(self.model.parameters()).device
......
......@@ -6,6 +6,8 @@ import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from lightx2v.utils.utils import load_weights
__all__ = [
"Wan2_2_VAE",
]
......@@ -806,7 +808,7 @@ class WanVAE_(nn.Module):
self._enc_feat_map = [None] * self._enc_conv_num
def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs):
def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", cpu_offload=False, **kwargs):
# params
cfg = dict(
dim=dim,
......@@ -825,7 +827,8 @@ def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs):
# load checkpoint
logging.info(f"loading {pretrained_path}")
model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True)
weights_dict = load_weights(pretrained_path, cpu_offload=cpu_offload)
model.load_state_dict(weights_dict)
return model
......@@ -955,6 +958,7 @@ class Wan2_2_VAE:
dim=c_dim,
dim_mult=dim_mult,
temperal_downsample=temperal_downsample,
cpu_offload=cpu_offload,
)
.eval()
.requires_grad_(False)
......
......@@ -324,51 +324,73 @@ def find_gguf_model_path(config, ckpt_config_key=None, subdir=None):
raise FileNotFoundError(f"No GGUF model files (.gguf) found.\nPlease download the model from: https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")
def load_weights_distributed(checkpoint_path):
def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
if not dist.is_initialized():
# Single GPU mode
logger.info(f"Loading weights from {checkpoint_path}")
return torch.load(checkpoint_path, map_location="cpu", weights_only=True)
is_leader = False
# Multi-GPU mode
is_weight_loader = False
current_rank = dist.get_rank()
if current_rank == 0:
is_leader = True
is_weight_loader = True
cpu_weight_dict = {}
if is_leader: ## rank 0 loads the full weight dict on CPU
if is_weight_loader: # rank 0 loads the full weight dict on CPU
logger.info(f"Loading weights from {checkpoint_path}")
cpu_weight_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
for key in list(cpu_weight_dict.keys()):
if remove_key and remove_key in key:
cpu_weight_dict.pop(key)
# Synchronize the structure of the dict (keys, shapes, dtypes) across ranks
meta_dict = {}
if is_leader:
if is_weight_loader:
for key, tensor in cpu_weight_dict.items():
meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype}
obj_list = [meta_dict] if is_leader else [None]
obj_list = [meta_dict] if is_weight_loader else [None]
# Use rank 0's global rank as the broadcast source
src_global_rank = 0
dist.broadcast_object_list(obj_list, src=src_global_rank)
synced_meta_dict = obj_list[0]
# Every rank creates an empty weight dict on its own GPU
target_device = torch.device(f"cuda:{current_rank}")
gpu_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()}
dist.barrier(device_ids=[torch.cuda.current_device()])
# Pick the target device according to the offload config
if cpu_offload:
# Multi-GPU + offload: weights on CPU
target_device = "cpu"
distributed_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()}
# CPU distribution uses a plain barrier
dist.barrier()
else:
# Multi-GPU + non-offload: weights on GPU
target_device = torch.device(f"cuda:{current_rank}")
distributed_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()}
# GPU distribution uses a CUDA-aware barrier
dist.barrier(device_ids=[torch.cuda.current_device()])
# Broadcast the weights
for key in sorted(synced_meta_dict.keys()):
tensor_to_broadcast = gpu_weight_dict[key]
if is_leader:
# rank 0 copies the CPU weights onto the target GPU before broadcasting
tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
tensor_to_broadcast = distributed_weight_dict[key]
if is_weight_loader:
# rank 0 copies the CPU weights onto the target device before broadcasting
if cpu_offload:
# CPU mode: copy directly
tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
else:
# GPU mode: copy to the current GPU first, then broadcast
tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
# Broadcast to all ranks
dist.broadcast(tensor_to_broadcast, src=src_global_rank)
if is_leader:
if is_weight_loader:
del cpu_weight_dict
return gpu_weight_dict
logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}")
return distributed_weight_dict
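
Taken together, the reworked helper is intended to behave as follows. The checkpoint names below are placeholders, and the distributed calls assume the process group has already been initialized:

```python
from lightx2v.utils.utils import load_weights

# Single process (dist not initialized): a plain torch.load onto CPU.
sd = load_weights("wan_vae.pth")

# Distributed run with offload: rank 0 loads once and every rank ends up
# with a full copy of the state dict on CPU.
sd = load_weights("wan_vae.pth", cpu_offload=True)

# Distributed run without offload, dropping keys that contain a substring
# (the loader rank filters before broadcasting), e.g. the CLIP text tower.
sd = load_weights("open-clip-xlm-roberta-large-vit-huge-14.pth", remove_key="textual")
```
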
def masks_like(tensor, zero=False, generator=None, p=0.2):
......