Unverified Commit 4a9ef7fd authored by Rick Ho, committed by GitHub

Merge pull request #81 from laekov/realmp

Model parallelism of single experts
parents 368c8e41 d30585dc
...@@ -2,6 +2,7 @@ r""" ...@@ -2,6 +2,7 @@ r"""
The fmoe package contains MoE Layers only. The fmoe package contains MoE Layers only.
""" """
from .layers import FMoELinear, FMoE from .layers import FMoE
from .linear import FMoELinear
from .transformer import FMoETransformerMLP from .transformer import FMoETransformerMLP
from .distributed import DistributedGroupedDataParallel from .distributed import DistributedGroupedDataParallel
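With `FMoELinear` moved into `fmoe/linear.py` but still re-exported from the package root, code that imports it from `fmoe` should keep working unchanged. A quick check, assuming the package and its CUDA extension are installed:

    from fmoe import FMoELinear                   # re-exported in fmoe/__init__.py
    from fmoe.linear import FMoELinear as Direct  # new home of the class

    assert FMoELinear is Direct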
@@ -25,11 +25,8 @@ class DistributedGroupedDataParallel(nn.Module):
     def __init__(
         self,
         module,
-        mp_group=None,
-        dp_group=None,
-        moe_group=None,
-        world_group=None,
         auto_allreduce=False,
+        **kwargs
     ):
         assert not auto_allreduce, "Automatic all-reduce is not implemented yet"
@@ -37,20 +34,12 @@ class DistributedGroupedDataParallel(nn.Module):
         self.module = module
         self.comms = dict()
-        if mp_group is not None:
-            self.comms["mp"] = mp_group
-        if dp_group is not None:
-            self.comms["dp"] = dp_group
-        else:
-            self.comms["dp"] = get_torch_default_comm()
-        if moe_group is not None:
-            self.comms["moe"] = moe_group
-        else:
-            self.comms["moe"] = get_torch_default_comm()
-        if world_group is None:
-            self.comms["world"] = get_torch_default_comm()
-        else:
-            self.comms["world"] = world_group
+        for k in kwargs:
+            if k.endswith('_group'):
+                self.comms[k[:-6]] = kwargs[k]
+        for k in ['dp', 'gate', 'moe', 'world']:
+            if k not in self.comms:
+                self.comms[k] = get_torch_default_comm()

         def allreduce_params(no_scale=False,
                              reduce_after=False, fp32_allreduce=False):
......
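The constructor now picks up any keyword argument ending in `_group` and stores it under the prefix before `_group`; whichever of `dp`, `gate`, `moe`, and `world` is not supplied falls back to the default torch communicator. A minimal sketch of the new calling convention, meant to be launched with `torchrun` and using the world group for every role purely for illustration:

    import torch.nn as nn
    import torch.distributed as dist
    from fmoe.distributed import DistributedGroupedDataParallel

    dist.init_process_group(backend="nccl")
    module = nn.Linear(16, 16).cuda()

    # Any kwarg ending in "_group" becomes self.comms[<prefix>].
    model = DistributedGroupedDataParallel(
        module,
        dp_group=dist.group.WORLD,    # gradients of ordinary data-parallel params
        moe_group=dist.group.WORLD,   # gradients of expert params
        gate_group=dist.group.WORLD,  # gradients of gate params
    )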
@@ -132,34 +132,6 @@ class MOEScatter(Function):
         grad_in = _local_gather(local_grad_in, pos, inp_batch_size)
         return grad_in, None, None, None, None, None

-class MOELinear(Function):
-    r"""
-    Computes linear operators within one GPU on different experts simultaneously.
-    """
-
-    @staticmethod
-    def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None):
-        global_output_buf = fmoe_cuda.linear_forward(
-            global_input_buf, fwd_expert_count, weight, bias
-        )
-        variables = (global_input_buf, fwd_expert_count, weight, bias)
-        ctx.save_for_backward(*variables)
-        return global_output_buf
-
-    @staticmethod
-    def backward(ctx, grad_out):
-        (input_buf, fwd_expert_count, weight, bias) = ctx.saved_tensors
-        grad_inp_buf, grad_weight, grad_bias = fmoe_cuda.linear_backward(
-            grad_out, input_buf, fwd_expert_count, weight, bias
-        )
-        if not torch.is_tensor(bias):
-            grad_bias = None
-        return grad_inp_buf, None, grad_weight, grad_bias
-
 class MOEGather(Function):
     r"""
     Gather output samples from contiguous alone experts back to [batch x
......
r""" r"""
Layers that FMoE provides to users FMoE core layer
""" """
import torch import torch
import torch.nn as nn import torch.nn as nn
import math
from .functions import prepare_forward, ensure_comm from .functions import prepare_forward, ensure_comm
from .functions import MOEScatter, MOEGather, MOELinear from .functions import MOEScatter, MOEGather
from .functions import AllGather, Slice from .functions import AllGather, Slice
from .gates import NaiveGate from .gates import NaiveGate
class FMoELinear(nn.Module):
r"""
A linear layer that contains multiple experts.
As multiple experts can be placed on the same worker, the computation can be
performed in parallel to increase the performance.
The FMoELinear module provides such function.
"""
def __init__(
self,
num_expert: int,
in_feat: int,
out_feat: int,
bias: bool = True,
rank: int = 0,
):
super().__init__()
self.num_expert = num_expert
self.in_feat = in_feat
self.out_feat = out_feat
self.rank = rank
self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
if bias:
self.bias = nn.Parameter(torch.zeros(num_expert, out_feat))
else:
self.register_parameter("bias", None)
self.reset_parameters()
def forward(self, inp, fwd_expert_count):
r"""
Call MOE function
"""
x = MOELinear.apply(inp, fwd_expert_count, self.weight, self.bias)
return x
def extra_repr(self) -> str:
return "num_expert={}, in_features={}, \
out_features={}, bias={}, rank={}".format(
self.num_expert,
self.in_feat,
self.out_feat,
self.bias is not None,
self.rank,
)
def reset_parameters(self):
# Approach is the same as in torch.nn.Linear
# https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88
# bias is left to zero, similar as megatron
torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def mark_module_parallel_comm(module, comm): def mark_module_parallel_comm(module, comm):
r""" r"""
@@ -121,11 +67,12 @@ class FMoE(nn.Module):
     * `num_expert` stands for the number of experts on **each** worker.
     * `world_size` stands for the total number of workers that contain
     different experts.
-    * `mp_group` can be a torch's communication group, indicating that model
-    parallel is applied across the group, which means that workers in the group
-    hold the same copy of the input feature, and demands the same copy of the
-    output. FMoE saves computation by slicing the input in the mp group and
-    performing all-gather after the MLP computation.
+    * `slice_group` can be a torch communication group, indicating that a
+    specific kind of model parallelism is applied across the group: workers
+    in the group hold the same copy of the input feature and require the
+    same copy of the output. Each worker computes the output of only a
+    certain slice of the input batch, and the outputs are all-gathered
+    after the computation.
     * `top_k` stands for the number of experts each token is going to.
     * `gate` is a gate class which can be found in `fmoe.gates`.
     * `expert` can be specified as a module class; it is used to generate
@@ -137,7 +84,8 @@ class FMoE(nn.Module):
         num_expert=32,
         d_model=1024,
         world_size=1,
-        mp_group=None,
+        mp_group=None,  # being deprecated
+        slice_group=None,
         moe_group=None,
         top_k=2,
         gate=NaiveGate,
@@ -150,13 +98,18 @@ class FMoE(nn.Module):
         self.num_expert = num_expert
         self.d_model = d_model
         self.world_size = world_size
-        self.mp_group = mp_group
-        if mp_group is None:
-            self.mp_size = 1
-            self.mp_rank = 0
+        self.slice_group = slice_group
+        if mp_group is not None:
+            print('[Warning] mp_group is being deprecated')
+            self.slice_group = mp_group
+        if self.slice_group is None:
+            self.slice_size = 1
+            self.slice_rank = 0
         else:
-            self.mp_size = mp_group.size()
-            self.mp_rank = mp_group.rank()
+            self.slice_size = self.slice_group.size()
+            self.slice_rank = self.slice_group.rank()
         self.top_k = top_k
         if type(expert) is list:
             self.experts = nn.ModuleList([e(d_model) for e in expert])
@@ -168,6 +121,7 @@ class FMoE(nn.Module):
             self.experts_fused = False
         else:
             self.experts_fused = True
         self.gate = gate(d_model, num_expert, world_size, top_k)
         self.gate_hook = gate_hook
         self.mask = mask
@@ -203,7 +157,7 @@ class FMoE(nn.Module):
                 mark_module_parallel_comm(e, comm)
         else:
             mark_module_parallel_comm(self.experts, comm)
-        mark_module_parallel_comm(self.gate, "moe")
+        mark_module_parallel_comm(self.gate, "gate")

     def forward(self, inp):
         r"""
@@ -213,8 +167,9 @@ class FMoE(nn.Module):
         """
         if self.world_size > 1:
             ensure_comm(inp, self.moe_group)
-        if self.mp_size > 1:
-            inp = Slice.apply(inp, self.mp_rank, self.mp_size, self.mp_group)
+        if self.slice_size > 1:
+            inp = Slice.apply(inp, self.slice_rank,
+                              self.slice_size, self.slice_group)
         gate_top_k_idx, gate_score = self.gate(inp)
@@ -249,6 +204,7 @@ class FMoE(nn.Module):
         gate_score = gate_score.view(x.shape[0], 1, self.top_k)
         x = torch.bmm(gate_score, x).reshape(-1, self.d_model)
-        if self.mp_size > 1:
-            x = AllGather.apply(x, self.mp_rank, self.mp_size, self.mp_group)
+        if self.slice_size > 1:
+            x = AllGather.apply(x, self.slice_rank,
+                                self.slice_size, self.slice_group)
         return x
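In the layer itself, `mp_group` survives only as a deprecated alias that is copied into `slice_group`; when a slice group of size `s` is present, every worker pushes only a `1/s` slice of the batch through the experts and the results are all-gathered at the end. A minimal single-process sketch (so `slice_size == 1`), assuming the CUDA extension is built and using a toy expert class of our own:

    import torch
    import torch.nn as nn
    from fmoe.layers import FMoE

    class ToyExpert(nn.Module):
        # FMoE builds one of these per expert by calling expert(d_model).
        def __init__(self, d_model):
            super().__init__()
            self.fc = nn.Linear(d_model, d_model)

        def forward(self, x, fwd_expert_count=None):
            return self.fc(x)

    moe = FMoE(num_expert=4, d_model=32, top_k=2, expert=ToyExpert).cuda()
    x = torch.randn(8, 32).cuda()
    y = moe(x)   # same shape as the input
    # In a multi-process run, pass slice_group=<comm group> so that each rank
    # computes only its slice and the outputs are all-gathered afterwards.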
r"""
FMoE's parallel linear layer
"""
import torch
import torch.nn as nn
from torch.autograd import Function
import math
import fmoe_cuda
class MOELinear(Function):
r"""
Computes linear operators within one GPU on different experts simutaneously.
"""
@staticmethod
def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None):
global_output_buf = fmoe_cuda.linear_forward(
global_input_buf, fwd_expert_count, weight, bias
)
variables = (global_input_buf, fwd_expert_count, weight, bias)
ctx.save_for_backward(*variables)
return global_output_buf
@staticmethod
def backward(ctx, grad_out):
(input_buf, fwd_expert_count, weight, bias) = ctx.saved_tensors
grad_inp_buf, grad_weight, grad_bias = fmoe_cuda.linear_backward(
grad_out, input_buf, fwd_expert_count, weight, bias
)
if not torch.is_tensor(bias):
grad_bias = None
return grad_inp_buf, None, grad_weight, grad_bias
class FMoELinear(nn.Module):
r"""
A linear layer that contains multiple experts.
As multiple experts can be placed on the same worker, the computation can be
performed in parallel to increase the performance.
The FMoELinear module provides such function.
"""
def __init__(
self,
num_expert: int,
in_feat: int,
out_feat: int,
bias: bool = True,
rank: int = 0,
):
super().__init__()
self.num_expert = num_expert
self.in_feat = in_feat
self.out_feat = out_feat
self.rank = rank
self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
if bias:
self.bias = nn.Parameter(torch.zeros(num_expert, out_feat))
else:
self.register_parameter("bias", None)
self.reset_parameters()
def forward(self, inp, fwd_expert_count):
r"""
Call MOE function
"""
x = MOELinear.apply(inp, fwd_expert_count, self.weight, self.bias)
return x
def extra_repr(self) -> str:
return "num_expert={}, in_features={}, \
out_features={}, bias={}, rank={}".format(
self.num_expert,
self.in_feat,
self.out_feat,
self.bias is not None,
self.rank,
)
def reset_parameters(self):
# Approach is the same as in torch.nn.Linear
# https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88
# bias is left to zero, similar as megatron
torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
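For intuition, the fused `fmoe_cuda.linear_forward` call above is equivalent to walking over the experts and applying each expert's own weight to its contiguous segment of the input buffer. A plain-PyTorch reference of that behaviour (illustration only, not the actual kernel, and without the custom backward):

    import torch

    def moe_linear_reference(inp, fwd_expert_count, weight, bias=None):
        # inp: (sum(counts), in_feat); weight: (num_expert, out_feat, in_feat)
        outputs, base = [], 0
        for e, n in enumerate(fwd_expert_count.tolist()):
            seg = inp[base:base + n]            # tokens routed to expert e
            out = seg @ weight[e].t()           # (n, out_feat)
            if bias is not None:
                out = out + bias[e]             # bias: (num_expert, out_feat)
            outputs.append(out)
            base += n
        return torch.cat(outputs, dim=0)

    # Example: 3 experts receiving 2, 0 and 5 tokens respectively.
    out = moe_linear_reference(torch.randn(7, 16),
                               torch.tensor([2, 0, 5]),
                               torch.randn(3, 32, 16))
    print(out.shape)                            # torch.Size([7, 32])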
@@ -15,5 +15,6 @@ from .balance import reset_gate_hook
 from .balance import get_balance_profile
 from .balance import generate_megatron_gate_hook
 from .balance import add_balance_log
-from .balance import patch_forward_step
-from .balance import patch_model_provider
+
+from .patch import patch_forward_step
+from .patch import patch_model_provider
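The user-facing import path does not change; only the implementation of the two helpers moves from `balance.py` into the new `patch.py`:

    # Still valid after the move, via the re-exports above.
    from fmoe.megatron import patch_forward_step, patch_model_provider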
@@ -5,7 +5,6 @@ import torch
 from fmoe.balance import reset_balance_profile
 from fmoe.balance import update_balance_profile
 from fmoe.utils import get_torch_default_comm
-from .distributed import get_moe_group
 balance_dict = {}
@@ -71,63 +70,3 @@ def add_balance_log(model, writer, iteration):
                 balance_dict_tensor[idx].mean().item(),
                 iteration,
             )
-
-def patch_forward_step(forward_step_func):
-    r"""
-    Patch model's forward_step_func to support balance loss
-    """
-    from megatron.mpu import is_pipeline_last_stage
-    from megatron import get_args
-
-    if not get_args().balance_strategy:
-        return forward_step_func
-
-    def forward_step_with_balance_loss(data_iterator, model, input_tensor):
-        args = get_args()
-        output = forward_step_func(data_iterator, model, input_tensor)
-        if not is_pipeline_last_stage() or not args.balance_strategy or args.balance_strategy == 'naive':
-            return output
-
-        loss_name = args.balance_strategy + "_loss"
-        while hasattr(model, 'module'):
-            model = model.module
-        loss_list = [l.mlp.gate.get_loss(clear=False).view(1)
-                     for l in model.language_model.transformer.layers]
-        (loss, state_dict), bal_loss = (
-            output,
-            torch.cat(loss_list).mean() * args.balance_loss_weight
-        )
-
-        # average across moe group
-        moe_group = get_moe_group()
-        world_size = torch.distributed.get_world_size(group=moe_group)
-        averaged_bal_loss = bal_loss.clone().detach()
-        torch.distributed.all_reduce(averaged_bal_loss, group=moe_group)
-        averaged_bal_loss /= world_size
-
-        loss += bal_loss
-        state_dict[loss_name] = averaged_bal_loss
-        return loss, state_dict
-
-    return forward_step_with_balance_loss
-
-def patch_model_provider(model_provider):
-    from megatron import get_args
-
-    def fmoefied_model_provider():
-        from .layers import fmoefy
-        args = get_args()
-        return fmoefy(
-            model_provider(),
-            num_experts=args.num_experts,
-            hidden_hidden_size=4 * args.hidden_size // args.top_k,
-            top_k=args.top_k,
-        )
-    return fmoefied_model_provider
r""" r"""
distributed support for Megatron distributed support for Megatron
""" """
import torch
from fmoe.distributed import DistributedGroupedDataParallel from fmoe.distributed import DistributedGroupedDataParallel
_moe_group = None _groups = None
def _set_groups(**kwargs):
global _groups
_groups = kwargs
def set_moe_group(moe_group):
global _moe_group
_moe_group = moe_group
def _init():
from megatron import get_args
from megatron import mpu
args = get_args()
# Create a comm prependicular to the pipeline group as gate group
stage_size = args.world_size // args.pipeline_model_parallel_size
for i in range(0, args.world_size, stage_size):
ranks = range(i, i + stage_size)
group = torch.distributed.new_group(ranks)
if args.rank in ranks:
gate_group = group
def get_moe_group(): _set_groups(
return _moe_group dp_group=mpu.get_data_parallel_group(),
moe_group=mpu.get_data_parallel_group(),
gate_group=gate_group)
class DistributedDataParallel(DistributedGroupedDataParallel): class DistributedDataParallel(DistributedGroupedDataParallel):
...@@ -24,14 +41,9 @@ class DistributedDataParallel(DistributedGroupedDataParallel): ...@@ -24,14 +41,9 @@ class DistributedDataParallel(DistributedGroupedDataParallel):
""" """
def __init__(self, module): def __init__(self, module):
from megatron import mpu if _groups is None:
_init()
super().__init__( super().__init__(module, **_groups)
module,
mp_group=mpu.get_model_parallel_group(),
dp_group=mpu.get_data_parallel_group(),
moe_group=_moe_group
)
def state_dict(self, *args, **kwargs): def state_dict(self, *args, **kwargs):
r""" r"""
......
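Since the groups are now resolved lazily in `_init()`, wrapping a Megatron model needs no explicit group arguments: the data-parallel group doubles as the MoE group, and a gate group perpendicular to the pipeline group is created on first construction. A sketch of the intended call site, assuming it runs inside an initialized Megatron training job:

    from fmoe.megatron.distributed import DistributedDataParallel as LocalDDP

    # `model` is a Megatron language model whose MLPs were replaced by fmoefy();
    # dp/moe/gate groups are discovered automatically via _init().
    model = LocalDDP(model)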
@@ -10,7 +10,6 @@ import torch.nn.functional as F
 from fmoe.transformer import FMoETransformerMLP
 from .balance import reset_gate_hook
 from .balance import generate_megatron_gate_hook
-from .distributed import set_moe_group

 class _FakeMegatronMLP(nn.Module):
@@ -75,33 +74,28 @@ class MegatronMLP(FMoETransformerMLP):
     communication group `group` to replace the original MLP layer in Megatron.
     """

-    def __init__(self, args, mp_group, moe_group, layer_idx):
-        assert (
-            args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size
-            == 0
-        ), "Batch size x sequence length should be multiple of mp size"
+    def __init__(self, args, layer_idx, gate=None):
         if not args.distributed_experts:
             world_size = 1
+            moe_group = None
         else:
-            world_size = args.tensor_model_parallel_size * args.data_parallel_size
-        gate = None
+            world_size = args.data_parallel_size
+            from megatron.mpu import get_data_parallel_group
+            moe_group = get_data_parallel_group()
         if not args.balance_strategy or args.balance_strategy == "naive":
             from fmoe.gates import NaiveGate
             gate = NaiveGate
         elif args.balance_strategy == "noisy":
             from fmoe.gates import NoisyGate
             gate = NoisyGate
         elif args.balance_strategy == "gshard":
            from fmoe.gates import GShardGate
             gate = GShardGate
         elif args.balance_strategy == "switch":
             from fmoe.gates import SwitchGate
             gate = SwitchGate
-        else:
+        elif gate is None:
             assert False, "Undefined balance strategy {}" % (args.balance_strategy)

         super().__init__(
@@ -110,7 +104,6 @@ class MegatronMLP(FMoETransformerMLP):
             d_model=args.hidden_size,
             d_hidden=args.hidden_hidden_size,
             world_size=world_size,
-            mp_group=mp_group,
             moe_group=moe_group,
             expert_dp_comm="none" if args.distributed_experts else "dp",
             gate_hook=generate_megatron_gate_hook(
@@ -139,8 +132,11 @@ class MegatronMLP(FMoETransformerMLP):
             _megatron_init_method(self.experts.h4toh, rng, std)

     def forward(self, inp):
+        from megatron import mpu
+        x = super().forward(inp)
+        x = mpu.reduce_from_tensor_model_parallel_region(x)
         return (
-            super().forward(inp),
+            x,
             torch.zeros(self.hidden_size, dtype=inp.dtype, device=inp.device),
         )
@@ -151,6 +147,7 @@ def fmoefy(
     distributed_experts=True,
     hidden_hidden_size=None,
     top_k=None,
+    gate=None,
 ):
     r"""
     Replace MLP layers in a transformer-based model in Megatron by MoE.
@@ -167,47 +164,28 @@ def fmoefy(
     tensor_model_parall_comm x data_parallel_comm, which is not created.
     """
     from megatron import get_args
-    from megatron import mpu
     args = get_args()

+    # Set distributed_experts to None to use default setting in args
+    if distributed_experts is not None:
+        args.distributed_experts = distributed_experts

     if num_experts is not None:
         args.num_experts = num_experts
     assert (
         "num_experts" in args
     ), "num_experts should be specified in arguments or fmoefy function"

-    if hidden_hidden_size is not None:
-        args.hidden_hidden_size = hidden_hidden_size
-    elif not hasattr(args, "hidden_hidden_size"):
-        args.hidden_hidden_size = args.hidden_size * 4

     if top_k is not None:
         args.top_k = top_k
     elif not hasattr(args, "top_k"):
         args.top_k = 2

-    # Set distributed_experts to None to use default setting in args
-    if distributed_experts is not None:
-        args.distributed_experts = distributed_experts
+    args.hidden_hidden_size = hidden_hidden_size

-    if hasattr(mpu, 'get_tensor_model_parallel_group'):
-        mp_group = mpu.get_tensor_model_parallel_group()
-    else:
-        # For compatibility to older versions of Megatron-LM
-        mp_group = mpu.get_model_parallel_group()

-    if args.pipeline_model_parallel_size == 1:
-        moe_group = None
-    else:
-        # Create a comm perpendicular to pipeline group
-        stage_size = args.world_size // args.pipeline_model_parallel_size
-        for i in range(0, args.world_size, stage_size):
-            ranks = range(i, i + stage_size)
-            group = torch.distributed.new_group(ranks)
-            if args.rank in ranks:
-                moe_group = group
-    set_moe_group(moe_group)

     for idx, l in enumerate(model.language_model.transformer.layers):
-        l.mlp = MegatronMLP(args, mp_group, moe_group, idx)
+        l.mlp = MegatronMLP(args, idx, gate=gate)

     # initialize gate hook
     num_layers = len(model.language_model.transformer.layers)
......
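With `MegatronMLP` deriving its MoE group from the data-parallel group and reducing its own output over the tensor-model-parallel region, `fmoefy` now only needs the model plus a few scalar options and, optionally, a gate class. A hedged sketch of the new call (to be run after Megatron has been initialized; the sizes are made up):

    from fmoe.megatron.layers import fmoefy
    from fmoe.gates import GShardGate  # optional; the default gate is chosen from args.balance_strategy

    model = fmoefy(
        model,                    # a Megatron GPT/BERT model
        num_experts=8,
        hidden_hidden_size=2048,  # patch.py computes 4*h / top_k / tensor_mp_size
        top_k=2,
        gate=GShardGate,
    )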
r"""
Patching some of Megatron-LM's functions to create an MoE model
"""
import torch
def patch_forward_step(forward_step_func):
r"""
Patch model's forward_step_func to support balance loss
"""
from megatron.mpu import is_pipeline_last_stage
from megatron.mpu import get_tensor_model_parallel_group
from megatron import get_args
if not get_args().balance_strategy:
return forward_step_func
def forward_step_with_balance_loss(data_iterator, model, input_tensor):
args = get_args()
output = forward_step_func(data_iterator, model, input_tensor)
if not is_pipeline_last_stage() or not args.balance_strategy or args.balance_strategy == 'naive':
return output
loss_name = args.balance_strategy + "_loss"
while hasattr(model, 'module'):
model = model.module
loss_list = [l.mlp.gate.get_loss(clear=False).view(1)
for l in model.language_model.transformer.layers]
(loss, state_dict), bal_loss = (
output,
torch.cat(loss_list).mean() * args.balance_loss_weight
)
# avarage across moe group
moe_group = get_tensor_model_parallel_group()
world_size = torch.distributed.get_world_size(group=moe_group)
averaged_bal_loss = bal_loss.clone().detach()
torch.distributed.all_reduce(averaged_bal_loss, group=moe_group)
averaged_bal_loss /= world_size
loss += bal_loss
state_dict[loss_name] = averaged_bal_loss
return loss, state_dict
return forward_step_with_balance_loss
def patch_model_provider(model_provider, gate=None):
from megatron import get_args
def fmoefied_model_provider():
from .layers import fmoefy
args = get_args()
hhs = args.hidden_size * 4
assert hhs % args.top_k == 0
hhs = hhs // args.top_k
assert hhs % args.tensor_model_parallel_size == 0
hhs = hhs // args.tensor_model_parallel_size
return fmoefy(
model_provider(),
num_experts=args.num_experts,
hidden_hidden_size=hhs,
top_k=args.top_k,
gate=gate
)
return fmoefied_model_provider
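Together, the two patches wrap the callables a Megatron pretraining script normally hands to `pretrain()`: the model provider is wrapped so every MLP becomes an MoE layer with `hidden_hidden_size = 4*h / top_k / tensor_mp_size`, and the forward step is wrapped so the gates' balance loss is added to the training loss and its average over the tensor-model-parallel group is logged. A sketch of the intended use, where `model_provider` and `forward_step` are the script's own functions:

    from fmoe.megatron import patch_forward_step, patch_model_provider
    from fmoe.gates import GShardGate

    # Wrap the usual Megatron entry points; the wrapped versions are then
    # passed to megatron.training.pretrain() in place of the originals.
    fmoe_model_provider = patch_model_provider(model_provider, gate=GShardGate)
    fmoe_forward_step = patch_forward_step(forward_step)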
@@ -3,8 +3,8 @@ Adaption to act as the MLP layer using an MoE MLP layer in transformer.
 """
 import torch
 import torch.nn as nn
-from .gates import NaiveGate
-from .layers import FMoE, FMoELinear
+from .layers import FMoE
+from .linear import FMoELinear

 class _Expert(nn.Module):
@@ -42,31 +42,14 @@ class FMoETransformerMLP(FMoE):
         num_expert=32,
         d_model=1024,
         d_hidden=4096,
-        world_size=1,
-        mp_group=None,
-        moe_group=None,
         activation=torch.nn.GELU(),
-        gate=NaiveGate,
-        top_k=2,
         expert_dp_comm="none",
-        gate_hook=None,
-        mask=None,
-        mask_dict=None,
+        expert_rank=0,
+        **kwargs
     ):
-        super().__init__(
-            num_expert=num_expert,
-            d_model=d_model,
-            gate=gate,
-            top_k=top_k,
-            world_size=world_size,
-            mp_group=mp_group,
-            moe_group=moe_group,
-            gate_hook=gate_hook,
-            mask=mask,
-            mask_dict=mask_dict
-        )
+        super().__init__(num_expert=num_expert, d_model=d_model, **kwargs)
         self.experts = _Expert(
-            num_expert, d_model, d_hidden, activation, rank=self.mp_rank
+            num_expert, d_model, d_hidden, activation, rank=expert_rank
         )
         self.mark_parallel_comm(expert_dp_comm)
......
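Because `FMoETransformerMLP` now forwards unknown keyword arguments straight to `FMoE`, options such as `gate`, `top_k`, `world_size`, and the communication groups are passed as keywords (as the updated test below shows), while `expert_rank` only sets the rank recorded in its `FMoELinear` experts. A minimal single-GPU sketch, assuming the CUDA extension is available:

    import torch
    from fmoe.transformer import FMoETransformerMLP
    from fmoe.gates import NaiveGate

    # d_hidden and activation stay local to the MLP wrapper; everything else
    # in **kwargs (gate, top_k, world_size, slice_group, ...) goes to FMoE.
    mlp = FMoETransformerMLP(num_expert=4, d_model=64, d_hidden=256,
                             gate=NaiveGate, top_k=2).cuda()
    x = torch.randn(8, 64).cuda()
    print(mlp(x).shape)   # torch.Size([8, 64])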
@@ -36,7 +36,7 @@ else:
 if __name__ == '__main__':
     setuptools.setup(
         name='fastmoe',
-        version='0.2.1',
+        version='0.3.0',
         description='An efficient Mixture-of-Experts system for PyTorch',
         author=', '.join(authors),
         author_email='hja20@mails.tsinghua.edu.cn',
......
@@ -342,7 +342,8 @@ def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
     torch.cuda.manual_seed(42 + rank)
     model = MyModule().cuda()
-    model_ddp = LocalDDP(deepcopy(model), mp_group, dp_group, world_group)
+    model_ddp = LocalDDP(deepcopy(model),
+            mp_group=mp_group, dp_group=dp_group, world_group=world_group)
     model.set_comm()
     model_ddp.module.set_comm()
......
@@ -56,8 +56,9 @@ def _test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1)
     mask_dict = {
         1: torch.zeros(d_hidden).cuda()
     }
-    model = FMoETransformerMLP(num_expert, d_hidden, d_hidden * 4, world_size,
-            gate=ConstantGate, mask=mask, mask_dict=mask_dict).cuda()
+    model = FMoETransformerMLP(num_expert, d_hidden, d_hidden * 4,
+            world_size=world_size, gate=ConstantGate, mask=mask,
+            mask_dict=mask_dict).cuda()
     oup = model(inp)
......