Commit fdbac1df authored by Sengxian

Format using black and add model_parallel_rank

parent ae658b89
@@ -12,31 +12,48 @@ def moe_prepare_forward(gate, num_expert, world_size, comm=None):
    with torch.no_grad():
        _, pos = torch.sort(gate)
        gate_idx, gate_count = torch.unique(gate, return_counts=True)
        local_expert_count = torch.zeros(
            num_expert * world_size, device=gate.device, dtype=torch.long
        )
        local_expert_count.index_put_((gate_idx.long(),), gate_count)

        if world_size > 1:
            (global_expert_count,) = fmoe_cuda.expert_exchange(
                local_expert_count, num_expert, world_size
            )
        else:
            global_expert_count = local_expert_count
        fwd_expert_count = global_expert_count.view(world_size, num_expert).sum(dim=0)
        fwd_batch_size = int(fwd_expert_count.sum().item())
    return (
        pos,
        local_expert_count.cpu(),
        global_expert_count.cpu(),
        fwd_expert_count.cpu(),
        fwd_batch_size,
    )
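A minimal single-process sketch of the bookkeeping above, in plain PyTorch with no fmoe_cuda and world_size == 1; the toy gate values are made up:

import torch

# 6 tokens routed to 4 experts on a single worker (world_size == 1)
gate = torch.tensor([2, 0, 3, 2, 2, 1])
num_expert, world_size = 4, 1

_, pos = torch.sort(gate)                        # token indices grouped by expert
idx, cnt = torch.unique(gate, return_counts=True)
local_expert_count = torch.zeros(num_expert * world_size, dtype=torch.long)
local_expert_count.index_put_((idx,), cnt)       # per-expert token counts: [1, 1, 3, 1]
fwd_batch_size = int(local_expert_count.sum())   # 6: each token is forwarded once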
class MOEScatter(Function):
    @staticmethod
    def forward(
        ctx,
        inp,
        pos,
        local_expert_count,
        global_expert_count,
        fwd_batch_size,
        world_size,
    ):
        (local_input_buf,) = fmoe_cuda.local_scatter(inp, pos)
        if world_size > 1:
            (global_input_buf,) = fmoe_cuda.global_scatter(
                local_input_buf,
                local_expert_count,
                global_expert_count,
                fwd_batch_size,
                world_size,
            )
        else:
            global_input_buf = local_input_buf
        ctx.moe_args = fwd_batch_size, inp.shape[0], world_size
@@ -50,20 +67,25 @@ class MOEScatter(Function):
        (fwd_batch_size, local_batch_size, world_size) = ctx.moe_args

        if world_size > 1:
            (local_grad_in,) = fmoe_cuda.global_gather(
                global_grad_in,
                local_expert_count,
                global_expert_count,
                local_batch_size,
                world_size,
            )
        else:
            local_grad_in = global_grad_in
        (grad_in,) = fmoe_cuda.local_gather(local_grad_in, pos)
        return grad_in, None, None, None, None, None


class MOELinear(Function):
    @staticmethod
    def forward(ctx, global_input_buf, weight, fwd_expert_count):
        (global_output_buf,) = fmoe_cuda.forward(
            global_input_buf, weight, fwd_expert_count
        )
        variables = (global_input_buf, weight, fwd_expert_count)
        ctx.save_for_backward(*variables)
        return global_output_buf
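For intuition, here is a rough pure-PyTorch stand-in for the per-expert GEMM that the fmoe_cuda.forward call performs; this is a sketch under the assumption that rows of global_input_buf are grouped by expert, not the extension's actual implementation:

import torch

def expert_linear_reference(input_buf, weight, fwd_expert_count):
    # input_buf: (sum(fwd_expert_count), in_feat), rows grouped by expert
    # weight:    (num_expert, out_feat, in_feat)
    chunks = input_buf.split(fwd_expert_count.tolist())
    return torch.cat([c @ w.t() for c, w in zip(chunks, weight)], dim=0)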
@@ -72,21 +94,33 @@ class MOELinear(Function):
    def backward(ctx, grad_out):
        (input_buf, weight, fwd_expert_count) = ctx.saved_tensors
        grad_inp_buf, grad_weight = fmoe_cuda.backward(
            grad_out, input_buf, weight, fwd_expert_count
        )
        return grad_inp_buf, grad_weight, None


class MOEGather(Function):
    @staticmethod
    def forward(
        ctx,
        global_output_buf,
        pos,
        local_expert_count,
        global_expert_count,
        local_batch_size,
        world_size,
    ):
        if world_size > 1:
            (local_output_buf,) = fmoe_cuda.global_gather(
                global_output_buf,
                local_expert_count,
                global_expert_count,
                local_batch_size,
                world_size,
            )
        else:
            local_output_buf = global_output_buf
        (output,) = fmoe_cuda.local_gather(local_output_buf, pos)
        ctx.moe_args = local_batch_size, global_output_buf.shape[0], world_size
        variables = (pos, local_expert_count, global_expert_count)
@@ -97,13 +131,15 @@ class MOEGather(Function):
    def backward(ctx, grad_out):
        pos, local_expert_count, global_expert_count = ctx.saved_tensors
        local_batch_size, fwd_batch_size, world_size = ctx.moe_args
        (grad_out_buf,) = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
        if world_size > 1:
            (global_grad_out_buf,) = fmoe_cuda.global_scatter(
                grad_out_buf,
                local_expert_count,
                global_expert_count,
                fwd_batch_size,
                world_size,
            )
        else:
            global_grad_out_buf = grad_out_buf
        return global_grad_out_buf, None, None, None, None, None
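In plain indexing terms, the local scatter/gather pair used above can be pictured like this (assuming local_scatter reads rows in pos order and local_gather writes them back; a sketch, not the CUDA kernels):

import torch

inp = torch.randn(6, 8)
pos = torch.tensor([1, 4, 0, 5, 3, 2])   # expert-sorted token order from torch.sort(gate)

buf = inp[pos]                           # "local_scatter": group rows by expert
restored = torch.empty_like(buf)
restored[pos] = buf                      # "local_gather": restore original order
assert torch.equal(restored, inp)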
@@ -9,8 +9,7 @@ class FMoELinear(nn.Module):
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
@@ -30,8 +29,9 @@ class FMoENaiveGate(nn.Module):
    def forward(self, inp):
        gate = self.gate(inp)
        gate_top_k_val, gate_top_k_idx = torch.topk(
            gate, k=self.top_k, dim=-1, largest=True, sorted=False
        )  # [.. x top_k]
        gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
        # (BxL) x 1 x top_k
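A toy view of the gate's output shapes; the softmax over the top-k logits (which would produce the gate_score used for the bmm combine later) is an assumption here, and the numbers are arbitrary:

import torch

logits = torch.randn(5, 8)                                # (BxL) x num_expert
val, idx = torch.topk(logits, k=2, dim=-1, largest=True, sorted=False)
gate_score = torch.softmax(val, dim=-1).unsqueeze(1)      # (BxL) x 1 x top_k
print(idx.shape, gate_score.shape)                        # torch.Size([5, 2]) torch.Size([5, 1, 2])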
@@ -42,23 +42,38 @@ class FMoENaiveGate(nn.Module):
def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
    (
        pos,
        local_expert_count,
        global_expert_count,
        fwd_expert_count,
        fwd_batch_size,
    ) = moe_prepare_forward(gate, num_expert, world_size)
    x = MOEScatter.apply(
        inp, pos, local_expert_count, global_expert_count, fwd_batch_size, world_size
    )
    for i, l in enumerate(linears):
        if i:
            x = activation(x)
        x = l(x, fwd_expert_count)
    x = MOEGather.apply(
        x, pos, local_expert_count, global_expert_count, inp.shape[0], world_size
    )
    return x


class FMoETransformerMLP(nn.Module):
    def __init__(
        self,
        num_expert=32,
        d_model=1024,
        d_hidden=4096,
        world_size=1,
        activation=torch.nn.functional.gelu,
        top_k=2,
        pre_lnorm=False,
        model_parallel_rank=-1,
    ):
        super(FMoETransformerMLP, self).__init__()
        self.num_expert = num_expert
        self.d_model = d_model
@@ -74,8 +89,10 @@ class FMoETransformerMLP(nn.Module):
        self.gate = FMoENaiveGate(d_model, num_expert, world_size, top_k)
        self.layer_norm = nn.LayerNorm(d_model)
        self.bias = torch.nn.parameter.Parameter(
            torch.zeros(d_model, dtype=torch.float32)
        )
        self.model_parallel_rank = model_parallel_rank

    def forward(self, inp):
        residual = inp
@@ -85,11 +102,17 @@ class FMoETransformerMLP(nn.Module):
        gate_top_k_idx, gate_score = self.gate(inp)

        # TODO: merge replication into local_scatter
        inp = inp.view(-1, self.d_model).repeat_interleave(
            repeats=self.top_k, dim=0
        )  # (BxLxtop_k) x d_model
        x = _fmoe_full_forward(
            inp,
            gate_top_k_idx,
            [self.htoh4, self.h4toh],
            self.activation,
            self.num_expert,
            self.world_size,
        )

        core_out = x.view(-1, self.top_k, self.d_model)  # (BxL) x top_k x d_model
        core_out = torch.bmm(gate_score, core_out)  # (BxL) x 1 x d_model
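A shape-only sketch of the top-k combine performed by the bmm above (toy sizes, random values):

import torch

BxL, top_k, d_model = 5, 2, 16
gate_score = torch.softmax(torch.randn(BxL, 1, top_k), dim=-1)  # (BxL) x 1 x top_k
core_out = torch.randn(BxL, top_k, d_model)                     # per-expert outputs per token
combined = torch.bmm(gate_score, core_out)                      # (BxL) x 1 x d_model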
@@ -99,4 +122,3 @@ class FMoETransformerMLP(nn.Module):
        if not self.pre_lnorm:
            output = self.layer_norm(output)
        return output, self.bias
@@ -2,11 +2,15 @@ from .layers import FMoETransformerMLP
def create_moe_mlp(args):
    assert (
        args.num_experts % args.model_parallel_size == 0
    ), "Num experts should be multiple of mp size"
    num_experts = args.num_experts // args.model_parallel_size
    fmoe = FMoETransformerMLP(
        num_experts,
        d_model=args.hidden_size,
        d_hidden=args.hidden_size * 4,
        world_size=args.model_parallel_size,
        model_parallel_rank=args.model_parallel_rank,
    )
    return fmoe
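A hypothetical call to create_moe_mlp with an argparse-style namespace; the field values are made up, but the attribute names follow what the function reads (including the new model_parallel_rank):

from types import SimpleNamespace

args = SimpleNamespace(
    num_experts=32,
    model_parallel_size=4,
    model_parallel_rank=0,
    hidden_size=1024,
)
mlp = create_moe_mlp(args)  # 8 experts per model-parallel rank, d_hidden = 4096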