Unverified commit 50a9aa94, authored by Rick Ho, committed by GitHub

Merge pull request #59 from laekov/cope-with-pipeline

Use moe_group instead of world for MoE
parents 55f8ca7d 59913cca
@@ -27,6 +27,7 @@ class DistributedGroupedDataParallel(nn.Module):
         module,
         mp_group=None,
         dp_group=None,
+        moe_group=None,
         world_group=None,
         auto_allreduce=False,
     ):
@@ -42,6 +43,10 @@ class DistributedGroupedDataParallel(nn.Module):
             self.comms["dp"] = dp_group
         else:
             self.comms["dp"] = get_torch_default_comm()
+        if moe_group is not None:
+            self.comms["moe"] = moe_group
+        else:
+            self.comms["moe"] = get_torch_default_comm()
         if world_group is None:
             self.comms["world"] = get_torch_default_comm()
         else:
...
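For context, a minimal usage sketch of the new keyword (not part of this diff). The four-rank layout is purely illustrative, build_model() is a placeholder, and it assumes torch.distributed is already initialised with every rank running the same new_group calls:

import torch.distributed as dist
from fmoe.distributed import DistributedGroupedDataParallel

rank = dist.get_rank()

# Illustrative layout: ranks {0,1} and {2,3} are model-parallel pairs,
# ranks {0,2} and {1,3} are data-parallel pairs, and all four ranks form
# a single expert-parallel (MoE) group.
mp_groups = [dist.new_group([0, 1]), dist.new_group([2, 3])]
dp_groups = [dist.new_group([0, 2]), dist.new_group([1, 3])]
moe_group = dist.new_group([0, 1, 2, 3])

net = build_model()   # placeholder for any nn.Module containing FMoE layers
net = DistributedGroupedDataParallel(
    net,
    mp_group=mp_groups[rank // 2],
    dp_group=dp_groups[rank % 2],
    moe_group=moe_group,   # parameters tagged dp_comm="moe" are all-reduced here
)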
@@ -10,13 +10,13 @@ import fmoe_cuda
 from .utils import get_torch_default_comm


-def _ensure_nccl(t, comm=None):
+def ensure_comm(t, comm):
     if comm is None:
         comm = get_torch_default_comm()
     fmoe_cuda.ensure_nccl(comm, t)


-def count_by_gate(gate, num_expert, world_size, require_pos=True):
+def count_by_gate(gate, num_expert, world_size, comm=None, require_pos=True):
     with torch.no_grad():
         local_expert_count = torch.zeros(
             num_expert * world_size, device=gate.device, dtype=torch.int32
@@ -25,7 +25,6 @@ def count_by_gate(gate, num_expert, world_size, require_pos=True):
         local_expert_count = local_expert_count.long()

         if world_size > 1:
-            _ensure_nccl(gate)
             global_expert_count = fmoe_cuda.expert_exchange(
                 local_expert_count, num_expert, world_size
             )
@@ -41,7 +40,7 @@ def count_by_gate(gate, num_expert, world_size, require_pos=True):
     return pos, local_expert_count, global_expert_count


-def prepare_forward(gate, num_expert, world_size, comm=None):
+def prepare_forward(gate, num_expert, world_size, comm):
     r"""
     Prepare necessary information from gate output for MoE computation.
@@ -52,11 +51,8 @@ def prepare_forward(gate, num_expert, world_size, comm=None):
         world_size: number of workers that hold different experts.
         comm: the communicator of all workers in the expert-parallel group.
     """
-    if world_size > 1:
-        _ensure_nccl(gate, comm=comm)
     pos, local_expert_count, global_expert_count = count_by_gate(gate,
-            num_expert, world_size)
+            num_expert, world_size, comm)
     with torch.no_grad():
         fwd_expert_count = global_expert_count.view(world_size,
                 num_expert).sum(dim=0)
...
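A rough sketch of the new calling contract (it mirrors the FMoE.forward change later in this diff): the caller establishes the NCCL communicator once per forward pass and threads the same comm down explicitly, instead of each helper silently falling back to the world group. moe_prepare is a hypothetical wrapper, not library code:

from fmoe.functions import ensure_comm, prepare_forward

def moe_prepare(inp, gate_idx, num_expert, world_size, moe_group):
    # Establish the communicator for this group once, then pass it down.
    if world_size > 1:
        ensure_comm(inp, moe_group)   # moe_group=None falls back to the default comm
    return prepare_forward(gate_idx, num_expert, world_size, moe_group)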
@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 import math

-from .functions import prepare_forward
+from .functions import prepare_forward, ensure_comm
 from .functions import MOEScatter, MOEGather, MOELinear
 from .functions import AllGather, Slice
 from .gates import NaiveGate
@@ -74,7 +74,8 @@ def mark_module_parallel_comm(module, comm):
         setattr(p, "dp_comm", comm)


-def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
+def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size,
+        comm=None):
     r"""
     A private function that performs the following steps to complete the MoE
     computation.
@@ -92,7 +93,7 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
         global_expert_count,
         fwd_expert_count,
         fwd_batch_size,
-    ) = prepare_forward(gate, num_expert, world_size)
+    ) = prepare_forward(gate, num_expert, world_size, comm)
     topk = 1
     if len(gate.shape) == 2:
         topk = gate.shape[1]
@@ -138,6 +139,7 @@ class FMoE(nn.Module):
         d_model=1024,
         world_size=1,
         mp_group=None,
+        moe_group=None,
         top_k=2,
         gate=NaiveGate,
         expert=None,
@@ -171,6 +173,7 @@ class FMoE(nn.Module):
         self.gate_hook = gate_hook
         self.mask = mask
         self.mask_dict = mask_dict
+        self.moe_group = moe_group

     def expert_fn(self, inp, fwd_expert_count):
         r"""
@@ -201,7 +204,7 @@ class FMoE(nn.Module):
                 mark_module_parallel_comm(e, comm)
         else:
             mark_module_parallel_comm(self.experts, comm)
-        mark_module_parallel_comm(self.gate, "world")
+        mark_module_parallel_comm(self.gate, "moe")

     def forward(self, inp):
         r"""
@@ -209,6 +212,8 @@ class FMoE(nn.Module):
         according to the gate. The score of the selected gate given by the
         expert is multiplied to the experts' output tensors as a weight.
         """
+        if self.world_size > 1:
+            ensure_comm(inp, self.moe_group)
         if self.mp_size > 1:
             inp = Slice.apply(inp, self.mp_rank, self.mp_size, self.mp_group)
@@ -224,7 +229,7 @@ class FMoE(nn.Module):
         fwd = _fmoe_general_global_forward(
             inp,
             gate_top_k_idx,
-            self.expert_fn, self.num_expert, self.world_size
+            self.expert_fn, self.num_expert, self.world_size, self.moe_group
         )

         # recover deleted tensors
...
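The effect of re-tagging the gate from "world" to "moe": under DistributedGroupedDataParallel, a parameter's dp_comm attribute selects the communicator its gradient is all-reduced over. A simplified standalone restatement of that tagging (tag_params is a stand-in, not the library source):

import torch.nn as nn

def tag_params(module: nn.Module, comm_name: str) -> None:
    # Mirrors mark_module_parallel_comm: label every parameter with the name
    # of the communicator its gradient should be reduced over.
    for p in module.parameters():
        setattr(p, "dp_comm", comm_name)

gate = nn.Linear(1024, 4)   # stand-in for a NaiveGate
tag_params(gate, "moe")     # previously "world"; now reduced over comms["moe"]
print({p.dp_comm for p in gate.parameters()})   # {'moe'}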
@@ -5,6 +5,7 @@ import torch
 from fmoe.balance import reset_balance_profile
 from fmoe.balance import update_balance_profile
 from fmoe.utils import get_torch_default_comm
+from .distributed import get_moe_group

 balance_dict = {}
@@ -101,11 +102,11 @@ def patch_forward_step(forward_step_func):
                 torch.cat(loss_list).mean() * args.balance_loss_weight
             )

-            # avarage across world group
-            world_group = get_torch_default_comm()
-            world_size = torch.distributed.get_world_size(group=world_group)
+            # average across the moe group
+            moe_group = get_moe_group()
+            world_size = torch.distributed.get_world_size(group=moe_group)
             averaged_bal_loss = bal_loss.clone().detach()
-            torch.distributed.all_reduce(averaged_bal_loss, group=world_group)
+            torch.distributed.all_reduce(averaged_bal_loss, group=moe_group)
             averaged_bal_loss /= world_size

             loss += bal_loss
...
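The averaging pattern above, in isolation: all-reduce (sum) a detached copy of the loss over the MoE group, then divide by that group's size. A small helper sketch, not library code:

import torch
import torch.distributed as dist

def average_over_group(value: torch.Tensor, group=None) -> torch.Tensor:
    # group=None means the default (world) communicator.
    averaged = value.clone().detach()
    dist.all_reduce(averaged, group=group)        # sum across the group
    averaged /= dist.get_world_size(group=group)  # turn the sum into a mean
    return averaged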
@@ -4,6 +4,18 @@ distributed support for Megatron
 from fmoe.distributed import DistributedGroupedDataParallel


+_moe_group = None
+
+
+def set_moe_group(moe_group):
+    global _moe_group
+    _moe_group = moe_group
+
+
+def get_moe_group():
+    return _moe_group
+
+
 class DistributedDataParallel(DistributedGroupedDataParallel):
     r"""
     A wrapper that is used to replace the DDP module provided by Megatron, which
@@ -18,6 +30,7 @@ class DistributedDataParallel(DistributedGroupedDataParallel):
             module,
             mp_group=mpu.get_model_parallel_group(),
             dp_group=mpu.get_data_parallel_group(),
+            moe_group=_moe_group
         )

     def state_dict(self, *args, **kwargs):
...
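A usage sketch of the new module-level accessors. The ranks are made up, and it assumes the module lives at fmoe.megatron.distributed and that torch.distributed is already initialised:

import torch.distributed as dist
from fmoe.megatron.distributed import set_moe_group, get_moe_group

moe_group = dist.new_group(ranks=[0, 1, 2, 3])   # illustrative stage-local ranks
set_moe_group(moe_group)                         # called once, e.g. from fmoefy

# Later, e.g. in the balance-loss patch, the same handle is retrieved:
assert get_moe_group() is moe_group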
@@ -10,6 +10,7 @@ import torch.nn.functional as F
 from fmoe.transformer import FMoETransformerMLP
 from .balance import reset_gate_hook
 from .balance import generate_megatron_gate_hook
+from .distributed import set_moe_group


 class _FakeMegatronMLP(nn.Module):
@@ -74,7 +75,7 @@ class MegatronMLP(FMoETransformerMLP):
     communication group `group` to replace the original MLP layer in Megatron.
     """

-    def __init__(self, args, group, layer_idx):
+    def __init__(self, args, mp_group, moe_group, layer_idx):
         assert (
             args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size
             == 0
@@ -82,7 +83,7 @@ class MegatronMLP(FMoETransformerMLP):
         if not args.distributed_experts:
             world_size = 1
         else:
-            world_size = args.world_size
+            world_size = args.tensor_model_parallel_size * args.data_parallel_size
         gate = None
         if not args.balance_strategy or args.balance_strategy == "naive":
             from fmoe.gates import NaiveGate
@@ -102,13 +103,15 @@ class MegatronMLP(FMoETransformerMLP):
             gate = SwitchGate
         else:
             assert False, "Undefined balance strategy {}" % (args.balance_strategy)
+
         super().__init__(
             args.num_experts,
             top_k=args.top_k,
             d_model=args.hidden_size,
             d_hidden=args.hidden_hidden_size,
             world_size=world_size,
-            mp_group=group,
+            mp_group=mp_group,
+            moe_group=moe_group,
             expert_dp_comm="none" if args.distributed_experts else "dp",
             gate_hook=generate_megatron_gate_hook(
                 layer_idx, args.num_experts * world_size
@@ -187,8 +190,24 @@ def fmoefy(
     if distributed_experts is not None:
         args.distributed_experts = distributed_experts
+    if hasattr(mpu, 'get_tensor_model_parallel_group'):
+        mp_group = mpu.get_tensor_model_parallel_group()
+    else:
+        # For compatibility with older versions of Megatron-LM
+        mp_group = mpu.get_model_parallel_group()
+    if args.pipeline_model_parallel_size == 1:
+        moe_group = None
+    else:
+        # Create a communicator perpendicular to the pipeline groups
+        stage_size = args.world_size // args.pipeline_model_parallel_size
+        for i in range(0, args.world_size, stage_size):
+            ranks = range(i, i + stage_size)
+            group = torch.distributed.new_group(ranks)
+            if args.rank in ranks:
+                moe_group = group
+    set_moe_group(moe_group)

     for idx, l in enumerate(model.language_model.transformer.layers):
-        l.mlp = MegatronMLP(args, mpu.get_model_parallel_group(), idx)
+        l.mlp = MegatronMLP(args, mp_group, moe_group, idx)

     # initialize gate hook
     num_layers = len(model.language_model.transformer.layers)
...
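A worked example of the grouping above, assuming Megatron's usual rank ordering in which each pipeline stage owns a contiguous block of world_size // pipeline_model_parallel_size ranks (pure Python, no torch required):

# With 8 ranks and pipeline_model_parallel_size = 2, each stage spans
# stage_size = 8 // 2 = 4 consecutive ranks, so the MoE groups are
# [0, 1, 2, 3] and [4, 5, 6, 7]: rank 5 exchanges expert inputs only with
# ranks 4, 6 and 7, never across pipeline stages.
world_size, pipeline_size = 8, 2   # assumed values for illustration
stage_size = world_size // pipeline_size
groups = [list(range(i, i + stage_size)) for i in range(0, world_size, stage_size)]
print(groups)   # [[0, 1, 2, 3], [4, 5, 6, 7]]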
@@ -44,6 +44,7 @@ class FMoETransformerMLP(FMoE):
         d_hidden=4096,
         world_size=1,
         mp_group=None,
+        moe_group=None,
         activation=torch.nn.GELU(),
         gate=NaiveGate,
         top_k=2,
@@ -59,6 +60,7 @@ class FMoETransformerMLP(FMoE):
             top_k=top_k,
             world_size=world_size,
             mp_group=mp_group,
+            moe_group=moe_group,
             gate_hook=gate_hook,
             mask=mask,
             mask_dict=mask_dict
...
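A minimal single-GPU construction sketch with the new keyword (argument names from this diff; the sizes are arbitrary, and it assumes fastmoe is built with its CUDA extension; with world_size=1 no process group is needed):

import torch
from fmoe.transformer import FMoETransformerMLP

mlp = FMoETransformerMLP(
    num_expert=4,
    d_model=256,
    d_hidden=1024,
    world_size=1,      # no expert parallelism, so the group is never used
    moe_group=None,    # would be a torch.distributed group when world_size > 1
    top_k=2,
).cuda()

x = torch.randn(8, 16, 256, device="cuda")   # (batch, seq, d_model)
y = mlp(x)
print(y.shape)   # expected: torch.Size([8, 16, 256])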