Commit ba878d29 authored by Rick Ho

fix lint

parent 66f7166d
@@ -90,14 +90,12 @@ class DistributedGroupedDataParallel(nn.Module):
                 groups[group_key] = [p]
             else:
                 groups[group_key].append(p)
-        for (dp_comm, dtype), group in groups.items():
+        for (dp_comm, _), group in groups.items():
             if dp_comm not in self.comms:
                 continue
             comm = self.comms[dp_comm]
             datas = [p.data for p in group]
             coalesced = _flatten_dense_tensors(datas)
-            if fp32_allreduce and dtype != torch.float32:
-                coalesced = coalesced.float()
             torch.distributed.broadcast(coalesced, 0, group=comm)
             torch.cuda.synchronize()
             synced = _unflatten_dense_tensors(coalesced, datas)
...
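Note: a minimal sketch, not part of this commit, of the flatten/broadcast/unflatten pattern the hunk above relies on. It assumes torch.distributed has already been initialized and that `params` holds tensors of the same dtype; the helper name `broadcast_params` is illustrative.

import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def broadcast_params(params, src=0, group=None):
    # Pack all parameter tensors into one contiguous buffer so a single
    # broadcast call covers the whole group.
    datas = [p.data for p in params]
    coalesced = _flatten_dense_tensors(datas)
    dist.broadcast(coalesced, src, group=group)
    # Unpack the synced buffer and copy each slice back in place.
    for p, synced in zip(params, _unflatten_dense_tensors(coalesced, datas)):
        p.data.copy_(synced)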
@@ -8,14 +8,16 @@ import torch.nn.functional as F
 class ZeroGate(nn.Module):
-    def __init__(self, d_model, num_expert, world_size, top_k=2):
+    r'''
+    Guide all input samples to gate 0.
+    '''
+    def __init__(self, _1, _2, _3, top_k=2):
         super().__init__()
         self.top_k = top_k
 
     def forward(self, inp):
         r'''
-        The naive implementation simply calculates the top-k of a linear layer's
-        output.
+        All output to expert 1
         '''
         idx = torch.zeros(inp.shape[0] * self.top_k,
                           dtype=torch.int64, device=inp.device)
...
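Note: as a quick illustration (not taken from the repository), the gate above emits expert index 0 for every token, repeated top_k times:

import torch

inp = torch.randn(4, 16)   # 4 tokens with a hidden size of 16
top_k = 2
# One index per (token, k) pair, all pointing at expert 0.
idx = torch.zeros(inp.shape[0] * top_k, dtype=torch.int64, device=inp.device)
print(idx)                 # tensor([0, 0, 0, 0, 0, 0, 0, 0])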
@@ -150,6 +150,10 @@ class FMoE(nn.Module):
         self.experts_fused = True
 
     def expert_fn(self, inp, fwd_expert_count):
+        r'''
+        The default expert function which either calls the experts as a whole
+        or as separate experts.
+        '''
         if self.experts_fused:
             return self.experts(inp, fwd_expert_count)
         outputs = []
...
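Note: a hedged sketch of the "separate experts" path the new docstring refers to, assuming `fwd_expert_count` holds the number of tokens assigned to each expert and the input rows are already grouped by expert. The variable names here are illustrative, not the library's API.

import torch
import torch.nn as nn

experts = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
inp = torch.randn(6, 8)                    # rows already sorted by expert
fwd_expert_count = [1, 2, 3]               # tokens routed to each expert

outputs, base = [], 0
for i, count in enumerate(fwd_expert_count):
    # Each expert only sees its own contiguous slice of the batch.
    outputs.append(experts[i](inp[base:base + count]))
    base += count
out = torch.cat(outputs, dim=0)            # same row order as the input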
@@ -3,22 +3,28 @@ The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
 lines of modification.
 See `examples/megatron` for usage instructions.
 '''
+import math
+import numpy as np
+import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from .transformer import FMoETransformerMLP
 from .distributed import DistributedGroupedDataParallel
+from .utils import get_torch_default_comm
 
 
 class _FakeMegatronMLP(nn.Module):
     r'''
     A fake mlp without model parallelism for correctness testing
     '''
-    def __init__(self, args, group):
+    def __init__(self, args, _):
         super().__init__()
         self.fc1 = nn.Linear(args.hidden_size, args.hidden_hidden_size)
         self.fc2 = nn.Linear(args.hidden_hidden_size, args.hidden_size)
 
     def forward(self, x):
+        r'''
+        Directly use GeLU
+        '''
         x = self.fc1(x)
         x = F.gelu(x)
         x = self.fc2(x)
@@ -71,7 +77,7 @@ class MegatronMLP(FMoETransformerMLP):
        r'''
        Initialize the weight as linear layers.
        As megatron is using fixed random seed for some nasty stuff, an
        additional numpy rng is used.
        '''
        rng = np.random.default_rng(np.random.randint(2048) + self.rank)
        _random_init_weight(self.experts.htoh4, rng)
...
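Note: the last hunk's context mentions re-initializing expert weights from a NumPy rng seeded per rank, so that experts do not all inherit Megatron's fixed torch seed. Below is a rough sketch of that idea only; the helper name `init_linear_with_numpy` and the uniform bound are chosen here for illustration and may differ from FastMoE's actual scheme.

import math
import numpy as np
import torch
import torch.nn as nn

def init_linear_with_numpy(linear: nn.Linear, rank: int):
    # Seed a NumPy generator differently on every rank.
    rng = np.random.default_rng(np.random.randint(2048) + rank)
    # Kaiming-style uniform bound based on fan-in (illustrative choice).
    bound = 1.0 / math.sqrt(linear.weight.shape[1])
    weight = rng.uniform(-bound, bound, size=tuple(linear.weight.shape))
    with torch.no_grad():
        linear.weight.copy_(torch.tensor(weight, dtype=linear.weight.dtype))
        if linear.bias is not None:
            linear.bias.uniform_(-bound, bound)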