Commit ec061565 authored by gushiqiao

Fix distribute offload bug.

parent 0cbd0544
@@ -157,8 +157,8 @@ class PerceiverAttentionCA(nn.Module):
             v=v,
             cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
             cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
-            max_seqlen_q=q_lens.max(),
-            max_seqlen_k=k_lens.max(),
+            max_seqlen_q=q_lens.max().item(),
+            max_seqlen_k=k_lens.max().item(),
             dropout_p=0.0,
             softmax_scale=None,
             causal=False,
...
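The `.item()` calls matter because `q_lens.max()` returns a 0-dim tensor, while flash-attention's varlen interface expects `max_seqlen_q`/`max_seqlen_k` as plain Python integers. A minimal sketch of the sequence-length bookkeeping, using only plain PyTorch and illustrative lengths (the attention call itself is omitted):

import torch

# Illustrative per-sample query lengths for a packed (varlen) batch.
q_lens = torch.tensor([3, 5, 2], dtype=torch.int32)

# Cumulative offsets in the packed dimension: [0, 3, 8, 10].
cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)

max_seqlen_q = q_lens.max()              # 0-dim tensor: tensor(5, dtype=torch.int32)
max_seqlen_q_int = q_lens.max().item()   # plain Python int: 5

print(cu_seqlens_q.tolist(), max_seqlen_q_int)  # [0, 3, 8, 10] 5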
@@ -215,7 +215,7 @@ class WanModel:
             weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
         if self.config.get("device_mesh") is not None:
-            weight_dict = self._distribute_weights_multi_gpu(weight_dict, is_weight_loader)
+            weight_dict = self._load_weights_distribute(weight_dict, is_weight_loader)
             self.original_weight_dict = weight_dict
         else:
@@ -234,48 +234,61 @@ class WanModel:
         del self.original_weight_dict
         torch.cuda.empty_cache()

-    def _distribute_weights_multi_gpu(self, weight_dict, is_weight_loader):
-        """Distribute weights across multiple GPUs or CPUs based on offload config."""
+    def _load_weights_distribute(self, weight_dict, is_weight_loader):
         global_src_rank = 0
-        # Determine target device for distribution
         target_device = "cpu" if self.cpu_offload else "cuda"

         if is_weight_loader:
-            # Create metadata for broadcasting
             meta_dict = {}
             for key, tensor in weight_dict.items():
                 meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype}
-            # Broadcast metadata to all ranks
             obj_list = [meta_dict]
             dist.broadcast_object_list(obj_list, src=global_src_rank)
             synced_meta_dict = obj_list[0]
         else:
-            # Non-loader ranks receive metadata
             obj_list = [None]
             dist.broadcast_object_list(obj_list, src=global_src_rank)
             synced_meta_dict = obj_list[0]

-        # Create empty tensors on target device for all ranks
         distributed_weight_dict = {}
         for key, meta in synced_meta_dict.items():
             distributed_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device)

-        # Synchronize before broadcasting
         if target_device == "cuda":
             dist.barrier(device_ids=[torch.cuda.current_device()])
         else:
             dist.barrier()

-        # Broadcast weights from rank 0 to all ranks
         for key in sorted(synced_meta_dict.keys()):
             if is_weight_loader:
-                # Copy weights to broadcast tensor
                 distributed_weight_dict[key].copy_(weight_dict[key], non_blocking=True)
-            # Broadcast to all ranks
-            dist.broadcast(distributed_weight_dict[key], src=global_src_rank)
+            if target_device == "cpu":
+                if is_weight_loader:
+                    gpu_tensor = distributed_weight_dict[key].cuda()
+                    dist.broadcast(gpu_tensor, src=global_src_rank)
+                    distributed_weight_dict[key].copy_(gpu_tensor.cpu(), non_blocking=True)
+                    del gpu_tensor
+                    torch.cuda.empty_cache()
+                else:
+                    gpu_tensor = torch.empty_like(distributed_weight_dict[key], device="cuda")
+                    dist.broadcast(gpu_tensor, src=global_src_rank)
+                    distributed_weight_dict[key].copy_(gpu_tensor.cpu(), non_blocking=True)
+                    del gpu_tensor
+                    torch.cuda.empty_cache()
+                if distributed_weight_dict[key].is_pinned():
+                    distributed_weight_dict[key].copy_(distributed_weight_dict[key], non_blocking=True)
+            else:
+                dist.broadcast(distributed_weight_dict[key], src=global_src_rank)
+
+        if target_device == "cuda":
+            torch.cuda.synchronize()
+        else:
+            for tensor in distributed_weight_dict.values():
+                if tensor.is_pinned():
+                    tensor.copy_(tensor, non_blocking=False)

         logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}")
         return distributed_weight_dict
...
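The restructured broadcast loop above works around the fact that NCCL collectives only operate on CUDA tensors: when weights are offloaded to the CPU, each tensor is staged in a temporary GPU buffer, broadcast there, copied back to host memory, and the buffer is freed. A minimal sketch of that staging pattern, assuming a torchrun launch with an NCCL process group (the function name and the test tensor are illustrative, not part of the repo):

import os
import torch
import torch.distributed as dist

def broadcast_cpu_tensor_via_gpu(cpu_tensor, src=0):
    # NCCL collectives require CUDA tensors, so stage the CPU tensor on the GPU,
    # broadcast it there, then copy the result back into the CPU storage in place.
    if dist.get_rank() == src:
        gpu_tensor = cpu_tensor.cuda()
    else:
        gpu_tensor = torch.empty_like(cpu_tensor, device="cuda")
    dist.broadcast(gpu_tensor, src=src)
    cpu_tensor.copy_(gpu_tensor.cpu())  # blocking copy back into host memory
    del gpu_tensor
    torch.cuda.empty_cache()            # release the temporary staging buffer

if __name__ == "__main__":
    # Assumes a torchrun launch so rank and world size come from the environment.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))
    t = torch.full((4,), float(dist.get_rank()))  # CPU tensor, different on every rank
    broadcast_cpu_tensor_via_gpu(t, src=0)
    print(dist.get_rank(), t)  # every rank now holds rank 0's values
    dist.destroy_process_group()

The loader-rank branch uploads its already-populated CPU copy, while the other ranks only allocate an empty staging buffer, which mirrors the `is_weight_loader` split in the diff.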
@@ -360,12 +360,10 @@ def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
     synced_meta_dict = obj_list[0]

     if cpu_offload:
-        # Multi-GPU + offload: weights on CPU
         target_device = "cpu"
         distributed_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()}
         dist.barrier()
     else:
-        # Multi-GPU + non-offload: weights on GPU
         target_device = torch.device(f"cuda:{current_rank}")
         distributed_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()}
         dist.barrier(device_ids=[torch.cuda.current_device()])
@@ -374,11 +372,29 @@ def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
         tensor_to_broadcast = distributed_weight_dict[key]
         if is_weight_loader:
             tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
-        dist.broadcast(tensor_to_broadcast, src=src_global_rank)
+        if cpu_offload:
+            if is_weight_loader:
+                gpu_tensor = tensor_to_broadcast.cuda()
+                dist.broadcast(gpu_tensor, src=src_global_rank)
+                tensor_to_broadcast.copy_(gpu_tensor.cpu(), non_blocking=True)
+                del gpu_tensor
+                torch.cuda.empty_cache()
+            else:
+                gpu_tensor = torch.empty_like(tensor_to_broadcast, device="cuda")
+                dist.broadcast(gpu_tensor, src=src_global_rank)
+                tensor_to_broadcast.copy_(gpu_tensor.cpu(), non_blocking=True)
+                del gpu_tensor
+                torch.cuda.empty_cache()
+        else:
+            dist.broadcast(tensor_to_broadcast, src=src_global_rank)

     if is_weight_loader:
         del cpu_weight_dict
+        if cpu_offload:
+            torch.cuda.empty_cache()

     logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}")
     return distributed_weight_dict
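The `is_pinned()` checks and the trailing synchronization pass in the class-method version relate to a separate hazard: `copy_(..., non_blocking=True)` from a CUDA tensor into pinned host memory returns before the DMA transfer finishes, so the CPU tensor is not safe to read until the stream has been synchronized. A small sketch of that hazard, with an explicit stream synchronization as one common guard (illustrative only, not the repo's exact mechanism):

import torch

assert torch.cuda.is_available()

src = torch.randn(1024, device="cuda")
dst = torch.empty(1024, pin_memory=True)  # pinned host buffer enables async D2H copies

# Non-blocking device-to-host copy: the call returns immediately and the
# transfer completes in the background on the current CUDA stream.
dst.copy_(src, non_blocking=True)

# Reading `dst` at this point could race with the in-flight transfer.
# Synchronizing the stream makes the host copy safe to consume.
torch.cuda.current_stream().synchronize()

assert torch.equal(dst, src.cpu())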
@@ -388,7 +404,6 @@ def masks_like(tensor, zero=False, generator=None, p=0.2):
     out = torch.ones_like(tensor)
     if zero:
         if generator is not None:
-            # Generate a random number to decide whether a modification is needed
             random_num = torch.rand(1, generator=generator, device=generator.device).item()
             if random_num < p:
                 out[:, 0] = torch.zeros_like(out[:, 0])
...