"vscode:/vscode.git/clone" did not exist on "b0dbbd7a6439f2248e194f96ddb58d427670ca56"
Commit ba878d29 authored by Rick Ho

fix lint

parent 66f7166d
@@ -90,14 +90,12 @@ class DistributedGroupedDataParallel(nn.Module):
                 groups[group_key] = [p]
             else:
                 groups[group_key].append(p)
-        for (dp_comm, dtype), group in groups.items():
+        for (dp_comm, _), group in groups.items():
             if dp_comm not in self.comms:
                 continue
             comm = self.comms[dp_comm]
             datas = [p.data for p in group]
             coalesced = _flatten_dense_tensors(datas)
-            if fp32_allreduce and dtype != torch.float32:
-                coalesced = coalesced.float()
             torch.distributed.broadcast(coalesced, 0, group=comm)
             torch.cuda.synchronize()
             synced = _unflatten_dense_tensors(coalesced, datas)
...
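
For context, a minimal sketch of the synchronization pattern this hunk touches: the parameters of one communication group are flattened into a single buffer, broadcast from rank 0, and then written back element-wise. The helper name sync_param_group and the copy-back loop at the end are assumptions for illustration, not a quote of the truncated part of the method.

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def sync_param_group(params, comm_group):
    # Flatten all parameter tensors of the group into one contiguous buffer.
    datas = [p.data for p in params]
    coalesced = _flatten_dense_tensors(datas)
    # Broadcast the buffer from rank 0 so every rank holds identical values.
    torch.distributed.broadcast(coalesced, 0, group=comm_group)
    torch.cuda.synchronize()
    # Split the buffer back into per-parameter views and copy them in place.
    for d, s in zip(datas, _unflatten_dense_tensors(coalesced, datas)):
        d.copy_(s)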
@@ -8,14 +8,16 @@ import torch.nn.functional as F
 class ZeroGate(nn.Module):
-    def __init__(self, d_model, num_expert, world_size, top_k=2):
+    r'''
+    Guide all input samples to gate 0.
+    '''
+    def __init__(self, _1, _2, _3, top_k=2):
         super().__init__()
         self.top_k = top_k

     def forward(self, inp):
         r'''
-        The naive implementation simply calculates the top-k of a linear layer's
-        output.
+        All output to expert 1
         '''
         idx = torch.zeros(inp.shape[0] * self.top_k,
                           dtype=torch.int64, device=inp.device)
...
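
A hedged sketch of what a gate like this ends up computing, assuming the usual FastMoE gate contract of returning expert indices plus gating scores; the class name ZeroGateSketch and the score tensor's shape are illustrative assumptions, not the repository's code.

import torch
import torch.nn as nn

class ZeroGateSketch(nn.Module):
    def __init__(self, top_k=2):
        super().__init__()
        self.top_k = top_k

    def forward(self, inp):
        # Route every one of the top-k slots of every sample to expert 0.
        idx = torch.zeros(inp.shape[0] * self.top_k,
                          dtype=torch.int64, device=inp.device)
        # Spread the gating weight evenly over the k identical choices
        # (assumed shape, for illustration only).
        score = torch.full((inp.shape[0], 1, self.top_k), 1.0 / self.top_k,
                           device=inp.device)
        return idx, score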
@@ -150,6 +150,10 @@ class FMoE(nn.Module):
         self.experts_fused = True

     def expert_fn(self, inp, fwd_expert_count):
+        r'''
+        The default expert function which either calls the experts as a whole
+        or as separate experts.
+        '''
         if self.experts_fused:
             return self.experts(inp, fwd_expert_count)
         outputs = []
...
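
The docstring added above mentions a non-fused path alongside the fused call. A sketch of how such a path can iterate over separate experts; the helper name run_experts_separately and the loop body are assumptions rather than a quote of the truncated hunk.

import torch

def run_experts_separately(experts, inp, fwd_expert_count):
    # Slice the token batch by the number of tokens assigned to each expert
    # and let every expert process only its own slice.
    outputs = []
    base = 0
    for i, count in enumerate(fwd_expert_count):
        count = int(count)              # works for Python ints and tensors
        outputs.append(experts[i](inp[base:base + count]))
        base += count
    return torch.cat(outputs, dim=0)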
@@ -3,22 +3,28 @@ The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
 lines of modification.
 See `examples/megatron` for usage instructions.
 '''
+import math
+import numpy as np
+import torch
 import torch.nn as nn
+import torch.nn.functional as F

 from .transformer import FMoETransformerMLP
 from .distributed import DistributedGroupedDataParallel
+from .utils import get_torch_default_comm


 class _FakeMegatronMLP(nn.Module):
     r'''
     A fake mlp without model parallelism for correctness testing
     '''
-    def __init__(self, args, group):
+    def __init__(self, args, _):
         super().__init__()
         self.fc1 = nn.Linear(args.hidden_size, args.hidden_hidden_size)
         self.fc2 = nn.Linear(args.hidden_hidden_size, args.hidden_size)

     def forward(self, x):
+        r'''
+        Directly use GeLU
+        '''
         x = self.fc1(x)
         x = F.gelu(x)
         x = self.fc2(x)
...
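
A small usage sketch for the fake MLP above, intended only for correctness testing; the SimpleNamespace stand-in for Megatron's argument object and the concrete sizes are assumptions for illustration.

from types import SimpleNamespace
import torch
from fmoe.megatron import _FakeMegatronMLP

# Only hidden_size and hidden_hidden_size are read from the args object.
args = SimpleNamespace(hidden_size=16, hidden_hidden_size=64)
mlp = _FakeMegatronMLP(args, None)      # the second argument is ignored ("_")
result = mlp(torch.randn(4, 16))        # fc1 -> GeLU -> fc2 over the last dim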