Commit 6de0996c authored by wangshankun

Feature: Implement single-load, multi-GPU broadcast model-loading logic

parent b072a45f
@@ -10,6 +10,7 @@ from loguru import logger
 from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, TorchaoQuantLinearInt8, VllmQuantLinearFp8, VllmQuantLinearInt8
 from lightx2v.utils.envs import *
+from lightx2v.utils.utils import load_weights_distributed
 from .tokenizer import HuggingfaceTokenizer
@@ -539,6 +540,7 @@ class T5EncoderModel:
         t5_quantized=False,
         t5_quantized_ckpt=None,
         quant_scheme=None,
+        seq_p_group=None,
     ):
         self.text_len = text_len
         self.dtype = dtype
@@ -569,9 +571,8 @@ class T5EncoderModel:
             .requires_grad_(False)
         )
-        logger.info(f"Start Loading weights from {self.checkpoint_path}")
-        model.load_state_dict(torch.load(self.checkpoint_path, map_location="cpu", weights_only=True))
-        logger.info(f"End Loading weights from {self.checkpoint_path}")
+        weights_dict = load_weights_distributed(self.checkpoint_path, seq_p_group)
+        model.load_state_dict(weights_dict)
         self.model = model
         if shard_fn is not None:
......
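Note: `load_weights_distributed` (defined in the utils hunk at the end of this commit) is a drop-in replacement for the old `torch.load` call here: with no `seq_p_group` it returns a plain CPU state dict, and with a group it returns a GPU-resident dict that `load_state_dict` copies from just the same. A minimal sketch of that contract, using a hypothetical toy module and checkpoint path:

```python
# Sketch only: the single-process case, where load_weights_distributed
# falls back to plain torch.load on CPU. Module and path are illustrative.
import torch
import torch.nn as nn

from lightx2v.utils.utils import load_weights_distributed

model = nn.Linear(4, 4)  # toy stand-in for T5EncoderModel
torch.save(model.state_dict(), "/tmp/toy.pth")  # hypothetical checkpoint

weights_dict = load_weights_distributed("/tmp/toy.pth", seq_p_group=None)
model.load_state_dict(weights_dict)  # same call pattern as the diff above
```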
@@ -11,6 +11,7 @@ from loguru import logger
 # from lightx2v.attentions import attention
 from lightx2v.common.ops.attn import TorchSDPAWeight
 from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, TorchaoQuantLinearInt8, VllmQuantLinearFp8, VllmQuantLinearInt8
+from lightx2v.utils.utils import load_weights_distributed
 __all__ = [
     "XLMRobertaCLIP",
@@ -417,10 +418,12 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r
 class CLIPModel:
-    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme):
+    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, seq_p_group=None):
         self.dtype = dtype
         self.device = device
         self.quantized = clip_quantized
+        self.seq_p_group = seq_p_group
         if self.quantized:
             self.checkpoint_path = clip_quantized_ckpt
         else:
@@ -431,15 +434,15 @@ class CLIPModel:
             pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device, quantized=self.quantized, quant_scheme=quant_scheme
         )
         self.model = self.model.eval().requires_grad_(False)
-        weight_dict = torch.load(self.checkpoint_path, map_location="cpu", weights_only=True)
+        weight_dict = load_weights_distributed(self.checkpoint_path, seq_p_group=self.seq_p_group)
+        # weight_dict = torch.load(self.checkpoint_path, map_location="cpu", weights_only=True)
         keys = list(weight_dict.keys())
         for key in keys:
             if "textual" in key:
                 weight_dict.pop(key)
-        logger.info(f"Start Loading weights from {self.checkpoint_path}")
         self.model.load_state_dict(weight_dict)
-        logger.info(f"End Loading weights from {self.checkpoint_path}")

     def visual(self, videos, args):
         if hasattr(args, "cpu_offload") and args.cpu_offload:
......
@@ -12,6 +12,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from diffusers.models.embeddings import TimestepEmbedding, Timesteps
 from einops import rearrange
+from loguru import logger
 from transformers import AutoModel
 from lightx2v.utils.envs import *
@@ -64,12 +65,13 @@ def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True, se
         is_leader = True
     if is_leader:
+        logger.info(f"Loading model state from {in_path}")
         state_dict = load_pt_safetensors(in_path)
         model.load_state_dict(state_dict, strict=strict)
     # Sync the model state from the leader to every other rank in the group
     if seq_p_group is not None and dist.is_initialized():
-        dist.barrier(group=seq_p_group)
+        dist.barrier(group=seq_p_group, device_ids=[torch.cuda.current_device()])
         src_global_rank = dist.get_process_group_ranks(seq_p_group)[0]
@@ -79,7 +81,7 @@ def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True, se
         for buffer in model.buffers():
             dist.broadcast(buffer.data, src=src_global_rank, group=seq_p_group)
     elif dist.is_initialized():
-        dist.barrier()
+        dist.barrier(device_ids=[torch.cuda.current_device()])
         for param in model.parameters():
             dist.broadcast(param.data, src=0)
         for buffer in model.buffers():
......
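The hunk above also pins both barriers to the current CUDA device via `device_ids`, so NCCL does not have to guess which GPU backs the barrier. Stripped of the lightx2v specifics, the leader-then-broadcast pattern is as follows (a sketch, assuming an initialized NCCL group with one process per GPU; `sync_model_from_leader` is an illustrative name, not a lightx2v API):

```python
# Broadcast a CUDA-resident model's parameters and buffers from the group's
# first rank to every other rank in the group.
import torch
import torch.distributed as dist

def sync_model_from_leader(model: torch.nn.Module, group=None):
    src = dist.get_process_group_ranks(group)[0] if group is not None else 0
    # device_ids pins the NCCL barrier to this rank's GPU.
    dist.barrier(group=group, device_ids=[torch.cuda.current_device()])
    for param in model.parameters():
        dist.broadcast(param.data, src=src, group=group)
    for buffer in model.buffers():
        dist.broadcast(buffer.data, src=src, group=group)
```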
@@ -19,8 +19,8 @@ class WanAudioModel(WanModel):
     post_weight_class = WanPostWeights
     transformer_weight_class = WanTransformerWeights

-    def __init__(self, model_path, config, device):
-        super().__init__(model_path, config, device)
+    def __init__(self, model_path, config, device, seq_p_group=None):
+        super().__init__(model_path, config, device, seq_p_group)

     def _init_infer_class(self):
         super()._init_infer_class()
......
@@ -23,8 +23,8 @@ class WanCausVidModel(WanModel):
     post_weight_class = WanPostWeights
     transformer_weight_class = WanTransformerWeights

-    def __init__(self, model_path, config, device):
-        super().__init__(model_path, config, device)
+    def __init__(self, model_path, config, device, seq_p_group=None):
+        super().__init__(model_path, config, device, seq_p_group)

     def _init_infer_class(self):
         self.pre_infer_class = WanPreInfer
......
@@ -19,8 +19,8 @@ class WanDistillModel(WanModel):
     post_weight_class = WanPostWeights
     transformer_weight_class = WanTransformerWeights

-    def __init__(self, model_path, config, device):
-        super().__init__(model_path, config, device)
+    def __init__(self, model_path, config, device, seq_p_group=None):
+        super().__init__(model_path, config, device, seq_p_group)

     def _load_ckpt(self, unified_dtype, sensitive_layer):
         # For the old t2v distill model: https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill
......
@@ -41,11 +41,12 @@ class WanModel:
     post_weight_class = WanPostWeights
     transformer_weight_class = WanTransformerWeights

-    def __init__(self, model_path, config, device):
+    def __init__(self, model_path, config, device, seq_p_group=None):
         self.model_path = model_path
         self.config = config
         self.cpu_offload = self.config.get("cpu_offload", False)
         self.offload_granularity = self.config.get("offload_granularity", "block")
+        self.seq_p_group = seq_p_group
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
@@ -187,16 +188,61 @@ class WanModel:
                 "img_emb.proj.0",
                 "img_emb.proj.4",
             }
         if weight_dict is None:
-            if not self.dit_quantized or self.weight_auto_quant:
-                self.original_weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
-            else:
-                if not self.config.get("lazy_load", False):
-                    self.original_weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
-                else:
-                    self.original_weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
+            is_weight_loader = False
+            if self.seq_p_group is None:
+                is_weight_loader = True
+                logger.info(f"Loading original dit model from {self.model_path}")
+            elif dist.is_initialized():
+                if dist.get_rank(group=self.seq_p_group) == 0:
+                    is_weight_loader = True
+                    logger.info(f"Loading original dit model from {self.model_path}")
+
+            cpu_weight_dict = {}
+            if is_weight_loader:
+                if not self.dit_quantized or self.weight_auto_quant:
+                    cpu_weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
+                else:
+                    if not self.config.get("lazy_load", False):
+                        cpu_weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
+                    else:
+                        cpu_weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
+
+            if self.seq_p_group is None:  # single-GPU mode
+                self.original_weight_dict = {}
+                for key, tensor in cpu_weight_dict.items():
+                    self.original_weight_dict[key] = tensor.to("cuda", non_blocking=True)
+            else:
+                seq_p_group = self.seq_p_group
+                global_src_rank = dist.get_process_group_ranks(seq_p_group)[0]
+
+                meta_dict = {}
+                if is_weight_loader:
+                    for key, tensor in cpu_weight_dict.items():
+                        meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype}
+                obj_list = [meta_dict] if is_weight_loader else [None]
+                dist.broadcast_object_list(obj_list, src=global_src_rank, group=seq_p_group)
+                synced_meta_dict = obj_list[0]
+
+                self.original_weight_dict = {}
+                for key, meta in synced_meta_dict.items():
+                    self.original_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device="cuda")
+
+                dist.barrier(group=seq_p_group, device_ids=[torch.cuda.current_device()])
+                for key in sorted(synced_meta_dict.keys()):
+                    tensor_to_broadcast = self.original_weight_dict[key]
+                    if is_weight_loader:
+                        tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
+                    dist.broadcast(tensor_to_broadcast, src=global_src_rank, group=seq_p_group)
+
+                if is_weight_loader:
+                    del cpu_weight_dict
         else:
             self.original_weight_dict = weight_dict

         # init weights
         self.pre_weight = self.pre_weight_class(self.config)
         self.post_weight = self.post_weight_class(self.config)
......
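The new branch is a two-phase handshake: `broadcast_object_list` first ships only each tensor's shape and dtype so non-leader ranks can pre-allocate matching GPU buffers, then the tensors are broadcast in sorted key order so every rank issues the collectives in the same sequence (dict iteration order alone would not guarantee that). A standalone sketch of the metadata phase, with illustrative names:

```python
# The leader serializes shapes/dtypes; every other rank receives the same
# dict and can allocate empty GPU tensors to broadcast into.
import torch.distributed as dist

def sync_weight_metadata(cpu_weight_dict, group, is_leader):
    src = dist.get_process_group_ranks(group)[0]
    meta = {k: {"shape": t.shape, "dtype": t.dtype} for k, t in cpu_weight_dict.items()} if is_leader else None
    obj_list = [meta]
    dist.broadcast_object_list(obj_list, src=src, group=group)
    return obj_list[0]  # identical on all ranks; sorted(keys) then fixes the order
```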
@@ -610,7 +610,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
     def load_transformer(self):
         """Load transformer with LoRA support"""
-        base_model = WanAudioModel(self.config.model_path, self.config, self.init_device)
+        base_model = WanAudioModel(self.config.model_path, self.config, self.init_device, self.seq_p_group)
         if self.config.get("lora_configs") and self.config.lora_configs:
             assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
......
@@ -29,6 +29,7 @@ class WanCausVidRunner(WanRunner):
             self.config.model_path,
             self.config,
             self.init_device,
+            self.seq_p_group,
         )
         lora_wrapper = WanLoraWrapper(model)
         for lora_config in self.config.lora_configs:
......
@@ -21,6 +21,7 @@ class WanDistillRunner(WanRunner):
             self.config.model_path,
             self.config,
             self.init_device,
+            self.seq_p_group,
         )
         lora_wrapper = WanLoraWrapper(model)
         for lora_config in self.config.lora_configs:
@@ -90,6 +91,7 @@ class Wan22MoeDistillRunner(WanDistillRunner):
             os.path.join(self.config.model_path, "high_noise_model"),
             self.config,
             self.init_device,
+            self.seq_p_group,
         )
         high_lora_wrapper = WanLoraWrapper(high_noise_model)
         for lora_config in self.config.lora_configs:
@@ -104,6 +106,7 @@ class Wan22MoeDistillRunner(WanDistillRunner):
             os.path.join(self.config.model_path, "distill_models", "high_noise_model"),
             self.config,
             self.init_device,
+            self.seq_p_group,
         )
         if use_low_lora:
......
@@ -34,12 +34,18 @@ from lightx2v.utils.utils import best_output_size, cache_video
 class WanRunner(DefaultRunner):
     def __init__(self, config):
         super().__init__(config)
+        device_mesh = self.config.get("device_mesh")
+        if device_mesh is not None:
+            self.seq_p_group = device_mesh.get_group(mesh_dim="seq_p")
+        else:
+            self.seq_p_group = None

     def load_transformer(self):
         model = WanModel(
             self.config.model_path,
             self.config,
             self.init_device,
+            self.seq_p_group,
         )
         if self.config.get("lora_configs") and self.config.lora_configs:
             assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
@@ -77,6 +83,7 @@ class WanRunner(DefaultRunner):
             clip_quantized=clip_quantized,
             clip_quantized_ckpt=clip_quantized_ckpt,
             quant_scheme=clip_quant_scheme,
+            seq_p_group=self.seq_p_group,
         )
         return image_encoder
@@ -118,6 +125,7 @@ class WanRunner(DefaultRunner):
             t5_quantized=t5_quantized,
             t5_quantized_ckpt=t5_quantized_ckpt,
             quant_scheme=t5_quant_scheme,
+            seq_p_group=self.seq_p_group,
         )
         text_encoders = [text_encoder]
         return text_encoders
@@ -128,6 +136,7 @@ class WanRunner(DefaultRunner):
             "device": self.init_device,
             "parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
             "use_tiling": self.config.get("use_tiling_vae", False),
+            "seq_p_group": self.seq_p_group,
         }
         if self.config.task != "i2v":
             return None
......
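`WanRunner.__init__` now derives `seq_p_group` from an optional `device_mesh` entry in the config, keyed by a `seq_p` mesh dimension. A plausible wiring, assuming PyTorch >= 2.2's `init_device_mesh` and a plain dict as a simplified stand-in for the real config object:

```python
# Hypothetical setup: a 1-D device mesh named "seq_p" handed to the runner
# config; WanRunner then calls device_mesh.get_group(mesh_dim="seq_p").
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("seq_p",))

config = {"device_mesh": mesh}  # simplified stand-in for lightx2v's config
seq_p_group = mesh.get_group(mesh_dim="seq_p")  # what the runner passes to every loader
```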
@@ -8,6 +8,8 @@ import torch.nn.functional as F
 from einops import rearrange
 from loguru import logger

+from lightx2v.utils.utils import load_weights_distributed
+
 __all__ = [
     "WanVAE",
 ]
@@ -758,7 +760,7 @@ class WanVAE_(nn.Module):
         self._enc_feat_map = [None] * self._enc_conv_num

-def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs):
+def _video_vae(pretrained_path=None, z_dim=None, device="cpu", seq_p_group=None, **kwargs):
     """
     Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
     """
@@ -780,7 +782,9 @@ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs):
     # load checkpoint
     logging.info(f"loading {pretrained_path}")
-    model.load_state_dict(torch.load(pretrained_path, map_location=device, weights_only=True), assign=True)
+    weights_dict = load_weights_distributed(pretrained_path, seq_p_group)
+    model.load_state_dict(weights_dict, assign=True)

     return model
@@ -794,6 +798,7 @@ class WanVAE:
         device="cuda",
         parallel=False,
         use_tiling=False,
+        seq_p_group=None,
     ):
         self.dtype = dtype
         self.device = device
@@ -845,6 +850,7 @@ class WanVAE:
             _video_vae(
                 pretrained_path=vae_pth,
                 z_dim=z_dim,
+                seq_p_group=seq_p_group,
             )
             .eval()
             .requires_grad_(False)
......
@@ -8,6 +8,7 @@ import imageio
 import imageio_ffmpeg as ffmpeg
 import numpy as np
 import torch
+import torch.distributed as dist
 import torchvision
 from einops import rearrange
 from loguru import logger
@@ -272,7 +273,6 @@ def find_torch_model_path(config, ckpt_config_key=None, filename=None, subdir=["
     for path in paths_to_check:
         if os.path.exists(path):
-            logger.info(f"Found PyTorch model checkpoint: {path}")
             return path
     raise FileNotFoundError(f"PyTorch model file '{filename}' not found.\nPlease download the model from https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")
@@ -292,7 +292,6 @@ def find_hf_model_path(config, model_path, ckpt_config_key=None, subdir=["origin
     safetensors_pattern = os.path.join(path, "*.safetensors")
     safetensors_files = glob.glob(safetensors_pattern)
     if safetensors_files:
-        logger.info(f"Found Hugging Face model files in: {path}")
         return path
     raise FileNotFoundError(f"No Hugging Face model files (.safetensors) found.\nPlease download the model from: https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")
@@ -325,6 +324,53 @@ def find_gguf_model_path(config, ckpt_config_key=None, subdir=None):
     raise FileNotFoundError(f"No GGUF model files (.gguf) found.\nPlease download the model from: https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")
+def load_weights_distributed(checkpoint_path, seq_p_group=None):
+    if seq_p_group is None or not dist.is_initialized():
+        logger.info(f"Loading weights from {checkpoint_path}")
+        return torch.load(checkpoint_path, map_location="cpu", weights_only=True)
+
+    is_leader = False
+    current_rank = dist.get_rank(seq_p_group)
+    if current_rank == 0:
+        is_leader = True
+
+    cpu_weight_dict = {}
+    if is_leader:  # rank 0 loads the full weight dict on CPU
+        logger.info(f"Loading weights from {checkpoint_path}")
+        cpu_weight_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
+
+    # Sync the structure of the dict across ranks
+    meta_dict = {}
+    if is_leader:
+        for key, tensor in cpu_weight_dict.items():
+            meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype}
+    obj_list = [meta_dict] if is_leader else [None]
+    # Resolve rank 0's global rank for the broadcast
+    src_global_rank = dist.get_process_group_ranks(seq_p_group)[0]
+    dist.broadcast_object_list(obj_list, src=src_global_rank, group=seq_p_group)
+    synced_meta_dict = obj_list[0]
+
+    # Every rank pre-allocates an empty weight dict on its own GPU
+    target_device = torch.device(f"cuda:{current_rank}")
+    gpu_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()}
+
+    dist.barrier(group=seq_p_group, device_ids=[torch.cuda.current_device()])
+    for key in sorted(synced_meta_dict.keys()):
+        tensor_to_broadcast = gpu_weight_dict[key]
+        if is_leader:
+            # Rank 0 stages the CPU weight on the target GPU before broadcasting
+            tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
+        dist.broadcast(tensor_to_broadcast, src=src_global_rank, group=seq_p_group)
+
+    if is_leader:
+        del cpu_weight_dict
+    return gpu_weight_dict
 def masks_like(tensor, zero=False, generator=None, p=0.2):
     assert isinstance(tensor, torch.Tensor)
     out = torch.ones_like(tensor)
......
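Taken together, a minimal end-to-end sketch of `load_weights_distributed` under `torchrun` (one process per GPU; the checkpoint path is a placeholder):

```python
# Launch: torchrun --nproc_per_node=4 demo.py
# Only the group's first rank reads the checkpoint from disk; the others
# receive every tensor over NCCL and end up with a GPU-resident dict.
import torch
import torch.distributed as dist

from lightx2v.utils.utils import load_weights_distributed

def main():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())
    seq_p_group = dist.new_group(ranks=list(range(dist.get_world_size())))
    weights = load_weights_distributed("/path/to/checkpoint.pth", seq_p_group)  # placeholder path
    print(f"rank {dist.get_rank()}: {len(weights)} tensors on GPU")

if __name__ == "__main__":
    main()
```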