"website/vscode:/vscode.git/clone" did not exist on "16d0aa82c1bbf3788571c651f5149f3c4e91a47a"
Commit 4d48209d authored by Sengxian's avatar Sengxian
Browse files

Format using black

parent 527a8cc9
@@ -142,6 +142,7 @@ disable=print-statement,
        arguments-differ,
        import-outside-toplevel,
        signature-differs,
        bad-continuation,

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
...
r''' r"""
Supportive modules to conduct distributed training Supportive modules to conduct distributed training
''' """
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
...@@ -8,7 +8,7 @@ from .utils import get_torch_default_comm ...@@ -8,7 +8,7 @@ from .utils import get_torch_default_comm
class DistributedGroupedDataParallel(nn.Module): class DistributedGroupedDataParallel(nn.Module):
r''' r"""
A customized DDP module to support different all-reduce regions in the A customized DDP module to support different all-reduce regions in the
model. The all-reduce region is defined as an attribution `dp_comm` in the model. The all-reduce region is defined as an attribution `dp_comm` in the
weight object. weight object.
...@@ -20,36 +20,42 @@ class DistributedGroupedDataParallel(nn.Module): ...@@ -20,36 +20,42 @@ class DistributedGroupedDataParallel(nn.Module):
If it is set to `world`, the gradients is synchronized across all workers, If it is set to `world`, the gradients is synchronized across all workers,
regardless their model or data parallel group. This is extremely useful for regardless their model or data parallel group. This is extremely useful for
shared layers like the gate. shared layers like the gate.
''' """
def __init__(self, module, mp_group=None, dp_group=None, world_group=None,
auto_allreduce=False): def __init__(
assert not auto_allreduce, 'Automatic all-reduce is not implemented yet' self,
module,
mp_group=None,
dp_group=None,
world_group=None,
auto_allreduce=False,
):
assert not auto_allreduce, "Automatic all-reduce is not implemented yet"
super().__init__() super().__init__()
self.module = module self.module = module
self.comms = dict() self.comms = dict()
if mp_group is not None: if mp_group is not None:
self.comms['mp'] = mp_group self.comms["mp"] = mp_group
if dp_group is not None: if dp_group is not None:
self.comms['dp'] = dp_group self.comms["dp"] = dp_group
else: else:
self.comms['dp'] = get_torch_default_comm() self.comms["dp"] = get_torch_default_comm()
if world_group is None: if world_group is None:
self.comms['world'] = get_torch_default_comm() self.comms["world"] = get_torch_default_comm()
else: else:
self.comms['world'] = world_group self.comms["world"] = world_group
def allreduce_params(no_scale=False, reduce_after=False, def allreduce_params(no_scale=False, reduce_after=False, fp32_allreduce=False):
fp32_allreduce=False):
groups = dict() groups = dict()
for p in self.module.parameters(): for p in self.module.parameters():
if not p.requires_grad or p.grad is None: if not p.requires_grad or p.grad is None:
continue continue
if hasattr(p, 'dp_comm'): if hasattr(p, "dp_comm"):
dp_comm = p.dp_comm dp_comm = p.dp_comm
else: else:
dp_comm = 'dp' dp_comm = "dp"
group_key = (dp_comm, p.dtype) group_key = (dp_comm, p.dtype)
if group_key not in groups: if group_key not in groups:
groups[group_key] = [p] groups[group_key] = [p]
...@@ -81,10 +87,10 @@ class DistributedGroupedDataParallel(nn.Module): ...@@ -81,10 +87,10 @@ class DistributedGroupedDataParallel(nn.Module):
for p in self.module.parameters(): for p in self.module.parameters():
if not p.requires_grad or p.grad is None: if not p.requires_grad or p.grad is None:
continue continue
if hasattr(p, 'dp_comm'): if hasattr(p, "dp_comm"):
dp_comm = p.dp_comm dp_comm = p.dp_comm
else: else:
dp_comm = 'dp' dp_comm = "dp"
group_key = (dp_comm, p.dtype) group_key = (dp_comm, p.dtype)
if group_key not in groups: if group_key not in groups:
groups[group_key] = [p] groups[group_key] = [p]
...@@ -103,7 +109,7 @@ class DistributedGroupedDataParallel(nn.Module): ...@@ -103,7 +109,7 @@ class DistributedGroupedDataParallel(nn.Module):
d.copy_(s) d.copy_(s)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
r''' r"""
Directly call the module's forward function. Directly call the module's forward function.
''' """
return self.module(*args, **kwargs) return self.module(*args, **kwargs)
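
For orientation, a minimal usage sketch of this wrapper follows. It assumes `torch.distributed` has already been initialized and that the package layout is `fmoe.distributed` (as the `from .distributed import ...` line in the Megatron adaptor further below suggests); the toy model and sizes are illustrative only.

# Hedged sketch: wrap a model and mark a shared parameter for the `world`
# all-reduce region described in the docstring above.
import torch
import torch.nn as nn
from fmoe.distributed import DistributedGroupedDataParallel

# torch.distributed.init_process_group(...) is assumed to have been called,
# since the default communicator is fetched via get_torch_default_comm().
model = nn.Linear(16, 16).cuda()

# Parameters default to the "dp" region; a parameter shared by all workers
# (e.g. a gate) can be marked explicitly, mirroring mark_module_parallel_comm.
model.weight.dp_comm = "world"

ddp_model = DistributedGroupedDataParallel(model)
out = ddp_model(torch.randn(4, 16).cuda())  # forward simply delegates to the module
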
@@ -40,8 +40,7 @@ def moe_prepare_forward(gate, num_expert, world_size, comm=None):
        )
    else:
        global_expert_count = local_expert_count
    fwd_expert_count = global_expert_count.view(world_size, num_expert).sum(dim=0)
    fwd_batch_size = int(fwd_expert_count.sum().item())
    return (
        pos,
@@ -58,6 +57,7 @@ class MOEScatter(Function):
    If `world_size` is greater than 1, the samples will first be locally
    scattered, and then exchanged across workers.
    """

    @staticmethod
    def forward(
        ctx,
@@ -107,6 +107,7 @@ class MOELinear(Function):
    r"""
    Computes linear operators within one GPU on different experts simutaneously.
    """

    @staticmethod
    def forward(ctx, global_input_buf, weight, fwd_expert_count):
        (global_output_buf,) = fmoe_cuda.forward(
@@ -130,6 +131,7 @@ class MOEGather(Function):
    Gather output samples from contiguous alone experts back to [batch x
    sequences]. Works symmetrically with MOEScatter.
    """

    @staticmethod
    def forward(
        ctx,
@@ -176,9 +178,10 @@ class MOEGather(Function):
class AllGather(Function):
    r"""
    A wrapper for the All-Gather function to support auto-differentiation.
    """

    @staticmethod
    def forward(ctx, inp, rank, world_size, group):
        tensor_list = [torch.empty_like(inp) for _ in range(world_size)]
@@ -191,13 +194,14 @@ class AllGather(Function):
    @staticmethod
    def backward(ctx, grad_out):
        rank, dim0 = ctx.args
        return grad_out[rank * dim0 : (rank + 1) * dim0], None, None, None


class Slice(Function):
    r"""
    A wrapper for the Slice function to support auto-differentiation.
    """

    @staticmethod
    def forward(ctx, inp, rank, world_size, group):
        B: int = inp.shape[0]
...
r''' r"""
Different implementations of the Gate are located here. Different implementations of the Gate are located here.
The `NaiveGate` is the reference to implement any other gate. The `NaiveGate` is the reference to implement any other gate.
''' """
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
class ZeroGate(nn.Module): class ZeroGate(nn.Module):
r''' r"""
Guide all input samples to gate 0. Guide all input samples to gate 0.
''' """
def __init__(self, _1, _2, _3, top_k=2): def __init__(self, _1, _2, _3, top_k=2):
super().__init__() super().__init__()
self.top_k = top_k self.top_k = top_k
def forward(self, inp): def forward(self, inp):
r''' r"""
All output to expert 1 All output to expert 1
''' """
idx = torch.zeros(inp.shape[0] * self.top_k, idx = torch.zeros(
dtype=torch.int64, device=inp.device) inp.shape[0] * self.top_k, dtype=torch.int64, device=inp.device
score = torch.ones(inp.shape[0] * self.top_k, )
device=inp.device) / self.top_k score = torch.ones(inp.shape[0] * self.top_k, device=inp.device) / self.top_k
return idx, score.reshape(-1, 1, self.top_k) return idx, score.reshape(-1, 1, self.top_k)
class NaiveGate(nn.Module): class NaiveGate(nn.Module):
r''' r"""
A naive gate implementation that defines the standard behavior of the gate A naive gate implementation that defines the standard behavior of the gate
which determines which experts the tokens are going to. which determines which experts the tokens are going to.
Both the indecies and the score, or confidence, are output to the parent Both the indecies and the score, or confidence, are output to the parent
module. module.
The load-balance strategies are also designed to be implemented within the The load-balance strategies are also designed to be implemented within the
`Gate` module. `Gate` module.
''' """
def __init__(self, d_model, num_expert, world_size, top_k=2): def __init__(self, d_model, num_expert, world_size, top_k=2):
super().__init__() super().__init__()
self.gate = nn.Linear(d_model, num_expert * world_size) self.gate = nn.Linear(d_model, num_expert * world_size)
self.top_k = top_k self.top_k = top_k
def forward(self, inp): def forward(self, inp):
r''' r"""
The naive implementation simply calculates the top-k of a linear layer's The naive implementation simply calculates the top-k of a linear layer's
output. output.
''' """
gate = self.gate(inp) gate = self.gate(inp)
gate_top_k_val, gate_top_k_idx = torch.topk( gate_top_k_val, gate_top_k_idx = torch.topk(
gate, k=self.top_k, dim=-1, largest=True, sorted=False gate, k=self.top_k, dim=-1, largest=True, sorted=False
......
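
Since the module docstring names `NaiveGate` as the reference for implementing other gates, a custom gate only needs to accept `(d_model, num_expert, world_size, top_k)` at construction and return an index tensor plus a score tensor shaped like the ones `ZeroGate` produces above. A hedged sketch follows; the random routing policy is purely illustrative and not part of this commit.

import torch
import torch.nn as nn


class RandomGate(nn.Module):
    r"""
    Illustrative gate that sends each token slot to a random expert, following
    the (idx, score) output convention of ZeroGate/NaiveGate above.
    """

    def __init__(self, d_model, num_expert, world_size, top_k=2):
        super().__init__()
        self.tot_expert = num_expert * world_size
        self.top_k = top_k

    def forward(self, inp):
        n_slot = inp.shape[0] * self.top_k
        # One int64 expert index per (token, k) slot, flattened as in ZeroGate.
        idx = torch.randint(self.tot_expert, (n_slot,), device=inp.device)
        # Uniform confidence over the k selected experts.
        score = torch.full(
            (inp.shape[0], 1, self.top_k), 1.0 / self.top_k, device=inp.device
        )
        return idx, score
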
r''' r"""
Layers that FMoE provides to users Layers that FMoE provides to users
''' """
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -11,14 +11,21 @@ from .gates import NaiveGate ...@@ -11,14 +11,21 @@ from .gates import NaiveGate
class FMoELinear(nn.Module): class FMoELinear(nn.Module):
r''' r"""
A linear layer that contains multiple experts. A linear layer that contains multiple experts.
As multiple experts can be placed on the same worker, the computation can be As multiple experts can be placed on the same worker, the computation can be
performed in parallel to increase the performance. performed in parallel to increase the performance.
The FMoELinear module provides such function. The FMoELinear module provides such function.
''' """
def __init__(self, num_expert: int, in_feat: int, out_feat: int,
bias: bool = True, rank: int = 0): def __init__(
self,
num_expert: int,
in_feat: int,
out_feat: int,
bias: bool = True,
rank: int = 0,
):
super().__init__() super().__init__()
self.num_expert = num_expert self.num_expert = num_expert
self.in_feat = in_feat self.in_feat = in_feat
...@@ -28,12 +35,12 @@ class FMoELinear(nn.Module): ...@@ -28,12 +35,12 @@ class FMoELinear(nn.Module):
if bias: if bias:
self.bias = nn.Parameter(torch.Tensor(num_expert, out_feat)) self.bias = nn.Parameter(torch.Tensor(num_expert, out_feat))
else: else:
self.register_parameter('bias', None) self.register_parameter("bias", None)
def forward(self, inp, fwd_expert_count): def forward(self, inp, fwd_expert_count):
r''' r"""
Call MOE function Call MOE function
''' """
x = MOELinear.apply(inp, self.weight, fwd_expert_count) x = MOELinear.apply(inp, self.weight, fwd_expert_count)
if self.bias is not None: if self.bias is not None:
# TODO: torch.repeat_interleave seems have numerical # TODO: torch.repeat_interleave seems have numerical
...@@ -45,8 +52,9 @@ class FMoELinear(nn.Module): ...@@ -45,8 +52,9 @@ class FMoELinear(nn.Module):
# like MOELinear.apply(x, weight, bias, count) # like MOELinear.apply(x, weight, bias, count)
# Solution 1 # Solution 1
bias = torch.repeat_interleave(self.bias, bias = torch.repeat_interleave(
fwd_expert_count.to(self.bias.device), dim=0) self.bias, fwd_expert_count.to(self.bias.device), dim=0
)
# Solution 2 # Solution 2
# bias_idx = torch.arange(self.num_expert)\ # bias_idx = torch.arange(self.num_expert)\
...@@ -67,24 +75,27 @@ class FMoELinear(nn.Module): ...@@ -67,24 +75,27 @@ class FMoELinear(nn.Module):
return x return x
def extra_repr(self) -> str: def extra_repr(self) -> str:
return 'num_expert={}, in_features={}, \ return "num_expert={}, in_features={}, \
out_features={}, bias={}, rank={}'.format( out_features={}, bias={}, rank={}".format(
self.num_expert, self.in_feat, self.num_expert,
self.out_feat, self.bias is not None, self.rank self.in_feat,
self.out_feat,
self.bias is not None,
self.rank,
) )
def mark_module_parallel_comm(module, comm): def mark_module_parallel_comm(module, comm):
r''' r"""
Mark all parameters in `module` as doing data parallel in `comm`, where Mark all parameters in `module` as doing data parallel in `comm`, where
`comm` may be one of `'world', 'dp', 'none'`. `comm` may be one of `'world', 'dp', 'none'`.
''' """
for p in module.parameters(): for p in module.parameters():
setattr(p, 'dp_comm', comm) setattr(p, "dp_comm", comm)
def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size): def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
r''' r"""
A private function that performs the following steps to complete the MoE A private function that performs the following steps to complete the MoE
computation. computation.
* Count the number of tokens from each worker to each expert. * Count the number of tokens from each worker to each expert.
...@@ -94,14 +105,16 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size): ...@@ -94,14 +105,16 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
* Gather the output features of experts back, and reorder them as sentences. * Gather the output features of experts back, and reorder them as sentences.
Intermediate results like expert counts are hidden from users by this Intermediate results like expert counts are hidden from users by this
function. function.
''' """
( (
pos, local_expert_count, global_expert_count, fwd_expert_count, pos,
fwd_batch_size local_expert_count,
global_expert_count,
fwd_expert_count,
fwd_batch_size,
) = moe_prepare_forward(gate, num_expert, world_size) ) = moe_prepare_forward(gate, num_expert, world_size)
x = MOEScatter.apply( x = MOEScatter.apply(
inp, pos, local_expert_count, global_expert_count, fwd_batch_size, inp, pos, local_expert_count, global_expert_count, fwd_batch_size, world_size
world_size
) )
x = expert_fn(x, fwd_expert_count) x = expert_fn(x, fwd_expert_count)
x = MOEGather.apply( x = MOEGather.apply(
...@@ -111,7 +124,7 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size): ...@@ -111,7 +124,7 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
class FMoE(nn.Module): class FMoE(nn.Module):
r''' r"""
A general moe implementation that supports an arbitrary module as the A general moe implementation that supports an arbitrary module as the
expert. expert.
* `num_expert` stands for the number of experts on **each** worker. * `num_expert` stands for the number of experts on **each** worker.
...@@ -126,9 +139,18 @@ class FMoE(nn.Module): ...@@ -126,9 +139,18 @@ class FMoE(nn.Module):
* `gate` is a gate class which can found in `fmoe.gates`. * `gate` is a gate class which can found in `fmoe.gates`.
* `expert` can be specified as a module class, it is used to generate * `expert` can be specified as a module class, it is used to generate
`num_expert` expert modules. `num_expert` expert modules.
''' """
def __init__(self, num_expert=32, d_model=1024, world_size=1, mp_group=None,
top_k=2, gate=NaiveGate, expert=None): def __init__(
self,
num_expert=32,
d_model=1024,
world_size=1,
mp_group=None,
top_k=2,
gate=NaiveGate,
expert=None,
):
super().__init__() super().__init__()
self.num_expert = num_expert self.num_expert = num_expert
self.d_model = d_model self.d_model = d_model
...@@ -143,34 +165,33 @@ class FMoE(nn.Module): ...@@ -143,34 +165,33 @@ class FMoE(nn.Module):
self.top_k = top_k self.top_k = top_k
self.gate = gate(d_model, num_expert, world_size, top_k) self.gate = gate(d_model, num_expert, world_size, top_k)
if expert is not None: if expert is not None:
self.experts = nn.ModuleList([expert(d_model) self.experts = nn.ModuleList([expert(d_model) for _ in range(num_expert)])
for _ in range(num_expert)])
self.experts_fused = False self.experts_fused = False
else: else:
self.experts_fused = True self.experts_fused = True
def expert_fn(self, inp, fwd_expert_count): def expert_fn(self, inp, fwd_expert_count):
r''' r"""
The default expert function which either calls the experts as a whole The default expert function which either calls the experts as a whole
or as separate experts. or as separate experts.
''' """
if self.experts_fused: if self.experts_fused:
return self.experts(inp, fwd_expert_count) return self.experts(inp, fwd_expert_count)
outputs = [] outputs = []
base_idx = 0 base_idx = 0
for i in range(self.num_expert): for i in range(self.num_expert):
batch_size = fwd_expert_count[i].item() batch_size = fwd_expert_count[i].item()
inp_slice = inp[base_idx:base_idx + batch_size] inp_slice = inp[base_idx : base_idx + batch_size]
outputs.append(self.experts[i](inp_slice)) outputs.append(self.experts[i](inp_slice))
base_idx += batch_size base_idx += batch_size
return torch.cat(outputs, dim=0) return torch.cat(outputs, dim=0)
def mark_parallel_comm(self, expert_dp_comm='none'): def mark_parallel_comm(self, expert_dp_comm="none"):
r''' r"""
Automatically mark the data parallel comms of the parameters within the Automatically mark the data parallel comms of the parameters within the
module. This can be typically called at the end of the __init__ function module. This can be typically called at the end of the __init__ function
in child classes. in child classes.
''' """
if self.experts is not None: if self.experts is not None:
comm = expert_dp_comm comm = expert_dp_comm
if isinstance(self.experts, list): if isinstance(self.experts, list):
...@@ -178,29 +199,28 @@ class FMoE(nn.Module): ...@@ -178,29 +199,28 @@ class FMoE(nn.Module):
mark_module_parallel_comm(e, comm) mark_module_parallel_comm(e, comm)
else: else:
mark_module_parallel_comm(self.experts, comm) mark_module_parallel_comm(self.experts, comm)
mark_module_parallel_comm(self.gate, 'world') mark_module_parallel_comm(self.gate, "world")
def forward(self, inp): def forward(self, inp):
r''' r"""
The FMoE module first computes gate output, and then conduct MoE forward The FMoE module first computes gate output, and then conduct MoE forward
according to the gate. The score of the selected gate given by the according to the gate. The score of the selected gate given by the
expert is multiplied to the experts' output tensors as a weight. expert is multiplied to the experts' output tensors as a weight.
''' """
if self.mp_size > 1: if self.mp_size > 1:
inp = Slice.apply(inp, inp = Slice.apply(inp, self.mp_rank, self.mp_size, self.mp_group)
self.mp_rank, self.mp_size, self.mp_group)
gate_top_k_idx, gate_score = self.gate(inp) gate_top_k_idx, gate_score = self.gate(inp)
# to: (BxLxtop_k) x d_model # to: (BxLxtop_k) x d_model
inp = inp.repeat_interleave(repeats=self.top_k, dim=0) inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
x = _fmoe_general_global_forward(inp, gate_top_k_idx, self.expert_fn, x = _fmoe_general_global_forward(
self.num_expert, self.world_size) inp, gate_top_k_idx, self.expert_fn, self.num_expert, self.world_size
)
# to: (BxL) x top_k x d_model # to: (BxL) x top_k x d_model
x = x.view(-1, self.top_k, self.d_model) x = x.view(-1, self.top_k, self.d_model)
# to: (BxL) x d_model # to: (BxL) x d_model
x = torch.bmm(gate_score, x).reshape(-1, self.d_model) x = torch.bmm(gate_score, x).reshape(-1, self.d_model)
if self.mp_size > 1: if self.mp_size > 1:
x = AllGather.apply(x, x = AllGather.apply(x, self.mp_rank, self.mp_size, self.mp_group)
self.mp_rank, self.mp_size, self.mp_group)
return x return x
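
To illustrate the `expert` argument documented above: passing a module class makes `FMoE` instantiate `num_expert` separate experts and route tokens through `expert_fn`. A hedged sketch; the import path `fmoe.layers`, the expert definition, and the sizes are assumptions, and running it needs a GPU with the fmoe_cuda extension built.

import torch
import torch.nn as nn
from fmoe.layers import FMoE  # module path assumed from the file layout


class TwoLayerExpert(nn.Module):
    # Illustrative per-expert module; FMoE builds `num_expert` copies of it.
    def __init__(self, d_model):
        super().__init__()
        self.fc1 = nn.Linear(d_model, 4 * d_model)
        self.fc2 = nn.Linear(4 * d_model, d_model)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


moe = FMoE(num_expert=4, d_model=64, top_k=2, expert=TwoLayerExpert).cuda()
x = torch.randn(8, 64).cuda()  # a flat batch of token features
y = moe(x)                     # same (tokens, d_model) shape, gate-weighted expert mix
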
r''' r"""
The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
lines of modification. lines of modification.
See `examples/megatron` for usage instructions. See `examples/megatron` for usage instructions.
''' """
import math import math
import numpy as np import numpy as np
import torch import torch
...@@ -14,27 +14,30 @@ from .distributed import DistributedGroupedDataParallel ...@@ -14,27 +14,30 @@ from .distributed import DistributedGroupedDataParallel
class _FakeMegatronMLP(nn.Module): class _FakeMegatronMLP(nn.Module):
r''' r"""
A fake mlp without model parallelism for correctness testing A fake mlp without model parallelism for correctness testing
''' """
def __init__(self, args, _): def __init__(self, args, _):
super().__init__() super().__init__()
self.fc1 = nn.Linear(args.hidden_size, args.hidden_hidden_size) self.fc1 = nn.Linear(args.hidden_size, args.hidden_hidden_size)
self.fc2 = nn.Linear(args.hidden_hidden_size, args.hidden_size) self.fc2 = nn.Linear(args.hidden_hidden_size, args.hidden_size)
def forward(self, x): def forward(self, x):
r''' r"""
Directly use GeLU Directly use GeLU
''' """
x = self.fc1(x) x = self.fc1(x)
x = F.gelu(x) x = F.gelu(x)
x = self.fc2(x) x = self.fc2(x)
return x, torch.zeros_like(x) return x, torch.zeros_like(x)
def _megatron_init_method(self, rng, sigma): def _megatron_init_method(self, rng, sigma):
r''' r"""
Init method based on N(0, sigma). Init method based on N(0, sigma).
Copied from Megatron-LM Copied from Megatron-LM
''' """
device = self.weight.device device = self.weight.device
dtype = self.weight.dtype dtype = self.weight.dtype
weight = rng.normal(loc=0.0, scale=sigma, size=tuple(self.weight.size())) weight = rng.normal(loc=0.0, scale=sigma, size=tuple(self.weight.size()))
...@@ -45,12 +48,13 @@ def _megatron_init_method(self, rng, sigma): ...@@ -45,12 +48,13 @@ def _megatron_init_method(self, rng, sigma):
with torch.no_grad(): with torch.no_grad():
self.bias.zero_() self.bias.zero_()
def _random_init_weight(self, rng): def _random_init_weight(self, rng):
r''' r"""
Copied from torch.nn.init.kaiming_uniform_ Copied from torch.nn.init.kaiming_uniform_
''' """
fan = nn.init._calculate_correct_fan(self.weight[0], 'fan_in') fan = nn.init._calculate_correct_fan(self.weight[0], "fan_in")
gain = nn.init.calculate_gain('leaky_relu', math.sqrt(5)) gain = nn.init.calculate_gain("leaky_relu", math.sqrt(5))
std = gain / math.sqrt(fan) std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std bound = math.sqrt(3.0) * std
device = self.weight.device device = self.weight.device
...@@ -66,23 +70,29 @@ def _random_init_weight(self, rng): ...@@ -66,23 +70,29 @@ def _random_init_weight(self, rng):
class MegatronMLP(FMoETransformerMLP): class MegatronMLP(FMoETransformerMLP):
r''' r"""
Make the FMoETransformerMLP layer that distributes experts across Make the FMoETransformerMLP layer that distributes experts across
communication group `group` to replace the original MLP layer in Megatron. communication group `group` to replace the original MLP layer in Megatron.
''' """
def __init__(self, args, group): def __init__(self, args, group):
assert (args.seq_length * args.micro_batch_size assert (
% args.tensor_model_parallel_size == 0 args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size
== 0
), "Batch size x sequence length should be multiple of mp size" ), "Batch size x sequence length should be multiple of mp size"
if not args.distributed_experts: if not args.distributed_experts:
world_size = 1 world_size = 1
else: else:
world_size = args.world_size world_size = args.world_size
super().__init__(args.num_experts, super().__init__(
args.num_experts,
top_k=args.top_k, top_k=args.top_k,
d_model=args.hidden_size, d_hidden=args.hidden_hidden_size, d_model=args.hidden_size,
world_size=world_size, mp_group=group, d_hidden=args.hidden_hidden_size,
expert_dp_comm='none' if args.distributed_experts else 'dp') world_size=world_size,
mp_group=group,
expert_dp_comm="none" if args.distributed_experts else "dp",
)
self.hidden_size = args.hidden_size self.hidden_size = args.hidden_size
if args.distributed_experts: if args.distributed_experts:
self.rank = args.rank self.rank = args.rank
...@@ -93,24 +103,31 @@ class MegatronMLP(FMoETransformerMLP): ...@@ -93,24 +103,31 @@ class MegatronMLP(FMoETransformerMLP):
self.reset_parameters() self.reset_parameters()
def reset_parameters(self): def reset_parameters(self):
r''' r"""
Initialize the weight as linear layers. Initialize the weight as linear layers.
As megatron is using fixed random seed for some nasty stuff, an As megatron is using fixed random seed for some nasty stuff, an
additional numpy rng is used. additional numpy rng is used.
''' """
rng = np.random.default_rng(np.random.randint(2048) + self.rank) rng = np.random.default_rng(np.random.randint(2048) + self.rank)
_megatron_init_method(self.experts.htoh4, rng, self.sigma) _megatron_init_method(self.experts.htoh4, rng, self.sigma)
std = self.sigma / math.sqrt(2.0 * self.num_layers) std = self.sigma / math.sqrt(2.0 * self.num_layers)
_megatron_init_method(self.experts.h4toh, rng, std) _megatron_init_method(self.experts.h4toh, rng, std)
def forward(self, inp): def forward(self, inp):
return super().forward(inp), torch.zeros(self.hidden_size, return (
dtype=inp.dtype, device=inp.device) super().forward(inp),
torch.zeros(self.hidden_size, dtype=inp.dtype, device=inp.device),
)
def fmoefy(model, num_experts=None, distributed_experts=True, def fmoefy(
hidden_hidden_size=None, top_k=None): model,
r''' num_experts=None,
distributed_experts=True,
hidden_hidden_size=None,
top_k=None,
):
r"""
Replace MLP layers in a transformer-based model in Megatron by MoE. Replace MLP layers in a transformer-based model in Megatron by MoE.
* `model` should be a standard Megatron model that has * `model` should be a standard Megatron model that has
`model.language_model.transformer.layers` as transformer layers, which is an `model.language_model.transformer.layers` as transformer layers, which is an
...@@ -123,24 +140,25 @@ def fmoefy(model, num_experts=None, distributed_experts=True, ...@@ -123,24 +140,25 @@ def fmoefy(model, num_experts=None, distributed_experts=True,
Note that pipeline parallel is not supported yet. When distributed experts Note that pipeline parallel is not supported yet. When distributed experts
are enabled, their communicator should be Megatron's are enabled, their communicator should be Megatron's
tensor_model_parall_comm x data_parallel_comm, which is not created. tensor_model_parall_comm x data_parallel_comm, which is not created.
''' """
from megatron import get_args from megatron import get_args
from megatron import mpu from megatron import mpu
args = get_args() args = get_args()
if num_experts is not None: if num_experts is not None:
args.num_experts = num_experts args.num_experts = num_experts
assert ( assert (
'num_experts' in args "num_experts" in args
), 'num_experts should be specified in arguments or fmoefy function' ), "num_experts should be specified in arguments or fmoefy function"
if hidden_hidden_size is not None: if hidden_hidden_size is not None:
args.hidden_hidden_size = hidden_hidden_size args.hidden_hidden_size = hidden_hidden_size
elif not hasattr(args, 'hidden_hidden_size'): elif not hasattr(args, "hidden_hidden_size"):
args.hidden_hidden_size = args.hidden_size * 4 args.hidden_hidden_size = args.hidden_size * 4
if top_k is not None: if top_k is not None:
args.top_k = top_k args.top_k = top_k
elif not hasattr(args, 'top_k'): elif not hasattr(args, "top_k"):
args.top_k = 2 args.top_k = 2
# Set distributed_experts to None to use default setting in args # Set distributed_experts to None to use default setting in args
...@@ -153,33 +171,35 @@ def fmoefy(model, num_experts=None, distributed_experts=True, ...@@ -153,33 +171,35 @@ def fmoefy(model, num_experts=None, distributed_experts=True,
class DistributedDataParallel(DistributedGroupedDataParallel): class DistributedDataParallel(DistributedGroupedDataParallel):
r''' r"""
A wrapper that is used to replace the DDP module provided by Megatron, which A wrapper that is used to replace the DDP module provided by Megatron, which
is adapted to enable the sophiscated parallel and reduction strategies in is adapted to enable the sophiscated parallel and reduction strategies in
Fast MoE. Fast MoE.
''' """
def __init__(self, module): def __init__(self, module):
from megatron import mpu from megatron import mpu
super().__init__( super().__init__(
module, module,
mp_group=mpu.get_model_parallel_group(), mp_group=mpu.get_model_parallel_group(),
dp_group=mpu.get_data_parallel_group() dp_group=mpu.get_data_parallel_group(),
) )
def state_dict(self, *args, **kwargs): def state_dict(self, *args, **kwargs):
r''' r"""
Keep consitency with Megatron Keep consitency with Megatron
''' """
return self.module.state_dict(*args, **kwargs) return self.module.state_dict(*args, **kwargs)
def state_dict_for_save_checkpoint(self, *args, **kwargs): def state_dict_for_save_checkpoint(self, *args, **kwargs):
r''' r"""
Keep consitency with Megatron Keep consitency with Megatron
''' """
return self.module.state_dict_for_save_checkpoint(*args, **kwargs) return self.module.state_dict_for_save_checkpoint(*args, **kwargs)
def load_state_dict(self, *args, **kwargs): def load_state_dict(self, *args, **kwargs):
r''' r"""
Keep consitency with Megatron Keep consitency with Megatron
''' """
return self.module.load_state_dict(*args, **kwargs) return self.module.load_state_dict(*args, **kwargs)
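
For context, the "at most two lines of modification" promised by the module docstring presumably look roughly like the following inside a Megatron-LM v2.0 training script. `build_megatron_model` is a placeholder for Megatron's own model construction, the `fmoe.megatron` import path is assumed from the file layout, and the sketch assumes `fmoefy` returns the patched model.

from fmoe.megatron import fmoefy, DistributedDataParallel

model = build_megatron_model()          # placeholder: Megatron's usual model setup
model = fmoefy(model, num_experts=8)    # line 1: swap Megatron MLP layers for MoE layers
model = DistributedDataParallel(model)  # line 2: FastMoE's DDP wrapper instead of Megatron's
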
r''' r"""
Adaption to act as the MLP layer using an MoE MLP layer in transformer. Adaption to act as the MLP layer using an MoE MLP layer in transformer.
''' """
import torch import torch
import torch.nn as nn import torch.nn as nn
from .gates import NaiveGate from .gates import NaiveGate
...@@ -8,23 +8,22 @@ from .layers import FMoE, FMoELinear ...@@ -8,23 +8,22 @@ from .layers import FMoE, FMoELinear
class _Expert(nn.Module): class _Expert(nn.Module):
r''' r"""
An expert using 2 FMoELinear modules to speed up the computation of experts An expert using 2 FMoELinear modules to speed up the computation of experts
within one worker. within one worker.
''' """
def __init__(self, num_expert, d_model, d_hidden, activation, rank=0): def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
super().__init__() super().__init__()
self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True, rank=rank)
bias=True, rank=rank) self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True, rank=rank)
self.h4toh = FMoELinear(num_expert, d_hidden, d_model,
bias=True, rank=rank)
self.activation = activation self.activation = activation
def forward(self, inp, fwd_expert_count): def forward(self, inp, fwd_expert_count):
r''' r"""
First expand input to 4h (the hidden size is variable, but is called h4 First expand input to 4h (the hidden size is variable, but is called h4
for convenience). Then perform activation. Finally shirink back to h. for convenience). Then perform activation. Finally shirink back to h.
''' """
x = self.htoh4(inp, fwd_expert_count) x = self.htoh4(inp, fwd_expert_count)
x = self.activation(x) x = self.activation(x)
x = self.h4toh(x, fwd_expert_count) x = self.h4toh(x, fwd_expert_count)
...@@ -32,11 +31,12 @@ class _Expert(nn.Module): ...@@ -32,11 +31,12 @@ class _Expert(nn.Module):
class FMoETransformerMLP(FMoE): class FMoETransformerMLP(FMoE):
r''' r"""
A complete MoE MLP module in a Transformer block. A complete MoE MLP module in a Transformer block.
* `activation` is the activation function to be used in MLP in each expert. * `activation` is the activation function to be used in MLP in each expert.
* `d_hidden` is the dimension of the MLP layer. * `d_hidden` is the dimension of the MLP layer.
''' """
def __init__( def __init__(
self, self,
num_expert=32, num_expert=32,
...@@ -47,19 +47,26 @@ class FMoETransformerMLP(FMoE): ...@@ -47,19 +47,26 @@ class FMoETransformerMLP(FMoE):
activation=torch.nn.GELU(), activation=torch.nn.GELU(),
gate=NaiveGate, gate=NaiveGate,
top_k=2, top_k=2,
expert_dp_comm='none' expert_dp_comm="none",
): ):
super().__init__(num_expert=num_expert, d_model=d_model, gate=gate, super().__init__(
top_k=top_k, world_size=world_size, mp_group=mp_group) num_expert=num_expert,
self.experts = _Expert(num_expert, d_model, d_hidden, activation, d_model=d_model,
rank=self.mp_rank) gate=gate,
top_k=top_k,
world_size=world_size,
mp_group=mp_group,
)
self.experts = _Expert(
num_expert, d_model, d_hidden, activation, rank=self.mp_rank
)
self.mark_parallel_comm(expert_dp_comm) self.mark_parallel_comm(expert_dp_comm)
def forward(self, inp: torch.Tensor): def forward(self, inp: torch.Tensor):
r''' r"""
This module wraps up the FMoE module with reshape, residual and layer This module wraps up the FMoE module with reshape, residual and layer
normalization. normalization.
''' """
original_shape = inp.shape original_shape = inp.shape
inp = inp.reshape(-1, self.d_model) inp = inp.reshape(-1, self.d_model)
output = super().forward(inp) output = super().forward(inp)
......
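
A short usage sketch of `FMoETransformerMLP` as a drop-in feed-forward block. The sizes are arbitrary, the `fmoe.transformer` import path is assumed from the file layout, and running it needs a GPU with the fmoe_cuda extension built; the keyword arguments mirror the ones the benchmark further below passes to the layer.

import torch
from fmoe.transformer import FMoETransformerMLP  # assumed module path

# 8 experts per worker with top-2 routing; d_hidden is the per-expert FFN width.
ffn = FMoETransformerMLP(num_expert=8, d_model=256, d_hidden=1024, top_k=2).cuda()

x = torch.randn(2, 16, 256).cuda()  # (batch, seq_len, d_model)
y = ffn(x)  # reshaped to (-1, d_model) internally and restored, per the docstring
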
r''' r"""
Utils to play with PyTorch. Utils to play with PyTorch.
''' """
import torch.distributed as dist import torch.distributed as dist
# pylint: disable=broad-except # pylint: disable=broad-except
# pylint: disable=protected-access # pylint: disable=protected-access
def get_torch_default_comm(): def get_torch_default_comm():
r''' r"""
The NCCL communicator is needed so that Fast MoE can perform customized The NCCL communicator is needed so that Fast MoE can perform customized
communication operators in the C code. However, it is not a publicly communication operators in the C code. However, it is not a publicly
available variable. Therefore, a hacking class of the `ProcessGroupNCCL` available variable. Therefore, a hacking class of the `ProcessGroupNCCL`
...@@ -15,7 +15,7 @@ def get_torch_default_comm(): ...@@ -15,7 +15,7 @@ def get_torch_default_comm():
communicator out from the object. As PyTorch's private interface varies from communicator out from the object. As PyTorch's private interface varies from
time to time, different hacking techniques are tried one-by-one to be time to time, different hacking techniques are tried one-by-one to be
compatible with various versions of PyTorch. compatible with various versions of PyTorch.
''' """
try: try:
comm = dist.distributed_c10d._get_default_group() comm = dist.distributed_c10d._get_default_group()
return comm return comm
...@@ -27,4 +27,4 @@ def get_torch_default_comm(): ...@@ -27,4 +27,4 @@ def get_torch_default_comm():
return comm return comm
except Exception as _: except Exception as _:
pass pass
raise RuntimeError('Unsupported PyTorch version') raise RuntimeError("Unsupported PyTorch version")
@@ -10,18 +10,27 @@ import os
rank = None
world_size = None
dev_name_default = "cuda:0"


class BruteForceMoE(nn.Module):
    def __init__(
        self,
        num_expert=32,
        d_model=1024,
        d_hidden=4096,
        world_size=1,
        mp_group=None,
        activation=torch.nn.functional.gelu,
        gate=NaiveGate,
        top_k=1,
        pre_lnorm=False,
    ):
        assert world_size == 1, "Distributed brute force is not supported"
        super().__init__()
        self.mlp = BruteForceMoELinear(
            activation, num_expert, d_model, d_hidden, 1, top_k
        )
        self.top_k = top_k
        self.gate = gate(d_model, num_expert, world_size, top_k)
        self.pre_lnorm = pre_lnorm
@@ -43,20 +52,32 @@ def benchmark_mlp(MOELayer, batch_size, in_feat, hidden_feat, num_expert, top_k)
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)
    if rank == 0:
        print(
            "Performance test of {} mm size {} {}x{} experts {}x{} topk {}".format(
                MOELayer.__name__,
                batch_size,
                in_feat,
                hidden_feat,
                world_size,
                num_expert,
                top_k,
            )
        )
    if world_size > 1:
        dev_name = "cuda"
    else:
        dev_name = dev_name_default
    inp = torch.rand(batch_size, in_feat).cuda(dev_name)
    inp.requires_grad = True

    moe = MOELayer(
        num_expert=num_expert,
        d_model=in_feat,
        d_hidden=hidden_feat,
        world_size=world_size,
        top_k=top_k,
    ).cuda(dev_name)
    moe.train()

    # warm up
@@ -64,10 +85,10 @@ def benchmark_mlp(MOELayer, batch_size, in_feat, hidden_feat, num_expert, top_k)
        _ = moe(inp)

    n_runs = 16
    tott = 0.0
    backt = 0.0
    maxt = 0.0
    sqtot = 0.0
    for i in range(n_runs):
        ts = time.time()
        o = moe(inp)
@@ -80,36 +101,48 @@ def benchmark_mlp(MOELayer, batch_size, in_feat, hidden_feat, num_expert, top_k)
        bte = time.time()
        tott += te - ts
        sqtot += (te - ts) ** 2
        maxt = max(maxt, te - ts)
        backt += bte - bts

    gflops = (
        2e-9
        * n_runs
        * (
            in_feat * hidden_feat * batch_size * top_k * 2
            + batch_size * in_feat * num_expert
        )
        / tott
    )
    print(
        "Time mean/max/stdev/back {:.3f} {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs".format(
            tott * 1e3 / n_runs,
            maxt * 1e3,
            (sqtot / n_runs - (tott / n_runs) ** 2) * 1e3 * top_k / n_runs,
            backt * 1e3 / n_runs,
            gflops,
        )
    )


if __name__ == "__main__":
    os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
    os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get(
        "OMPI_COMM_WORLD_LOCAL_RANK", "0"
    )
    if int(os.environ["WORLD_SIZE"]) > 1:
        torch.distributed.init_process_group(backend="nccl")
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
    else:
        rank = 0
        world_size = 1
    batch_size = int(os.environ.get("BATCH_SIZE", "4096"))
    d_model = int(os.environ.get("D_MODEL", "1024"))
    d_hidden = int(os.environ.get("D_HIDDEN", "4096"))
    num_expert = int(os.environ.get("NUM_EXPERT", "64"))
    top_k = int(os.environ.get("TOP_K", "2"))
    benchmark_mlp(FMoETransformerMLP, batch_size, d_model, d_hidden, num_expert, top_k)
    if world_size == 1:
        benchmark_mlp(BruteForceMoE, batch_size, d_model, d_hidden, num_expert, top_k)
@@ -20,24 +20,19 @@ class BruteForceMoELinear(nn.Module):
        self.weight_htoh4 = nn.Parameter(
            torch.Tensor(num_expert * world_size, d_hidden, d_model)
        )
        self.bias_htoh4 = nn.Parameter(torch.Tensor(num_expert * world_size, d_hidden))
        self.weight_h4toh = nn.Parameter(
            torch.Tensor(num_expert * world_size, d_model, d_hidden)
        )
        self.bias_h4toh = nn.Parameter(torch.Tensor(num_expert * world_size, d_model))
        self.top_k = top_k

    def forward(self, inp, gate_idx, gate_score):
        gate_long = gate_idx.long()
        batch_size = inp.size(0)
        o = torch.empty(batch_size, self.d_model, dtype=inp.dtype, device=inp.device)
        for i in range(self.weight_htoh4.shape[0]):
            idx = gate_idx == i
            x = inp[idx]
            x = x @ self.weight_htoh4[i].t()
            x = x + self.bias_htoh4[i]
@@ -45,8 +40,9 @@ class BruteForceMoELinear(nn.Module):
            x = x @ self.weight_h4toh[i].t()
            x = x + self.bias_h4toh[i]
            o[idx] = x
        x = torch.bmm(gate_score, o.view(-1, self.top_k, self.d_model)).reshape(
            -1, self.d_model
        )
        return x
...
@@ -18,7 +18,7 @@ class MyMoE(FMoE):
            gate=NaiveGate,
            world_size=1,
            mp_group=None,
            top_k=top_k,
        )
        self.experts = _Expert(num_expert, d_model, d_hidden, activation)
@@ -46,5 +46,5 @@ def test_fmoe_dp(
    output = moe_dp(torch.rand(batch_size, d_model).cuda())


if __name__ == "__main__":
    test_fmoe_dp(4, 2, 4, 16, 32)
@@ -68,7 +68,6 @@ class MyMoE(FMoE):
        super().__init__(
            num_expert=num_expert,
            d_model=d_model,
            gate=NaiveGate,
            world_size=world_size,
            mp_group=mp_group,
@@ -77,8 +76,8 @@ class MyMoE(FMoE):
        self.experts = _Expert(num_expert, d_model, d_hidden, activation)
        rng = np.random.default_rng(1234)
        _megatron_init_method(self.experts.htoh4, rng, 1.0)
        _megatron_init_method(self.experts.h4toh, rng, 1.0)


@pytest.mark.parametrize("num_expert", [4, 8])
@@ -152,8 +151,22 @@ def test_fmoe_linear(
        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
    )

    moe_out_list = (
        moe_out,
        moe_grad_in,
        moe.experts.htoh4.weight.grad,
        moe.experts.h4toh.weight.grad,
        moe.experts.htoh4.bias.grad,
        moe.experts.h4toh.bias.grad,
    )
    raw_out_list = (
        raw_out,
        raw_grad_in,
        moe_raw.weight_htoh4.grad,
        moe_raw.weight_h4toh.grad,
        moe_raw.bias_htoh4.grad,
        moe_raw.bias_h4toh.grad,
    )

    if world_size > 1:
        _, __, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad = raw_out_list
@@ -176,7 +189,14 @@ def test_fmoe_linear(
        )
        raw_out_list = _, __, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad

    names = [
        "output",
        "input grad",
        "htoh4 weight grad",
        "h4toh weight grad",
        "htoh4 bias grad",
        "h4toh bias grad",
    ]
    _assert_numercial(names, moe_out_list, raw_out_list, rank)
...