"encoding/datasets/cityscapescoarse.py" did not exist on "b4ff422b498e1b640ddb874a4245f6a39985b264"
Commit 6de0996c authored by wangshankun

Feature: Implement single-load, multi-GPU broadcast model loading logic

parent b072a45f
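The change, in brief: rank 0 of each sequence-parallel (`seq_p`) group reads each checkpoint from disk once and broadcasts the tensors to its peer ranks over NCCL, instead of every rank loading the same file. A minimal, hedged sketch of the pattern (illustrative only, not code from this repo; assumes torchrun with one process per GPU):

```python
# Hedged sketch of the pattern this commit introduces: one rank reads the
# checkpoint, every rank receives identical weights over NCCL. The random
# tensor stands in for a real torch.load(...) on rank 0.
import os
import torch
import torch.distributed as dist

def demo():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    w = torch.empty(1024, 1024, device="cuda")
    if dist.get_rank() == 0:
        w.copy_(torch.randn(1024, 1024))  # stands in for loading from disk
    dist.broadcast(w, src=0)  # now identical on every rank
    dist.destroy_process_group()

if __name__ == "__main__":
    demo()
```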
@@ -10,6 +10,7 @@ from loguru import logger
 from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, TorchaoQuantLinearInt8, VllmQuantLinearFp8, VllmQuantLinearInt8
 from lightx2v.utils.envs import *
+from lightx2v.utils.utils import load_weights_distributed
 from .tokenizer import HuggingfaceTokenizer
@@ -539,6 +540,7 @@ class T5EncoderModel:
         t5_quantized=False,
         t5_quantized_ckpt=None,
         quant_scheme=None,
+        seq_p_group=None,
     ):
         self.text_len = text_len
         self.dtype = dtype
@@ -569,9 +571,8 @@
             .requires_grad_(False)
         )
         logger.info(f"Start Loading weights from {self.checkpoint_path}")
-        model.load_state_dict(torch.load(self.checkpoint_path, map_location="cpu", weights_only=True))
-        logger.info(f"End Loading weights from {self.checkpoint_path}")
+        weights_dict = load_weights_distributed(self.checkpoint_path, seq_p_group)
+        model.load_state_dict(weights_dict)
         self.model = model
         if shard_fn is not None:
@@ -11,6 +11,7 @@ from loguru import logger
 # from lightx2v.attentions import attention
 from lightx2v.common.ops.attn import TorchSDPAWeight
 from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, TorchaoQuantLinearInt8, VllmQuantLinearFp8, VllmQuantLinearInt8
+from lightx2v.utils.utils import load_weights_distributed

 __all__ = [
     "XLMRobertaCLIP",
@@ -417,10 +418,12 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r
 class CLIPModel:
-    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme):
+    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, seq_p_group=None):
         self.dtype = dtype
         self.device = device
         self.quantized = clip_quantized
+        self.seq_p_group = seq_p_group
         if self.quantized:
             self.checkpoint_path = clip_quantized_ckpt
         else:
@@ -431,15 +434,15 @@ class CLIPModel:
             pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device, quantized=self.quantized, quant_scheme=quant_scheme
         )
         self.model = self.model.eval().requires_grad_(False)
-        weight_dict = torch.load(self.checkpoint_path, map_location="cpu", weights_only=True)
+        weight_dict = load_weights_distributed(self.checkpoint_path, seq_p_group=self.seq_p_group)
         keys = list(weight_dict.keys())
         for key in keys:
             if "textual" in key:
                 weight_dict.pop(key)
         logger.info(f"Start Loading weights from {self.checkpoint_path}")
         self.model.load_state_dict(weight_dict)
         logger.info(f"End Loading weights from {self.checkpoint_path}")

     def visual(self, videos, args):
         if hasattr(args, "cpu_offload") and args.cpu_offload:
@@ -12,6 +12,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from diffusers.models.embeddings import TimestepEmbedding, Timesteps
 from einops import rearrange
+from loguru import logger
 from transformers import AutoModel

 from lightx2v.utils.envs import *
@@ -64,12 +65,13 @@ def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True, se
         is_leader = True
     if is_leader:
         logger.info(f"Loading model state from {in_path}")
         state_dict = load_pt_safetensors(in_path)
         model.load_state_dict(state_dict, strict=strict)
+    # Sync the model state from the leader to all other ranks in the group
     if seq_p_group is not None and dist.is_initialized():
-        dist.barrier(group=seq_p_group)
+        dist.barrier(group=seq_p_group, device_ids=[torch.cuda.current_device()])
         src_global_rank = dist.get_process_group_ranks(seq_p_group)[0]
@@ -79,7 +81,7 @@
         for buffer in model.buffers():
             dist.broadcast(buffer.data, src=src_global_rank, group=seq_p_group)
     elif dist.is_initialized():
-        dist.barrier()
+        dist.barrier(device_ids=[torch.cuda.current_device()])
         for param in model.parameters():
             dist.broadcast(param.data, src=0)
         for buffer in model.buffers():
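The barrier changes above pass `device_ids` so the NCCL barrier is pinned to the GPU this rank actually owns; without it, NCCL infers a device heuristically, which can mismatch the process-to-GPU mapping and hang. A hedged, standalone sketch of the idiom (assumes one process per GPU, launched via torchrun):

```python
# Hedged sketch: pinning an NCCL barrier to the current rank's GPU.
import os
import torch
import torch.distributed as dist

def main():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    # device_ids tells NCCL exactly which device the barrier runs on,
    # instead of letting it pick one.
    dist.barrier(device_ids=[torch.cuda.current_device()])
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```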
@@ -19,8 +19,8 @@ class WanAudioModel(WanModel):
     post_weight_class = WanPostWeights
     transformer_weight_class = WanTransformerWeights

-    def __init__(self, model_path, config, device):
-        super().__init__(model_path, config, device)
+    def __init__(self, model_path, config, device, seq_p_group=None):
+        super().__init__(model_path, config, device, seq_p_group)

     def _init_infer_class(self):
         super()._init_infer_class()
@@ -23,8 +23,8 @@ class WanCausVidModel(WanModel):
     post_weight_class = WanPostWeights
     transformer_weight_class = WanTransformerWeights

-    def __init__(self, model_path, config, device):
-        super().__init__(model_path, config, device)
+    def __init__(self, model_path, config, device, seq_p_group=None):
+        super().__init__(model_path, config, device, seq_p_group)

     def _init_infer_class(self):
         self.pre_infer_class = WanPreInfer
@@ -19,8 +19,8 @@ class WanDistillModel(WanModel):
     post_weight_class = WanPostWeights
     transformer_weight_class = WanTransformerWeights

-    def __init__(self, model_path, config, device):
-        super().__init__(model_path, config, device)
+    def __init__(self, model_path, config, device, seq_p_group=None):
+        super().__init__(model_path, config, device, seq_p_group)

     def _load_ckpt(self, unified_dtype, sensitive_layer):
         # For the old t2v distill model: https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill
@@ -41,11 +41,12 @@ class WanModel:
     post_weight_class = WanPostWeights
     transformer_weight_class = WanTransformerWeights

-    def __init__(self, model_path, config, device):
+    def __init__(self, model_path, config, device, seq_p_group=None):
         self.model_path = model_path
         self.config = config
         self.cpu_offload = self.config.get("cpu_offload", False)
         self.offload_granularity = self.config.get("offload_granularity", "block")
+        self.seq_p_group = seq_p_group
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
@@ -187,16 +188,61 @@
             "img_emb.proj.0",
             "img_emb.proj.4",
         }
         if weight_dict is None:
-            if not self.dit_quantized or self.weight_auto_quant:
-                self.original_weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
-            else:
-                if not self.config.get("lazy_load", False):
-                    self.original_weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
-                else:
-                    self.original_weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
+            is_weight_loader = False
+            if self.seq_p_group is None:
+                is_weight_loader = True
+                logger.info(f"Loading original dit model from {self.model_path}")
+            elif dist.is_initialized():
+                if dist.get_rank(group=self.seq_p_group) == 0:
+                    is_weight_loader = True
+                    logger.info(f"Loading original dit model from {self.model_path}")
+
+            cpu_weight_dict = {}
+            if is_weight_loader:
+                if not self.dit_quantized or self.weight_auto_quant:
+                    cpu_weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
+                else:
+                    if not self.config.get("lazy_load", False):
+                        cpu_weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
+                    else:
+                        cpu_weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
+
+            if self.seq_p_group is None:  # single-GPU mode
+                self.original_weight_dict = {}
+                for key, tensor in cpu_weight_dict.items():
+                    self.original_weight_dict[key] = tensor.to("cuda", non_blocking=True)
+            else:
+                seq_p_group = self.seq_p_group
+                global_src_rank = dist.get_process_group_ranks(seq_p_group)[0]
+
+                # Sync the dict structure (keys, shapes, dtypes) from the loader rank
+                meta_dict = {}
+                if is_weight_loader:
+                    for key, tensor in cpu_weight_dict.items():
+                        meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype}
+                obj_list = [meta_dict] if is_weight_loader else [None]
+                dist.broadcast_object_list(obj_list, src=global_src_rank, group=seq_p_group)
+                synced_meta_dict = obj_list[0]
+
+                # Every rank allocates empty GPU tensors, then receives the broadcasts
+                self.original_weight_dict = {}
+                for key, meta in synced_meta_dict.items():
+                    self.original_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device="cuda")
+                dist.barrier(group=seq_p_group, device_ids=[torch.cuda.current_device()])
+                for key in sorted(synced_meta_dict.keys()):
+                    tensor_to_broadcast = self.original_weight_dict[key]
+                    if is_weight_loader:
+                        tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
+                    dist.broadcast(tensor_to_broadcast, src=global_src_rank, group=seq_p_group)
+
+                if is_weight_loader:
+                    del cpu_weight_dict
         else:
             self.original_weight_dict = weight_dict
         # init weights
         self.pre_weight = self.pre_weight_class(self.config)
         self.post_weight = self.post_weight_class(self.config)
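The block above is a two-phase broadcast: the dict structure travels first as pickled metadata, then each tensor is broadcast individually. A condensed, hedged sketch of the same idea (not the repo's API; assumes dist is initialized and torch.cuda.set_device was already called):

```python
# Hedged sketch: broadcast a state dict from the group leader to all ranks.
import torch
import torch.distributed as dist

def broadcast_state_dict(state_dict, group=None):
    src = dist.get_process_group_ranks(group)[0] if group is not None else 0
    is_src = dist.get_rank() == src
    # Phase 1: ship keys/shapes/dtypes as a pickled object.
    meta = [{k: (v.shape, v.dtype) for k, v in state_dict.items()} if is_src else None]
    dist.broadcast_object_list(meta, src=src, group=group)
    # Phase 2: broadcast each tensor into a preallocated GPU buffer.
    out = {}
    for key in sorted(meta[0]):
        shape, dtype = meta[0][key]
        t = state_dict[key].cuda() if is_src else torch.empty(shape, dtype=dtype, device="cuda")
        dist.broadcast(t, src=src, group=group)
        out[key] = t
    return out
```

Iterating in sorted key order is load-bearing: NCCL collectives must be issued in the same order on every rank, and after the metadata sync the keys are the only ordering all ranks share.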
@@ -610,7 +610,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
     def load_transformer(self):
         """Load transformer with LoRA support"""
-        base_model = WanAudioModel(self.config.model_path, self.config, self.init_device)
+        base_model = WanAudioModel(self.config.model_path, self.config, self.init_device, self.seq_p_group)
         if self.config.get("lora_configs") and self.config.lora_configs:
             assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
@@ -29,6 +29,7 @@ class WanCausVidRunner(WanRunner):
             self.config.model_path,
             self.config,
             self.init_device,
+            self.seq_p_group,
         )
         lora_wrapper = WanLoraWrapper(model)
         for lora_config in self.config.lora_configs:
@@ -21,6 +21,7 @@ class WanDistillRunner(WanRunner):
             self.config.model_path,
             self.config,
             self.init_device,
+            self.seq_p_group,
         )
         lora_wrapper = WanLoraWrapper(model)
         for lora_config in self.config.lora_configs:
@@ -90,6 +91,7 @@ class Wan22MoeDistillRunner(WanDistillRunner):
             os.path.join(self.config.model_path, "high_noise_model"),
             self.config,
             self.init_device,
+            self.seq_p_group,
         )
         high_lora_wrapper = WanLoraWrapper(high_noise_model)
         for lora_config in self.config.lora_configs:
@@ -104,6 +106,7 @@
             os.path.join(self.config.model_path, "distill_models", "high_noise_model"),
             self.config,
             self.init_device,
+            self.seq_p_group,
         )
         if use_low_lora:
@@ -34,12 +34,18 @@ from lightx2v.utils.utils import best_output_size, cache_video
 class WanRunner(DefaultRunner):
     def __init__(self, config):
         super().__init__(config)
+        device_mesh = self.config.get("device_mesh")
+        if device_mesh is not None:
+            self.seq_p_group = device_mesh.get_group(mesh_dim="seq_p")
+        else:
+            self.seq_p_group = None

     def load_transformer(self):
         model = WanModel(
             self.config.model_path,
             self.config,
             self.init_device,
+            self.seq_p_group,
         )
         if self.config.get("lora_configs") and self.config.lora_configs:
             assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
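`WanRunner` expects `config.device_mesh` to already carry a `seq_p` mesh dimension, set up elsewhere in the launch path. A hedged sketch of how such a mesh might be constructed (the shape and helper name are assumptions, not from this commit):

```python
# Hedged sketch: a 2-D device mesh with a "seq_p" dimension.
# dp_size * seq_p_size must equal the world size.
from torch.distributed.device_mesh import init_device_mesh

def build_seq_p_group(dp_size: int, seq_p_size: int):
    mesh = init_device_mesh("cuda", (dp_size, seq_p_size), mesh_dim_names=("dp", "seq_p"))
    # The same call WanRunner.__init__ makes on the configured mesh.
    return mesh, mesh.get_group(mesh_dim="seq_p")
```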
@@ -77,6 +83,7 @@ class WanRunner(DefaultRunner):
             clip_quantized=clip_quantized,
             clip_quantized_ckpt=clip_quantized_ckpt,
             quant_scheme=clip_quant_scheme,
+            seq_p_group=self.seq_p_group,
         )
         return image_encoder
@@ -118,6 +125,7 @@ class WanRunner(DefaultRunner):
             t5_quantized=t5_quantized,
             t5_quantized_ckpt=t5_quantized_ckpt,
             quant_scheme=t5_quant_scheme,
+            seq_p_group=self.seq_p_group,
         )
         text_encoders = [text_encoder]
         return text_encoders
@@ -128,6 +136,7 @@
             "device": self.init_device,
             "parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
             "use_tiling": self.config.get("use_tiling_vae", False),
+            "seq_p_group": self.seq_p_group,
         }
         if self.config.task != "i2v":
             return None
@@ -8,6 +8,8 @@ import torch.nn.functional as F
 from einops import rearrange
 from loguru import logger

+from lightx2v.utils.utils import load_weights_distributed
+
 __all__ = [
     "WanVAE",
 ]
@@ -758,7 +760,7 @@ class WanVAE_(nn.Module):
         self._enc_feat_map = [None] * self._enc_conv_num

-def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs):
+def _video_vae(pretrained_path=None, z_dim=None, device="cpu", seq_p_group=None, **kwargs):
     """
     Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
     """
@@ -780,7 +782,9 @@ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", seq_p_group=None,
     # load checkpoint
     logging.info(f"loading {pretrained_path}")
-    model.load_state_dict(torch.load(pretrained_path, map_location=device, weights_only=True), assign=True)
+    weights_dict = load_weights_distributed(pretrained_path, seq_p_group)
+    model.load_state_dict(weights_dict, assign=True)

     return model
@@ -794,6 +798,7 @@ class WanVAE:
         device="cuda",
         parallel=False,
         use_tiling=False,
+        seq_p_group=None,
     ):
         self.dtype = dtype
         self.device = device
@@ -845,6 +850,7 @@
             _video_vae(
                 pretrained_path=vae_pth,
                 z_dim=z_dim,
+                seq_p_group=seq_p_group,
             )
             .eval()
             .requires_grad_(False)
@@ -8,6 +8,7 @@ import imageio
 import imageio_ffmpeg as ffmpeg
 import numpy as np
 import torch
+import torch.distributed as dist
 import torchvision
 from einops import rearrange
 from loguru import logger
@@ -272,7 +273,6 @@ def find_torch_model_path(config, ckpt_config_key=None, filename=None, subdir=["
     for path in paths_to_check:
         if os.path.exists(path):
             logger.info(f"Found PyTorch model checkpoint: {path}")
             return path
-
     raise FileNotFoundError(f"PyTorch model file '{filename}' not found.\nPlease download the model from https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")
@@ -292,7 +292,6 @@ def find_hf_model_path(config, model_path, ckpt_config_key=None, subdir=["origin
         safetensors_pattern = os.path.join(path, "*.safetensors")
         safetensors_files = glob.glob(safetensors_pattern)
         if safetensors_files:
             logger.info(f"Found Hugging Face model files in: {path}")
             return path
-
     raise FileNotFoundError(f"No Hugging Face model files (.safetensors) found.\nPlease download the model from: https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")
@@ -325,6 +324,53 @@ def find_gguf_model_path(config, ckpt_config_key=None, subdir=None):
     raise FileNotFoundError(f"No GGUF model files (.gguf) found.\nPlease download the model from: https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")

+def load_weights_distributed(checkpoint_path, seq_p_group=None):
+    if seq_p_group is None or not dist.is_initialized():
+        logger.info(f"Loading weights from {checkpoint_path}")
+        return torch.load(checkpoint_path, map_location="cpu", weights_only=True)
+
+    is_leader = False
+    current_rank = dist.get_rank(seq_p_group)
+    if current_rank == 0:
+        is_leader = True
+
+    cpu_weight_dict = {}
+    if is_leader:  # rank 0 loads the full weight dict on CPU
+        logger.info(f"Loading weights from {checkpoint_path}")
+        cpu_weight_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
+
+    # Sync the structure of the dict (keys, shapes, dtypes)
+    meta_dict = {}
+    if is_leader:
+        for key, tensor in cpu_weight_dict.items():
+            meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype}
+    obj_list = [meta_dict] if is_leader else [None]
+    # Use rank 0's global rank for the broadcast
+    src_global_rank = dist.get_process_group_ranks(seq_p_group)[0]
+    dist.broadcast_object_list(obj_list, src=src_global_rank, group=seq_p_group)
+    synced_meta_dict = obj_list[0]
+
+    # Every rank allocates an empty weight dict on its own GPU
+    target_device = torch.device(f"cuda:{current_rank}")
+    gpu_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()}
+    dist.barrier(group=seq_p_group, device_ids=[torch.cuda.current_device()])
+
+    for key in sorted(synced_meta_dict.keys()):
+        tensor_to_broadcast = gpu_weight_dict[key]
+        if is_leader:
+            # rank 0 copies the CPU weights onto its GPU before broadcasting
+            tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
+        dist.broadcast(tensor_to_broadcast, src=src_global_rank, group=seq_p_group)
+
+    if is_leader:
+        del cpu_weight_dict
+    return gpu_weight_dict
+
 def masks_like(tensor, zero=False, generator=None, p=0.2):
     assert isinstance(tensor, torch.Tensor)
     out = torch.ones_like(tensor)
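A hedged usage sketch for `load_weights_distributed` (run under torchrun; the group construction and checkpoint path are illustrative assumptions). Note the dual behavior: without a group, or outside a distributed run, it returns a CPU dict from a plain `torch.load`; with a group, every rank gets GPU-resident tensors while only the leader touches disk.

```python
# Hedged sketch: calling the new helper from a torchrun-launched script.
import os
import torch
import torch.distributed as dist
from lightx2v.utils.utils import load_weights_distributed

def main():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    # Illustrative group spanning all ranks; real code takes it from the device mesh.
    seq_p_group = dist.new_group(ranks=list(range(dist.get_world_size())))
    weights = load_weights_distributed("checkpoint.pt", seq_p_group)  # placeholder path
    print(dist.get_rank(), next(iter(weights.values())).device)
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```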