Commit daa06243 authored by wangshankun

BugFix: 1. cfg-parallel conflicted with the model-loading process group; 2. offload conflicted with the weight broadcast; 3. save_video wrote the file multiple times in parallel runs

parent 64948a2e
{
    "infer_steps": 4,
    "target_fps": 16,
    "video_duration": 5,
    "audio_sr": 16000,
    "target_video_length": 81,
    "target_height": 720,
    "target_width": 1280,
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 1.0,
    "sample_shift": 5,
    "enable_cfg": false,
    "use_31_block": false,
    "adaptive_resize": true,
    "parallel": {
        "seq_p_size": 4,
        "seq_p_attn_type": "ulysses"
    },
    "cpu_offload": true,
    "offload_granularity": "block",
    "t5_cpu_offload": true,
    "offload_ratio_val": 1,
    "t5_offload_granularity": "block",
    "use_tiling_vae": true
}
{
    "infer_steps": 4,
    "target_fps": 16,
    "video_duration": 5,
    "audio_sr": 16000,
    "target_video_length": 81,
    "target_height": 720,
    "target_width": 1280,
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 1,
    "sample_shift": 5,
    "enable_cfg": false,
    "adaptive_resize": true,
    "use_31_block": false,
    "cpu_offload": true,
    "offload_granularity": "block",
    "t5_cpu_offload": true,
    "offload_ratio_val": 1,
    "t5_offload_granularity": "block",
    "use_tiling_vae": true
}
@@ -571,7 +571,7 @@ class T5EncoderModel:
             .requires_grad_(False)
         )
-        weights_ditc = load_weights_distributed(self.checkpoint_path, seq_p_group)
+        weights_ditc = load_weights_distributed(self.checkpoint_path)
         model.load_state_dict(weights_ditc)
         self.model = model
...
@@ -434,8 +434,7 @@ class CLIPModel:
             pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device, quantized=self.quantized, quant_scheme=quant_scheme
         )
         self.model = self.model.eval().requires_grad_(False)
-        weight_dict = load_weights_distributed(self.checkpoint_path, seq_p_group=self.seq_p_group)
-        # weight_dict = torch.load(self.checkpoint_path, map_location="cpu", weights_only=True)
+        weight_dict = load_weights_distributed(self.checkpoint_path)
         keys = list(weight_dict.keys())
         for key in keys:
...
@@ -53,13 +53,13 @@ def load_pt_safetensors(in_path: str):
     return state_dict


-def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True, seq_p_group=None):
+def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True):
     model = model.to("cuda")

     # Determine whether the current process is the leader (responsible for loading the weights)
     is_leader = False
-    if seq_p_group is not None and dist.is_initialized():
-        group_rank = dist.get_rank(group=seq_p_group)
-        if group_rank == 0:
+    if dist.is_initialized():
+        current_rank = dist.get_rank()
+        if current_rank == 0:
             is_leader = True
     elif not dist.is_initialized() or dist.get_rank() == 0:
         is_leader = True
@@ -70,16 +70,13 @@ def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True, se
     model.load_state_dict(state_dict, strict=strict)

     # Sync the model state from the leader to all other processes in the group
-    if seq_p_group is not None and dist.is_initialized():
-        dist.barrier(group=seq_p_group, device_ids=[torch.cuda.current_device()])
-        src_global_rank = dist.get_process_group_ranks(seq_p_group)[0]
+    if dist.is_initialized():
+        dist.barrier(device_ids=[torch.cuda.current_device()])
+        src_global_rank = 0
         for param in model.parameters():
-            dist.broadcast(param.data, src=src_global_rank, group=seq_p_group)
+            dist.broadcast(param.data, src=src_global_rank)
         for buffer in model.buffers():
-            dist.broadcast(buffer.data, src=src_global_rank, group=seq_p_group)
+            dist.broadcast(buffer.data, src=src_global_rank)
     elif dist.is_initialized():
         dist.barrier(device_ids=[torch.cuda.current_device()])
         for param in model.parameters():
...
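For reference, here is a condensed, self-contained sketch of the pattern rank0_load_state_dict_from_path now follows: the global rank-0 process loads the checkpoint, then every parameter and buffer is broadcast over the default process group instead of a separate seq_p_group. The helper name is illustrative, not the repo's code.

# Hedged sketch of the leader-loads-then-broadcast pattern over the default group.
import torch
import torch.distributed as dist

def broadcast_model_from_rank0(model: torch.nn.Module, ckpt_path: str) -> torch.nn.Module:
    model = model.to("cuda")
    is_leader = (not dist.is_initialized()) or dist.get_rank() == 0
    if is_leader:
        # Only the leader touches the checkpoint file.
        state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=False)
    if dist.is_initialized():
        dist.barrier(device_ids=[torch.cuda.current_device()])
        for param in model.parameters():
            dist.broadcast(param.data, src=0)   # src=0 is a global rank on the default group
        for buffer in model.buffers():
            dist.broadcast(buffer.data, src=0)
    return model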
@@ -191,11 +191,11 @@ class WanModel:
         if weight_dict is None:
             is_weight_loader = False
-            if self.seq_p_group is None:
+            if self.config.get("device_mesh") is None:
                 is_weight_loader = True
                 logger.info(f"Loading original dit model from {self.model_path}")
             elif dist.is_initialized():
-                if dist.get_rank(group=self.seq_p_group) == 0:
+                if dist.get_rank() == 0:
                     is_weight_loader = True
                     logger.info(f"Loading original dit model from {self.model_path}")
@@ -209,13 +209,13 @@ class WanModel:
             else:
                 cpu_weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
-            if self.seq_p_group is None:  # single-GPU mode
+            if self.config.get("device_mesh") is None:  # single-GPU mode
                 self.original_weight_dict = {}
+                init_device = "cpu" if self.cpu_offload else "cuda"
                 for key, tensor in cpu_weight_dict.items():
-                    self.original_weight_dict[key] = tensor.to("cuda", non_blocking=True)
+                    self.original_weight_dict[key] = tensor.to(init_device, non_blocking=True)
             else:
-                seq_p_group = self.seq_p_group
-                global_src_rank = dist.get_process_group_ranks(seq_p_group)[0]
+                global_src_rank = 0
                 meta_dict = {}
                 if is_weight_loader:
@@ -223,20 +223,20 @@ class WanModel:
                     meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype}
                 obj_list = [meta_dict] if is_weight_loader else [None]
-                dist.broadcast_object_list(obj_list, src=global_src_rank, group=seq_p_group)
+                dist.broadcast_object_list(obj_list, src=global_src_rank)
                 synced_meta_dict = obj_list[0]

                 self.original_weight_dict = {}
                 for key, meta in synced_meta_dict.items():
                     self.original_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device="cuda")
-                dist.barrier(group=seq_p_group, device_ids=[torch.cuda.current_device()])
+                dist.barrier(device_ids=[torch.cuda.current_device()])
                 for key in sorted(synced_meta_dict.keys()):
                     tensor_to_broadcast = self.original_weight_dict[key]
                     if is_weight_loader:
                         tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
-                    dist.broadcast(tensor_to_broadcast, src=global_src_rank, group=seq_p_group)
+                    dist.broadcast(tensor_to_broadcast, src=global_src_rank)
                 if is_weight_loader:
                     del cpu_weight_dict
@@ -252,6 +252,9 @@ class WanModel:
         self.post_weight.load(self.original_weight_dict)
         self.transformer_weights.load(self.original_weight_dict)
+        del self.original_weight_dict
+        torch.cuda.empty_cache()

     def _init_infer(self):
         self.pre_infer = self.pre_infer_class(self.config)
         self.post_infer = self.post_infer_class(self.config)
...
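The single-GPU branch above now stages weights on the CPU when cpu_offload is enabled, and the staging dict is deleted once the weight containers have loaded from it. A minimal sketch of that staging-then-release idea, with hypothetical consumer objects standing in for the pre/post/transformer weight modules:

# Hedged sketch (hypothetical names, not the repo's API): stage checkpoint tensors
# on CPU when offloading, hand them to each consumer, then free the staging copy.
import torch

def stage_and_release(cpu_weight_dict: dict, consumers: list, cpu_offload: bool) -> None:
    init_device = "cpu" if cpu_offload else "cuda"  # mirrors the fix: don't force CUDA when offloading
    staged = {k: t.to(init_device, non_blocking=True) for k, t in cpu_weight_dict.items()}
    for consumer in consumers:
        consumer.load(staged)  # hypothetical .load(state_dict) hook on each weight container
    del staged                 # release the staging dict once every consumer holds its tensors
    torch.cuda.empty_cache()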
@@ -7,6 +7,7 @@ from typing import Dict, List, Optional, Tuple
 import numpy as np
 import torch
+import torch.distributed as dist
 import torchaudio as ta
 from PIL import Image
 from einops import rearrange
@@ -432,7 +433,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         else:
             seq_p_group = None
-        audio_adapter = rank0_load_state_dict_from_path(audio_adapter, audio_adapter_path, strict=False, seq_p_group=seq_p_group)
+        audio_adapter = rank0_load_state_dict_from_path(audio_adapter, audio_adapter_path, strict=False)
         self._audio_adapter_pipe = AudioAdapterPipe(
             audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=GET_DTYPE(), device=device, weight=1.0, cpu_offload=cpu_offload, seq_p_group=seq_p_group
@@ -564,8 +565,9 @@ class WanAudioRunner(WanRunner):  # type:ignore
         comfyui_audio = {"waveform": audio_waveform, "sample_rate": self._audio_processor.audio_sr}

         # Save video if requested
-        if save_video and self.config.get("save_video_path", None):
-            self._save_video_with_audio(comfyui_images, merge_audio, target_fps)
+        if (self.config.get("device_mesh") is not None and dist.get_rank() == 0) or self.config.get("device_mesh") is None:
+            if save_video and self.config.get("save_video_path", None):
+                self._save_video_with_audio(comfyui_images, merge_audio, target_fps)

         # Final cleanup
         self.end_run()
...
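The guard above is what stops every rank from writing the same video file. A minimal sketch of the rank-gating predicate it implements, assuming the config's device_mesh entry marks a parallel run:

# Sketch of the duplicate-save guard: in a parallel run only global rank 0 writes
# the video; single-process runs (no device_mesh) save unconditionally.
import torch.distributed as dist

def should_save_on_this_rank(device_mesh) -> bool:
    if device_mesh is None:          # non-parallel run
        return True
    return dist.get_rank() == 0      # parallel run: one writer only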
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
-import logging
 import torch
 import torch.distributed as dist
@@ -781,8 +780,7 @@ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", seq_p_group=None,
     model = WanVAE_(**cfg)

     # load checkpoint
-    logging.info(f"loading {pretrained_path}")
-    weights_dict = load_weights_distributed(pretrained_path, seq_p_group)
+    weights_dict = load_weights_distributed(pretrained_path)
     model.load_state_dict(weights_dict, assign=True)
...
@@ -324,13 +324,13 @@ def find_gguf_model_path(config, ckpt_config_key=None, subdir=None):
     raise FileNotFoundError(f"No GGUF model files (.gguf) found.\nPlease download the model from: https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")


-def load_weights_distributed(checkpoint_path, seq_p_group=None):
-    if seq_p_group is None or not dist.is_initialized():
+def load_weights_distributed(checkpoint_path):
+    if not dist.is_initialized():
         logger.info(f"Loading weights from {checkpoint_path}")
         return torch.load(checkpoint_path, map_location="cpu", weights_only=True)

     is_leader = False
-    current_rank = dist.get_rank(seq_p_group)
+    current_rank = dist.get_rank()
     if current_rank == 0:
         is_leader = True
@@ -348,22 +348,22 @@ def load_weights_distributed(checkpoint_path, seq_p_group=None):
     obj_list = [meta_dict] if is_leader else [None]

     # Get the global rank of rank 0 for broadcasting
-    src_global_rank = dist.get_process_group_ranks(seq_p_group)[0]
-    dist.broadcast_object_list(obj_list, src=src_global_rank, group=seq_p_group)
+    src_global_rank = 0
+    dist.broadcast_object_list(obj_list, src=src_global_rank)
     synced_meta_dict = obj_list[0]

     # Create an empty weight dict on every process's GPU
     target_device = torch.device(f"cuda:{current_rank}")
     gpu_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()}

-    dist.barrier(group=seq_p_group, device_ids=[torch.cuda.current_device()])
+    dist.barrier(device_ids=[torch.cuda.current_device()])
     for key in sorted(synced_meta_dict.keys()):
         tensor_to_broadcast = gpu_weight_dict[key]
         if is_leader:
             # Rank 0 copies the CPU weights to the target GPU before broadcasting
             tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
-        dist.broadcast(tensor_to_broadcast, src=src_global_rank, group=seq_p_group)
+        dist.broadcast(tensor_to_broadcast, src=src_global_rank)

     if is_leader:
         del cpu_weight_dict
...
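The same leader-loads-then-broadcast scheme, condensed into a standalone sketch for clarity (an illustrative helper, not the repo's exact code): rank 0 reads the checkpoint to CPU, shares the tensor shapes and dtypes with broadcast_object_list, and then broadcasts each tensor over the default group.

# Hedged sketch of distributing a state dict from global rank 0 over the default group.
import torch
import torch.distributed as dist

def broadcast_state_dict_from_rank0(checkpoint_path: str) -> dict:
    if not dist.is_initialized():
        return torch.load(checkpoint_path, map_location="cpu", weights_only=True)

    is_leader = dist.get_rank() == 0
    cpu_weights = torch.load(checkpoint_path, map_location="cpu", weights_only=True) if is_leader else None

    # Share only the metadata first so non-leader ranks can allocate receive buffers.
    meta = {k: (v.shape, v.dtype) for k, v in cpu_weights.items()} if is_leader else None
    obj_list = [meta]
    dist.broadcast_object_list(obj_list, src=0)
    meta = obj_list[0]

    device = torch.device("cuda", torch.cuda.current_device())
    out = {k: torch.empty(shape, dtype=dtype, device=device) for k, (shape, dtype) in meta.items()}
    for k in sorted(meta):
        if is_leader:
            out[k].copy_(cpu_weights[k], non_blocking=True)
        dist.broadcast(out[k], src=0)  # every rank ends up with the full tensor on its GPU
    return out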