Commit 18a4395c authored by Rick Ho

add moe_comm

parent 55f8ca7d
@@ -27,6 +27,7 @@ class DistributedGroupedDataParallel(nn.Module):
         module,
         mp_group=None,
         dp_group=None,
+        moe_group=None,
         world_group=None,
         auto_allreduce=False,
     ):
@@ -42,6 +43,10 @@ class DistributedGroupedDataParallel(nn.Module):
             self.comms["dp"] = dp_group
         else:
             self.comms["dp"] = get_torch_default_comm()
+        if moe_group is not None:
+            self.comms["moe"] = moe_group
+        else:
+            self.comms["moe"] = get_torch_default_comm()
         if world_group is None:
             self.comms["world"] = get_torch_default_comm()
         else:
......
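
For reference, a minimal usage sketch of the new argument (the process-group choice and the model constructor below are hypothetical, not part of this commit): parameters tagged with dp_comm == "moe" are presumably reduced over comms["moe"] instead of the default communicator.

import torch.distributed as dist
from fmoe.distributed import DistributedGroupedDataParallel

dist.init_process_group(backend="nccl")
# Hypothetical choice: put every rank into one MoE group; real setups may
# restrict this to the ranks that actually share experts.
moe_group = dist.new_group(ranks=list(range(dist.get_world_size())))
model = build_fmoe_model()  # hypothetical constructor for an FMoE-based model
model = DistributedGroupedDataParallel(model, moe_group=moe_group)
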
@@ -10,13 +10,13 @@ import fmoe_cuda
 from .utils import get_torch_default_comm


-def _ensure_nccl(t, comm=None):
+def _ensure_nccl(t, comm):
     if comm is None:
         comm = get_torch_default_comm()
     fmoe_cuda.ensure_nccl(comm, t)


-def count_by_gate(gate, num_expert, world_size, require_pos=True):
+def count_by_gate(gate, num_expert, world_size, comm, require_pos=True):
     with torch.no_grad():
         local_expert_count = torch.zeros(
             num_expert * world_size, device=gate.device, dtype=torch.int32
@@ -25,7 +25,7 @@ def count_by_gate(gate, num_expert, world_size, require_pos=True):
         local_expert_count = local_expert_count.long()
         if world_size > 1:
-            _ensure_nccl(gate)
+            _ensure_nccl(gate, comm)
             global_expert_count = fmoe_cuda.expert_exchange(
                 local_expert_count, num_expert, world_size
             )
@@ -41,7 +41,7 @@ def count_by_gate(gate, num_expert, world_size, require_pos=True):
     return pos, local_expert_count, global_expert_count


-def prepare_forward(gate, num_expert, world_size, comm=None):
+def prepare_forward(gate, num_expert, world_size, comm):
     r"""
     Prepare necessary information from gate output for MoE computation.
@@ -56,7 +56,7 @@ def prepare_forward(gate, num_expert, world_size, comm=None):
         _ensure_nccl(gate, comm=comm)
     pos, local_expert_count, global_expert_count = count_by_gate(gate,
-            num_expert, world_size)
+            num_expert, world_size, comm)
     with torch.no_grad():
         fwd_expert_count = global_expert_count.view(world_size,
                 num_expert).sum(dim=0)
......
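
As an aside, here is a rough pure-PyTorch illustration of what the count exchange over comm amounts to, assuming fmoe_cuda.expert_exchange performs an all-to-all of per-expert token counts within the group. This helper is illustrative only and not part of the library.

import torch
import torch.distributed as dist

def exchange_expert_counts(local_expert_count, num_expert, world_size, comm):
    # local_expert_count has shape (world_size * num_expert,); the slice
    # [w * num_expert : (w + 1) * num_expert] counts the tokens this rank
    # routes to the experts hosted on rank w.
    global_expert_count = torch.empty_like(local_expert_count)
    # Equal chunks of num_expert entries are scattered to / gathered from
    # every rank in the group.
    dist.all_to_all_single(global_expert_count, local_expert_count, group=comm)
    return global_expert_count
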
@@ -74,7 +74,8 @@ def mark_module_parallel_comm(module, comm):
         setattr(p, "dp_comm", comm)


-def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
+def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size,
+        comm=None):
     r"""
     A private function that performs the following steps to complete the MoE
     computation.
@@ -92,7 +93,7 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
         global_expert_count,
         fwd_expert_count,
         fwd_batch_size,
-    ) = prepare_forward(gate, num_expert, world_size)
+    ) = prepare_forward(gate, num_expert, world_size, comm)
     topk = 1
     if len(gate.shape) == 2:
         topk = gate.shape[1]
@@ -138,6 +139,7 @@ class FMoE(nn.Module):
         d_model=1024,
         world_size=1,
         mp_group=None,
+        moe_group=None,
         top_k=2,
         gate=NaiveGate,
         expert=None,
@@ -171,6 +173,7 @@ class FMoE(nn.Module):
         self.gate_hook = gate_hook
         self.mask = mask
         self.mask_dict = mask_dict
+        self.moe_group = moe_group

     def expert_fn(self, inp, fwd_expert_count):
         r"""
@@ -201,7 +204,7 @@ class FMoE(nn.Module):
                 mark_module_parallel_comm(e, comm)
         else:
             mark_module_parallel_comm(self.experts, comm)
-        mark_module_parallel_comm(self.gate, "world")
+        mark_module_parallel_comm(self.gate, "moe")

     def forward(self, inp):
         r"""
@@ -224,7 +227,7 @@ class FMoE(nn.Module):
         fwd = _fmoe_general_global_forward(
             inp,
             gate_top_k_idx,
-            self.expert_fn, self.num_expert, self.world_size
+            self.expert_fn, self.num_expert, self.world_size, self.moe_group
         )
         # recover deleted tensors
......
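
To make the retagging above concrete, a tiny sketch (the import path and the toy gate are assumptions) of what mark_module_parallel_comm does: it stamps a dp_comm attribute on each parameter, which DistributedGroupedDataParallel later uses to pick the matching entry in self.comms, so gate parameters now follow the "moe" communicator rather than "world".

import torch.nn as nn
from fmoe.layers import mark_module_parallel_comm  # import path assumed

gate = nn.Linear(1024, 8)  # stand-in for a gate network
mark_module_parallel_comm(gate, "moe")
print({name: p.dp_comm for name, p in gate.named_parameters()})
# -> {'weight': 'moe', 'bias': 'moe'}
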
@@ -4,6 +4,14 @@ distributed support for Megatron
 from fmoe.distributed import DistributedGroupedDataParallel


+_moe_group = None
+
+
+def set_moe_group(moe_group):
+    global _moe_group
+    _moe_group = moe_group
+
+
 class DistributedDataParallel(DistributedGroupedDataParallel):
     r"""
     A wrapper that is used to replace the DDP module provided by Megatron, which
@@ -18,6 +26,7 @@ class DistributedDataParallel(DistributedGroupedDataParallel):
             module,
             mp_group=mpu.get_model_parallel_group(),
             dp_group=mpu.get_data_parallel_group(),
+            moe_group=_moe_group
         )

     def state_dict(self, *args, **kwargs):
......
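
A short sketch of the intended call order (assuming torch.distributed is already initialized and that the module path matches the relative imports in this commit): register the group once, and every DistributedDataParallel constructed afterwards forwards it as moe_group.

import torch.distributed as dist
from fmoe.megatron.distributed import DistributedDataParallel, set_moe_group  # path assumed

# Hypothetical group spanning all ranks; fmoefy below builds a per-stage one.
moe_group = dist.new_group(ranks=list(range(dist.get_world_size())))
set_moe_group(moe_group)
wrapped = DistributedDataParallel(model)  # model: an fmoefy-ed Megatron model (assumed)
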
@@ -10,6 +10,7 @@ import torch.nn.functional as F
 from fmoe.transformer import FMoETransformerMLP
 from .balance import reset_gate_hook
 from .balance import generate_megatron_gate_hook
+from .distributed import set_moe_group


 class _FakeMegatronMLP(nn.Module):
@@ -74,7 +75,7 @@ class MegatronMLP(FMoETransformerMLP):
     communication group `group` to replace the original MLP layer in Megatron.
     """

-    def __init__(self, args, group, layer_idx):
+    def __init__(self, args, mp_group, moe_group, layer_idx):
         assert (
             args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size
             == 0
@@ -82,7 +83,7 @@ class MegatronMLP(FMoETransformerMLP):
         if not args.distributed_experts:
             world_size = 1
         else:
-            world_size = args.world_size
+            world_size = args.tensor_model_parallel_size * args.data_parallel_size
         gate = None
         if not args.balance_strategy or args.balance_strategy == "naive":
             from fmoe.gates import NaiveGate
@@ -102,13 +103,15 @@ class MegatronMLP(FMoETransformerMLP):
             gate = SwitchGate
         else:
             assert False, "Undefined balance strategy {}" % (args.balance_strategy)
+
         super().__init__(
             args.num_experts,
             top_k=args.top_k,
             d_model=args.hidden_size,
             d_hidden=args.hidden_hidden_size,
             world_size=world_size,
-            mp_group=group,
+            mp_group=mp_group,
+            moe_group=moe_group,
             expert_dp_comm="none" if args.distributed_experts else "dp",
             gate_hook=generate_megatron_gate_hook(
                 layer_idx, args.num_experts * world_size
@@ -187,8 +190,24 @@ def fmoefy(
     if distributed_experts is not None:
         args.distributed_experts = distributed_experts

+    if hasattr(mpu, 'get_tensor_model_parallel_group'):
+        mp_group = mpu.get_tensor_model_parallel_group()
+    else:
+        # For compatibility with older versions of Megatron-LM
+        mp_group = mpu.get_model_parallel_group()
+    if args.pipeline_model_parallel_size == 1:
+        moe_group = None
+    else:
+        # Create a communicator perpendicular to the pipeline-parallel groups
+        stage_size = args.world_size // args.pipeline_model_parallel_size
+        for i in range(0, args.world_size, stage_size):
+            ranks = range(i, i + stage_size)
+            group = torch.distributed.new_group(ranks)
+            if args.rank in ranks:
+                moe_group = group
+    set_moe_group(moe_group)
     for idx, l in enumerate(model.language_model.transformer.layers):
-        l.mlp = MegatronMLP(args, mpu.get_model_parallel_group(), idx)
+        l.mlp = MegatronMLP(args, mp_group, moe_group, idx)
     # initialize gate hook
     num_layers = len(model.language_model.transformer.layers)
......
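
To see which ranks the loop above groups together, a standalone sketch with hypothetical sizes: a world_size of 8 and pipeline_model_parallel_size of 2 give a stage_size of 4, so each pipeline stage gets its own MoE group, orthogonal to the pipeline groups.

world_size = 8                       # hypothetical
pipeline_model_parallel_size = 2     # hypothetical
stage_size = world_size // pipeline_model_parallel_size
stages = [list(range(i, i + stage_size))
          for i in range(0, world_size, stage_size)]
print(stages)  # [[0, 1, 2, 3], [4, 5, 6, 7]] -- one MoE group per pipeline stage
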
@@ -44,6 +44,7 @@ class FMoETransformerMLP(FMoE):
         d_hidden=4096,
         world_size=1,
         mp_group=None,
+        moe_group=None,
         activation=torch.nn.GELU(),
         gate=NaiveGate,
         top_k=2,
@@ -59,6 +60,7 @@ class FMoETransformerMLP(FMoE):
             top_k=top_k,
             world_size=world_size,
             mp_group=mp_group,
+            moe_group=moe_group,
             gate_hook=gate_hook,
             mask=mask,
             mask_dict=mask_dict
......
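
Outside of Megatron, the new keyword can be passed directly; a minimal sketch, with the expert count and group choice as hypothetical values:

import torch.distributed as dist
from fmoe.transformer import FMoETransformerMLP

moe_group = dist.new_group(ranks=list(range(dist.get_world_size())))
mlp = FMoETransformerMLP(
    4,                     # experts per worker (positional, parameter name assumed)
    d_model=1024,
    d_hidden=4096,
    world_size=dist.get_world_size(),
    moe_group=moe_group,
)
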