Commit daa06243 authored by wangshankun

BugFix: 1. conflict between cfg parallelism and the model-loading process group 2. conflict between offload and the weight-broadcast logic 3. save_video writing the file multiple times under parallelism

parent 64948a2e
{
    "infer_steps": 4,
    "target_fps": 16,
    "video_duration": 5,
    "audio_sr": 16000,
    "target_video_length": 81,
    "target_height": 720,
    "target_width": 1280,
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 1.0,
    "sample_shift": 5,
    "enable_cfg": false,
    "use_31_block": false,
    "adaptive_resize": true,
    "parallel": {
        "seq_p_size": 4,
        "seq_p_attn_type": "ulysses"
    },
    "cpu_offload": true,
    "offload_granularity": "block",
    "t5_cpu_offload": true,
    "offload_ratio_val": 1,
    "t5_offload_granularity": "block",
    "use_tiling_vae": true
}
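For reference, a minimal sketch of how a config with a `parallel` block like the one above could be consumed: the default process group is initialized once per process (e.g. under `torchrun --nproc_per_node=4`, matching `seq_p_size: 4`), and after this commit the sequence-parallel group no longer has to be threaded through the weight loaders. The helper name and launch details below are illustrative, not part of this repo.

```python
# Hypothetical helper (illustrative only): read the JSON config and, if a
# "parallel" block is present, make sure the default process group exists.
import json
import os

import torch
import torch.distributed as dist


def load_config_and_init_dist(config_path: str) -> dict:
    with open(config_path) as f:
        cfg = json.load(f)
    if cfg.get("parallel") is not None and not dist.is_initialized():
        # torchrun provides RANK / WORLD_SIZE / LOCAL_RANK environment variables.
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
    return cfg
```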
{
    "infer_steps": 4,
    "target_fps": 16,
    "video_duration": 5,
    "audio_sr": 16000,
    "target_video_length": 81,
    "target_height": 720,
    "target_width": 1280,
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 1,
    "sample_shift": 5,
    "enable_cfg": false,
    "adaptive_resize": true,
    "use_31_block": false,
    "cpu_offload": true,
    "offload_granularity": "block",
    "t5_cpu_offload": true,
    "offload_ratio_val": 1,
    "t5_offload_granularity": "block",
    "use_tiling_vae": true
}
......@@ -571,7 +571,7 @@ class T5EncoderModel:
             .requires_grad_(False)
         )
-        weights_ditc = load_weights_distributed(self.checkpoint_path, seq_p_group)
+        weights_ditc = load_weights_distributed(self.checkpoint_path)
         model.load_state_dict(weights_ditc)
         self.model = model
......
......@@ -434,8 +434,7 @@ class CLIPModel:
             pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device, quantized=self.quantized, quant_scheme=quant_scheme
         )
         self.model = self.model.eval().requires_grad_(False)
-        weight_dict = load_weights_distributed(self.checkpoint_path, seq_p_group=self.seq_p_group)
-        # weight_dict = torch.load(self.checkpoint_path, map_location="cpu", weights_only=True)
+        weight_dict = load_weights_distributed(self.checkpoint_path)
         keys = list(weight_dict.keys())
         for key in keys:
......
......@@ -53,13 +53,13 @@ def load_pt_safetensors(in_path: str):
     return state_dict


-def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True, seq_p_group=None):
+def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True):
     model = model.to("cuda")
     # Determine whether the current process is the leader (responsible for loading the weights)
     is_leader = False
-    if seq_p_group is not None and dist.is_initialized():
-        group_rank = dist.get_rank(group=seq_p_group)
-        if group_rank == 0:
+    if dist.is_initialized():
+        current_rank = dist.get_rank()
+        if current_rank == 0:
             is_leader = True
     elif not dist.is_initialized() or dist.get_rank() == 0:
         is_leader = True
......@@ -70,16 +70,13 @@ def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True, se
         model.load_state_dict(state_dict, strict=strict)

     # Synchronize the model state from the leader to all other processes in the group
-    if seq_p_group is not None and dist.is_initialized():
-        dist.barrier(group=seq_p_group, device_ids=[torch.cuda.current_device()])
-        src_global_rank = dist.get_process_group_ranks(seq_p_group)[0]
+    if dist.is_initialized():
+        dist.barrier(device_ids=[torch.cuda.current_device()])
+        src_global_rank = 0
         for param in model.parameters():
-            dist.broadcast(param.data, src=src_global_rank, group=seq_p_group)
+            dist.broadcast(param.data, src=src_global_rank)
         for buffer in model.buffers():
-            dist.broadcast(buffer.data, src=src_global_rank, group=seq_p_group)
+            dist.broadcast(buffer.data, src=src_global_rank)
-    elif dist.is_initialized():
-        dist.barrier(device_ids=[torch.cuda.current_device()])
-        for param in model.parameters():
......
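The change above amounts to the following load-on-rank-0-then-broadcast pattern over the default process group. This is a simplified sketch, not the repo's exact code; `torch.load` stands in for the actual checkpoint reader.

```python
import torch
import torch.distributed as dist


def load_and_broadcast(model: torch.nn.Module, ckpt_path: str) -> torch.nn.Module:
    model = model.to("cuda")
    # Only the global rank 0 (or a single non-distributed process) reads the file.
    is_leader = not dist.is_initialized() or dist.get_rank() == 0
    if is_leader:
        state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=False)
    if dist.is_initialized():
        # Everyone waits for the leader, then receives its parameters and buffers.
        dist.barrier(device_ids=[torch.cuda.current_device()])
        for param in model.parameters():
            dist.broadcast(param.data, src=0)
        for buffer in model.buffers():
            dist.broadcast(buffer.data, src=0)
    return model
```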
......@@ -191,11 +191,11 @@ class WanModel:
         if weight_dict is None:
             is_weight_loader = False
-            if self.seq_p_group is None:
+            if self.config.get("device_mesh") is None:
                 is_weight_loader = True
                 logger.info(f"Loading original dit model from {self.model_path}")
             elif dist.is_initialized():
-                if dist.get_rank(group=self.seq_p_group) == 0:
+                if dist.get_rank() == 0:
                     is_weight_loader = True
                     logger.info(f"Loading original dit model from {self.model_path}")
......@@ -209,13 +209,13 @@ class WanModel:
             else:
                 cpu_weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)

-            if self.seq_p_group is None:  # single-GPU mode
+            if self.config.get("device_mesh") is None:  # single-GPU mode
                 self.original_weight_dict = {}
+                init_device = "cpu" if self.cpu_offload else "cuda"
                 for key, tensor in cpu_weight_dict.items():
-                    self.original_weight_dict[key] = tensor.to("cuda", non_blocking=True)
+                    self.original_weight_dict[key] = tensor.to(init_device, non_blocking=True)
             else:
-                seq_p_group = self.seq_p_group
-                global_src_rank = dist.get_process_group_ranks(seq_p_group)[0]
+                global_src_rank = 0
                 meta_dict = {}
                 if is_weight_loader:
......@@ -223,20 +223,20 @@ class WanModel:
                         meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype}

                 obj_list = [meta_dict] if is_weight_loader else [None]
-                dist.broadcast_object_list(obj_list, src=global_src_rank, group=seq_p_group)
+                dist.broadcast_object_list(obj_list, src=global_src_rank)
                 synced_meta_dict = obj_list[0]

                 self.original_weight_dict = {}
                 for key, meta in synced_meta_dict.items():
                     self.original_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device="cuda")

-                dist.barrier(group=seq_p_group, device_ids=[torch.cuda.current_device()])
+                dist.barrier(device_ids=[torch.cuda.current_device()])
                 for key in sorted(synced_meta_dict.keys()):
                     tensor_to_broadcast = self.original_weight_dict[key]
                     if is_weight_loader:
                         tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
-                    dist.broadcast(tensor_to_broadcast, src=global_src_rank, group=seq_p_group)
+                    dist.broadcast(tensor_to_broadcast, src=global_src_rank)

                 if is_weight_loader:
                     del cpu_weight_dict
......@@ -252,6 +252,9 @@ class WanModel:
         self.post_weight.load(self.original_weight_dict)
         self.transformer_weights.load(self.original_weight_dict)
         del self.original_weight_dict
+        torch.cuda.empty_cache()

     def _init_infer(self):
         self.pre_infer = self.pre_infer_class(self.config)
         self.post_infer = self.post_infer_class(self.config)
......
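A condensed sketch of the metadata-then-tensor broadcast WanModel now performs over the default group (names simplified and illustrative; `cpu_weight_dict` is only populated on the loading rank):

```python
import torch
import torch.distributed as dist


def broadcast_weights(cpu_weight_dict, is_weight_loader: bool) -> dict:
    # Step 1: the loader rank shares shapes/dtypes so every rank can allocate buffers.
    meta = {k: {"shape": t.shape, "dtype": t.dtype} for k, t in cpu_weight_dict.items()} if is_weight_loader else None
    obj_list = [meta]
    dist.broadcast_object_list(obj_list, src=0)
    synced_meta = obj_list[0]

    # Step 2: every rank allocates empty CUDA tensors of the right shape and dtype.
    weights = {k: torch.empty(m["shape"], dtype=m["dtype"], device="cuda") for k, m in synced_meta.items()}

    # Step 3: the loader copies its CPU tensors in and broadcasts them key by key.
    dist.barrier(device_ids=[torch.cuda.current_device()])
    for key in sorted(synced_meta):
        if is_weight_loader:
            weights[key].copy_(cpu_weight_dict[key], non_blocking=True)
        dist.broadcast(weights[key], src=0)
    return weights
```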
......@@ -7,6 +7,7 @@ from typing import Dict, List, Optional, Tuple
 import numpy as np
 import torch
 import torch.distributed as dist
+import torchaudio as ta
 from PIL import Image
 from einops import rearrange
......@@ -432,7 +433,7 @@ class WanAudioRunner(WanRunner): # type:ignore
         else:
             seq_p_group = None

-        audio_adapter = rank0_load_state_dict_from_path(audio_adapter, audio_adapter_path, strict=False, seq_p_group=seq_p_group)
+        audio_adapter = rank0_load_state_dict_from_path(audio_adapter, audio_adapter_path, strict=False)
         self._audio_adapter_pipe = AudioAdapterPipe(
             audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=GET_DTYPE(), device=device, weight=1.0, cpu_offload=cpu_offload, seq_p_group=seq_p_group
......@@ -564,6 +565,7 @@ class WanAudioRunner(WanRunner): # type:ignore
         comfyui_audio = {"waveform": audio_waveform, "sample_rate": self._audio_processor.audio_sr}

         # Save video if requested
+        if (self.config.get("device_mesh") is not None and dist.get_rank() == 0) or self.config.get("device_mesh") is None:
             if save_video and self.config.get("save_video_path", None):
                 self._save_video_with_audio(comfyui_images, merge_audio, target_fps)
......
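The save-video guard boils down to a rank check like the sketch below, which is what keeps the clip from being written once per GPU when sequence parallelism is enabled (config key names as in this repo; the function name is illustrative):

```python
import torch.distributed as dist


def should_save_video(config: dict) -> bool:
    # Multi-GPU run: only the global rank 0 writes the file.
    if config.get("device_mesh") is not None and dist.is_initialized():
        return dist.get_rank() == 0
    # Single-process run: always save.
    return True
```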
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import torch
import torch.distributed as dist
......@@ -781,8 +780,7 @@ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", seq_p_group=None,
     model = WanVAE_(**cfg)

     # load checkpoint
     logging.info(f"loading {pretrained_path}")
-    weights_dict = load_weights_distributed(pretrained_path, seq_p_group)
+    weights_dict = load_weights_distributed(pretrained_path)
     model.load_state_dict(weights_dict, assign=True)
......
......@@ -324,13 +324,13 @@ def find_gguf_model_path(config, ckpt_config_key=None, subdir=None):
     raise FileNotFoundError(f"No GGUF model files (.gguf) found.\nPlease download the model from: https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")


-def load_weights_distributed(checkpoint_path, seq_p_group=None):
-    if seq_p_group is None or not dist.is_initialized():
+def load_weights_distributed(checkpoint_path):
+    if not dist.is_initialized():
         logger.info(f"Loading weights from {checkpoint_path}")
         return torch.load(checkpoint_path, map_location="cpu", weights_only=True)

     is_leader = False
-    current_rank = dist.get_rank(seq_p_group)
+    current_rank = dist.get_rank()
     if current_rank == 0:
         is_leader = True
......@@ -348,22 +348,22 @@ def load_weights_distributed(checkpoint_path, seq_p_group=None):
     obj_list = [meta_dict] if is_leader else [None]

     # Get the global rank of rank 0 to use for the broadcast
-    src_global_rank = dist.get_process_group_ranks(seq_p_group)[0]
-    dist.broadcast_object_list(obj_list, src=src_global_rank, group=seq_p_group)
+    src_global_rank = 0
+    dist.broadcast_object_list(obj_list, src=src_global_rank)
     synced_meta_dict = obj_list[0]

     # Create an empty weight dict on each process's GPU
     target_device = torch.device(f"cuda:{current_rank}")
     gpu_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()}

-    dist.barrier(group=seq_p_group, device_ids=[torch.cuda.current_device()])
+    dist.barrier(device_ids=[torch.cuda.current_device()])
     for key in sorted(synced_meta_dict.keys()):
         tensor_to_broadcast = gpu_weight_dict[key]
         if is_leader:
             # rank 0 copies the CPU weights to the target GPU in preparation for the broadcast
             tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
-        dist.broadcast(tensor_to_broadcast, src=src_global_rank, group=seq_p_group)
+        dist.broadcast(tensor_to_broadcast, src=src_global_rank)

     if is_leader:
         del cpu_weight_dict
......