Commit f2040d9f authored by Rick Ho

pass pylint

parent bf2fd0c0
@@ -138,7 +138,10 @@ disable=print-statement,
         xreadlines-attribute,
         deprecated-sys-function,
         exception-escape,
-        comprehension-escape
+        comprehension-escape,
+        arguments-differ,
+        import-outside-toplevel,
+        signature-differs,

 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option
@@ -398,7 +401,7 @@ indent-after-paren=4
 indent-string='    '

 # Maximum number of characters on a single line.
-max-line-length=100
+max-line-length=81

 # Maximum number of lines in a module.
 max-module-lines=1000
@@ -553,7 +556,7 @@ preferred-modules=
 max-args=12

 # Maximum number of attributes for a class (see R0902).
-max-attributes=7
+max-attributes=32

 # Maximum number of boolean expressions in an if statement (see R0916).
 max-bool-expr=5
...
+r'''
+Supportive modules to conduct distributed training
+'''
 import torch
 import torch.nn as nn
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
@@ -5,11 +8,24 @@ from .utils import get_torch_default_comm
 class DistributedGroupedDataParallel(nn.Module):
+    r'''
+    A customized DDP module to support different all-reduce regions in the
+    model. The all-reduce region is defined by an attribute `dp_comm` on the
+    weight object.
+    The gradients of the weights are reduced in different groups according to
+    the weights' `dp_comm` attribute.
+    If it is set to `dp`, the gradient is only reduced across the data-parallel
+    group, which means that it is not synchronized within the model-parallel
+    group.
+    If it is set to `world`, the gradient is synchronized across all workers,
+    regardless of their model or data parallel group. This is extremely useful
+    for shared layers like the gate.
+    '''
     def __init__(self, module, mp_group=None, dp_group=None, world_group=None,
             auto_allreduce=False):
         assert not auto_allreduce, 'Automatic all-reduce is not implemented yet'
-        super(DistributedGroupedDataParallel, self).__init__()
+        super().__init__()
         self.module = module
         self.comms = dict()
@@ -39,10 +55,9 @@ class DistributedGroupedDataParallel(nn.Module):
                     groups[group_key] = [p]
                 else:
                     groups[group_key].append(p)
-            for dp_comm, dtype in groups:
+            for (dp_comm, dtype), group in groups.items():
                 if dp_comm not in self.comms:
                     continue
-                group = groups[dp_comm, dtype]
                 comm = self.comms[dp_comm]
                 grads = [p.grad.data for p in group]
                 coalesced = _flatten_dense_tensors(grads)
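The hunk ends right after the flatten call; for context, the usual coalesced all-reduce pattern that this kind of code completes looks roughly like the sketch below (a hedged illustration, not the commit's exact code; the averaging step in particular is an assumption):

    import torch.distributed as dist
    from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

    def allreduce_coalesced_sketch(grads, comm):
        # flatten one (dp_comm, dtype) group into a single buffer, reduce it once,
        # then scatter the reduced values back into the original gradient tensors
        coalesced = _flatten_dense_tensors(grads)
        dist.all_reduce(coalesced, group=comm)
        coalesced /= dist.get_world_size(group=comm)   # averaging is an assumption
        for grad, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
            grad.copy_(synced)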
@@ -61,5 +76,7 @@ class DistributedGroupedDataParallel(nn.Module):
         self.allreduce_params = allreduce_params

     def forward(self, *args, **kwargs):
+        r'''
+        Directly call the module's forward function.
+        '''
         return self.module(*args, **kwargs)
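A minimal usage sketch of the module above, assuming the script is launched with one process per GPU (e.g. via torchrun) so that torch.distributed can be initialized; the toy example reuses the default group for all three roles, and the `fmoe.distributed` import path is an assumption:

    import torch
    import torch.nn as nn
    import torch.distributed as dist
    from fmoe.distributed import DistributedGroupedDataParallel  # path is an assumption

    dist.init_process_group(backend='nccl')       # 'gloo' also works for this toy example
    group = dist.group.WORLD

    net = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4)).cuda()
    for p in net[1].parameters():
        p.dp_comm = 'world'                       # gate-like shared layer: reduce everywhere
    for p in net[0].parameters():
        p.dp_comm = 'dp'                          # ordinary weight: data-parallel group only

    model = DistributedGroupedDataParallel(net, mp_group=group, dp_group=group,
                                           world_group=group)
    model(torch.randn(8, 16, device='cuda')).sum().backward()
    model.allreduce_params()                      # grouped all-reduce, as defined above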
@@ -40,7 +40,8 @@ def moe_prepare_forward(gate, num_expert, world_size, comm=None):
         )
     else:
         global_expert_count = local_expert_count
-    fwd_expert_count = global_expert_count.view(world_size, num_expert).sum(dim=0)
+    fwd_expert_count = global_expert_count.view(world_size,
+            num_expert).sum(dim=0)
     fwd_batch_size = int(fwd_expert_count.sum().item())
     return (
         pos,
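For intuition, a tiny worked example of the reshape-and-sum above, with made-up counts for 2 workers and 3 local experts:

    import torch

    world_size, num_expert = 2, 3
    # one count per (worker, local expert) pair, flattened
    global_expert_count = torch.tensor([4, 0, 2,    # worker 0
                                        1, 3, 0])   # worker 1
    fwd_expert_count = global_expert_count.view(world_size, num_expert).sum(dim=0)
    print(fwd_expert_count)                          # tensor([5, 3, 2])
    fwd_batch_size = int(fwd_expert_count.sum().item())
    print(fwd_batch_size)                            # 10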
@@ -175,6 +176,9 @@ class MOEGather(Function):
 class AllGather(Function):
+    r'''
+    A wrapper for the All-Gather function to support auto-differentiation.
+    '''
     @staticmethod
     def forward(ctx, inp, rank, world_size, group):
         tensor_list = [torch.empty_like(inp) for _ in range(world_size)]
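The backward of this wrapper lies outside the hunk; a typical autograd all-gather wrapper, written as a hedged sketch rather than FastMoE's exact code, pairs the gather in forward with slicing out the local chunk in backward:

    import torch
    import torch.distributed as dist
    from torch.autograd import Function

    class AllGatherSketch(Function):
        # illustrative stand-in, not the library's implementation
        @staticmethod
        def forward(ctx, inp, rank, world_size, group):
            tensor_list = [torch.empty_like(inp) for _ in range(world_size)]
            dist.all_gather(tensor_list, inp, group=group)
            ctx.rank, ctx.dim0 = rank, inp.shape[0]
            return torch.cat(tensor_list, dim=0)

        @staticmethod
        def backward(ctx, grad_out):
            # each worker keeps only the gradient slice matching its own input
            start = ctx.rank * ctx.dim0
            return grad_out[start:start + ctx.dim0], None, None, None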
...
-from .functions import *
+r'''
+Layers that FMoE provides to users
+'''
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from .functions import moe_prepare_forward
+from .functions import MOEScatter, MOEGather, MOELinear
+from .functions import AllGather
 class FMoELinear(nn.Module):
+    r'''
+    A linear layer that contains multiple experts.
+    As multiple experts can be placed on the same worker, the computation can be
+    performed in parallel to increase performance.
+    The FMoELinear module provides such a function.
+    '''
     def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
-        super(FMoELinear, self).__init__()
+        super().__init__()
         self.num_expert = num_expert
         self.in_feat = in_feat
         self.out_feat = out_feat
@@ -13,21 +26,40 @@ class FMoELinear(nn.Module):
         self.reset_parameters()

     def reset_parameters(self):
+        r'''
+        Initialize the weight as linear layers
+        '''
         for i in range(self.num_expert):
-            linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
+            linear = nn.Linear(in_features=self.in_feat,
+                    out_features=self.out_feat)
             self.weight.data[i] = linear.weight.data

     def forward(self, inp, fwd_expert_count):
+        r'''
+        Call MOE function
+        '''
         return MOELinear.apply(inp, self.weight, fwd_expert_count)
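MOELinear.apply is a custom fused kernel; as a hedged reference for its semantics, the per-expert computation described in the docstring above can be written in plain PyTorch as follows (layout assumptions: tokens are contiguous per expert, and `weight[i]` has shape `(out_feat, in_feat)` as set by `reset_parameters`):

    import torch

    def experts_linear_reference(inp, weight, fwd_expert_count):
        # apply expert i's weight to its contiguous block of fwd_expert_count[i] tokens
        out, base = [], 0
        for i, n in enumerate(fwd_expert_count.tolist()):
            out.append(inp[base:base + n] @ weight[i].t())
            base += n
        return torch.cat(out, dim=0)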
 class FMoENaiveGate(nn.Module):
+    r'''
+    A naive gate implementation that defines the standard behavior of the gate
+    which determines which experts the tokens are going to.
+    Both the indices and the score, or confidence, are output to the parent
+    module.
+    The load-balance strategies are also designed to be implemented within the
+    `Gate` module.
+    '''
     def __init__(self, d_model, num_expert, world_size, top_k=2):
-        super(FMoENaiveGate, self).__init__()
+        super().__init__()
         self.gate = nn.Linear(d_model, num_expert * world_size)
         self.top_k = top_k
     def forward(self, inp):
+        r'''
+        The naive implementation simply calculates the top-k of a linear layer's
+        output.
+        '''
         gate = self.gate(inp)
         gate_top_k_val, gate_top_k_idx = torch.topk(
             gate, k=self.top_k, dim=-1, largest=True, sorted=False
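The rest of the forward is outside the hunk; a self-contained sketch of a naive top-k gate in the spirit of the docstring above (the softmax over the selected scores is an assumption, not necessarily the committed code):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class NaiveGateSketch(nn.Module):
        # illustrative stand-in for FMoENaiveGate, not the library's implementation
        def __init__(self, d_model, num_expert, world_size, top_k=2):
            super().__init__()
            self.gate = nn.Linear(d_model, num_expert * world_size)
            self.top_k = top_k

        def forward(self, inp):
            score = self.gate(inp)                       # (tokens, total experts)
            top_val, top_idx = torch.topk(score, k=self.top_k, dim=-1,
                                          largest=True, sorted=False)
            # normalize the selected scores so they can weight the experts' outputs
            return top_idx, F.softmax(top_val, dim=-1)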
@@ -42,15 +74,25 @@ class FMoENaiveGate(nn.Module):
 def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
+    r'''
+    A private function that performs the following steps to complete the MoE
+    computation.
+    * Count the number of tokens from each worker to each expert.
+    * Send the features to their target position so that input features to each
+    expert are contiguous in memory.
+    * Perform the MLP of the experts by applying MoELinear and the activation in
+    turn.
+    * Gather the output features of experts back, and reorder them as sentences.
+    Intermediate results like expert counts are hidden from users by this
+    function.
+    '''
     (
-        pos,
-        local_expert_count,
-        global_expert_count,
-        fwd_expert_count,
-        fwd_batch_size,
+        pos, local_expert_count, global_expert_count, fwd_expert_count,
+        fwd_batch_size
     ) = moe_prepare_forward(gate, num_expert, world_size)
     x = MOEScatter.apply(
-        inp, pos, local_expert_count, global_expert_count, fwd_batch_size, world_size
+        inp, pos, local_expert_count, global_expert_count, fwd_batch_size,
+        world_size
     )
     for i, l in enumerate(linears):
         if i:
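The hunk is cut off inside the expert loop; per the step list in the docstring, the remainder follows roughly this pattern (a hedged sketch of the visible pattern, not the committed code; the MOEGather argument order in particular is an assumption):

    # continue _fmoe_full_forward: run the expert MLP, then gather back
    for i, l in enumerate(linears):
        if i:                             # activation between consecutive expert linears
            x = activation(x)
        x = l(x, fwd_expert_count)
    x = MOEGather.apply(                  # reorder outputs back to the original token order
        x, pos, local_expert_count, global_expert_count,
        inp.shape[0], world_size          # argument order is an assumption
    )
    return x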
@@ -63,6 +105,19 @@ def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
 class FMoETransformerMLP(nn.Module):
+    r'''
+    A complete MoE MLP module in a Transformer block.
+    * `num_expert` stands for the number of experts on **each** worker.
+    * `world_size` stands for the total number of workers that contain
+    different experts.
+    * `mp_group` can be a torch communication group, indicating that model
+    parallelism is applied across the group, which means that workers in the
+    group hold the same copy of the input feature and demand the same copy of
+    the output. FMoE saves computation by slicing the input in the mp group and
+    performing all-gather after the MLP computation.
+    * `activation` is the activation function to be used in the MLP of each
+    expert.
+    * `top_k` stands for the number of experts each token is going to.
+    '''
     def __init__(
         self,
         num_expert=32,
@@ -72,9 +127,9 @@ class FMoETransformerMLP(nn.Module):
         mp_group=None,
         activation=torch.nn.functional.gelu,
         top_k=2,
-        pre_lnorm=False,
+        pre_lnorm=False
     ):
-        super(FMoETransformerMLP, self).__init__()
+        super().__init__()
         self.num_expert = num_expert
         self.d_model = d_model
         self.d_hidden = d_hidden
@@ -103,6 +158,11 @@ class FMoETransformerMLP(nn.Module):
         )

     def forward(self, inp: torch.Tensor):
+        r'''
+        The FMoETransformerMLP module automatically performs reshape and layer
+        normalization. The score that the gate assigns to each selected expert
+        is multiplied onto that expert's output tensor as a weight.
+        '''
         original_shape = inp.shape
         inp = inp.reshape(-1, self.d_model)
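A hedged construction sketch of the module documented above; running it requires FastMoE's compiled CUDA extension, and `world_size=1` / `mp_group=None` keep the example on a single worker. The argument names follow the attributes and defaults visible in the hunks, and the `fmoe.layers` import path is an assumption:

    import torch
    from fmoe.layers import FMoETransformerMLP   # import path is an assumption

    mlp = FMoETransformerMLP(num_expert=4, d_model=256, d_hidden=1024,
                             world_size=1, top_k=2,
                             activation=torch.nn.functional.gelu).cuda()
    x = torch.randn(8, 16, 256, device='cuda')   # (batch, seq, d_model)
    y = mlp(x)                                   # reshaped internally to (-1, d_model)
    print(y.shape)                               # expected to match the original shape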
...
+'''
+The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
+lines of modification.
+See `examples/megatron` for usage instructions.
+'''
 from .layers import FMoETransformerMLP
 from .distributed import DistributedGroupedDataParallel

-def create_moe_mlp(args, group):
-    assert (
-        args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size == 0
+def _create_moe_mlp(args, group):
+    r'''
+    Make an FMoETransformerMLP layer that distributes experts across the
+    communication group `group` to replace the original MLP layer in Megatron.
+    '''
+    assert (args.seq_length * args.micro_batch_size
+            % args.tensor_model_parallel_size == 0
     ), "Batch size x sequence length should be multiple of mp size"
     if not args.distributed_experts:
         world_size = 1
@@ -23,6 +32,17 @@ def create_moe_mlp(args, group):
 def fmoefy(model, num_experts=None, distributed_experts=True):
+    r'''
+    Replace the MLP layers in a Megatron transformer-based model with MoE
+    layers.
+    * `model` should be a standard Megatron model that has
+    `model.language_model.transformer.layers` as transformer layers, which is an
+    array of transformer blocks that contain an `mlp` member.
+    * `distributed_experts` is set to True if different experts are located on
+    different workers. Otherwise, the experts on the workers are identical, and
+    they are trained in data-parallel mode. This can be useful when testing on
+    small models that do not require high training throughput or large parameter
+    capacity.
+    '''
     from megatron import get_args
     from megatron import mpu

     args = get_args()
@@ -37,24 +57,38 @@ def fmoefy(model, num_experts=None, distributed_experts=True):
     args.distributed_experts = distributed_experts

     for l in model.language_model.transformer.layers:
-        l.mlp = create_moe_mlp(args, mpu.get_model_parallel_group())
+        l.mlp = _create_moe_mlp(args, mpu.get_model_parallel_group())

     return model
 class DistributedDataParallel(DistributedGroupedDataParallel):
+    r'''
+    A wrapper that is used to replace the DDP module provided by Megatron, which
+    is adapted to enable the sophisticated parallel and reduction strategies in
+    Fast MoE.
+    '''
     def __init__(self, module):
         from megatron import mpu

-        super(DistributedDataParallel, self).__init__(
+        super().__init__(
             module,
             mp_group=mpu.get_model_parallel_group(),
             dp_group=mpu.get_data_parallel_group()
         )
     def state_dict(self, *args, **kwargs):
+        r'''
+        Keep consistency with Megatron
+        '''
         return self.module.state_dict(*args, **kwargs)

     def state_dict_for_save_checkpoint(self, *args, **kwargs):
+        r'''
+        Keep consistency with Megatron
+        '''
         return self.module.state_dict_for_save_checkpoint(*args, **kwargs)

     def load_state_dict(self, *args, **kwargs):
+        r'''
+        Keep consistency with Megatron
+        '''
         return self.module.load_state_dict(*args, **kwargs)
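Putting the adaptor together, the "at most two lines of modification" promised by the module docstring look roughly like this inside a Megatron-LM v2.0 pretraining script (a sketch; the `fmoe.megatron` import path is an assumption):

    from fmoe.megatron import fmoefy, DistributedDataParallel  # path is an assumption

    # after Megatron has built `model` and parsed its arguments:
    model = fmoefy(model, num_experts=4)     # swap every transformer MLP for an MoE MLP
    model = DistributedDataParallel(model)   # use FastMoE's grouped DDP instead of Megatron's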
+r'''
+Utils to play with PyTorch.
+'''
 import torch.distributed as dist

+# pylint: disable=broad-except
+# pylint: disable=protected-access
 def get_torch_default_comm():
+    r'''
+    The NCCL communicator is needed so that Fast MoE can perform customized
+    communication operators in the C code. However, it is not a publicly
+    available variable. Therefore, a hacking class of `ProcessGroupNCCL` in
+    Fast MoE's C code takes the `_default_pg` and tries to dig the communicator
+    out of the object. As PyTorch's private interface varies from time to time,
+    different hacking techniques are tried one by one to be compatible with
+    various versions of PyTorch.
+    '''
     try:
         comm = dist.distributed_c10d._get_default_group()
         return comm
-    except Exception as e:
-        print('Error {}'.format(e))
+    except Exception as _:
         pass
     try:
         comm = dist.distributed_c10d._default_pg
@@ -15,6 +28,3 @@ def get_torch_default_comm():
     except Exception as _:
         pass
     raise RuntimeError('Unsupported PyTorch version')
-    return None