Unverified commit 5e5b4044 authored by Rick Ho, committed by GitHub

Merge pull request #9 from laekov/laekov/accfix

Laekov/accfix
parents 1cfc5462 ba878d29
@@ -10,6 +10,3 @@ a.out
build
*swp
logs
examples/transformer-xl/data
examples/data
examples/transformer-xl/LM-TFM-enwik8
transformer-xl/data
transformer-xl/LM-TFM-enwik8
data
@@ -90,14 +90,12 @@ class DistributedGroupedDataParallel(nn.Module):
groups[group_key] = [p]
else:
groups[group_key].append(p)
for (dp_comm, dtype), group in groups.items():
for (dp_comm, _), group in groups.items():
if dp_comm not in self.comms:
continue
comm = self.comms[dp_comm]
datas = [p.data for p in group]
coalesced = _flatten_dense_tensors(datas)
if fp32_allreduce and dtype != torch.float32:
coalesced = coalesced.float()
torch.distributed.broadcast(coalesced, 0, group=comm)
torch.cuda.synchronize()
synced = _unflatten_dense_tensors(coalesced, datas)
......
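For readers unfamiliar with the coalesced-broadcast pattern used in this hunk, the sketch below shows the same flatten, broadcast, unflatten sequence in isolation. It is a hedged illustration, not FastMoE code; `broadcast_params` and its arguments are made up for the example, and it assumes `torch.distributed` is already initialized.

```python
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def broadcast_params(params, group=None, src=0):
    '''Copy the parameters on rank `src` to every other rank in `group`.'''
    datas = [p.data for p in params]
    # Pack all tensors into one contiguous buffer so a single collective
    # call suffices, mirroring the hunk above.
    coalesced = _flatten_dense_tensors(datas)
    dist.broadcast(coalesced, src, group=group)
    # Unpack and write the synchronized values back in place.
    for buf, synced in zip(datas, _unflatten_dense_tensors(coalesced, datas)):
        buf.copy_(synced)
```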
@@ -7,6 +7,25 @@ import torch.nn as nn
import torch.nn.functional as F
class ZeroGate(nn.Module):
r'''
A trivial gate that routes all input samples to expert 0.
'''
def __init__(self, _1, _2, _3, top_k=2):
super().__init__()
self.top_k = top_k
def forward(self, inp):
r'''
Assign every token to expert 0 with a uniform score.
'''
idx = torch.zeros(inp.shape[0] * self.top_k,
dtype=torch.int64, device=inp.device)
score = torch.ones(inp.shape[0] * self.top_k,
device=inp.device) / self.top_k
return idx, score.reshape(-1, 1, self.top_k)
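A quick sanity check of the new ZeroGate (a hedged sketch; the shapes and values follow directly from the code above, but it assumes ZeroGate is importable from wherever this file lives):

```python
import torch

gate = ZeroGate(None, None, None, top_k=2)   # the three placeholders are unused
idx, score = gate(torch.rand(4, 16))         # 4 tokens of arbitrary width
assert idx.tolist() == [0] * 8               # 4 tokens * top_k, all sent to expert 0
assert score.shape == (4, 1, 2)              # one score row per token
assert torch.allclose(score, torch.full_like(score, 0.5))  # uniform 1/top_k
```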
class NaiveGate(nn.Module):
r'''
A naive gate implementation that defines the standard behavior of the gate
......
r'''
Layers that FMoE provides to users
'''
import math
import torch
import torch.nn as nn
import numpy as np
from .functions import moe_prepare_forward
from .functions import MOEScatter, MOEGather, MOELinear
@@ -31,29 +29,6 @@ class FMoELinear(nn.Module):
self.bias = nn.Parameter(torch.Tensor(num_expert, out_feat))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
r'''
Initialize the weight as linear layers
'''
rng = np.random.default_rng(np.random.randint(2048) + self.rank)
# copied from torch.nn.init.kaiming_uniform_
fan = nn.init._calculate_correct_fan(self.weight[0], 'fan_in')
gain = nn.init.calculate_gain('leaky_relu', math.sqrt(5))
std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std
device = self.weight.device
dtype = self.weight.dtype
weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
self.weight.data = torch.tensor(weight, dtype=dtype, device=device)
if self.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
bound = 1 / math.sqrt(fan_in)
bias = rng.uniform(-bound, bound, size=tuple(self.bias.size()))
self.bias.data = torch.tensor(bias, dtype=dtype, device=device)
def forward(self, inp, fwd_expert_count):
r'''
@@ -175,6 +150,10 @@ class FMoE(nn.Module):
self.experts_fused = True
def expert_fn(self, inp, fwd_expert_count):
r'''
The default expert function, which runs the experts either as one fused
module or one by one as separate modules.
'''
if self.experts_fused:
return self.experts(inp, fwd_expert_count)
outputs = []
......
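The branch truncated above handles the non-fused case; the hedged sketch below illustrates the per-expert dispatch idea (it is not the verbatim continuation of the hunk, and the expert call signature is an assumption):

```python
import torch

def expert_fn_unfused(experts, inp, fwd_expert_count):
    '''Tokens in `inp` arrive already grouped by expert;
    fwd_expert_count[i] says how many of them belong to expert i.'''
    outputs, base = [], 0
    for expert, count in zip(experts, fwd_expert_count):
        outputs.append(expert(inp[base:base + count]))  # slice for this expert
        base += count
    return torch.cat(outputs, dim=0)
```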
@@ -3,9 +3,52 @@ The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
lines of modification.
See `examples/megatron` for usage instructions.
'''
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .transformer import FMoETransformerMLP
from .distributed import DistributedGroupedDataParallel
from .utils import get_torch_default_comm
class _FakeMegatronMLP(nn.Module):
r'''
A fake MLP without model parallelism, used for correctness testing
'''
def __init__(self, args, _):
super().__init__()
self.fc1 = nn.Linear(args.hidden_size, args.hidden_hidden_size)
self.fc2 = nn.Linear(args.hidden_hidden_size, args.hidden_size)
def forward(self, x):
r'''
Run fc1 -> GeLU -> fc2 and return the output together with a dummy
zero bias, matching Megatron's MLP interface.
'''
x = self.fc1(x)
x = F.gelu(x)
x = self.fc2(x)
return x, torch.zeros_like(x)
def _random_init_weight(self, rng):
r'''
Initialize weight and bias the way nn.Linear does (kaiming uniform with
a=sqrt(5)), but drawing samples from the given numpy Generator instead
of torch's global RNG. Adapted from torch.nn.init.kaiming_uniform_.
'''
fan = nn.init._calculate_correct_fan(self.weight[0], 'fan_in')
gain = nn.init.calculate_gain('leaky_relu', math.sqrt(5))
std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std
device = self.weight.device
dtype = self.weight.dtype
weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
self.weight.data = torch.tensor(weight, dtype=dtype, device=device)
if self.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
bound = 1 / math.sqrt(fan_in)
bias = rng.uniform(-bound, bound, size=tuple(self.bias.size()))
self.bias.data = torch.tensor(bias, dtype=dtype, device=device)
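For comparison, the bounds computed in `_random_init_weight` are exactly the ones `nn.Linear.reset_parameters` produces; the only difference is that the samples come from the supplied numpy Generator. A torch-native equivalent (a sketch for reference, not FastMoE code) would be:

```python
import math
import torch.nn as nn

def torch_native_init(weight, bias=None):
    '''`weight` is a 2D (out_features, in_features) tensor, e.g. one
    expert's slice of the 3D expert weight.'''
    # kaiming_uniform_ with a=sqrt(5) yields bound = 1/sqrt(fan_in),
    # the same bound that the numpy-based code above computes.
    nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
    if bias is not None:
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight)
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(bias, -bound, bound)
```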
class MegatronMLP(FMoETransformerMLP):
@@ -26,10 +69,23 @@ class MegatronMLP(FMoETransformerMLP):
d_model=args.hidden_size, d_hidden=args.hidden_hidden_size,
world_size=world_size, mp_group=group,
expert_dp_comm='none' if args.distributed_experts else 'dp')
self.hidden_size = args.hidden_size
self.rank = args.rank
self.reset_parameters()
def reset_parameters(self):
r'''
Initialize the expert weights the same way linear layers are initialized.
Because Megatron fixes the torch random seed for its own purposes, an
additional numpy RNG, seeded with an offset that depends on the rank, is
used here so that experts on different ranks do not receive identical
initial weights.
'''
rng = np.random.default_rng(np.random.randint(2048) + self.rank)
_random_init_weight(self.experts.htoh4, rng)
_random_init_weight(self.experts.h4toh, rng)
def forward(self, inp):
output = super().forward(inp)
bias = output.new_zeros(output.size(-1), requires_grad=False)
return output, bias
return super().forward(inp), torch.zeros(self.hidden_size,
dtype=inp.dtype, device=inp.device)
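The extra zero tensor returned here keeps MegatronMLP interface-compatible with Megatron's own MLP, whose callers unpack an (output, bias) pair and add the bias later. A hedged sketch of such a caller (names are illustrative, not copied from Megatron-LM):

```python
import torch

def apply_mlp_block(mlp, hidden_states, residual):
    mlp_output, mlp_bias = mlp(hidden_states)  # MegatronMLP returns both
    # Megatron folds the bias into a later bias-dropout-add step; with a
    # zero bias that step reduces to a plain residual addition.
    return mlp_output + mlp_bias + residual
```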
def fmoefy(model, num_experts=None, distributed_experts=True,
@@ -49,6 +105,7 @@ def fmoefy(model, num_experts=None, distributed_experts=True,
tensor_model_parall_comm x data_parallel_comm, which is not created.
'''
from megatron import get_args
from megatron import mpu
args = get_args()
if num_experts is not None:
args.num_experts = num_experts
@@ -71,7 +128,7 @@ def fmoefy(model, num_experts=None, distributed_experts=True,
args.distributed_experts = distributed_experts
for l in model.language_model.transformer.layers:
l.mlp = MegatronMLP(args, get_torch_default_comm())
l.mlp = MegatronMLP(args, mpu.get_model_parallel_group())
return model
......
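As a usage illustration of the adaptor (the module docstring above promises "at most two lines of modification"; see `examples/megatron` for the authoritative instructions), here is a hedged sketch of those two lines in a Megatron pretraining script. `build_megatron_model` is a placeholder for whatever the script already does, and the import path is an assumption:

```python
from fmoe.megatron import fmoefy   # line 1: the import (module path assumed)

def model_provider():
    model = build_megatron_model()        # placeholder for Megatron's own model setup
    model = fmoefy(model, num_experts=4)  # line 2: patch every transformer MLP
    return model
```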
@@ -47,7 +47,7 @@ class FMoETransformerMLP(FMoE):
activation=torch.nn.GELU(),
gate=NaiveGate,
top_k=2,
expert_dp_comm='none',
expert_dp_comm='none'
):
super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
top_k=top_k, world_size=world_size, mp_group=mp_group)
......
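To make the constructor in the hunk above concrete, a stand-alone sketch of instantiating the layer outside Megatron (argument values are arbitrary; the import path, the single-process defaults, and the input shape are assumptions):

```python
import torch
from fmoe.transformer import FMoETransformerMLP  # import path is an assumption

# FastMoE's dispatch kernels are CUDA-based, so the layer and input live on GPU.
mlp = FMoETransformerMLP(num_expert=4, d_model=256, d_hidden=1024, top_k=2).cuda()
x = torch.rand(16, 256, device='cuda')  # 16 tokens of width d_model (shape illustrative)
y = mlp(x)                              # tokens routed through the experts; same shape as x
```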