Commit 4d48209d authored by Sengxian

Format using black

parent 527a8cc9
......@@ -142,6 +142,7 @@ disable=print-statement,
arguments-differ,
import-outside-toplevel,
signature-differs,
bad-continuation,
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
......
r'''
r"""
Supporting modules for conducting distributed training
'''
"""
import torch
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
......@@ -8,7 +8,7 @@ from .utils import get_torch_default_comm
class DistributedGroupedDataParallel(nn.Module):
r'''
r"""
A customized DDP module to support different all-reduce regions in the
model. The all-reduce region is defined by an attribute `dp_comm` on the
weight object.
......@@ -20,36 +20,42 @@ class DistributedGroupedDataParallel(nn.Module):
If it is set to `world`, the gradients are synchronized across all workers,
regardless of their model or data parallel group. This is extremely useful for
shared layers like the gate.
'''
def __init__(self, module, mp_group=None, dp_group=None, world_group=None,
auto_allreduce=False):
assert not auto_allreduce, 'Automatic all-reduce is not implemented yet'
"""
def __init__(
self,
module,
mp_group=None,
dp_group=None,
world_group=None,
auto_allreduce=False,
):
assert not auto_allreduce, "Automatic all-reduce is not implemented yet"
super().__init__()
self.module = module
self.comms = dict()
if mp_group is not None:
self.comms['mp'] = mp_group
self.comms["mp"] = mp_group
if dp_group is not None:
self.comms['dp'] = dp_group
self.comms["dp"] = dp_group
else:
self.comms['dp'] = get_torch_default_comm()
self.comms["dp"] = get_torch_default_comm()
if world_group is None:
self.comms['world'] = get_torch_default_comm()
self.comms["world"] = get_torch_default_comm()
else:
self.comms['world'] = world_group
self.comms["world"] = world_group
def allreduce_params(no_scale=False, reduce_after=False,
fp32_allreduce=False):
def allreduce_params(no_scale=False, reduce_after=False, fp32_allreduce=False):
groups = dict()
for p in self.module.parameters():
if not p.requires_grad or p.grad is None:
continue
if hasattr(p, 'dp_comm'):
if hasattr(p, "dp_comm"):
dp_comm = p.dp_comm
else:
dp_comm = 'dp'
dp_comm = "dp"
group_key = (dp_comm, p.dtype)
if group_key not in groups:
groups[group_key] = [p]
......@@ -81,10 +87,10 @@ class DistributedGroupedDataParallel(nn.Module):
for p in self.module.parameters():
if not p.requires_grad or p.grad is None:
continue
if hasattr(p, 'dp_comm'):
if hasattr(p, "dp_comm"):
dp_comm = p.dp_comm
else:
dp_comm = 'dp'
dp_comm = "dp"
group_key = (dp_comm, p.dtype)
if group_key not in groups:
groups[group_key] = [p]
......@@ -103,7 +109,7 @@ class DistributedGroupedDataParallel(nn.Module):
d.copy_(s)
def forward(self, *args, **kwargs):
r'''
r"""
Directly call the module's forward function.
'''
"""
return self.module(*args, **kwargs)
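For illustration, a minimal sketch of how the `dp_comm` regions described above are meant to be used; the `fmoe.distributed` import path and the training-loop details are assumptions, not part of this commit.

import torch.nn as nn
import torch.distributed as dist
from fmoe.distributed import DistributedGroupedDataParallel  # assumed import path

dist.init_process_group(backend="nccl")  # the default comm backs the "dp"/"world" regions
model = nn.Linear(1024, 1024)
# Parameters default to the "dp" region; a shared layer such as a gate can be
# tagged so its gradients are all-reduced across every worker instead.
for p in model.parameters():
    p.dp_comm = "world"
model = DistributedGroupedDataParallel(model)
# Gradient synchronization groups parameters by their `dp_comm` attribute
# (see allreduce_params above); how it is triggered depends on the training loop.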
......@@ -40,8 +40,7 @@ def moe_prepare_forward(gate, num_expert, world_size, comm=None):
)
else:
global_expert_count = local_expert_count
fwd_expert_count = global_expert_count.view(world_size,
num_expert).sum(dim=0)
fwd_expert_count = global_expert_count.view(world_size, num_expert).sum(dim=0)
fwd_batch_size = int(fwd_expert_count.sum().item())
return (
pos,
......@@ -58,6 +57,7 @@ class MOEScatter(Function):
If `world_size` is greater than 1, the samples will first be locally
scattered, and then exchanged across workers.
"""
@staticmethod
def forward(
ctx,
......@@ -107,6 +107,7 @@ class MOELinear(Function):
r"""
Computes linear operators within one GPU on different experts simultaneously.
"""
@staticmethod
def forward(ctx, global_input_buf, weight, fwd_expert_count):
(global_output_buf,) = fmoe_cuda.forward(
......@@ -130,6 +131,7 @@ class MOEGather(Function):
Gather output samples from the buffer that is contiguous along experts back
to [batch x sequences]. Works symmetrically with MOEScatter.
"""
@staticmethod
def forward(
ctx,
......@@ -176,9 +178,10 @@ class MOEGather(Function):
class AllGather(Function):
r'''
r"""
A wrapper for the All-Gather function to support auto-differentiation.
'''
"""
@staticmethod
def forward(ctx, inp, rank, world_size, group):
tensor_list = [torch.empty_like(inp) for _ in range(world_size)]
......@@ -191,13 +194,14 @@ class AllGather(Function):
@staticmethod
def backward(ctx, grad_out):
rank, dim0 = ctx.args
return grad_out[rank * dim0:(rank + 1) * dim0], None, None, None
return grad_out[rank * dim0 : (rank + 1) * dim0], None, None, None
class Slice(Function):
r'''
r"""
A wrapper for the Slice function to support auto-differentiation.
'''
"""
@staticmethod
def forward(ctx, inp, rank, world_size, group):
B: int = inp.shape[0]
......
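The Slice/AllGather wrappers above are used by FMoE.forward (see layers.py later in this diff) to shard the batch across a model-parallel group and gather the outputs back. Below is a single-process sketch of the index arithmetic, mirroring AllGather.backward above; it is an illustration only, and the exact slicing done by Slice is an assumption, since the real Functions use torch.distributed collectives.

import torch

world_size, rank = 4, 1
inp = torch.randn(8, 16)                      # full (batch, d_model) tensor seen by every rank
dim0 = inp.shape[0] // world_size             # per-rank slice length
local = inp[rank * dim0 : (rank + 1) * dim0]  # roughly what Slice hands to this rank
gathered = torch.cat([local] * world_size, dim=0)  # stand-in for dist.all_gather + cat
grad_slice = torch.ones_like(gathered)[rank * dim0 : (rank + 1) * dim0]  # AllGather.backward keeps the local slice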
r'''
r"""
Different implementations of the Gate are located here.
The `NaiveGate` serves as the reference for implementing any other gate.
'''
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class ZeroGate(nn.Module):
r'''
r"""
Guide all input samples to gate 0.
'''
"""
def __init__(self, _1, _2, _3, top_k=2):
super().__init__()
self.top_k = top_k
def forward(self, inp):
r'''
r"""
All outputs go to expert 0 (the first expert)
'''
idx = torch.zeros(inp.shape[0] * self.top_k,
dtype=torch.int64, device=inp.device)
score = torch.ones(inp.shape[0] * self.top_k,
device=inp.device) / self.top_k
"""
idx = torch.zeros(
inp.shape[0] * self.top_k, dtype=torch.int64, device=inp.device
)
score = torch.ones(inp.shape[0] * self.top_k, device=inp.device) / self.top_k
return idx, score.reshape(-1, 1, self.top_k)
class NaiveGate(nn.Module):
r'''
r"""
A naive gate implementation that defines the standard behavior of the gate
which determines which experts the tokens are going to.
Both the indices and the score, or confidence, are output to the parent
module.
The load-balance strategies are also designed to be implemented within the
`Gate` module.
'''
"""
def __init__(self, d_model, num_expert, world_size, top_k=2):
super().__init__()
self.gate = nn.Linear(d_model, num_expert * world_size)
self.top_k = top_k
def forward(self, inp):
r'''
r"""
The naive implementation simply calculates the top-k of a linear layer's
output.
'''
"""
gate = self.gate(inp)
gate_top_k_val, gate_top_k_idx = torch.topk(
gate, k=self.top_k, dim=-1, largest=True, sorted=False
......
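As a usage sketch of the gate interface (the `fmoe.gates` import path is taken from the FMoE docstring later in this diff; the exact return shapes of NaiveGate are inferred from how FMoE.forward consumes them):

import torch
from fmoe.gates import NaiveGate

d_model, num_expert, world_size, top_k = 1024, 4, 1, 2
gate = NaiveGate(d_model, num_expert, world_size, top_k)
inp = torch.randn(8, d_model)              # 8 tokens
gate_top_k_idx, gate_score = gate(inp)     # selected expert indices and their scores
# FMoE.forward bmm's gate_score with a (tokens, top_k, d_model) tensor, so
# gate_score is expected to be (tokens, 1, top_k), as ZeroGate above returns.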
r'''
r"""
Layers that FMoE provides to users
'''
"""
import torch
import torch.nn as nn
......@@ -11,14 +11,21 @@ from .gates import NaiveGate
class FMoELinear(nn.Module):
r'''
r"""
A linear layer that contains multiple experts.
As multiple experts can be placed on the same worker, the computation can be
performed in parallel to increase the performance.
The FMoELinear module provides this functionality.
'''
def __init__(self, num_expert: int, in_feat: int, out_feat: int,
bias: bool = True, rank: int = 0):
"""
def __init__(
self,
num_expert: int,
in_feat: int,
out_feat: int,
bias: bool = True,
rank: int = 0,
):
super().__init__()
self.num_expert = num_expert
self.in_feat = in_feat
......@@ -28,12 +35,12 @@ class FMoELinear(nn.Module):
if bias:
self.bias = nn.Parameter(torch.Tensor(num_expert, out_feat))
else:
self.register_parameter('bias', None)
self.register_parameter("bias", None)
def forward(self, inp, fwd_expert_count):
r'''
r"""
Call MOE function
'''
"""
x = MOELinear.apply(inp, self.weight, fwd_expert_count)
if self.bias is not None:
# TODO: torch.repeat_interleave seems to have numerical
......@@ -45,8 +52,9 @@ class FMoELinear(nn.Module):
# like MOELinear.apply(x, weight, bias, count)
# Solution 1
bias = torch.repeat_interleave(self.bias,
fwd_expert_count.to(self.bias.device), dim=0)
bias = torch.repeat_interleave(
self.bias, fwd_expert_count.to(self.bias.device), dim=0
)
# Solution 2
# bias_idx = torch.arange(self.num_expert)\
......@@ -67,24 +75,27 @@ class FMoELinear(nn.Module):
return x
def extra_repr(self) -> str:
return 'num_expert={}, in_features={}, \
out_features={}, bias={}, rank={}'.format(
self.num_expert, self.in_feat,
self.out_feat, self.bias is not None, self.rank
return "num_expert={}, in_features={}, \
out_features={}, bias={}, rank={}".format(
self.num_expert,
self.in_feat,
self.out_feat,
self.bias is not None,
self.rank,
)
def mark_module_parallel_comm(module, comm):
r'''
r"""
Mark all parameters in `module` as doing data-parallel reduction in `comm`,
where `comm` may be one of `'world'`, `'dp'`, or `'none'`.
'''
"""
for p in module.parameters():
setattr(p, 'dp_comm', comm)
setattr(p, "dp_comm", comm)
def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
r'''
r"""
A private function that performs the following steps to complete the MoE
computation.
* Count the number of tokens from each worker to each expert.
......@@ -94,14 +105,16 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
* Gather the output features of experts back, and reorder them as sentences.
Intermediate results like expert counts are hidden from users by this
function.
'''
"""
(
pos, local_expert_count, global_expert_count, fwd_expert_count,
fwd_batch_size
pos,
local_expert_count,
global_expert_count,
fwd_expert_count,
fwd_batch_size,
) = moe_prepare_forward(gate, num_expert, world_size)
x = MOEScatter.apply(
inp, pos, local_expert_count, global_expert_count, fwd_batch_size,
world_size
inp, pos, local_expert_count, global_expert_count, fwd_batch_size, world_size
)
x = expert_fn(x, fwd_expert_count)
x = MOEGather.apply(
......@@ -111,7 +124,7 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
class FMoE(nn.Module):
r'''
r"""
A general MoE implementation that supports an arbitrary module as the
expert.
* `num_expert` stands for the number of experts on **each** worker.
......@@ -126,9 +139,18 @@ class FMoE(nn.Module):
* `gate` is a gate class which can be found in `fmoe.gates`.
* `expert` can be specified as a module class; it is used to generate
`num_expert` expert modules.
'''
def __init__(self, num_expert=32, d_model=1024, world_size=1, mp_group=None,
top_k=2, gate=NaiveGate, expert=None):
"""
def __init__(
self,
num_expert=32,
d_model=1024,
world_size=1,
mp_group=None,
top_k=2,
gate=NaiveGate,
expert=None,
):
super().__init__()
self.num_expert = num_expert
self.d_model = d_model
......@@ -143,34 +165,33 @@ class FMoE(nn.Module):
self.top_k = top_k
self.gate = gate(d_model, num_expert, world_size, top_k)
if expert is not None:
self.experts = nn.ModuleList([expert(d_model)
for _ in range(num_expert)])
self.experts = nn.ModuleList([expert(d_model) for _ in range(num_expert)])
self.experts_fused = False
else:
self.experts_fused = True
def expert_fn(self, inp, fwd_expert_count):
r'''
r"""
The default expert function which either calls the experts as a whole
or as separate experts.
'''
"""
if self.experts_fused:
return self.experts(inp, fwd_expert_count)
outputs = []
base_idx = 0
for i in range(self.num_expert):
batch_size = fwd_expert_count[i].item()
inp_slice = inp[base_idx:base_idx + batch_size]
inp_slice = inp[base_idx : base_idx + batch_size]
outputs.append(self.experts[i](inp_slice))
base_idx += batch_size
return torch.cat(outputs, dim=0)
def mark_parallel_comm(self, expert_dp_comm='none'):
r'''
def mark_parallel_comm(self, expert_dp_comm="none"):
r"""
Automatically mark the data parallel comms of the parameters within the
module. This is typically called at the end of the __init__ function
in child classes.
'''
"""
if self.experts is not None:
comm = expert_dp_comm
if isinstance(self.experts, list):
......@@ -178,29 +199,28 @@ class FMoE(nn.Module):
mark_module_parallel_comm(e, comm)
else:
mark_module_parallel_comm(self.experts, comm)
mark_module_parallel_comm(self.gate, 'world')
mark_module_parallel_comm(self.gate, "world")
def forward(self, inp):
r'''
r"""
The FMoE module first computes the gate output, and then conducts the MoE
forward pass according to the gate. The scores of the selected experts, given
by the gate, are multiplied with the experts' output tensors as weights.
'''
"""
if self.mp_size > 1:
inp = Slice.apply(inp,
self.mp_rank, self.mp_size, self.mp_group)
inp = Slice.apply(inp, self.mp_rank, self.mp_size, self.mp_group)
gate_top_k_idx, gate_score = self.gate(inp)
# to: (BxLxtop_k) x d_model
inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
x = _fmoe_general_global_forward(inp, gate_top_k_idx, self.expert_fn,
self.num_expert, self.world_size)
x = _fmoe_general_global_forward(
inp, gate_top_k_idx, self.expert_fn, self.num_expert, self.world_size
)
# to: (BxL) x top_k x d_model
x = x.view(-1, self.top_k, self.d_model)
# to: (BxL) x d_model
x = torch.bmm(gate_score, x).reshape(-1, self.d_model)
if self.mp_size > 1:
x = AllGather.apply(x,
self.mp_rank, self.mp_size, self.mp_group)
x = AllGather.apply(x, self.mp_rank, self.mp_size, self.mp_group)
return x
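A hedged sketch of constructing FMoE with a custom expert module, based on the constructor and expert_fn shown above; the `fmoe.layers` import path is an assumption, and running it requires the fmoe_cuda extension and a GPU.

import torch
import torch.nn as nn
from fmoe.layers import FMoE  # assumed import path

class MyExpert(nn.Module):
    # `expert` is instantiated as expert(d_model); each expert then receives
    # its own slice of tokens in expert_fn above.
    def __init__(self, d_model):
        super().__init__()
        self.fc = nn.Linear(d_model, d_model)

    def forward(self, x):
        return self.fc(x)

moe = FMoE(num_expert=4, d_model=512, top_k=2, expert=MyExpert).cuda()
tokens = torch.randn(16, 512).cuda()  # forward expects a flat (tokens, d_model) input
out = moe(tokens)                     # weighted mixture of top-k expert outputs per token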
r'''
r"""
The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
lines of modification.
See `examples/megatron` for usage instructions.
'''
"""
import math
import numpy as np
import torch
......@@ -14,27 +14,30 @@ from .distributed import DistributedGroupedDataParallel
class _FakeMegatronMLP(nn.Module):
r'''
r"""
A fake MLP without model parallelism, for correctness testing
'''
"""
def __init__(self, args, _):
super().__init__()
self.fc1 = nn.Linear(args.hidden_size, args.hidden_hidden_size)
self.fc2 = nn.Linear(args.hidden_hidden_size, args.hidden_size)
def forward(self, x):
r'''
r"""
Directly use GeLU
'''
"""
x = self.fc1(x)
x = F.gelu(x)
x = self.fc2(x)
return x, torch.zeros_like(x)
def _megatron_init_method(self, rng, sigma):
r'''
r"""
Init method based on N(0, sigma).
Copied from Megatron-LM
'''
"""
device = self.weight.device
dtype = self.weight.dtype
weight = rng.normal(loc=0.0, scale=sigma, size=tuple(self.weight.size()))
......@@ -45,12 +48,13 @@ def _megatron_init_method(self, rng, sigma):
with torch.no_grad():
self.bias.zero_()
def _random_init_weight(self, rng):
r'''
r"""
Copied from torch.nn.init.kaiming_uniform_
'''
fan = nn.init._calculate_correct_fan(self.weight[0], 'fan_in')
gain = nn.init.calculate_gain('leaky_relu', math.sqrt(5))
"""
fan = nn.init._calculate_correct_fan(self.weight[0], "fan_in")
gain = nn.init.calculate_gain("leaky_relu", math.sqrt(5))
std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std
device = self.weight.device
......@@ -66,23 +70,29 @@ def _random_init_weight(self, rng):
class MegatronMLP(FMoETransformerMLP):
r'''
r"""
Build an FMoETransformerMLP layer that distributes experts across the
communication group `group`, replacing the original MLP layer in Megatron.
'''
"""
def __init__(self, args, group):
assert (args.seq_length * args.micro_batch_size
% args.tensor_model_parallel_size == 0
assert (
args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size
== 0
), "Batch size x sequence length should be multiple of mp size"
if not args.distributed_experts:
world_size = 1
else:
world_size = args.world_size
super().__init__(args.num_experts,
top_k=args.top_k,
d_model=args.hidden_size, d_hidden=args.hidden_hidden_size,
world_size=world_size, mp_group=group,
expert_dp_comm='none' if args.distributed_experts else 'dp')
super().__init__(
args.num_experts,
top_k=args.top_k,
d_model=args.hidden_size,
d_hidden=args.hidden_hidden_size,
world_size=world_size,
mp_group=group,
expert_dp_comm="none" if args.distributed_experts else "dp",
)
self.hidden_size = args.hidden_size
if args.distributed_experts:
self.rank = args.rank
......@@ -93,24 +103,31 @@ class MegatronMLP(FMoETransformerMLP):
self.reset_parameters()
def reset_parameters(self):
r'''
r"""
Initialize the weights as ordinary linear layers.
As Megatron uses a fixed random seed for some of its internals, an
additional numpy rng is used.
'''
"""
rng = np.random.default_rng(np.random.randint(2048) + self.rank)
_megatron_init_method(self.experts.htoh4, rng, self.sigma)
std = self.sigma / math.sqrt(2.0 * self.num_layers)
_megatron_init_method(self.experts.h4toh, rng, std)
def forward(self, inp):
return super().forward(inp), torch.zeros(self.hidden_size,
dtype=inp.dtype, device=inp.device)
return (
super().forward(inp),
torch.zeros(self.hidden_size, dtype=inp.dtype, device=inp.device),
)
def fmoefy(model, num_experts=None, distributed_experts=True,
hidden_hidden_size=None, top_k=None):
r'''
def fmoefy(
model,
num_experts=None,
distributed_experts=True,
hidden_hidden_size=None,
top_k=None,
):
r"""
Replace the MLP layers in a transformer-based Megatron model with MoE layers.
* `model` should be a standard Megatron model that has
`model.language_model.transformer.layers` as transformer layers, which is an
......@@ -123,24 +140,25 @@ def fmoefy(model, num_experts=None, distributed_experts=True,
Note that pipeline parallelism is not supported yet. When distributed experts
are enabled, their communicator should be Megatron's
tensor_model_parallel_comm x data_parallel_comm, which is not created.
'''
"""
from megatron import get_args
from megatron import mpu
args = get_args()
if num_experts is not None:
args.num_experts = num_experts
assert (
'num_experts' in args
), 'num_experts should be specified in arguments or fmoefy function'
"num_experts" in args
), "num_experts should be specified in arguments or fmoefy function"
if hidden_hidden_size is not None:
args.hidden_hidden_size = hidden_hidden_size
elif not hasattr(args, 'hidden_hidden_size'):
elif not hasattr(args, "hidden_hidden_size"):
args.hidden_hidden_size = args.hidden_size * 4
if top_k is not None:
args.top_k = top_k
elif not hasattr(args, 'top_k'):
elif not hasattr(args, "top_k"):
args.top_k = 2
# Set distributed_experts to None to use default setting in args
......@@ -153,33 +171,35 @@ def fmoefy(model, num_experts=None, distributed_experts=True,
class DistributedDataParallel(DistributedGroupedDataParallel):
r'''
r"""
A wrapper that is used to replace the DDP module provided by Megatron, which
is adapted to enable the sophisticated parallel and reduction strategies in
Fast MoE.
'''
"""
def __init__(self, module):
from megatron import mpu
super().__init__(
module,
mp_group=mpu.get_model_parallel_group(),
dp_group=mpu.get_data_parallel_group()
dp_group=mpu.get_data_parallel_group(),
)
def state_dict(self, *args, **kwargs):
r'''
r"""
Keep consistency with Megatron
'''
"""
return self.module.state_dict(*args, **kwargs)
def state_dict_for_save_checkpoint(self, *args, **kwargs):
r'''
r"""
Keep consistency with Megatron
'''
"""
return self.module.state_dict_for_save_checkpoint(*args, **kwargs)
def load_state_dict(self, *args, **kwargs):
r'''
r"""
Keep consistency with Megatron
'''
"""
return self.module.load_state_dict(*args, **kwargs)
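An illustrative sketch of the "at most two lines of modification" advertised at the top of this file; the `fmoe.megatron` import path, the helper name build_megatron_model, and the assumption that fmoefy returns the patched model are not part of this commit.

from fmoe.megatron import fmoefy, DistributedDataParallel  # assumed import path

model = build_megatron_model()                 # placeholder for Megatron's usual model setup
model = fmoefy(model, num_experts=8, top_k=2)  # assumed to return the MoE-fied model
model = DistributedDataParallel(model)         # FastMoE-aware DDP wrapper defined above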
r'''
r"""
Adaptation to act as the MLP layer in a Transformer by using an MoE MLP layer.
'''
"""
import torch
import torch.nn as nn
from .gates import NaiveGate
......@@ -8,23 +8,22 @@ from .layers import FMoE, FMoELinear
class _Expert(nn.Module):
r'''
r"""
An expert using 2 FMoELinear modules to speed up the computation of experts
within one worker.
'''
"""
def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
super().__init__()
self.htoh4 = FMoELinear(num_expert, d_model, d_hidden,
bias=True, rank=rank)
self.h4toh = FMoELinear(num_expert, d_hidden, d_model,
bias=True, rank=rank)
self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True, rank=rank)
self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True, rank=rank)
self.activation = activation
def forward(self, inp, fwd_expert_count):
r'''
r"""
First expand input to 4h (the hidden size is variable, but is called h4
for convenience). Then perform activation. Finally shrink back to h.
'''
"""
x = self.htoh4(inp, fwd_expert_count)
x = self.activation(x)
x = self.h4toh(x, fwd_expert_count)
......@@ -32,11 +31,12 @@ class _Expert(nn.Module):
class FMoETransformerMLP(FMoE):
r'''
r"""
A complete MoE MLP module in a Transformer block.
* `activation` is the activation function to be used in the MLP of each expert.
* `d_hidden` is the dimension of the MLP layer.
'''
"""
def __init__(
self,
num_expert=32,
......@@ -47,19 +47,26 @@ class FMoETransformerMLP(FMoE):
activation=torch.nn.GELU(),
gate=NaiveGate,
top_k=2,
expert_dp_comm='none'
expert_dp_comm="none",
):
super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
top_k=top_k, world_size=world_size, mp_group=mp_group)
self.experts = _Expert(num_expert, d_model, d_hidden, activation,
rank=self.mp_rank)
super().__init__(
num_expert=num_expert,
d_model=d_model,
gate=gate,
top_k=top_k,
world_size=world_size,
mp_group=mp_group,
)
self.experts = _Expert(
num_expert, d_model, d_hidden, activation, rank=self.mp_rank
)
self.mark_parallel_comm(expert_dp_comm)
def forward(self, inp: torch.Tensor):
r'''
r"""
This module wraps up the FMoE module with reshape, residual and layer
normalization.
'''
"""
original_shape = inp.shape
inp = inp.reshape(-1, self.d_model)
output = super().forward(inp)
......
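A small usage sketch of FMoETransformerMLP, based on the constructor above; the `fmoe.transformer` import path is an assumption, and running it requires the fmoe_cuda extension and a GPU.

import torch
from fmoe.transformer import FMoETransformerMLP  # assumed import path

mlp = FMoETransformerMLP(num_expert=8, d_model=1024, d_hidden=4096, top_k=2).cuda()
x = torch.randn(2, 128, 1024).cuda()  # (batch, seq_len, d_model)
y = mlp(x)  # flattened to (-1, d_model) internally; the reshape back is outside this hunk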
r'''
r"""
Utils to play with PyTorch.
'''
"""
import torch.distributed as dist
# pylint: disable=broad-except
# pylint: disable=protected-access
def get_torch_default_comm():
r'''
r"""
The NCCL communicator is needed so that Fast MoE can perform customized
communication operators in the C code. However, it is not a publicly
available variable. Therefore, a hacking class of the `ProcessGroupNCCL`
......@@ -15,7 +15,7 @@ def get_torch_default_comm():
communicator out from the object. As PyTorch's private interface varies from
time to time, different hacking techniques are tried one-by-one to be
compatible with various versions of PyTorch.
'''
"""
try:
comm = dist.distributed_c10d._get_default_group()
return comm
......@@ -27,4 +27,4 @@ def get_torch_default_comm():
return comm
except Exception as _:
pass
raise RuntimeError('Unsupported PyTorch version')
raise RuntimeError("Unsupported PyTorch version")
......@@ -10,18 +10,27 @@ import os
rank = None
world_size = None
dev_name_default = 'cuda:0'
dev_name_default = "cuda:0"
class BruteForceMoE(nn.Module):
def __init__(self, num_expert=32, d_model=1024, d_hidden=4096,
world_size=1, mp_group=None,
activation=torch.nn.functional.gelu,
gate=NaiveGate, top_k=1, pre_lnorm=False):
assert world_size == 1, 'Distributed brute force is not supported'
def __init__(
self,
num_expert=32,
d_model=1024,
d_hidden=4096,
world_size=1,
mp_group=None,
activation=torch.nn.functional.gelu,
gate=NaiveGate,
top_k=1,
pre_lnorm=False,
):
assert world_size == 1, "Distributed brute force is not supported"
super().__init__()
self.mlp = BruteForceMoELinear(activation, num_expert, d_model,
d_hidden, 1, top_k)
self.mlp = BruteForceMoELinear(
activation, num_expert, d_model, d_hidden, 1, top_k
)
self.top_k = top_k
self.gate = gate(d_model, num_expert, world_size, top_k)
self.pre_lnorm = pre_lnorm
......@@ -43,20 +52,32 @@ def benchmark_mlp(MOELayer, batch_size, in_feat, hidden_feat, num_expert, top_k)
torch.manual_seed(42 + rank)
torch.cuda.manual_seed(42 + rank)
if rank == 0:
print('Performance test of {} mm size {} {}x{} experts {}x{} topk {}'
.format(MOELayer.__name__, batch_size, in_feat, hidden_feat,
world_size, num_expert, top_k))
print(
"Performance test of {} mm size {} {}x{} experts {}x{} topk {}".format(
MOELayer.__name__,
batch_size,
in_feat,
hidden_feat,
world_size,
num_expert,
top_k,
)
)
if world_size > 1:
dev_name = 'cuda'
dev_name = "cuda"
else:
dev_name = dev_name_default
inp = torch.rand(batch_size, in_feat).cuda(dev_name)
inp.requires_grad = True
moe = MOELayer(num_expert=num_expert,
d_model=in_feat, d_hidden=hidden_feat,
world_size=world_size, top_k=top_k).cuda(dev_name)
moe = MOELayer(
num_expert=num_expert,
d_model=in_feat,
d_hidden=hidden_feat,
world_size=world_size,
top_k=top_k,
).cuda(dev_name)
moe.train()
# warm up
......@@ -64,10 +85,10 @@ def benchmark_mlp(MOELayer, batch_size, in_feat, hidden_feat, num_expert, top_k)
_ = moe(inp)
n_runs = 16
tott = 0.
backt = 0.
maxt = 0.
sqtot = 0.
tott = 0.0
backt = 0.0
maxt = 0.0
sqtot = 0.0
for i in range(n_runs):
ts = time.time()
o = moe(inp)
......@@ -80,36 +101,48 @@ def benchmark_mlp(MOELayer, batch_size, in_feat, hidden_feat, num_expert, top_k)
bte = time.time()
tott += te - ts
sqtot += (te - ts)**2
sqtot += (te - ts) ** 2
maxt = max(maxt, te - ts)
backt += bte - bts
gflops = 2e-9 * n_runs * (in_feat * hidden_feat * batch_size * top_k * 2 +
batch_size * in_feat * num_expert) / tott
print('Time mean/max/stdev/back {:.3f} {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs'.format(
tott * 1e3 / n_runs, maxt * 1e3,
(sqtot / n_runs - (tott / n_runs)**2) * 1e3 * top_k / n_runs,
backt * 1e3 / n_runs, gflops))
if __name__ == '__main__':
os.environ['RANK'] = os.environ.get('OMPI_COMM_WORLD_RANK', '0')
os.environ['WORLD_SIZE'] = os.environ.get('OMPI_COMM_WORLD_SIZE', '1')
os.environ['CUDA_VISIBLE_DEVICES'] = os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK', '0')
if int(os.environ['WORLD_SIZE']) > 1:
torch.distributed.init_process_group(backend='nccl')
gflops = (
2e-9
* n_runs
* (
in_feat * hidden_feat * batch_size * top_k * 2
+ batch_size * in_feat * num_expert
)
/ tott
)
print(
"Time mean/max/stdev/back {:.3f} {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs".format(
tott * 1e3 / n_runs,
maxt * 1e3,
(sqtot / n_runs - (tott / n_runs) ** 2) * 1e3 * top_k / n_runs,
backt * 1e3 / n_runs,
gflops,
)
)
if __name__ == "__main__":
os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get(
"OMPI_COMM_WORLD_LOCAL_RANK", "0"
)
if int(os.environ["WORLD_SIZE"]) > 1:
torch.distributed.init_process_group(backend="nccl")
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
else:
rank = 0
world_size = 1
batch_size = int(os.environ.get('BATCH_SIZE', '4096'))
d_model = int(os.environ.get('D_MODEL', '1024'))
d_hidden = int(os.environ.get('D_HIDDEN', '4096'))
num_expert = int(os.environ.get('NUM_EXPERT', '64'))
top_k = int(os.environ.get('TOP_K', '2'))
benchmark_mlp(FMoETransformerMLP, batch_size, d_model,
d_hidden, num_expert, top_k)
batch_size = int(os.environ.get("BATCH_SIZE", "4096"))
d_model = int(os.environ.get("D_MODEL", "1024"))
d_hidden = int(os.environ.get("D_HIDDEN", "4096"))
num_expert = int(os.environ.get("NUM_EXPERT", "64"))
top_k = int(os.environ.get("TOP_K", "2"))
benchmark_mlp(FMoETransformerMLP, batch_size, d_model, d_hidden, num_expert, top_k)
if world_size == 1:
benchmark_mlp(BruteForceMoE, batch_size, d_model, d_hidden, num_expert,
top_k)
benchmark_mlp(BruteForceMoE, batch_size, d_model, d_hidden, num_expert, top_k)
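For context, a hypothetical way to launch the benchmark above, based only on the environment variables it reads (the script name is assumed):

# BATCH_SIZE=4096 D_MODEL=1024 D_HIDDEN=4096 NUM_EXPERT=64 TOP_K=2 \
#   mpirun -np 2 python moe_benchmark.py
# The __main__ block maps OMPI_COMM_WORLD_RANK / OMPI_COMM_WORLD_SIZE /
# OMPI_COMM_WORLD_LOCAL_RANK onto RANK / WORLD_SIZE / CUDA_VISIBLE_DEVICES and
# initializes NCCL when WORLD_SIZE > 1; BruteForceMoE is only benchmarked when
# world_size == 1.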
......@@ -20,24 +20,19 @@ class BruteForceMoELinear(nn.Module):
self.weight_htoh4 = nn.Parameter(
torch.Tensor(num_expert * world_size, d_hidden, d_model)
)
self.bias_htoh4 = nn.Parameter(
torch.Tensor(num_expert * world_size, d_hidden)
)
self.bias_htoh4 = nn.Parameter(torch.Tensor(num_expert * world_size, d_hidden))
self.weight_h4toh = nn.Parameter(
torch.Tensor(num_expert * world_size, d_model, d_hidden)
)
self.bias_h4toh = nn.Parameter(
torch.Tensor(num_expert * world_size, d_model)
)
self.bias_h4toh = nn.Parameter(torch.Tensor(num_expert * world_size, d_model))
self.top_k = top_k
def forward(self, inp, gate_idx, gate_score):
gate_long = gate_idx.long()
batch_size = inp.size(0)
o = torch.empty(batch_size, self.d_model, dtype=inp.dtype,
device=inp.device)
o = torch.empty(batch_size, self.d_model, dtype=inp.dtype, device=inp.device)
for i in range(self.weight_htoh4.shape[0]):
idx = (gate_idx == i)
idx = gate_idx == i
x = inp[idx]
x = x @ self.weight_htoh4[i].t()
x = x + self.bias_htoh4[i]
......@@ -45,8 +40,9 @@ class BruteForceMoELinear(nn.Module):
x = x @ self.weight_h4toh[i].t()
x = x + self.bias_h4toh[i]
o[idx] = x
x = torch.bmm(gate_score, o.view(-1, self.top_k,
self.d_model)).reshape(-1, self.d_model)
x = torch.bmm(gate_score, o.view(-1, self.top_k, self.d_model)).reshape(
-1, self.d_model
)
return x
......
......@@ -18,7 +18,7 @@ class MyMoE(FMoE):
gate=NaiveGate,
world_size=1,
mp_group=None,
top_k=top_k
top_k=top_k,
)
self.experts = _Expert(num_expert, d_model, d_hidden, activation)
......@@ -46,5 +46,5 @@ def test_fmoe_dp(
output = moe_dp(torch.rand(batch_size, d_model).cuda())
if __name__ == '__main__':
if __name__ == "__main__":
test_fmoe_dp(4, 2, 4, 16, 32)
......@@ -68,7 +68,6 @@ class MyMoE(FMoE):
super().__init__(
num_expert=num_expert,
d_model=d_model,
gate=NaiveGate,
world_size=world_size,
mp_group=mp_group,
......@@ -77,8 +76,8 @@ class MyMoE(FMoE):
self.experts = _Expert(num_expert, d_model, d_hidden, activation)
rng = np.random.default_rng(1234)
_megatron_init_method(self.experts.htoh4, rng, 1.)
_megatron_init_method(self.experts.h4toh, rng, 1.)
_megatron_init_method(self.experts.htoh4, rng, 1.0)
_megatron_init_method(self.experts.h4toh, rng, 1.0)
@pytest.mark.parametrize("num_expert", [4, 8])
......@@ -152,8 +151,22 @@ def test_fmoe_linear(
moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
)
moe_out_list = moe_out, moe_grad_in, moe.experts.htoh4.weight.grad, moe.experts.h4toh.weight.grad, moe.experts.htoh4.bias.grad, moe.experts.h4toh.bias.grad
raw_out_list = raw_out, raw_grad_in, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad, moe_raw.bias_htoh4.grad, moe_raw.bias_h4toh.grad
moe_out_list = (
moe_out,
moe_grad_in,
moe.experts.htoh4.weight.grad,
moe.experts.h4toh.weight.grad,
moe.experts.htoh4.bias.grad,
moe.experts.h4toh.bias.grad,
)
raw_out_list = (
raw_out,
raw_grad_in,
moe_raw.weight_htoh4.grad,
moe_raw.weight_h4toh.grad,
moe_raw.bias_htoh4.grad,
moe_raw.bias_h4toh.grad,
)
if world_size > 1:
_, __, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad = raw_out_list
......@@ -176,7 +189,14 @@ def test_fmoe_linear(
)
raw_out_list = _, __, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad
names = ["output", "input grad", "htoh4 weight grad", "h4toh weight grad", "htoh4 bias grad", "h4toh bias grad"]
names = [
"output",
"input grad",
"htoh4 weight grad",
"h4toh weight grad",
"htoh4 bias grad",
"h4toh bias grad",
]
_assert_numercial(names, moe_out_list, raw_out_list, rank)
......