Commit fdbac1df authored by Sengxian

Format using black and add model_parallel_rank

parent ae658b89
@@ -12,31 +12,48 @@ def moe_prepare_forward(gate, num_expert, world_size, comm=None):
     with torch.no_grad():
         _, pos = torch.sort(gate)
         gate_idx, gate_count = torch.unique(gate, return_counts=True)
-        local_expert_count = torch.zeros(num_expert * world_size,
-                device=gate.device, dtype=torch.long)
-        local_expert_count.index_put_((gate_idx.long(), ), gate_count)
+        local_expert_count = torch.zeros(
+            num_expert * world_size, device=gate.device, dtype=torch.long
+        )
+        local_expert_count.index_put_((gate_idx.long(),), gate_count)
         if world_size > 1:
-            global_expert_count, = fmoe_cuda.expert_exchange(
-                    local_expert_count, num_expert, world_size)
+            (global_expert_count,) = fmoe_cuda.expert_exchange(
+                local_expert_count, num_expert, world_size
+            )
         else:
             global_expert_count = local_expert_count
-        fwd_expert_count = global_expert_count.view(world_size,
-                num_expert).sum(dim=0)
+        fwd_expert_count = global_expert_count.view(world_size, num_expert).sum(dim=0)
         fwd_batch_size = int(fwd_expert_count.sum().item())
-    return (pos, local_expert_count.cpu(), global_expert_count.cpu(),
-            fwd_expert_count.cpu(), fwd_batch_size)
+    return (
+        pos,
+        local_expert_count.cpu(),
+        global_expert_count.cpu(),
+        fwd_expert_count.cpu(),
+        fwd_batch_size,
+    )
 
 
 class MOEScatter(Function):
     @staticmethod
-    def forward(ctx, inp, pos, local_expert_count, global_expert_count,
-            fwd_batch_size, world_size):
-        local_input_buf, = fmoe_cuda.local_scatter(inp, pos)
+    def forward(
+        ctx,
+        inp,
+        pos,
+        local_expert_count,
+        global_expert_count,
+        fwd_batch_size,
+        world_size,
+    ):
+        (local_input_buf,) = fmoe_cuda.local_scatter(inp, pos)
         if world_size > 1:
-            global_input_buf, = fmoe_cuda.global_scatter(local_input_buf,
-                    local_expert_count, global_expert_count,
-                    fwd_batch_size, world_size)
+            (global_input_buf,) = fmoe_cuda.global_scatter(
+                local_input_buf,
+                local_expert_count,
+                global_expert_count,
+                fwd_batch_size,
+                world_size,
+            )
         else:
             global_input_buf = local_input_buf
         ctx.moe_args = fwd_batch_size, inp.shape[0], world_size
@@ -50,20 +67,25 @@ class MOEScatter(Function):
         (fwd_batch_size, local_batch_size, world_size) = ctx.moe_args
         if world_size > 1:
-            local_grad_in, = fmoe_cuda.global_gather(global_grad_in,
-                    local_expert_count, global_expert_count,
-                    local_batch_size, world_size)
+            (local_grad_in,) = fmoe_cuda.global_gather(
+                global_grad_in,
+                local_expert_count,
+                global_expert_count,
+                local_batch_size,
+                world_size,
+            )
         else:
             local_grad_in = global_grad_in
-        grad_in, = fmoe_cuda.local_gather(local_grad_in, pos)
+        (grad_in,) = fmoe_cuda.local_gather(local_grad_in, pos)
         return grad_in, None, None, None, None, None
 
 
 class MOELinear(Function):
     @staticmethod
     def forward(ctx, global_input_buf, weight, fwd_expert_count):
-        global_output_buf, = fmoe_cuda.forward(global_input_buf, weight,
-                fwd_expert_count)
+        (global_output_buf,) = fmoe_cuda.forward(
+            global_input_buf, weight, fwd_expert_count
+        )
         variables = (global_input_buf, weight, fwd_expert_count)
         ctx.save_for_backward(*variables)
         return global_output_buf
@@ -72,21 +94,33 @@ class MOELinear(Function):
     def backward(ctx, grad_out):
         (input_buf, weight, fwd_expert_count) = ctx.saved_tensors
         grad_inp_buf, grad_weight = fmoe_cuda.backward(
-                grad_out, input_buf, weight, fwd_expert_count)
+            grad_out, input_buf, weight, fwd_expert_count
+        )
         return grad_inp_buf, grad_weight, None
 
 
 class MOEGather(Function):
     @staticmethod
-    def forward(ctx, global_output_buf, pos, local_expert_count,
-            global_expert_count, local_batch_size, world_size):
+    def forward(
+        ctx,
+        global_output_buf,
+        pos,
+        local_expert_count,
+        global_expert_count,
+        local_batch_size,
+        world_size,
+    ):
         if world_size > 1:
-            local_output_buf, = fmoe_cuda.global_gather(global_output_buf,
-                    local_expert_count, global_expert_count,
-                    local_batch_size, world_size)
+            (local_output_buf,) = fmoe_cuda.global_gather(
+                global_output_buf,
+                local_expert_count,
+                global_expert_count,
+                local_batch_size,
+                world_size,
+            )
         else:
             local_output_buf = global_output_buf
-        output, = fmoe_cuda.local_gather(local_output_buf, pos)
+        (output,) = fmoe_cuda.local_gather(local_output_buf, pos)
         ctx.moe_args = local_batch_size, global_output_buf.shape[0], world_size
         variables = (pos, local_expert_count, global_expert_count)
@@ -97,13 +131,15 @@ class MOEGather(Function):
     def backward(ctx, grad_out):
         pos, local_expert_count, global_expert_count = ctx.saved_tensors
         local_batch_size, fwd_batch_size, world_size = ctx.moe_args
-        grad_out_buf, = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
+        (grad_out_buf,) = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
         if world_size > 1:
-            global_grad_out_buf, = fmoe_cuda.global_scatter(grad_out_buf,
-                    local_expert_count, global_expert_count,
-                    fwd_batch_size, world_size)
+            (global_grad_out_buf,) = fmoe_cuda.global_scatter(
+                grad_out_buf,
+                local_expert_count,
+                global_expert_count,
+                fwd_batch_size,
+                world_size,
+            )
         else:
             global_grad_out_buf = grad_out_buf
         return global_grad_out_buf, None, None, None, None, None
@@ -9,8 +9,7 @@ class FMoELinear(nn.Module):
         self.num_expert = num_expert
         self.in_feat = in_feat
         self.out_feat = out_feat
-        self.weight = nn.Parameter(
-                torch.Tensor(num_expert, out_feat, in_feat))
+        self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
         self.reset_parameters()
 
     def reset_parameters(self):
@@ -30,8 +29,9 @@ class FMoENaiveGate(nn.Module):
     def forward(self, inp):
         gate = self.gate(inp)
-        gate_top_k_val, gate_top_k_idx = torch.topk(gate, k=self.top_k, dim=-1,
-                largest=True, sorted=False)  # [.. x top_k]
+        gate_top_k_val, gate_top_k_idx = torch.topk(
+            gate, k=self.top_k, dim=-1, largest=True, sorted=False
+        )  # [.. x top_k]
         gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
 
         # (BxL) x 1 x top_k
@@ -42,23 +42,38 @@ class FMoENaiveGate(nn.Module):
 def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
-    (pos, local_expert_count, global_expert_count, fwd_expert_count,
-            fwd_batch_size) = moe_prepare_forward(gate, num_expert, world_size)
-    x = MOEScatter.apply(inp, pos, local_expert_count, global_expert_count,
-            fwd_batch_size, world_size)
+    (
+        pos,
+        local_expert_count,
+        global_expert_count,
+        fwd_expert_count,
+        fwd_batch_size,
+    ) = moe_prepare_forward(gate, num_expert, world_size)
+    x = MOEScatter.apply(
+        inp, pos, local_expert_count, global_expert_count, fwd_batch_size, world_size
+    )
     for i, l in enumerate(linears):
         if i:
             x = activation(x)
         x = l(x, fwd_expert_count)
-    x = MOEGather.apply(x, pos, local_expert_count, global_expert_count,
-            inp.shape[0], world_size)
+    x = MOEGather.apply(
+        x, pos, local_expert_count, global_expert_count, inp.shape[0], world_size
+    )
     return x
 
 
 class FMoETransformerMLP(nn.Module):
-    def __init__(self, num_expert=32, d_model=1024, d_hidden=4096,
-            world_size=1, activation=torch.nn.functional.gelu,
-            top_k=2, pre_lnorm=False):
+    def __init__(
+        self,
+        num_expert=32,
+        d_model=1024,
+        d_hidden=4096,
+        world_size=1,
+        activation=torch.nn.functional.gelu,
+        top_k=2,
+        pre_lnorm=False,
+        model_parallel_rank=-1,
+    ):
         super(FMoETransformerMLP, self).__init__()
         self.num_expert = num_expert
         self.d_model = d_model
@@ -74,8 +89,10 @@ class FMoETransformerMLP(nn.Module):
         self.gate = FMoENaiveGate(d_model, num_expert, world_size, top_k)
         self.layer_norm = nn.LayerNorm(d_model)
-        self.bias = torch.nn.parameter.Parameter(torch.zeros(d_model,
-            dtype=torch.float32))
+        self.bias = torch.nn.parameter.Parameter(
+            torch.zeros(d_model, dtype=torch.float32)
+        )
+        self.model_parallel_rank = model_parallel_rank
 
     def forward(self, inp):
         residual = inp
@@ -85,11 +102,17 @@ class FMoETransformerMLP(nn.Module):
         gate_top_k_idx, gate_score = self.gate(inp)
 
         # TODO: merge replication into local_scatter
-        inp = inp.view(-1, self.d_model).repeat_interleave(repeats=self.top_k,
-                dim=0)  # (BxLxtop_k) x d_model
-        x = _fmoe_full_forward(inp, gate_top_k_idx,
-                [self.htoh4, self.h4toh], self.activation,
-                self.num_expert, self.world_size)
+        inp = inp.view(-1, self.d_model).repeat_interleave(
+            repeats=self.top_k, dim=0
+        )  # (BxLxtop_k) x d_model
+        x = _fmoe_full_forward(
+            inp,
+            gate_top_k_idx,
+            [self.htoh4, self.h4toh],
+            self.activation,
+            self.num_expert,
+            self.world_size,
+        )
 
         core_out = x.view(-1, self.top_k, self.d_model)  # (BxL) x top_k x d_model
         core_out = torch.bmm(gate_score, core_out)  # (BxL) x 1 x d_model
@@ -99,4 +122,3 @@ class FMoETransformerMLP(nn.Module):
         if not self.pre_lnorm:
             output = self.layer_norm(output)
         return output, self.bias
-
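
For readers tracing the shape comments in FMoETransformerMLP.forward above, here is a minimal, self-contained sketch of the top-k replication and combination steps using plain PyTorch only (no fmoe_cuda kernels; the tensor sizes are illustrative and not taken from the diff):

import torch

# Illustrative sizes: num_tokens = B*L flattened tokens, top_k experts per token.
num_tokens, top_k, d_model = 6, 2, 8

# Mirror of inp.view(-1, d_model).repeat_interleave(repeats=top_k, dim=0):
# each token row is duplicated top_k times before being scattered to experts.
inp = torch.randn(num_tokens, d_model)
expanded = inp.repeat_interleave(repeats=top_k, dim=0)  # (num_tokens*top_k) x d_model

# Stand-in for the expert MLP outputs (one row per replicated token).
expert_out = torch.randn_like(expanded)

# Mirror of core_out = torch.bmm(gate_score, core_out): weight each token's
# top_k expert outputs by its gate scores and sum them.
gate_score = torch.softmax(torch.randn(num_tokens, 1, top_k), dim=-1)
core_out = expert_out.view(num_tokens, top_k, d_model)  # (B*L) x top_k x d_model
combined = torch.bmm(gate_score, core_out)              # (B*L) x 1 x d_model
print(combined.squeeze(1).shape)                        # torch.Size([6, 8])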
@@ -2,11 +2,15 @@ from .layers import FMoETransformerMLP
 def create_moe_mlp(args):
-    assert args.num_experts % args.model_parallel_size == 0, 'Num experts should be multiple of mp size'
+    assert (
+        args.num_experts % args.model_parallel_size == 0
+    ), "Num experts should be multiple of mp size"
     num_experts = args.num_experts // args.model_parallel_size
-    fmoe = FMoETransformerMLP(num_experts,
+    fmoe = FMoETransformerMLP(
+        num_experts,
         d_model=args.hidden_size,
         d_hidden=args.hidden_size * 4,
-        world_size=args.model_parallel_size)
+        world_size=args.model_parallel_size,
+        model_parallel_rank=args.model_parallel_rank,
+    )
     return fmoe
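
As a quick aside on the new argument, here is a hedged sketch of the expert-partitioning arithmetic that create_moe_mlp above performs; the concrete values are hypothetical stand-ins for the Megatron-style training args (num_experts, model_parallel_size, hidden_size, and the newly forwarded model_parallel_rank):

# Hypothetical values; the real ones come from the training launcher's args.
num_experts = 32
model_parallel_size = 4
model_parallel_rank = 0  # newly threaded through to FMoETransformerMLP in this commit
hidden_size = 1024

assert num_experts % model_parallel_size == 0, "Num experts should be multiple of mp size"
experts_per_partition = num_experts // model_parallel_size  # 8 local experts per rank

# create_moe_mlp would then construct, roughly:
#   FMoETransformerMLP(
#       experts_per_partition,
#       d_model=hidden_size,
#       d_hidden=hidden_size * 4,
#       world_size=model_parallel_size,
#       model_parallel_rank=model_parallel_rank,
#   )
print(experts_per_partition)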