Unverified Commit 9a01f8fa authored by Rick Ho's avatar Rick Ho Committed by GitHub
Browse files

Merge pull request #120 from laekov/ddp-bcast-fix

Fix Broadcast rank bug in DGDP
parents 670e1407 dd68fd78
......@@ -4,7 +4,7 @@ Supportive modules to conduct distributed training
import torch
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from .utils import get_torch_default_comm
from .utils import get_torch_default_comm, get_rank_0_in_comm
class DistributedGroupedDataParallel(nn.Module):
......@@ -97,7 +97,8 @@ class DistributedGroupedDataParallel(nn.Module):
comm = self.comms[dp_comm]
datas = [p.data for p in group]
coalesced = _flatten_dense_tensors(datas)
torch.distributed.broadcast(coalesced, 0, group=comm)
torch.distributed.broadcast(coalesced,
get_rank_0_in_comm(comm), group=comm)
torch.cuda.synchronize()
synced = _unflatten_dense_tensors(coalesced, datas)
for d, s in zip(datas, synced):
......
r"""
Utils to play with PyTorch.
"""
import torch
import torch.distributed as dist
......@@ -28,3 +29,13 @@ def get_torch_default_comm():
except Exception as _:
pass
raise RuntimeError("Unsupported PyTorch version")
def get_rank_0_in_comm(comm):
    r"""
    Resolve the process that is rank 0 *within* ``comm`` to its rank in the
    default (global) process group.

    Each member of ``comm`` contributes its global rank via an all-gather
    over ``comm``; entry 0 of the gathered list therefore belongs to the
    comm-local rank-0 process.  Useful as the ``src`` argument of a
    broadcast restricted to ``comm``.

    Args:
        comm: a ``torch.distributed`` process group (collective group).

    Returns:
        int: the default-group rank of ``comm``'s rank-0 member.
    """
    members = dist.get_world_size(comm)
    # NOTE: dist.get_rank() with no group argument deliberately reads the
    # *global* rank, which is what broadcast(src=...) expects.
    my_rank = torch.tensor([dist.get_rank()], dtype=torch.int64, device='cuda')
    gathered = [torch.empty_like(my_rank) for _ in range(members)]
    dist.all_gather(gathered, my_rank, group=comm)
    return gathered[0].item()
......@@ -71,7 +71,7 @@ class MyMoE(FMoE):
d_model=d_model,
gate=NaiveGate,
world_size=world_size,
mp_group=mp_group,
slice_group=mp_group,
top_k=top_k,
)
self.experts = _Expert(num_expert, d_model, d_hidden, activation)
......@@ -344,6 +344,7 @@ def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
model = MyModule().cuda()
model_ddp = LocalDDP(deepcopy(model),
mp_group=mp_group, dp_group=dp_group, world_group=world_group)
model = deepcopy(model_ddp.module)
model.set_comm()
model_ddp.module.set_comm()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment