Unverified Commit 4a9ef7fd authored by Rick Ho, committed by GitHub

Merge pull request #81 from laekov/realmp

Model parallelism of single experts
parents 368c8e41 d30585dc
......@@ -2,6 +2,7 @@ r"""
The fmoe package contains MoE Layers only.
"""
from .layers import FMoELinear, FMoE
from .layers import FMoE
from .linear import FMoELinear
from .transformer import FMoETransformerMLP
from .distributed import DistributedGroupedDataParallel
......@@ -25,11 +25,8 @@ class DistributedGroupedDataParallel(nn.Module):
def __init__(
self,
module,
mp_group=None,
dp_group=None,
moe_group=None,
world_group=None,
auto_allreduce=False,
**kwargs
):
assert not auto_allreduce, "Automatic all-reduce is not implemented yet"
......@@ -37,20 +34,12 @@ class DistributedGroupedDataParallel(nn.Module):
self.module = module
self.comms = dict()
if mp_group is not None:
self.comms["mp"] = mp_group
if dp_group is not None:
self.comms["dp"] = dp_group
else:
self.comms["dp"] = get_torch_default_comm()
if moe_group is not None:
self.comms["moe"] = moe_group
else:
self.comms["moe"] = get_torch_default_comm()
if world_group is None:
self.comms["world"] = get_torch_default_comm()
else:
self.comms["world"] = world_group
for k in kwargs:
if k.endswith('_group'):
self.comms[k[:-6]] = kwargs[k]
for k in ['dp', 'gate', 'moe', 'world']:
if k not in self.comms:
self.comms[k] = get_torch_default_comm()
def allreduce_params(no_scale=False,
reduce_after=False, fp32_allreduce=False):
......
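The rewritten constructor above collects every keyword argument that ends in `_group` into `self.comms`, and the `dp`, `gate`, `moe`, and `world` entries fall back to the default communicator when not given. A minimal sketch of that collection logic, using a placeholder `default_comm` instead of a real torch communicator:

def collect_comms(default_comm, **kwargs):
    comms = dict()
    # every keyword ending in `_group` becomes a named communicator
    for k in kwargs:
        if k.endswith('_group'):
            comms[k[:-len('_group')]] = kwargs[k]
    # required communicators that were not passed fall back to the default
    for k in ['dp', 'gate', 'moe', 'world']:
        comms.setdefault(k, default_comm)
    return comms

# e.g. collect_comms('WORLD', dp_group='DP', moe_group='MOE')
# -> {'dp': 'DP', 'moe': 'MOE', 'gate': 'WORLD', 'world': 'WORLD'}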
......@@ -132,34 +132,6 @@ class MOEScatter(Function):
grad_in = _local_gather(local_grad_in, pos, inp_batch_size)
return grad_in, None, None, None, None, None
class MOELinear(Function):
r"""
Computes linear operators within one GPU on different experts simultaneously.
"""
@staticmethod
def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None):
global_output_buf = fmoe_cuda.linear_forward(
global_input_buf, fwd_expert_count, weight, bias
)
variables = (global_input_buf, fwd_expert_count, weight, bias)
ctx.save_for_backward(*variables)
return global_output_buf
@staticmethod
def backward(ctx, grad_out):
(input_buf, fwd_expert_count, weight, bias) = ctx.saved_tensors
grad_inp_buf, grad_weight, grad_bias = fmoe_cuda.linear_backward(
grad_out, input_buf, fwd_expert_count, weight, bias
)
if not torch.is_tensor(bias):
grad_bias = None
return grad_inp_buf, None, grad_weight, grad_bias
class MOEGather(Function):
r"""
Gather output samples from the contiguous per-expert layout back to [batch x
......
r"""
Layers that FMoE provides to users
FMoE core layer
"""
import torch
import torch.nn as nn
import math
from .functions import prepare_forward, ensure_comm
from .functions import MOEScatter, MOEGather, MOELinear
from .functions import MOEScatter, MOEGather
from .functions import AllGather, Slice
from .gates import NaiveGate
class FMoELinear(nn.Module):
r"""
A linear layer that contains multiple experts.
As multiple experts can be placed on the same worker, the computation can be
performed in parallel to improve performance.
The FMoELinear module provides such a function.
"""
def __init__(
self,
num_expert: int,
in_feat: int,
out_feat: int,
bias: bool = True,
rank: int = 0,
):
super().__init__()
self.num_expert = num_expert
self.in_feat = in_feat
self.out_feat = out_feat
self.rank = rank
self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
if bias:
self.bias = nn.Parameter(torch.zeros(num_expert, out_feat))
else:
self.register_parameter("bias", None)
self.reset_parameters()
def forward(self, inp, fwd_expert_count):
r"""
Call MOE function
"""
x = MOELinear.apply(inp, fwd_expert_count, self.weight, self.bias)
return x
def extra_repr(self) -> str:
return "num_expert={}, in_features={}, \
out_features={}, bias={}, rank={}".format(
self.num_expert,
self.in_feat,
self.out_feat,
self.bias is not None,
self.rank,
)
def reset_parameters(self):
# Approach is the same as in torch.nn.Linear
# https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88
# bias is left at zero, similar to Megatron
torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def mark_module_parallel_comm(module, comm):
r"""
......@@ -121,11 +67,12 @@ class FMoE(nn.Module):
* `num_expert` stands for the number of experts on **each** worker.
* `world_size` stands for the total number of workers that contains
different experts.
* `mp_group` can be a torch's communication group, indicating that model
parallel is applied across the group, which means that workers in the group
hold the same copy of the input feature, and demand the same copy of the
output. FMoE saves computation by slicing the input in the mp group and
performing all-gather after the MLP computation.
* `slice_group` can be a torch communication group, indicating that a
specific form of model parallelism is applied across the group: workers in
the group hold the same copy of the input features and require the same copy
of the output. For each worker, FMoE only computes the output of a certain
slice of the input batch, and all-gathers the outputs after the
computation.
* `top_k` stands for the number of experts each token is going to.
* `gate` is a gate class which can be found in `fmoe.gates`.
* `expert` can be specified as a module class; it is used to generate
......@@ -137,7 +84,8 @@ class FMoE(nn.Module):
num_expert=32,
d_model=1024,
world_size=1,
mp_group=None,
mp_group=None, # being deprecated
slice_group=None,
moe_group=None,
top_k=2,
gate=NaiveGate,
......@@ -150,13 +98,18 @@ class FMoE(nn.Module):
self.num_expert = num_expert
self.d_model = d_model
self.world_size = world_size
self.mp_group = mp_group
if mp_group is None:
self.mp_size = 1
self.mp_rank = 0
self.slice_group = slice_group
if mp_group is not None:
print('[Warning] mp_group is being deprecated')
self.slice_group = mp_group
if self.slice_group is None:
self.slice_size = 1
self.slice_rank = 0
else:
self.mp_size = mp_group.size()
self.mp_rank = mp_group.rank()
self.slice_size = self.slice_group.size()
self.slice_rank = self.slice_group.rank()
self.top_k = top_k
if type(expert) is list:
self.experts = nn.ModuleList([e(d_model) for e in expert])
......@@ -168,6 +121,7 @@ class FMoE(nn.Module):
self.experts_fused = False
else:
self.experts_fused = True
self.gate = gate(d_model, num_expert, world_size, top_k)
self.gate_hook = gate_hook
self.mask = mask
......@@ -203,7 +157,7 @@ class FMoE(nn.Module):
mark_module_parallel_comm(e, comm)
else:
mark_module_parallel_comm(self.experts, comm)
mark_module_parallel_comm(self.gate, "moe")
mark_module_parallel_comm(self.gate, "gate")
def forward(self, inp):
r"""
......@@ -213,8 +167,9 @@ class FMoE(nn.Module):
"""
if self.world_size > 1:
ensure_comm(inp, self.moe_group)
if self.mp_size > 1:
inp = Slice.apply(inp, self.mp_rank, self.mp_size, self.mp_group)
if self.slice_size > 1:
inp = Slice.apply(inp, self.slice_rank,
self.slice_size, self.slice_group)
gate_top_k_idx, gate_score = self.gate(inp)
......@@ -249,6 +204,7 @@ class FMoE(nn.Module):
gate_score = gate_score.view(x.shape[0], 1, self.top_k)
x = torch.bmm(gate_score, x).reshape(-1, self.d_model)
if self.mp_size > 1:
x = AllGather.apply(x, self.mp_rank, self.mp_size, self.mp_group)
if self.slice_size > 1:
x = AllGather.apply(x, self.slice_rank,
self.slice_size, self.slice_group)
return x
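When `slice_size > 1`, each worker keeps only its own slice of the batch before gating and all-gathers the per-slice outputs afterwards, so every worker ends up with the full output. A single-process sketch of this data flow, with `torch.chunk`/`torch.cat` standing in for the real `Slice.apply`/`AllGather.apply` collectives:

import torch

def sliced_forward(inp, slice_size, expert_fn):
    # each rank keeps only its slice of the batch (what Slice.apply does)
    slices = torch.chunk(inp, slice_size, dim=0)
    # each rank runs the expert computation on its own slice
    outputs = [expert_fn(s) for s in slices]
    # slices are gathered back so every rank sees the full output (AllGather.apply)
    return torch.cat(outputs, dim=0)

x = torch.randn(8, 16)
y = sliced_forward(x, slice_size=2, expert_fn=lambda t: t * 2)
assert torch.allclose(y, x * 2)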
r"""
FMoE's parallel linear layer
"""
import torch
import torch.nn as nn
from torch.autograd import Function
import math
import fmoe_cuda
class MOELinear(Function):
r"""
Computes linear operators within one GPU on different experts simultaneously.
"""
@staticmethod
def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None):
global_output_buf = fmoe_cuda.linear_forward(
global_input_buf, fwd_expert_count, weight, bias
)
variables = (global_input_buf, fwd_expert_count, weight, bias)
ctx.save_for_backward(*variables)
return global_output_buf
@staticmethod
def backward(ctx, grad_out):
(input_buf, fwd_expert_count, weight, bias) = ctx.saved_tensors
grad_inp_buf, grad_weight, grad_bias = fmoe_cuda.linear_backward(
grad_out, input_buf, fwd_expert_count, weight, bias
)
if not torch.is_tensor(bias):
grad_bias = None
return grad_inp_buf, None, grad_weight, grad_bias
class FMoELinear(nn.Module):
r"""
A linear layer that contains multiple experts.
As multiple experts can be placed on the same worker, the computation can be
performed in parallel to improve performance.
The FMoELinear module provides such a function.
"""
def __init__(
self,
num_expert: int,
in_feat: int,
out_feat: int,
bias: bool = True,
rank: int = 0,
):
super().__init__()
self.num_expert = num_expert
self.in_feat = in_feat
self.out_feat = out_feat
self.rank = rank
self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
if bias:
self.bias = nn.Parameter(torch.zeros(num_expert, out_feat))
else:
self.register_parameter("bias", None)
self.reset_parameters()
def forward(self, inp, fwd_expert_count):
r"""
Call MOE function
"""
x = MOELinear.apply(inp, fwd_expert_count, self.weight, self.bias)
return x
def extra_repr(self) -> str:
return "num_expert={}, in_features={}, \
out_features={}, bias={}, rank={}".format(
self.num_expert,
self.in_feat,
self.out_feat,
self.bias is not None,
self.rank,
)
def reset_parameters(self):
# Approach is the same as in torch.nn.Linear
# https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88
# bias is left at zero, similar to Megatron
torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
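`MOELinear` relies on the fused `fmoe_cuda.linear_forward` kernel, which expects the rows of the input buffer to be grouped by expert, with `fwd_expert_count[i]` consecutive rows belonging to expert `i`. A plain-PyTorch reference of the assumed semantics, useful for checking shapes without the CUDA extension:

import torch
import torch.nn.functional as F

def expert_linear_reference(inp, fwd_expert_count, weight, bias=None):
    # inp: (sum(fwd_expert_count), in_feat); weight: (num_expert, out_feat, in_feat)
    outputs, base = [], 0
    for i, cnt in enumerate(fwd_expert_count.tolist()):
        seg = inp[base:base + cnt]
        b = bias[i] if bias is not None else None
        outputs.append(F.linear(seg, weight[i], b))
        base += cnt
    return torch.cat(outputs, dim=0)

# e.g. 2 experts: 3 rows routed to expert 0, 1 row routed to expert 1
out = expert_linear_reference(torch.randn(4, 8), torch.tensor([3, 1]),
                              torch.randn(2, 16, 8))
assert out.shape == (4, 16)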
......@@ -15,5 +15,6 @@ from .balance import reset_gate_hook
from .balance import get_balance_profile
from .balance import generate_megatron_gate_hook
from .balance import add_balance_log
from .balance import patch_forward_step
from .balance import patch_model_provider
from .patch import patch_forward_step
from .patch import patch_model_provider
......@@ -5,7 +5,6 @@ import torch
from fmoe.balance import reset_balance_profile
from fmoe.balance import update_balance_profile
from fmoe.utils import get_torch_default_comm
from .distributed import get_moe_group
balance_dict = {}
......@@ -71,63 +70,3 @@ def add_balance_log(model, writer, iteration):
balance_dict_tensor[idx].mean().item(),
iteration,
)
def patch_forward_step(forward_step_func):
r"""
Patch model's forward_step_func to support balance loss
"""
from megatron.mpu import is_pipeline_last_stage
from megatron import get_args
if not get_args().balance_strategy:
return forward_step_func
def forward_step_with_balance_loss(data_iterator, model, input_tensor):
args = get_args()
output = forward_step_func(data_iterator, model, input_tensor)
if not is_pipeline_last_stage() or not args.balance_strategy or args.balance_strategy == 'naive':
return output
loss_name = args.balance_strategy + "_loss"
while hasattr(model, 'module'):
model = model.module
loss_list = [l.mlp.gate.get_loss(clear=False).view(1)
for l in model.language_model.transformer.layers]
(loss, state_dict), bal_loss = (
output,
torch.cat(loss_list).mean() * args.balance_loss_weight
)
# average across the MoE group
moe_group = get_moe_group()
world_size = torch.distributed.get_world_size(group=moe_group)
averaged_bal_loss = bal_loss.clone().detach()
torch.distributed.all_reduce(averaged_bal_loss, group=moe_group)
averaged_bal_loss /= world_size
loss += bal_loss
state_dict[loss_name] = averaged_bal_loss
return loss, state_dict
return forward_step_with_balance_loss
def patch_model_provider(model_provider):
from megatron import get_args
def fmoefied_model_provider():
from .layers import fmoefy
args = get_args()
return fmoefy(
model_provider(),
num_experts=args.num_experts,
hidden_hidden_size=4 * args.hidden_size // args.top_k,
top_k=args.top_k,
)
return fmoefied_model_provider
r"""
distributed support for Megatron
"""
import torch
from fmoe.distributed import DistributedGroupedDataParallel
_moe_group = None
_groups = None
def _set_groups(**kwargs):
global _groups
_groups = kwargs
def set_moe_group(moe_group):
global _moe_group
_moe_group = moe_group
def _init():
from megatron import get_args
from megatron import mpu
args = get_args()
# Create a communicator perpendicular to the pipeline group as the gate group
stage_size = args.world_size // args.pipeline_model_parallel_size
for i in range(0, args.world_size, stage_size):
ranks = range(i, i + stage_size)
group = torch.distributed.new_group(ranks)
if args.rank in ranks:
gate_group = group
def get_moe_group():
return _moe_group
_set_groups(
dp_group=mpu.get_data_parallel_group(),
moe_group=mpu.get_data_parallel_group(),
gate_group=gate_group)
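`_init` builds one gate group per pipeline stage: each group covers the `stage_size` consecutive ranks that sit in the same stage, i.e. perpendicular to the pipeline dimension. A sketch of the rank arithmetic without creating real process groups; for world_size=8 and pipeline_model_parallel_size=2, rank 5 lands in the group [4, 5, 6, 7]:

def gate_group_ranks(world_size, pipeline_model_parallel_size, rank):
    # one group per pipeline stage, covering `stage_size` consecutive ranks
    stage_size = world_size // pipeline_model_parallel_size
    for i in range(0, world_size, stage_size):
        ranks = range(i, i + stage_size)
        if rank in ranks:
            return list(ranks)

assert gate_group_ranks(8, 2, rank=5) == [4, 5, 6, 7]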
class DistributedDataParallel(DistributedGroupedDataParallel):
......@@ -24,14 +41,9 @@ class DistributedDataParallel(DistributedGroupedDataParallel):
"""
def __init__(self, module):
from megatron import mpu
super().__init__(
module,
mp_group=mpu.get_model_parallel_group(),
dp_group=mpu.get_data_parallel_group(),
moe_group=_moe_group
)
if _groups is None:
_init()
super().__init__(module, **_groups)
def state_dict(self, *args, **kwargs):
r"""
......
......@@ -10,7 +10,6 @@ import torch.nn.functional as F
from fmoe.transformer import FMoETransformerMLP
from .balance import reset_gate_hook
from .balance import generate_megatron_gate_hook
from .distributed import set_moe_group
class _FakeMegatronMLP(nn.Module):
......@@ -75,33 +74,28 @@ class MegatronMLP(FMoETransformerMLP):
communication group `group` to replace the original MLP layer in Megatron.
"""
def __init__(self, args, mp_group, moe_group, layer_idx):
assert (
args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size
== 0
), "Batch size x sequence length should be multiple of mp size"
def __init__(self, args, layer_idx, gate=None):
if not args.distributed_experts:
world_size = 1
moe_group = None
else:
world_size = args.tensor_model_parallel_size * args.data_parallel_size
gate = None
world_size = args.data_parallel_size
from megatron.mpu import get_data_parallel_group
moe_group = get_data_parallel_group()
if not args.balance_strategy or args.balance_strategy == "naive":
from fmoe.gates import NaiveGate
gate = NaiveGate
elif args.balance_strategy == "noisy":
from fmoe.gates import NoisyGate
gate = NoisyGate
elif args.balance_strategy == "gshard":
from fmoe.gates import GShardGate
gate = GShardGate
elif args.balance_strategy == "switch":
from fmoe.gates import SwitchGate
gate = SwitchGate
else:
elif gate is None:
assert False, "Undefined balance strategy {}".format(args.balance_strategy)
super().__init__(
......@@ -110,7 +104,6 @@ class MegatronMLP(FMoETransformerMLP):
d_model=args.hidden_size,
d_hidden=args.hidden_hidden_size,
world_size=world_size,
mp_group=mp_group,
moe_group=moe_group,
expert_dp_comm="none" if args.distributed_experts else "dp",
gate_hook=generate_megatron_gate_hook(
......@@ -139,8 +132,11 @@ class MegatronMLP(FMoETransformerMLP):
_megatron_init_method(self.experts.h4toh, rng, std)
def forward(self, inp):
from megatron import mpu
x = super().forward(inp)
x = mpu.reduce_from_tensor_model_parallel_region(x)
return (
super().forward(inp),
x,
torch.zeros(self.hidden_size, dtype=inp.dtype, device=inp.device),
)
......@@ -151,6 +147,7 @@ def fmoefy(
distributed_experts=True,
hidden_hidden_size=None,
top_k=None,
gate=None,
):
r"""
Replace MLP layers in a transformer-based model in Megatron with MoE layers.
......@@ -167,47 +164,28 @@ def fmoefy(
tensor_model_parallel_comm x data_parallel_comm, which is not created.
"""
from megatron import get_args
from megatron import mpu
args = get_args()
# Set distributed_experts to None to use default setting in args
if distributed_experts is not None:
args.distributed_experts = distributed_experts
if num_experts is not None:
args.num_experts = num_experts
assert (
"num_experts" in args
), "num_experts should be specified in arguments or fmoefy function"
if hidden_hidden_size is not None:
args.hidden_hidden_size = hidden_hidden_size
elif not hasattr(args, "hidden_hidden_size"):
args.hidden_hidden_size = args.hidden_size * 4
if top_k is not None:
args.top_k = top_k
elif not hasattr(args, "top_k"):
args.top_k = 2
# Set distributed_experts to None to use default setting in args
if distributed_experts is not None:
args.distributed_experts = distributed_experts
args.hidden_hidden_size = hidden_hidden_size
if hasattr(mpu, 'get_tensor_model_parallel_group'):
mp_group = mpu.get_tensor_model_parallel_group()
else:
# For compatibility to older versions of Megatron-LM
mp_group = mpu.get_model_parallel_group()
if args.pipeline_model_parallel_size == 1:
moe_group = None
else:
# Create a communicator perpendicular to the pipeline group
stage_size = args.world_size // args.pipeline_model_parallel_size
for i in range(0, args.world_size, stage_size):
ranks = range(i, i + stage_size)
group = torch.distributed.new_group(ranks)
if args.rank in ranks:
moe_group = group
set_moe_group(moe_group)
for idx, l in enumerate(model.language_model.transformer.layers):
l.mlp = MegatronMLP(args, mp_group, moe_group, idx)
l.mlp = MegatronMLP(args, idx, gate=gate)
# initialize gate hook
num_layers = len(model.language_model.transformer.layers)
......
r"""
Patching some of Megatron-LM's functions to create an MoE model
"""
import torch
def patch_forward_step(forward_step_func):
r"""
Patch model's forward_step_func to support balance loss
"""
from megatron.mpu import is_pipeline_last_stage
from megatron.mpu import get_tensor_model_parallel_group
from megatron import get_args
if not get_args().balance_strategy:
return forward_step_func
def forward_step_with_balance_loss(data_iterator, model, input_tensor):
args = get_args()
output = forward_step_func(data_iterator, model, input_tensor)
if not is_pipeline_last_stage() or not args.balance_strategy or args.balance_strategy == 'naive':
return output
loss_name = args.balance_strategy + "_loss"
while hasattr(model, 'module'):
model = model.module
loss_list = [l.mlp.gate.get_loss(clear=False).view(1)
for l in model.language_model.transformer.layers]
(loss, state_dict), bal_loss = (
output,
torch.cat(loss_list).mean() * args.balance_loss_weight
)
# average across the MoE group
moe_group = get_tensor_model_parallel_group()
world_size = torch.distributed.get_world_size(group=moe_group)
averaged_bal_loss = bal_loss.clone().detach()
torch.distributed.all_reduce(averaged_bal_loss, group=moe_group)
averaged_bal_loss /= world_size
loss += bal_loss
state_dict[loss_name] = averaged_bal_loss
return loss, state_dict
return forward_step_with_balance_loss
def patch_model_provider(model_provider, gate=None):
from megatron import get_args
def fmoefied_model_provider():
from .layers import fmoefy
args = get_args()
hhs = args.hidden_size * 4
assert hhs % args.top_k == 0
hhs = hhs // args.top_k
assert hhs % args.tensor_model_parallel_size == 0
hhs = hhs // args.tensor_model_parallel_size
return fmoefy(
model_provider(),
num_experts=args.num_experts,
hidden_hidden_size=hhs,
top_k=args.top_k,
gate=gate
)
return fmoefied_model_provider
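`fmoefied_model_provider` shrinks the per-expert hidden size so the per-token FFN cost stays roughly constant after MoE-fication: the usual `4 * hidden_size` is divided by `top_k` (each token now passes through `top_k` experts) and by `tensor_model_parallel_size` (each expert is sliced across the tensor-parallel group). A worked example with illustrative values:

# hidden_size, top_k and tensor parallel size below are illustrative values
hidden_size, top_k, tmp_size = 1024, 2, 2
hhs = hidden_size * 4                  # 4096, the dense FFN hidden size
assert hhs % top_k == 0 and (hhs // top_k) % tmp_size == 0
hhs = hhs // top_k // tmp_size         # 4096 // 2 // 2 = 1024 per expert slice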
......@@ -3,8 +3,8 @@ Adaption to act as the MLP layer using an MoE MLP layer in transformer.
"""
import torch
import torch.nn as nn
from .gates import NaiveGate
from .layers import FMoE, FMoELinear
from .layers import FMoE
from .linear import FMoELinear
class _Expert(nn.Module):
......@@ -42,31 +42,14 @@ class FMoETransformerMLP(FMoE):
num_expert=32,
d_model=1024,
d_hidden=4096,
world_size=1,
mp_group=None,
moe_group=None,
activation=torch.nn.GELU(),
gate=NaiveGate,
top_k=2,
expert_dp_comm="none",
gate_hook=None,
mask=None,
mask_dict=None,
expert_rank=0,
**kwargs
):
super().__init__(
num_expert=num_expert,
d_model=d_model,
gate=gate,
top_k=top_k,
world_size=world_size,
mp_group=mp_group,
moe_group=moe_group,
gate_hook=gate_hook,
mask=mask,
mask_dict=mask_dict
)
super().__init__(num_expert=num_expert, d_model=d_model, **kwargs)
self.experts = _Expert(
num_expert, d_model, d_hidden, activation, rank=self.mp_rank
num_expert, d_model, d_hidden, activation, rank=expert_rank
)
self.mark_parallel_comm(expert_dp_comm)
......
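With the simplified constructor above, `FMoETransformerMLP` keeps only its own arguments (`d_hidden`, `activation`, `expert_dp_comm`, `expert_rank`) and forwards everything else to `FMoE` via `**kwargs`, so options such as `world_size`, `gate`, `top_k` or `mask` are passed by keyword. A minimal single-worker construction sketch (assumes a CUDA build of fastmoe and the default gate):

import torch
from fmoe.transformer import FMoETransformerMLP

# world_size=1 keeps all experts on the local worker; top_k is forwarded to FMoE
mlp = FMoETransformerMLP(num_expert=4, d_model=1024, d_hidden=4096,
                         world_size=1, top_k=2).cuda()
out = mlp(torch.randn(16, 1024, device='cuda'))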
......@@ -36,7 +36,7 @@ else:
if __name__ == '__main__':
setuptools.setup(
name='fastmoe',
version='0.2.1',
version='0.3.0',
description='An efficient Mixture-of-Experts system for PyTorch',
author=', '.join(authors),
author_email='hja20@mails.tsinghua.edu.cn',
......
......@@ -342,7 +342,8 @@ def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
torch.cuda.manual_seed(42 + rank)
model = MyModule().cuda()
model_ddp = LocalDDP(deepcopy(model), mp_group, dp_group, world_group)
model_ddp = LocalDDP(deepcopy(model),
mp_group=mp_group, dp_group=dp_group, world_group=world_group)
model.set_comm()
model_ddp.module.set_comm()
......
......@@ -56,8 +56,9 @@ def _test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1)
mask_dict = {
1: torch.zeros(d_hidden).cuda()
}
model = FMoETransformerMLP(num_expert, d_hidden, d_hidden * 4, world_size,
gate=ConstantGate, mask=mask, mask_dict=mask_dict).cuda()
model = FMoETransformerMLP(num_expert, d_hidden, d_hidden * 4,
world_size=world_size, gate=ConstantGate, mask=mask,
mask_dict=mask_dict).cuda()
oup = model(inp)
......