Commit bb3c8966 authored by Rick Ho
Browse files

reconstruct fmoe cuda code

parent 4c90e6e8
......@@ -111,7 +111,7 @@ class MOELinear(Function):
@staticmethod
def forward(ctx, global_input_buf, weight, fwd_expert_count):
(global_output_buf,) = fmoe_cuda.forward(
(global_output_buf,) = fmoe_cuda.linear_forward(
global_input_buf, weight, fwd_expert_count
)
variables = (global_input_buf, weight, fwd_expert_count)
......@@ -121,7 +121,7 @@ class MOELinear(Function):
@staticmethod
def backward(ctx, grad_out):
    """Backward pass of the expert-parallel linear layer.

    Recovers the tensors saved during forward and delegates the gradient
    computation to the CUDA extension.  Note: this commit renamed the
    extension entry point from ``fmoe_cuda.backward`` to
    ``fmoe_cuda.linear_backward``; the scrape kept both lines, the
    renamed call is the correct one.

    Args:
        ctx: autograd context holding (input_buf, weight, fwd_expert_count).
        grad_out: gradient of the loss w.r.t. the forward output buffer.

    Returns:
        Tuple of (grad wrt input buffer, grad wrt weight, None) — the
        third slot is None because ``fwd_expert_count`` is a
        non-differentiable integer count tensor.
    """
    (input_buf, weight, fwd_expert_count) = ctx.saved_tensors
    grad_inp_buf, grad_weight = fmoe_cuda.linear_backward(
        grad_out, input_buf, weight, fwd_expert_count
    )
    return grad_inp_buf, grad_weight, None
......
# Build-time flags for the fmoe_cuda extension.  NCCL support is opt-in
# via the USE_NCCL environment variable.  This commit renamed the macro
# from MOE_USE_NCCL to FMOE_USE_NCCL; the diff scrape kept both append
# lines, only the renamed one is correct.
cxx_flags = []
ext_libs = []
if os.environ.get('USE_NCCL', '0') == '1':
    cxx_flags.append('-DFMOE_USE_NCCL')
    ext_libs.append('nccl')
......@@ -25,11 +25,11 @@ if __name__ == '__main__':
CUDAExtension(
name='fmoe_cuda',
sources=[
'cuda/moe.cpp',
'cuda/cuda_stream_manager.cpp',
'cuda/moe_compute_kernel.cu',
'cuda/moe_comm_kernel.cu',
'cuda/moe_fused_kernel.cu',
'cuda/stream_manager.cpp',
'cuda/local_exchange.cu',
'cuda/global_exchange.cpp',
'cuda/parallel_linear.cpp',
'cuda/fmoe_cuda.cpp',
],
extra_compile_args={
'cxx': cxx_flags,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment