"server/text_generation_server/models/flash_starcoder2.py" did not exist on "e3e487dc711449c23826cfe1d74786f71309d6bd"
Commit 5e0af68d authored by Rick Ho

split fmoe functions

parent ad07f07a
import torch
from torch.autograd import Function

import fmoe_cuda


def moe_prepare_forward(gate, num_expert, world_size):
    # Set up NCCL on the default process group, then count how many tokens
    # are routed to each expert locally and exchange the counts across workers.
    fmoe_cuda.ensure_nccl(torch.distributed.distributed_c10d._default_pg, gate)
    with torch.no_grad():
        _, pos = torch.sort(gate)
        gate_idx, gate_count = torch.unique(gate, return_counts=True)
        local_expert_count = torch.zeros(num_expert * world_size,
                device=gate.device, dtype=torch.long)
        local_expert_count.index_put_((gate_idx.long(), ), gate_count)
        global_expert_count, = fmoe_cuda.expert_exchange(
                local_expert_count, num_expert, world_size)
        fwd_expert_count = global_expert_count.view(world_size,
                num_expert).sum(dim=0).cpu()
        fwd_batch_size = int(fwd_expert_count.sum().item())
    return (pos, local_expert_count.cpu(), global_expert_count.cpu(),
            fwd_expert_count.cpu(), fwd_batch_size)
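
# A hypothetical illustration of the prepared tensors (not output from a real
# run), assuming num_expert = 2 and world_size = 2, i.e. 4 expert slots
# globally, and a gate over four tokens:
#   gate = tensor([3, 0, 0, 2])                    # expert chosen per token
#   pos                -> tensor([1, 2, 3, 0])     # tokens sorted by expert
#   local_expert_count -> tensor([2, 0, 1, 1])     # tokens per global expert
# global_expert_count, fwd_expert_count and fwd_batch_size then follow from
# exchanging these counts across workers via fmoe_cuda.expert_exchange.
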
class MOEScatter(Function):
    # Reorder local tokens by expert, then scatter them to the workers that
    # host the experts they are routed to.
    @staticmethod
    def forward(ctx, inp, pos, local_expert_count, global_expert_count,
            fwd_batch_size, world_size):
        local_input_buf, = fmoe_cuda.local_gather(inp, pos)
        if world_size > 1:
            global_input_buf, = fmoe_cuda.global_scatter(local_input_buf,
                    local_expert_count, global_expert_count,
                    fwd_batch_size, world_size)
        else:
            global_input_buf = local_input_buf
        ctx.moe_args = fwd_batch_size, inp.shape[0], world_size
        variables = (pos, local_expert_count, global_expert_count)
        ctx.save_for_backward(*variables)
        return global_input_buf

    @staticmethod
    def backward(ctx, global_grad_in):
        (pos, local_expert_count, global_expert_count) = ctx.saved_tensors
        (fwd_batch_size, local_batch_size, world_size) = ctx.moe_args
        if world_size > 1:
            local_grad_in, = fmoe_cuda.global_gather(global_grad_in,
                    local_expert_count, global_expert_count,
                    local_batch_size, world_size)
        else:
            local_grad_in = global_grad_in
        grad_in, = fmoe_cuda.local_scatter(local_grad_in, pos)
        return grad_in, None, None, None, None, None
class MOELinear(Function):
    # Batched expert GEMM: each expert applies its own weight slice to the
    # contiguous chunk of the buffer assigned to it by fwd_expert_count.
    @staticmethod
    def forward(ctx, global_input_buf, weight, fwd_expert_count):
        global_output_buf, = fmoe_cuda.forward(global_input_buf, weight,
                fwd_expert_count)
        variables = (global_input_buf, weight, fwd_expert_count)
        ctx.save_for_backward(*variables)
        return global_output_buf

    @staticmethod
    def backward(ctx, grad_out):
        (input_buf, weight, fwd_expert_count) = ctx.saved_tensors
        grad_inp_buf, grad_weight = fmoe_cuda.backward(
                grad_out, input_buf, weight, fwd_expert_count)
        return grad_inp_buf, grad_weight, None
class MOEGather(Function):
    # Inverse of MOEScatter: gather expert outputs back to the workers that
    # own the corresponding inputs and restore the original token order.
    @staticmethod
    def forward(ctx, global_output_buf, pos, local_expert_count,
            global_expert_count, local_batch_size, world_size):
        if world_size > 1:
            local_output_buf, = fmoe_cuda.global_gather(global_output_buf,
                    local_expert_count, global_expert_count,
                    local_batch_size, world_size)
        else:
            local_output_buf = global_output_buf
        output, = fmoe_cuda.local_scatter(local_output_buf, pos)
        # The forward batch size is the number of rows in the expert-ordered
        # (global) buffer; it is needed to scatter gradients back in backward.
        fwd_batch_size = global_output_buf.shape[0]
        ctx.moe_args = fwd_batch_size, world_size
        variables = (pos, local_expert_count, global_expert_count)
        ctx.save_for_backward(*variables)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        pos, local_expert_count, global_expert_count = ctx.saved_tensors
        fwd_batch_size, world_size = ctx.moe_args
        grad_out_buf, = fmoe_cuda.local_gather(grad_out.contiguous(), pos)
        if world_size > 1:
            global_grad_out_buf, = fmoe_cuda.global_scatter(grad_out_buf,
                    local_expert_count, global_expert_count,
                    fwd_batch_size, world_size)
        else:
            global_grad_out_buf = grad_out_buf
        return global_grad_out_buf, None, None, None, None, None
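
A minimal sketch of how these pieces compose into one MoE forward pass. The wrapper function and the experts_weight tensor are hypothetical names introduced for illustration; fmoe_cuda must be built and torch.distributed initialised for the world_size > 1 path:

def moe_forward_sketch(inp, gate, experts_weight, num_expert, world_size):
    (pos, local_expert_count, global_expert_count,
     fwd_expert_count, fwd_batch_size) = moe_prepare_forward(
            gate, num_expert, world_size)
    # Route each token to the worker/expert chosen by the gate.
    x = MOEScatter.apply(inp, pos, local_expert_count, global_expert_count,
            fwd_batch_size, world_size)
    # Apply each expert's weight to its slice of the scattered buffer.
    x = MOELinear.apply(x, experts_weight, fwd_expert_count)
    # Return expert outputs to their source workers in the original order.
    return MOEGather.apply(x, pos, local_expert_count, global_expert_count,
            inp.shape[0], world_size)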