Unverified commit baae8fb9 authored by Rick Ho, committed by GitHub

Merge pull request #31 from laekov/gate

Reconstruct gate and add gshard / switch
parents 3c42c892 8d14dd29
#include "../balancing.cuh"
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cuda.h>
#include <cuda_runtime.h>
int main(int argc, char* args[]) {
int n_worker = atoi(args[1]);
int n_expert = atoi(args[2]);
int cap_v = atoi(args[3]);
int tot_expert = n_worker * n_expert;
long* lec = new long[tot_expert];
for (int i = 0; i < tot_expert; ++i) {
lec[i] = i;
}
long* g_lec;
cudaMalloc(&g_lec, sizeof(long) * tot_expert);
cudaMemcpy(g_lec, lec, sizeof(long) * tot_expert, cudaMemcpyHostToDevice);
int* cap = new int[n_expert];
for (int i = 0; i < n_expert; ++i) {
cap[i] = cap_v;
}
int* g_cap;
cudaMalloc(&g_cap, sizeof(int) * n_expert);
cudaMemcpy(g_cap, cap, sizeof(int) * n_expert, cudaMemcpyHostToDevice);
long* eca = new long[tot_expert];
long* g_eca;
cudaMalloc(&g_eca, sizeof(long) * tot_expert);
auto smgr = getCudaStreamManager(0);
fmoe_cuda_limit_by_capacity_impl(g_lec, g_cap, g_eca, n_expert, n_worker, smgr);
cudaMemcpy(cap, g_cap, sizeof(int) * n_expert, cudaMemcpyDeviceToHost);
cudaMemcpy(eca, g_eca, sizeof(long) * tot_expert, cudaMemcpyDeviceToHost);
printf("%d\n", cap[0]);
for (int i = 0; i < tot_expert; ++i) {
printf("%ld %ld\n", lec[i], eca[i]);
}
}
#include "../balancing.cuh"
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cuda.h>
#include <cuda_runtime.h>
int main(int argc, char* args[]) {
int n_worker = atoi(args[1]);
int n_expert = atoi(args[2]);
int batch_size = atoi(args[3]);
int tot_expert = n_worker * n_expert;
long* gate_idx = new long[batch_size];
long* n_gate_idx = new long[batch_size];
int* lec = new int[tot_expert];
memset(lec, 0, sizeof(int) * tot_expert);
for (int i = 0; i < batch_size; ++i) {
gate_idx[i] = rand() % tot_expert;
++lec[gate_idx[i]];
}
for (int i = 0; i < tot_expert; ++i) {
lec[i] >>= 1;
}
int* g_lec;
cudaMalloc(&g_lec, sizeof(int) * tot_expert);
cudaMemcpy(g_lec, lec, sizeof(int) * tot_expert, cudaMemcpyHostToDevice);
long* g_gate_idx;
cudaMalloc(&g_gate_idx, sizeof(long) * batch_size);
cudaMemcpy(g_gate_idx, gate_idx, sizeof(long) * batch_size, cudaMemcpyHostToDevice);
auto smgr = getCudaStreamManager(0);
fmoe_cuda_prune_gate_by_capacity_impl(g_gate_idx, g_lec,
batch_size, n_expert, n_worker, smgr);
cudaMemcpy(n_gate_idx, g_gate_idx, sizeof(long) * batch_size, cudaMemcpyDeviceToHost);
for (int i = 0; i < batch_size; ++i) {
printf("%ld %ld (%d)\n", gate_idx[i], n_gate_idx[i], lec[gate_idx[i]]);
}
}
......@@ -85,11 +85,11 @@ inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
const c10::Half *beta,
c10::Half *C, int ldc) {
    return cublasHgemm(handle, transa, transb, m, n, k,
            (const __half*)alpha,
            (const __half*)A, lda,
            (const __half*)B, ldb,
            (const __half*)beta,
            (__half*)C, ldc);
}
#endif // CUBLAS_WRAPPER_H
#ifndef FMOE_UTILS_H
#define FMOE_UTILS_H
#define CHECK_CUDA(x) AT_ASSERTM(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
#endif // FMOE_UTILS_H
......@@ -31,6 +31,7 @@
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <stdio.h>
#ifndef HELPER_CUDA_H
#define HELPER_CUDA_H
......
#ifndef TIMER_HH
#define TIMER_HH
/*
 * This part of code is not used.
#include <chrono>
inline double getDuration(std::chrono::time_point<std::chrono::system_clock> a,
        std::chrono::time_point<std::chrono::system_clock> b) {
    return std::chrono::duration<double>(b - a).count();
}
#define timestamp(__var__) auto __var__ = std::chrono::system_clock::now();
*/
#endif // TIMER_HH
......@@ -10,36 +10,58 @@ import fmoe_cuda
from .utils import get_torch_default_comm
def _ensure_nccl(t, comm=None):
    if comm is None:
        comm = get_torch_default_comm()
    fmoe_cuda.ensure_nccl(comm, t)
def count_by_gate(gate, num_expert, world_size, require_pos=True):
    with torch.no_grad():
        flatten_gate = gate.view(-1)
        eff_gate = flatten_gate[flatten_gate != -1]
        local_expert_count = torch.zeros(
            num_expert * world_size, device=gate.device, dtype=torch.long
        )
        ones = torch.ones(eff_gate.numel(),
                device=gate.device, dtype=torch.long)
        local_expert_count.index_add_(0, eff_gate, ones)

        if world_size > 1:
            _ensure_nccl(gate)
            (global_expert_count,) = fmoe_cuda.expert_exchange(
                local_expert_count, num_expert, world_size
            )
        else:
            global_expert_count = local_expert_count
        if not require_pos:
            pos = None
        else:
            lec_cum = torch.cumsum(local_expert_count, dim=0).int()
            pos_size = lec_cum[-1].item()
            pos = torch.empty((pos_size,), device=gate.device, dtype=torch.long)
            fmoe_cuda.assign_pos_(lec_cum, gate, pos)
    return pos, local_expert_count, global_expert_count
def prepare_forward(gate, num_expert, world_size, comm=None):
    r"""
    Prepare necessary information from gate output for MoE computation.

    Args:
        gate: a 1-d Long Tensor representing the target expert of each input
            sample.
        num_expert: number of experts on each worker.
        world_size: number of workers that hold different experts.
        comm: the communicator of all workers in the expert-parallel group.
    """
    if world_size > 1:
        _ensure_nccl(gate, comm=comm)

    pos, local_expert_count, global_expert_count = count_by_gate(gate,
            num_expert, world_size)
    with torch.no_grad():
        fwd_expert_count = global_expert_count.view(world_size,
                num_expert).sum(dim=0)
        fwd_batch_size = int(fwd_expert_count.sum().item())
......@@ -52,6 +74,21 @@ def moe_prepare_forward(gate, num_expert, world_size, comm=None):
)
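For orientation, a minimal single-worker call might look like the sketch below (values are illustrative and not part of this diff; with world_size=1 no NCCL communicator is needed):

# Sketch: single-worker prepare_forward usage with hypothetical values.
gate = torch.tensor([0, 1, 1, 0], dtype=torch.long, device="cuda")
(
    pos, local_expert_count, global_expert_count,
    fwd_expert_count, fwd_batch_size,
) = prepare_forward(gate, num_expert=2, world_size=1)
# pos orders samples so those routed to the same expert are contiguous;
# here fwd_batch_size == 4 because no sample is dropped.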
def _local_scatter(inp, pos):
    inp_buf = torch.index_select(inp, 0, pos)
    return inp_buf


def _local_gather(inp, pos, out_batch_size, maybe_overlap=True):
    inp_buf = torch.zeros(out_batch_size, inp.shape[-1],
            dtype=inp.dtype, device=inp.device)
    if maybe_overlap:
        inp_buf.index_add_(0, pos, inp)
    else:
        inp_buf.index_copy_(0, pos, inp)
    return inp_buf
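A quick sanity check of these helpers (illustrative, assuming pos is a permutation): gathering with the same pos inverts the scatter.

# Sketch: _local_gather with maybe_overlap=False inverts _local_scatter.
inp = torch.randn(4, 8)
pos = torch.tensor([2, 0, 3, 1])
buf = _local_scatter(inp, pos)                         # buf[i] = inp[pos[i]]
out = _local_gather(buf, pos, 4, maybe_overlap=False)  # out[pos[i]] = buf[i]
assert torch.equal(out, inp)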
class MOEScatter(Function):
    r"""
    Scatter input samples from [batch x sequences] to be contiguous along
    experts.
......@@ -69,7 +106,7 @@ class MOEScatter(Function):
        fwd_batch_size,
        world_size,
    ):
        local_input_buf = _local_scatter(inp, pos)
        if world_size > 1:
            (global_input_buf,) = fmoe_cuda.global_scatter(
                local_input_buf,
......@@ -80,7 +117,7 @@ class MOEScatter(Function):
            )
        else:
            global_input_buf = local_input_buf
        ctx.moe_args = inp.shape[0], pos.shape[0], world_size

        variables = (pos, local_expert_count, global_expert_count)
        ctx.save_for_backward(*variables)
        return global_input_buf
......@@ -88,19 +125,19 @@ class MOEScatter(Function):
    @staticmethod
    def backward(ctx, global_grad_in):
        (pos, local_expert_count, global_expert_count) = ctx.saved_tensors
        (inp_batch_size, buf_batch_size, world_size) = ctx.moe_args

        if world_size > 1:
            (local_grad_in,) = fmoe_cuda.global_gather(
                global_grad_in,
                local_expert_count,
                global_expert_count,
                buf_batch_size,
                world_size,
            )
        else:
            local_grad_in = global_grad_in
        grad_in = _local_gather(local_grad_in, pos, inp_batch_size)

        return grad_in, None, None, None, None, None
......@@ -111,7 +148,7 @@ class MOELinear(Function):
    @staticmethod
    def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None):
        (global_output_buf,) = fmoe_cuda.linear_forward(
            global_input_buf, fwd_expert_count, weight, bias
        )
        variables = (global_input_buf, fwd_expert_count, weight, bias)
......@@ -121,7 +158,7 @@ class MOELinear(Function):
    @staticmethod
    def backward(ctx, grad_out):
        (input_buf, fwd_expert_count, weight, bias) = ctx.saved_tensors
        grad_inp_buf, grad_weight, grad_bias = fmoe_cuda.linear_backward(
            grad_out, input_buf, fwd_expert_count, weight, bias
        )
......@@ -157,7 +194,8 @@ class MOEGather(Function):
            )
        else:
            local_output_buf = global_output_buf
        output = _local_gather(local_output_buf, pos, local_batch_size,
                maybe_overlap=False)

        ctx.moe_args = (global_output_buf.shape[0], world_size)
        variables = (pos, local_expert_count, global_expert_count)
......@@ -168,7 +206,7 @@ class MOEGather(Function):
    def backward(ctx, grad_out):
        pos, local_expert_count, global_expert_count = ctx.saved_tensors
        fwd_batch_size, world_size = ctx.moe_args

        grad_out_buf = _local_scatter(grad_out.contiguous(), pos)
        if world_size > 1:
            (global_grad_out_buf,) = fmoe_cuda.global_scatter(
                grad_out_buf,
......
r"""
Different implementations of the Gate are located in separate files here.
"""
from .zero_gate import ZeroGate
from .naive_gate import NaiveGate
from .noisy_gate import NoisyGate
from .gshard_gate import GShardGate
from .switch_gate import SwitchGate
r"""
Base gate with standard interface
"""
import torch.nn as nn
class BaseGate(nn.Module):
def __init__(self, num_expert, world_size):
super().__init__()
self.world_size = world_size
self.num_expert = num_expert
self.tot_expert = world_size * num_expert
self.loss = None
def forward(self, x):
raise NotImplementedError('Base gate cannot be directly used for fwd')
def set_loss(self, loss):
self.loss = loss
def get_loss(self, clear=True):
loss = self.loss
if clear:
self.loss = None
return loss
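To illustrate the interface, a custom gate only needs to subclass BaseGate, return (indices, scores) from forward, and optionally report a balance loss via set_loss. The RandomGate below is a hypothetical example, not part of this PR:

# Hypothetical example gate, for illustration only.
import torch

class RandomGate(BaseGate):
    def __init__(self, d_model, num_expert, world_size, top_k=1):
        super().__init__(num_expert, world_size)
        self.top_k = top_k

    def forward(self, inp):
        # route every sample to a uniformly random expert, full weight
        idx = torch.randint(0, self.tot_expert,
                (inp.shape[0], self.top_k), device=inp.device)
        score = torch.full((inp.shape[0], self.top_k), 1.0 / self.top_k,
                device=inp.device)
        return idx, score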
r"""
Balanced gate with GShard's policy (Google, 2020)
"""
import math
import torch
import torch.nn.functional as F
from .naive_gate import NaiveGate
from .utils import limit_by_capacity
class GShardGate(NaiveGate):
def __init__(self, d_model, num_expert, world_size, capacity=(1.2, 2.4)):
super().__init__(d_model, num_expert, world_size, top_k=2)
self.capacity = capacity
def forward(self, x):
naive_outs = super().forward(x, return_all_scores=True)
topk_idx, topk_val, gate_score = naive_outs
S = gate_score.shape[0]
top_k = topk_idx.shape[0] // gate_score.shape[0]
top1_idx = topk_idx.view((-1, top_k))[:, 0]
c_e = torch.scatter_add(
torch.zeros(self.tot_expert, device=top1_idx.device),
0,
top1_idx,
torch.ones_like(top1_idx, dtype=torch.float),
) / S
m_e = torch.mean(F.softmax(gate_score, dim=1), dim=0)
loss = torch.mean(c_e * m_e) * (self.num_expert ** 2)
self.set_loss(loss)
cap_rate = self.capacity[0 if self.training else 1]
capacity = math.ceil(cap_rate * x.shape[0])
limit_by_capacity(topk_idx, self.num_expert, self.world_size, capacity)
return topk_idx, topk_val
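For reference, the loss above is GShard's auxiliary balance term: c_e is the fraction of samples whose top-1 choice is expert e, and m_e is the mean softmax probability assigned to e. A hypothetical single-process usage sketch (illustrative shapes and values):

# Sketch, not part of the diff.
gate = GShardGate(d_model=16, num_expert=4, world_size=1).cuda()
x = torch.rand(8, 16).cuda()
topk_idx, topk_val = gate(x)   # topk_idx entries are -1 for dropped samples
aux_loss = gate.get_loss()     # add this to the training objective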
r"""
Naive gate
"""
from .base_gate import BaseGate
import torch
import torch.nn as nn
import torch.nn.functional as F
class NaiveGate(BaseGate):
r"""
A naive gate implementation that defines the standard behavior of the gate
which determines which experts the tokens are going to.
Both the indicies and the score, or confidence, are output to the parent
module.
The load-balance strategies are also designed to be implemented within the
`Gate` module.
"""
def __init__(self, d_model, num_expert, world_size, top_k=2):
super().__init__(num_expert, world_size)
self.gate = nn.Linear(d_model, self.tot_expert)
self.top_k = top_k
def forward(self, inp, return_all_scores=False):
r"""
The naive implementation simply calculates the top-k of a linear layer's
output.
"""
gate = self.gate(inp)
gate_top_k_val, gate_top_k_idx = torch.topk(
gate, k=self.top_k, dim=-1, largest=True, sorted=False
) # [.. x top_k]
gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
# (BxL) x 1 x top_k
gate_score = F.softmax(gate_top_k_val, dim=-1)
if return_all_scores:
return gate_top_k_idx, gate_top_k_val, gate
return gate_top_k_idx, gate_top_k_val
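A shape sketch for the naive gate (illustrative, derived from the code above): for a (batch x d_model) input with top_k=2, both returned tensors are (batch, 2).

# Sketch, assuming CPU execution is acceptable for NaiveGate alone.
g = NaiveGate(d_model=16, num_expert=4, world_size=1, top_k=2)
out_idx, out_val = g(torch.rand(8, 16))
assert out_idx.shape == (8, 2) and out_val.shape == (8, 2)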
r"""
Different implementations of the Gate are located here.
The `NaiveGate` is the reference to implement any other gate.
Noisy gate for gshard and switch
"""
from .base_gate import BaseGate
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
class NoisyGate(BaseGate):
    def __init__(self, d_model, num_expert, world_size, top_k=2):
        super().__init__(num_expert, world_size)
        self.w_gate = nn.Parameter(
            torch.zeros(d_model, self.tot_expert), requires_grad=True
        )
        self.w_noise = nn.Parameter(
            torch.zeros(d_model, self.tot_expert), requires_grad=True
        )
        self.top_k = top_k
        self.softplus = nn.Softplus()
......@@ -163,7 +105,7 @@ class NoisyGate(nn.Module):
        # calculate topk + 1 that will be needed for the noisy gates
        top_logits, top_indices = logits.topk(
            min(self.top_k + 1, self.tot_expert), dim=1
        )
        top_k_logits = top_logits[:, : self.top_k]
        top_k_indices = top_indices[:, : self.top_k]
......@@ -172,7 +114,7 @@ class NoisyGate(nn.Module):
        zeros = torch.zeros_like(logits, requires_grad=True)
        gates = zeros.scatter(1, top_k_indices, top_k_gates)

        if self.top_k < self.tot_expert:
            load = (
                self._prob_in_top_k(
                    clean_logits, noisy_logits, noise_stddev, top_logits
......@@ -183,9 +125,9 @@ class NoisyGate(nn.Module):
        importance = gates.sum(0)
        loss = self.cv_squared(importance) + self.cv_squared(load)
        self.set_loss(loss)

        return (
            top_k_indices.contiguous().view(-1),
            top_k_gates.contiguous().unsqueeze(1),
        )
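The cv_squared terms (body elided in this hunk) follow the load-balancing loss of Shazeer et al. (2017). Assuming the elided implementation matches that paper, a reference sketch is:

def _cv_squared_reference(x, eps=1e-10):
    # squared coefficient of variation, Var(x) / Mean(x)^2: small when
    # importance/load is spread evenly across experts
    x = x.float()
    return x.var() / (x.mean() ** 2 + eps)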
r"""
Balanced gate with Switch Transformer's policy (Google, 2021)
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .naive_gate import NaiveGate
from .utils import limit_by_capacity
class SwitchGate(NaiveGate):
r"""
A switch gate implementation
"""
def __init__(self, d_model, num_expert, world_size,
switch_eps=.1, capacity=(1.2, 2.4)):
super().__init__(d_model, num_expert, world_size, top_k=1)
self.switch_eps = switch_eps
self.capacity = capacity
def forward(self, inp):
r"""
The switch firstly conduct softmax and then calculates the top-1
"""
score = self.gate(inp)
if self.training:
# random uniform number from [1-eps, 1+eps]
noise = torch.rand_like(score)
noise = noise * 2 * self.switch_eps + 1.0 - self.switch_eps
score += noise
# fp32 softmax for numerical stability
score = F.softmax(score.float(), dim=-1)
top1_score, top1_idx = torch.topk(
score, k=1, dim=-1, largest=True
) # [.. x top_k]
top1_score = top1_score.to(dtype=inp.dtype)
top1_score = top1_score.to(dtype=inp.dtype)
cap_rate = self.capacity[0 if self.training else 1]
capacity = math.ceil(cap_rate * inp.shape[0])
limit_by_capacity(top1_idx, self.num_expert, self.world_size, capacity)
valid_idx = top1_idx[top1_idx > -1]
fraction_expert = torch.scatter_add(
torch.zeros(self.tot_expert, device=valid_idx.device),
0,
valid_idx,
torch.ones_like(valid_idx, dtype=torch.float),
) / valid_idx.numel()
prob_expert = score.sum(dim=0) / valid_idx.numel()
loss = (fraction_expert * prob_expert).sum() * self.tot_expert
self.set_loss(loss)
return top1_idx, top1_score
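A hypothetical usage sketch mirroring the GShard one, here with top-1 routing (values illustrative, not part of the diff):

# Sketch.
gate = SwitchGate(d_model=16, num_expert=4, world_size=1).cuda()
x = torch.rand(8, 16).cuda()
top1_idx, top1_score = gate(x)  # top1_idx entries are -1 for dropped samples
aux_loss = gate.get_loss()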
r"""
Utilities that may be used in the gates
"""
import torch
from fmoe.functions import count_by_gate
import fmoe_cuda as fmoe_native
def limit_by_capacity(topk_idx, num_expert, world_size, capacity):
capacity = torch.ones(num_expert, dtype=torch.int32,
device=topk_idx.device) * capacity
pos, lec, gec = count_by_gate(topk_idx, num_expert, world_size,
require_pos=False)
new_gec, = fmoe_native.limit_by_capacity(gec, capacity,
num_expert, world_size)
if world_size > 1:
new_lec, = fmoe_native.expert_exchange(new_gec, num_expert, world_size)
else:
new_lec = new_gec
fmoe_native.prune_gate_by_capacity(topk_idx,
new_lec.to(torch.int32), num_expert, world_size)
return new_lec, new_gec
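The pruning semantics, modeled in pure Python for illustration (a hypothetical reference only; the CUDA kernel's ordering of which tokens survive may differ):

def _limit_by_capacity_reference(topk_idx, capacity):
    # keep at most `capacity` tokens per expert; drop the rest by
    # overwriting their expert index with -1
    seen = {}
    pruned = topk_idx.clone().view(-1)
    for i, e in enumerate(pruned.tolist()):
        seen[e] = seen.get(e, 0) + 1
        if e != -1 and seen[e] > capacity:
            pruned[i] = -1
    return pruned.view(topk_idx.shape)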
r"""
Zero gate that direct all input to gate 0
"""
from .base_gate import BaseGate
import torch
import torch.nn as nn
import torch.nn.functional as F
class ZeroGate(BaseGate):
r"""
Guide all input samples to gate 0.
"""
def __init__(self, _1, num_expert, world_size, top_k=2):
super().__init__(num_expert, world_size)
self.top_k = top_k
def forward(self, inp):
r"""
All output to expert 1
"""
idx = torch.zeros(
inp.shape[0] * self.top_k, dtype=torch.int64, device=inp.device
)
gate_score = (
torch.ones(inp.shape[0] * self.top_k, device=inp.device) / self.top_k
)
return idx, gate_score.reshape(-1, 1, self.top_k)
......@@ -4,7 +4,7 @@ Layers that FMoE provides to users
import torch
import torch.nn as nn
from .functions import prepare_forward
from .functions import MOEScatter, MOEGather, MOELinear
from .functions import AllGather, Slice
from .gates import NaiveGate
......@@ -82,14 +82,24 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
        global_expert_count,
        fwd_expert_count,
        fwd_batch_size,
    ) = prepare_forward(gate, num_expert, world_size)

    topk = 1
    if len(gate.shape) == 2:
        topk = gate.shape[1]
    x = MOEScatter.apply(
        inp, pos // topk,
        local_expert_count, global_expert_count, fwd_batch_size, world_size
    )
    x = expert_fn(x, fwd_expert_count)

    out_batch_size = inp.shape[0]
    if len(gate.shape) == 2:
        out_batch_size *= gate.shape[1]
    x = MOEGather.apply(
        x, pos,
        local_expert_count, global_expert_count,
        out_batch_size, world_size
    )
    return x
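The pos // topk trick above avoids materializing a repeated input: pos indexes a virtual (batch x topk) expansion, and integer division maps each virtual slot back to its source row. A tiny illustration (hypothetical values):

# With topk = 2, slots [0, 2, 1, 3] (ordered by target expert) come from
# rows pos // 2 = [0, 1, 0, 1], so row 0 reaches two experts without
# calling repeat_interleave on the input.
import torch
pos = torch.tensor([0, 2, 1, 3])
rows = pos // 2      # tensor([0, 1, 0, 1])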
......@@ -184,17 +194,16 @@ class FMoE(nn.Module):
        if self.mp_size > 1:
            inp = Slice.apply(inp, self.mp_rank, self.mp_size, self.mp_group)

        gate_top_k_idx, gate_score = self.gate(inp)

        x = _fmoe_general_global_forward(
            inp,
            gate_top_k_idx,
            self.expert_fn, self.num_expert, self.world_size
        )

        # to: (BxL) x top_k x d_model
        x = x.view(inp.shape[0], self.top_k, self.d_model)
        gate_score = gate_score.view(inp.shape[0], 1, self.top_k)
        x = torch.bmm(gate_score, x).reshape(-1, self.d_model)
if self.mp_size > 1:
......
......@@ -7,7 +7,7 @@ cxx_flags = []
ext_libs = []
if os.environ.get('USE_NCCL', '0') == '1':
    cxx_flags.append('-DFMOE_USE_NCCL')
    ext_libs.append('nccl')
......@@ -20,16 +20,17 @@ if __name__ == '__main__':
        author_email='hja20@mails.tsinghua.edu.cn',
        license='Apache-2',
        url='https://github.com/laekov/fastmoe',
        packages=['fmoe', 'fmoe.megatron', 'fmoe.gates'],
        ext_modules=[
            CUDAExtension(
                name='fmoe_cuda',
                sources=[
                    'cuda/stream_manager.cpp',
                    'cuda/local_exchange.cu',
                    'cuda/balancing.cu',
                    'cuda/global_exchange.cpp',
                    'cuda/parallel_linear.cu',
                    'cuda/fmoe_cuda.cpp',
                ],
                extra_compile_args={
                    'cxx': cxx_flags,
......
......@@ -28,11 +28,12 @@ class BruteForceMoELinear(nn.Module):
        self.top_k = top_k

    def forward(self, inp, gate_idx, gate_score):
        inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
        gate_long = gate_idx.long().view(-1)
        batch_size = inp.size(0)
        o = torch.empty(batch_size, self.d_model, dtype=inp.dtype,
                device=inp.device)
        for i in range(self.weight_htoh4.shape[0]):
            idx = gate_long == i
            x = inp[idx]
            x = x @ self.weight_htoh4[i].t()
            x = x + self.bias_htoh4[i]
......@@ -40,6 +41,8 @@ class BruteForceMoELinear(nn.Module):
            x = x @ self.weight_h4toh[i].t()
            x = x + self.bias_h4toh[i]
            o[idx] = x
        gate_score = gate_score.unsqueeze(1)
        x = torch.bmm(gate_score, o.view(-1, self.top_k, self.d_model)).reshape(
            -1, self.d_model
        )
......@@ -55,11 +58,13 @@ class BruteForceMoE(nn.Module):
        self.experts = [expert(d_model) for _ in range(num_expert * world_size)]

    def forward(self, inp, gate_idx, gate_score):
        inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
        gate_long = gate_idx.long().view(-1)
        batch_size = inp.size(0)
        x = inp.new_zeros((batch_size, self.d_model))
        for i in range(batch_size):
            x[i] = self.experts[gate_long[i]](inp[i])
        gate_score = gate_score.unsqueeze(1)
        x = torch.bmm(gate_score, x.view(-1, self.top_k, self.d_model)).reshape(
            -1, self.d_model
        )
......
......@@ -11,7 +11,7 @@ from test_numerical import test_fmoe_linear as _test_fmoe_linear
from test_numerical import _test_fmoe_local_ddp
def _run_distributed(func, world_size, args: Dict, script=__file__):
    if torch.cuda.device_count() < world_size:
        pytest.skip("Not enough GPUs")
    import subprocess
......@@ -25,7 +25,7 @@ def _run_distributed(func, world_size, args: Dict):
    for i in range(world_size):
        os.environ["OMPI_COMM_WORLD_RANK"] = str(i)
        p = subprocess.Popen(
            [sys.executable, script, func, json.dumps(args)],
            stdout=subprocess.PIPE
        )
        ps.append(p)
......
import pytest

import os
import sys
import json
import math

import torch
import torch.distributed as dist
from fmoe.gates import GShardGate, SwitchGate

from test_ddp import _run_distributed


def _ensure_initialized():
    if not dist.is_initialized():
        os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
        os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
        os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
        os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost")
        os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12211")
        dist.init_process_group(backend="nccl")


@pytest.mark.parametrize("d_model", [1024])
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("n_expert", [1, 4])
@pytest.mark.parametrize("cap", [.1, 1.1])
def test_gshard_gate(d_model, batch_size, n_expert, cap):
    # GShardGate routes top-2, so at least 2 experts are required in total
    # (world_size is 1 in this test).
    if 1 * n_expert < 2:
        pytest.skip("Not enough experts")
    _run_distributed('_test_gshard_gate',
            1,
            {
                'd_model': d_model,
                'batch_size': batch_size,
                'n_expert': n_expert,
                'cap': cap
            },
            script=__file__
    )


def _test_gshard_gate(d_model, batch_size, n_expert, cap):
    _ensure_initialized()
    gate = GShardGate(d_model, n_expert, dist.get_world_size(),
            capacity=(cap, cap)).cuda()
    x = torch.rand(batch_size, d_model).cuda()
    topk_idx, topk_val = gate(x)
    counts = [0 for _ in range(n_expert * dist.get_world_size())]
    for v in topk_idx.cpu().view(-1).numpy():
        if v != -1:
            counts[v] += 1
    real_cap = math.ceil(cap * batch_size)
    for i in counts:
        assert(i <= real_cap)


@pytest.mark.parametrize("d_model", [1024])
@pytest.mark.parametrize("batch_size", [4096])
@pytest.mark.parametrize("n_expert", [1, 16])
@pytest.mark.parametrize("cap", [.1, .8])
def test_switch_gate(d_model, batch_size, n_expert, cap):
    _run_distributed('_test_switch_gate',
            1,
            {
                'd_model': d_model,
                'batch_size': batch_size,
                'n_expert': n_expert,
                'cap': cap
            },
            script=__file__
    )


def _test_switch_gate(d_model, batch_size, n_expert, cap):
    _ensure_initialized()
    gate = SwitchGate(d_model, n_expert, dist.get_world_size(),
            capacity=(cap, cap)).cuda()
    x = torch.rand(batch_size, d_model).cuda()
    topk_idx, topk_val = gate(x)
    counts = [0 for _ in range(n_expert * dist.get_world_size())]
    for v in topk_idx.cpu().view(-1).numpy():
        if v != -1:
            counts[v] += 1
    real_cap = math.ceil(cap * batch_size)
    for i in counts:
        assert(i <= real_cap)


if __name__ == '__main__':
    if len(sys.argv) >= 3:
        args = json.loads(sys.argv[2])
        locals()[sys.argv[1]](**args)
    else:
        _ensure_initialized()
        # test_gshard_gate(4096, 1024, 4, .2)
        test_gshard_gate(8, 16, 1, .1)
        # test_switch_gate(4096, 1024, 4, .2)