"vscode:/vscode.git/clone" did not exist on "804d9f2e4c498519554cdf8dcbf7bcf2a1c90467"
Commit d0f07ff7 authored by Rick Ho

basic megatron support frame

parent 832385c2
from torch import nn
from .moe import FFFN


def create_moe_mlp(args):
    assert args.num_experts % args.model_parallel_size == 0, 'Num experts should be multiple of mp size'
    num_experts = args.num_experts // args.model_parallel_size
    fmoe = FFFN(num_experts, in_feat=args.hidden_size,
            hidden_feat=args.hidden_size * 4, out_feat=args.hidden_size,
            world_size=args.model_parallel_size)
    return fmoe
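For orientation, a minimal usage sketch (not part of the commit): the argument namespace below is hypothetical, but it carries exactly the three fields that create_moe_mlp reads.

# Hypothetical Megatron-style argument namespace (illustration only);
# create_moe_mlp above reads just these three fields.
from types import SimpleNamespace

args = SimpleNamespace(
    num_experts=4,          # total experts; must be a multiple of model_parallel_size
    model_parallel_size=1,  # experts are split evenly across these ranks
    hidden_size=1024,       # FFN input/output width; the hidden layer is 4x this
)
mlp = create_moe_mlp(args)  # returns an FFFN holding 4 local experts on this rank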
@@ -26,6 +26,23 @@ class FMoE(nn.Module):
         return moe(inp, gate.int(), self.weight, self.world_size)
+
+
+class FFFN(nn.Module):
+    def __init__(self, num_expert=32, in_feat=1024, hidden_feat=4096,
+            out_feat=1024, world_size=None, activation=torch.nn.functional.gelu):
+        super(FFFN, self).__init__()
+        self.htoh4 = FMoE(num_expert, in_feat, hidden_feat,
+                world_size=world_size)
+        self.activation = activation
+        self.h4toh = FMoE(num_expert, hidden_feat, out_feat,
+                world_size=world_size)
+
+    def forward(self, inp, gate):
+        x = self.htoh4(inp, gate)
+        x = self.activation(x)
+        x = self.h4toh(x, gate)
+        return x


 class BruteForceMoE(nn.Module):
     def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
             world_size=0):
......
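A rough usage sketch of FFFN (not from the diff). It assumes the fmoe_cuda extension is built, a GPU is available, and that world_size=None selects the single-process path.

# Usage sketch (illustration only): one expert index per token in `gate`;
# FFFN applies the chosen expert's h -> 4h -> h projections with GeLU.
import torch

d_model, num_expert, tokens = 1024, 4, 16
ffn = FFFN(num_expert=num_expert, in_feat=d_model,
        hidden_feat=4 * d_model, out_feat=d_model, world_size=None).cuda()

inp = torch.randn(tokens, d_model, device='cuda')
gate = torch.randint(0, num_expert, (tokens,), device='cuda')
out = ffn(inp, gate)  # same shape as inp: (tokens, d_model)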
 import torch
 from torch.autograd import Function

-import moe_cuda
+import fmoe_cuda


 class MOELocal(Function):
     @staticmethod
     def forward(ctx, inp, gate, weight):
-        expert_count, pos = moe_cuda.expert_count(gate, weight.shape[0])
-        input_buf, = moe_cuda.local_scatter(inp, pos)
-        output_buf, = moe_cuda.forward(input_buf, weight, expert_count)
-        output, = moe_cuda.local_gather(output_buf, pos)
+        expert_count, pos = fmoe_cuda.expert_count(gate, weight.shape[0])
+        input_buf, = fmoe_cuda.local_scatter(inp, pos)
+        output_buf, = fmoe_cuda.forward(input_buf, weight, expert_count)
+        output, = fmoe_cuda.local_gather(output_buf, pos)

         variables = [input_buf, gate, weight, expert_count, pos]
         ctx.save_for_backward(*variables)
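For readability, an unoptimized pure-PyTorch rendering of what these four kernel calls are understood to do. This is an interpretation, not the extension's implementation, and it assumes weight is laid out as (num_expert, out_feat, in_feat).

# Reference sketch (assumption: weight[num_expert, out_feat, in_feat]).
# Mirrors expert_count / local_scatter / forward / local_gather as plain ops.
import torch
import torch.nn.functional as F

def moe_local_reference(inp, gate, weight):
    gate = gate.long()
    num_expert = weight.shape[0]
    pos = torch.argsort(gate)                         # group rows by expert
    expert_count = torch.bincount(gate, minlength=num_expert)
    input_buf = inp[pos]                              # "local_scatter"
    chunks, base = [], 0
    for e in range(num_expert):                       # one GEMM per expert
        cnt = int(expert_count[e])
        chunks.append(F.linear(input_buf[base:base + cnt], weight[e]))
        base += cnt
    output_buf = torch.cat(chunks, dim=0)
    output = torch.empty_like(output_buf)
    output[pos] = output_buf                          # "local_gather"
    return output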
@@ -20,10 +20,10 @@ class MOELocal(Function):
     def backward(ctx, grad_out):
         input_buf, gate, weight, expert_count, pos = ctx.saved_tensors

-        grad_out_buf, = moe_cuda.local_scatter(grad_out.contiguous(), pos)
-        grad_inp_buf, grad_weight = moe_cuda.backward(
+        grad_out_buf, = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
+        grad_inp_buf, grad_weight = fmoe_cuda.backward(
                 grad_out_buf, input_buf, weight, expert_count)
-        grad_inp, = moe_cuda.local_gather(grad_inp_buf, pos)
+        grad_inp, = fmoe_cuda.local_gather(grad_inp_buf, pos)

         return grad_inp, None, grad_weight
@@ -33,20 +33,20 @@ class MOEGlobal(Function):
     def forward(ctx, inp, gate, weight, world_size):
         num_expert = weight.shape[0]

-        local_expert_count, pos = moe_cuda.expert_count(gate,
+        local_expert_count, pos = fmoe_cuda.expert_count(gate,
                 world_size * num_expert)
-        global_expert_count, fwd_expert_count = moe_cuda.expert_exchange(
+        global_expert_count, fwd_expert_count = fmoe_cuda.expert_exchange(
                 local_expert_count, num_expert, world_size)
         fwd_batch_size = int(fwd_expert_count.sum().item())

-        local_input_buf, = moe_cuda.local_scatter(inp, pos)
+        local_input_buf, = fmoe_cuda.local_scatter(inp, pos)
-        local_output_buf, global_input_buf = moe_cuda.global_fused_forward(
+        local_output_buf, global_input_buf = fmoe_cuda.global_fused_forward(
                 local_input_buf, weight,
                 local_expert_count, global_expert_count,
                 fwd_batch_size, inp.shape[0], world_size)
-        output, = moe_cuda.local_gather(local_output_buf, pos)
+        output, = fmoe_cuda.local_gather(local_output_buf, pos)

         variables = (global_input_buf, gate, weight,
                 local_expert_count, global_expert_count, fwd_expert_count,
@@ -63,18 +63,18 @@ class MOEGlobal(Function):
                 pos) = ctx.saved_tensors
         num_expert, local_batch_size, fwd_batch_size, world_size = ctx.moe_args

-        grad_out_buf, = moe_cuda.local_scatter(grad_out.contiguous(), pos)
-        global_grad_out_buf, = moe_cuda.global_scatter(grad_out_buf,
+        grad_out_buf, = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
+        global_grad_out_buf, = fmoe_cuda.global_scatter(grad_out_buf,
                 local_expert_count, global_expert_count,
                 fwd_batch_size, world_size)
-        grad_inp_buf, grad_weight = moe_cuda.backward(
+        grad_inp_buf, grad_weight = fmoe_cuda.backward(
                 global_grad_out_buf, input_buf, weight, fwd_expert_count)
-        local_grad_inp_buf, = moe_cuda.global_gather(grad_inp_buf,
+        local_grad_inp_buf, = fmoe_cuda.global_gather(grad_inp_buf,
                 local_expert_count, global_expert_count,
                 local_batch_size, world_size)
-        grad_inp, = moe_cuda.local_gather(local_grad_inp_buf, pos)
+        grad_inp, = fmoe_cuda.local_gather(local_grad_inp_buf, pos)

         return grad_inp, None, grad_weight, None
......
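The global path adds an all-to-all step: each rank first tells every other rank how many tokens it is sending to each of that rank's local experts, then the token buffers themselves are exchanged. A hedged sketch of the count exchange only, assuming an initialized default process group; this is an interpretation of expert_exchange, not the extension's code.

# Sketch of the count exchange (illustration only).
import torch
import torch.distributed as dist

def expert_exchange_reference(local_expert_count, num_expert, world_size):
    # local_expert_count: [world_size * num_expert] -- how many of this
    # rank's tokens go to each (destination rank, local expert) pair.
    global_expert_count = torch.empty_like(local_expert_count)
    dist.all_to_all_single(global_expert_count, local_expert_count)
    # Tokens this rank will actually run through its own experts.
    fwd_expert_count = global_expert_count.view(world_size, num_expert).sum(dim=0)
    return global_expert_count, fwd_expert_count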
@@ -12,8 +12,8 @@ if os.environ.get('USE_NCCL', '0') == '1':

 if __name__ == '__main__':
     setuptools.setup(
-        name='fmoe_cuda',
-        packages=setuptools.find_packages(),
+        name='fmoe',
+        packages=['fmoe'],
         ext_modules=[
             CUDAExtension(
                 name='fmoe_cuda',
@@ -30,6 +30,7 @@ if __name__ == '__main__':
                 }
             )
         ],
+        version='0.0.1',
         cmdclass={
             'build_ext': BuildExtension
         })
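This change separates the distribution name (fmoe, the pure-Python package) from the compiled extension (fmoe_cuda). A plausible post-install check, not taken from the diff:

# After building, e.g. `USE_NCCL=1 python setup.py install` (a plausible
# sequence, not from the diff), the two names import independently:
import fmoe        # the Python package listed in packages=['fmoe']
import fmoe_cuda   # the CUDAExtension, which keeps its original name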
@@ -159,6 +159,7 @@ def test_dp():

 if __name__ == '__main__':
     torch.distributed.init_process_group(backend='mpi')
     rank = torch.distributed.get_rank()
+    world_size = torch.distributed.get_world_size()
     if len(sys.argv) >= 2:
         task = sys.argv[1]
......
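The test now also records the world size. A brief, standalone reminder of how an MPI-backed script like this is typically launched; the script name below is a placeholder, since the diff view elides the file name, and it assumes PyTorch was built with MPI support.

# Placeholder script name; run as:  mpirun -np 2 python check_mpi.py dp
import torch.distributed as dist

dist.init_process_group(backend='mpi')   # rank/world size come from mpirun
print('rank', dist.get_rank(), 'of', dist.get_world_size())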