Commit 0dc34857 authored by gushiqiao, committed by GitHub

Fix distribute offload bug.

parents 0cbd0544 ec061565
@@ -157,8 +157,8 @@ class PerceiverAttentionCA(nn.Module):
            v=v,
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
            max_seqlen_q=q_lens.max(),
            max_seqlen_k=k_lens.max(),
            max_seqlen_q=q_lens.max().item(),
            max_seqlen_k=k_lens.max().item(),
            dropout_p=0.0,
            softmax_scale=None,
            causal=False,
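The max_seqlen_q / max_seqlen_k change above swaps 0-dim tensors for plain Python ints. A minimal sketch of the difference (the variable values below are illustrative, not taken from the repository):

import torch

q_lens = torch.tensor([3, 5, 2], dtype=torch.int32)

# q_lens.max() returns a 0-dim tensor; .item() turns it into a plain Python int,
# which is the safer type for scalar length arguments such as max_seqlen_q.
print(type(q_lens.max()))         # <class 'torch.Tensor'>
print(type(q_lens.max().item()))  # <class 'int'>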
@@ -215,7 +215,7 @@ class WanModel:
                weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
            if self.config.get("device_mesh") is not None:
                weight_dict = self._distribute_weights_multi_gpu(weight_dict, is_weight_loader)
                weight_dict = self._load_weights_distribute(weight_dict, is_weight_loader)
                self.original_weight_dict = weight_dict
            else:
@@ -234,49 +234,62 @@ class WanModel:
        del self.original_weight_dict
        torch.cuda.empty_cache()

    def _distribute_weights_multi_gpu(self, weight_dict, is_weight_loader):
        """Distribute weights across multiple GPUs or CPUs based on offload config."""
    def _load_weights_distribute(self, weight_dict, is_weight_loader):
        global_src_rank = 0
        # Determine target device for distribution
        target_device = "cpu" if self.cpu_offload else "cuda"

        if is_weight_loader:
            # Create metadata for broadcasting
            meta_dict = {}
            for key, tensor in weight_dict.items():
                meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype}
            # Broadcast metadata to all ranks
            obj_list = [meta_dict]
            dist.broadcast_object_list(obj_list, src=global_src_rank)
            synced_meta_dict = obj_list[0]
        else:
            # Non-loader ranks receive metadata
            obj_list = [None]
            dist.broadcast_object_list(obj_list, src=global_src_rank)
            synced_meta_dict = obj_list[0]

        # Create empty tensors on target device for all ranks
        distributed_weight_dict = {}
        for key, meta in synced_meta_dict.items():
            distributed_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device)

        # Synchronize before broadcasting
        if target_device == "cuda":
            dist.barrier(device_ids=[torch.cuda.current_device()])
        else:
            dist.barrier()

        # Broadcast weights from rank 0 to all ranks
        for key in sorted(synced_meta_dict.keys()):
            if is_weight_loader:
                # Copy weights to broadcast tensor
                distributed_weight_dict[key].copy_(weight_dict[key], non_blocking=True)
            # Broadcast to all ranks
            if target_device == "cpu":
                if is_weight_loader:
                    gpu_tensor = distributed_weight_dict[key].cuda()
                    dist.broadcast(gpu_tensor, src=global_src_rank)
                    distributed_weight_dict[key].copy_(gpu_tensor.cpu(), non_blocking=True)
                    del gpu_tensor
                    torch.cuda.empty_cache()
                else:
                    gpu_tensor = torch.empty_like(distributed_weight_dict[key], device="cuda")
                    dist.broadcast(gpu_tensor, src=global_src_rank)
                    distributed_weight_dict[key].copy_(gpu_tensor.cpu(), non_blocking=True)
                    del gpu_tensor
                    torch.cuda.empty_cache()
                if distributed_weight_dict[key].is_pinned():
                    distributed_weight_dict[key].copy_(distributed_weight_dict[key], non_blocking=True)
            else:
                dist.broadcast(distributed_weight_dict[key], src=global_src_rank)

        if target_device == "cuda":
            torch.cuda.synchronize()
        else:
            for tensor in distributed_weight_dict.values():
                if tensor.is_pinned():
                    tensor.copy_(tensor, non_blocking=False)

        logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}")
        return distributed_weight_dict
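For reference, the cpu-offload branch of the broadcast loop above follows this pattern: assuming the default process group uses the NCCL backend, which only moves CUDA tensors, each CPU-resident weight is staged on the GPU, broadcast, and copied back to host memory. A minimal sketch under that assumption, with one CUDA device per rank (the helper name is made up for illustration):

import torch
import torch.distributed as dist

def broadcast_cpu_tensor_via_gpu(cpu_tensor, src=0):
    # The loader rank uploads the real data; the other ranks allocate a receive buffer.
    if dist.get_rank() == src:
        gpu_tensor = cpu_tensor.cuda()
    else:
        gpu_tensor = torch.empty_like(cpu_tensor, device="cuda")
    # The NCCL broadcast happens on the GPU staging buffer.
    dist.broadcast(gpu_tensor, src=src)
    # Land the synchronized result back in (possibly pinned) host memory.
    cpu_tensor.copy_(gpu_tensor.cpu(), non_blocking=True)
    del gpu_tensor
    torch.cuda.empty_cache()
    return cpu_tensor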
@@ -360,12 +360,10 @@ def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
    synced_meta_dict = obj_list[0]

    if cpu_offload:
        # Multi-GPU + offload: weights on CPU
        target_device = "cpu"
        distributed_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()}
        dist.barrier()
    else:
        # Multi-GPU + non-offload: weights on GPU
        target_device = torch.device(f"cuda:{current_rank}")
        distributed_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()}
        dist.barrier(device_ids=[torch.cuda.current_device()])
@@ -374,11 +372,29 @@ def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
        tensor_to_broadcast = distributed_weight_dict[key]
        if is_weight_loader:
            tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)

        if cpu_offload:
            if is_weight_loader:
                gpu_tensor = tensor_to_broadcast.cuda()
                dist.broadcast(gpu_tensor, src=src_global_rank)
                tensor_to_broadcast.copy_(gpu_tensor.cpu(), non_blocking=True)
                del gpu_tensor
                torch.cuda.empty_cache()
            else:
                gpu_tensor = torch.empty_like(tensor_to_broadcast, device="cuda")
                dist.broadcast(gpu_tensor, src=src_global_rank)
                tensor_to_broadcast.copy_(gpu_tensor.cpu(), non_blocking=True)
                del gpu_tensor
                torch.cuda.empty_cache()
        else:
            dist.broadcast(tensor_to_broadcast, src=src_global_rank)

    if is_weight_loader:
        del cpu_weight_dict
        if cpu_offload:
            torch.cuda.empty_cache()

    logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}")
    return distributed_weight_dict
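The shape/dtype hand-off that precedes this loop (synced_meta_dict) relies on dist.broadcast_object_list, so that non-loader ranks can allocate matching empty buffers before any tensor broadcast. A minimal sketch under the same assumptions as above (the function name is illustrative, not part of the repository):

import torch
import torch.distributed as dist

def sync_weight_metadata(weight_dict=None, src=0):
    # Only the loader rank has the checkpoint; it describes each tensor by shape and dtype.
    if dist.get_rank() == src:
        obj_list = [{k: {"shape": t.shape, "dtype": t.dtype} for k, t in weight_dict.items()}]
    else:
        obj_list = [None]
    # broadcast_object_list pickles the metadata and ships it to every rank.
    dist.broadcast_object_list(obj_list, src=src)
    meta = obj_list[0]
    # Every rank can now allocate receive buffers that match the loader's tensors.
    return {k: torch.empty(m["shape"], dtype=m["dtype"]) for k, m in meta.items()}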
@@ -388,7 +404,6 @@ def masks_like(tensor, zero=False, generator=None, p=0.2):
    out = torch.ones_like(tensor)
    if zero:
        if generator is not None:
            # Generate a random number to decide whether the mask needs to be modified
            random_num = torch.rand(1, generator=generator, device=generator.device).item()
            if random_num < p:
                out[:, 0] = torch.zeros_like(out[:, 0])
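The masks_like branch above zeroes the mask for index 0 along dim 1 with probability p. A small self-contained example of that behavior (the tensor shape is chosen arbitrarily):

import torch

g = torch.Generator(device="cpu").manual_seed(0)
x = torch.randn(2, 4, 8)
mask = torch.ones_like(x)
# With probability p (default 0.2), drop the mask for the first entry along dim 1.
if torch.rand(1, generator=g, device=g.device).item() < 0.2:
    mask[:, 0] = torch.zeros_like(mask[:, 0])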