"...text-generation-inference.git" did not exist on "b3b7ea0d74627d30a0b739c5e7e4a74b4f2f4437"
Commit dd68fd78 authored by Rick Ho

update ddp with get first rank and tests

parent 670e1407
@@ -4,7 +4,7 @@ Supportive modules to conduct distributed training
 import torch
 import torch.nn as nn
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-from .utils import get_torch_default_comm
+from .utils import get_torch_default_comm, get_rank_0_in_comm


 class DistributedGroupedDataParallel(nn.Module):
@@ -97,7 +97,8 @@ class DistributedGroupedDataParallel(nn.Module):
             comm = self.comms[dp_comm]
             datas = [p.data for p in group]
             coalesced = _flatten_dense_tensors(datas)
-            torch.distributed.broadcast(coalesced, 0, group=comm)
+            torch.distributed.broadcast(coalesced,
+                                        get_rank_0_in_comm(comm), group=comm)
             torch.cuda.synchronize()
             synced = _unflatten_dense_tensors(coalesced, datas)
             for d, s in zip(datas, synced):
......
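Why the change matters: torch.distributed.broadcast interprets src as a global rank, so the hard-coded 0 is only valid when the communicator actually contains global rank 0. For a data-parallel sub-group made of, say, global ranks 4-7, broadcasting from src=0 fails because the source is not a member of the group. Below is a minimal sketch of the failure mode and the fix; the 8-process layout and the fmoe.utils import path are assumptions for illustration, not taken from the commit.

# Sketch: the broadcast root must be a *global* rank that belongs to the
# group. Assumes 8 processes launched with torchrun, one GPU per process;
# the fmoe.utils import path is assumed, not shown in this diff.
import torch
import torch.distributed as dist

from fmoe.utils import get_rank_0_in_comm  # helper added by this commit


def main():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Two data-parallel sub-groups: global ranks {0..3} and {4..7}.
    # Every process must create both groups, in the same order.
    groups = [dist.new_group([0, 1, 2, 3]), dist.new_group([4, 5, 6, 7])]
    comm = groups[rank // 4]

    x = torch.full((4,), float(rank), device="cuda")

    # Buggy form: src=0 is not a member of the second group, so ranks
    # 4..7 cannot broadcast from it.
    # dist.broadcast(x, 0, group=comm)

    # Fixed form: broadcast from the group's own first global rank,
    # which is 0 for the first group and 4 for the second.
    dist.broadcast(x, get_rank_0_in_comm(comm), group=comm)


if __name__ == "__main__":
    main()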
r""" r"""
Utils to play with PyTorch. Utils to play with PyTorch.
""" """
import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -28,3 +29,13 @@ def get_torch_default_comm(): ...@@ -28,3 +29,13 @@ def get_torch_default_comm():
except Exception as _: except Exception as _:
pass pass
raise RuntimeError("Unsupported PyTorch version") raise RuntimeError("Unsupported PyTorch version")
def get_rank_0_in_comm(comm):
world_size = dist.get_world_size(comm)
x = torch.tensor([dist.get_rank()], dtype=torch.int64, device='cuda')
ys = [torch.empty_like(x) for _ in range(world_size)]
dist.all_gather(ys, x, group=comm)
root_rank = ys[0].item()
return root_rank
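How the helper works: dist.get_rank() with no group argument returns the caller's global rank, and all_gather orders its outputs by rank within comm, so ys[0] is the global rank of the process whose in-group rank is 0. The tensor is placed on CUDA because the NCCL backend only gathers GPU tensors. Newer PyTorch releases expose a direct lookup that needs no collective; a sketch, hedged on dist.get_global_rank being available in the targeted version (it is a real API in PyTorch 2.x, but not in the older versions this helper also supports):

# Sketch: communication-free equivalent, assuming a PyTorch version
# (2.x) that exposes dist.get_global_rank; the committed helper trades
# one all_gather for compatibility with older releases.
import torch.distributed as dist


def get_rank_0_in_comm_v2(comm):
    # Map in-group rank 0 back to its global rank directly.
    return dist.get_global_rank(comm, 0)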
@@ -71,7 +71,7 @@ class MyMoE(FMoE):
             d_model=d_model,
             gate=NaiveGate,
             world_size=world_size,
-            mp_group=mp_group,
+            slice_group=mp_group,
             top_k=top_k,
         )
         self.experts = _Expert(num_expert, d_model, d_hidden, activation)
@@ -344,6 +344,7 @@ def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
     model = MyModule().cuda()
     model_ddp = LocalDDP(deepcopy(model),
             mp_group=mp_group, dp_group=dp_group, world_group=world_group)
+    model = deepcopy(model_ddp.module)
     model.set_comm()
     model_ddp.module.set_comm()
......
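The test fix re-copies the reference model after the LocalDDP wrapper is constructed, presumably because the wrapper synchronizes parameters across the data-parallel group at set-up time: a copy taken before that sync would diverge from the wrapped model on any rank whose initial weights differ. A toy stand-in (not the repo's wrapper class) that shows the ordering; TinyDDP and the per-rank seeding are illustrative assumptions.

# Toy stand-in (not the repo's wrapper) for the ordering the test fixes:
# a DDP-style wrapper that broadcasts parameters at construction, so the
# reference copy must be taken from wrapper.module *after* wrapping.
from copy import deepcopy

import torch
import torch.distributed as dist
import torch.nn as nn


class TinyDDP(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module
        # Sync every parameter from global rank 0 at construction.
        for p in self.module.parameters():
            dist.broadcast(p.data, 0)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
torch.manual_seed(dist.get_rank())           # ranks start out different
wrapped = TinyDDP(nn.Linear(4, 4).cuda())    # parameters synced here
reference = deepcopy(wrapped.module)         # copy taken after the sync

for p_ref, p_ddp in zip(reference.parameters(), wrapped.module.parameters()):
    assert torch.equal(p_ref, p_ddp)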