Commit 67c667f2 authored by Rick Ho

ddp module for sophisticated hybrid parallelism

parent ea66e5e5
import torch
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


class DistributedGroupedDataParallel(nn.Module):
    def __init__(self, module, mp_group=None, dp_group=None, world_group=None,
                 auto_allreduce=False):
        assert not auto_allreduce, 'Automatic all-reduce is not implemented yet'
        super(DistributedGroupedDataParallel, self).__init__()
        self.module = module

        # Communicators keyed by parallel method: 'mp', 'dp' and 'world'.
        self.comms = dict()
        if mp_group is not None:
            self.comms['mp'] = mp_group
        if dp_group is not None:
            self.comms['dp'] = dp_group
        else:
            self.comms['dp'] = torch.distributed.distributed_c10d._default_pg
        if world_group is None:
            self.comms['world'] = torch.distributed.distributed_c10d._default_pg
        else:
            self.comms['world'] = world_group

        def allreduce_params(no_scale=False, reduce_after=False,
                             fp32_allreduce=False):
            # Bucket gradients by (parallel method, dtype) so each bucket can
            # be flattened and all-reduced over the matching communicator.
            groups = dict()
            for p in self.module.parameters():
                if not p.requires_grad or p.grad is None:
                    continue
                if hasattr(p, 'parallel_method'):
                    pm = p.parallel_method
                else:
                    pm = 'dp'
                group_key = (pm, p.dtype)
                if group_key not in groups:
                    groups[group_key] = [p]
                else:
                    groups[group_key].append(p)
            for pm, dtype in groups:
                if pm not in self.comms:
                    continue
                group = groups[pm, dtype]
                comm = self.comms[pm]
                grads = [p.grad.data for p in group]
                coalesced = _flatten_dense_tensors(grads)
                if fp32_allreduce and dtype != torch.float32:
                    coalesced = coalesced.float()
                if not no_scale and not reduce_after:
                    coalesced /= comm.size()
                torch.distributed.all_reduce(coalesced, group=comm)
                torch.cuda.synchronize()
                if not no_scale and reduce_after:
                    coalesced /= comm.size()
                # Scatter the reduced buffer back into the per-parameter grads.
                synced = _unflatten_dense_tensors(coalesced, grads)
                for g, s in zip(grads, synced):
                    g.copy_(s)

        self.allreduce_params = allreduce_params

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)
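A minimal usage sketch, not part of the commit: the 4-rank layout (2-way model parallel x 2-way data parallel), the group construction and the toy model below are assumptions for illustration only.

# Sketch only; assumes a 4-process job launched with torchrun or similar.
import torch
import torch.nn as nn
import torch.distributed as dist

dist.init_process_group(backend='nccl')
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

# new_group must be called by every rank for every group, in the same order;
# each rank then keeps only the groups it actually belongs to.
mp_groups = [dist.new_group([0, 1]), dist.new_group([2, 3])]
dp_groups = [dist.new_group([0, 2]), dist.new_group([1, 3])]
mp_group = mp_groups[rank // 2]
dp_group = dp_groups[rank % 2]

model = nn.Linear(16, 16).cuda()  # placeholder model
# Parameters tagged with `parallel_method = 'mp'` are reduced over the mp
# group; untagged parameters default to the 'dp' group.
model = DistributedGroupedDataParallel(model, mp_group=mp_group, dp_group=dp_group)
opt = torch.optim.SGD(model.parameters(), lr=1e-3)

x = torch.randn(8, 16).cuda()
model(x).sum().backward()
model.allreduce_params()  # synchronize gradients before stepping the optimizer
opt.step()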
 from .layers import FMoETransformerMLP
+from .distributed import DistributedGroupedDataParallel

 def create_moe_mlp(args, model_parallel_rank, group):
     assert (
         args.seq_length * args.batch_size % args.model_parallel_size == 0
     ), "Batch size x sequence length should be multiple of mp size"
-    if args.model_parallel_size == 1:
+    if not args.distributed_experts:
         world_size = 1
     else:
         world_size = args.world_size
@@ -21,7 +21,7 @@ def create_moe_mlp(args, model_parallel_rank, group):
     return fmoe

-def fmoefy(model, num_experts=None):
+def fmoefy(model, num_experts=None, distributed_experts=True):
     from megatron import get_args
     from megatron import mpu
     args = get_args()
@@ -30,8 +30,23 @@ def fmoefy(model, num_experts=None):
     assert (
         'num_experts' in args
     ), 'num_experts should be specified in arguments or fmoefy function'
+    # Set distributed_experts to None to use default setting in args
+    if distributed_experts is not None:
+        args.distributed_experts = distributed_experts
     for l in model.language_model.transformer.layers:
         l.mlp = create_moe_mlp(args,
                                mpu.get_model_parallel_rank(),
                                mpu.get_model_parallel_group())
     return model
+
+
+class DistributedDataParallel(DistributedGroupedDataParallel):
+    def __init__(self, module):
+        from megatron import mpu
+        super(DistributedDataParallel, self).__init__(
+            module,
+            mp_group=mpu.get_model_parallel_group(),
+            dp_group=mpu.get_data_parallel_group()
+        )
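A rough sketch of how the new Megatron-side helpers could be wired into a training script; the module path `fmoe.megatron`, the expert count, and the surrounding training loop are assumptions, not part of this commit.

# Sketch only: swap each transformer MLP for an FMoE layer, then wrap the
# model so allreduce_params() reduces gradients over the mp/dp groups that
# megatron.mpu already created. `model` and `optimizer` come from the usual
# Megatron-LM setup (assumed).
from fmoe.megatron import fmoefy, DistributedDataParallel  # assumed module path

model = fmoefy(model, num_experts=4, distributed_experts=True)
model = DistributedDataParallel(model)

# ... run forward/backward as usual, then synchronize gradients:
model.allreduce_params()
optimizer.step()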