Commit bb3c8966 authored by Rick Ho
Browse files

reconstruct fmoe cuda code

parent 4c90e6e8
...@@ -111,7 +111,7 @@ class MOELinear(Function): ...@@ -111,7 +111,7 @@ class MOELinear(Function):
@staticmethod @staticmethod
def forward(ctx, global_input_buf, weight, fwd_expert_count): def forward(ctx, global_input_buf, weight, fwd_expert_count):
(global_output_buf,) = fmoe_cuda.forward( (global_output_buf,) = fmoe_cuda.linear_forward(
global_input_buf, weight, fwd_expert_count global_input_buf, weight, fwd_expert_count
) )
variables = (global_input_buf, weight, fwd_expert_count) variables = (global_input_buf, weight, fwd_expert_count)
...@@ -121,7 +121,7 @@ class MOELinear(Function): ...@@ -121,7 +121,7 @@ class MOELinear(Function):
@staticmethod @staticmethod
def backward(ctx, grad_out): def backward(ctx, grad_out):
(input_buf, weight, fwd_expert_count) = ctx.saved_tensors (input_buf, weight, fwd_expert_count) = ctx.saved_tensors
grad_inp_buf, grad_weight = fmoe_cuda.backward( grad_inp_buf, grad_weight = fmoe_cuda.linear_backward(
grad_out, input_buf, weight, fwd_expert_count grad_out, input_buf, weight, fwd_expert_count
) )
return grad_inp_buf, grad_weight, None return grad_inp_buf, grad_weight, None
......
...@@ -7,7 +7,7 @@ cxx_flags = [] ...@@ -7,7 +7,7 @@ cxx_flags = []
ext_libs = [] ext_libs = []
if os.environ.get('USE_NCCL', '0') == '1': if os.environ.get('USE_NCCL', '0') == '1':
cxx_flags.append('-DMOE_USE_NCCL') cxx_flags.append('-DFMOE_USE_NCCL')
ext_libs.append('nccl') ext_libs.append('nccl')
...@@ -25,11 +25,11 @@ if __name__ == '__main__': ...@@ -25,11 +25,11 @@ if __name__ == '__main__':
CUDAExtension( CUDAExtension(
name='fmoe_cuda', name='fmoe_cuda',
sources=[ sources=[
'cuda/moe.cpp', 'cuda/stream_manager.cpp',
'cuda/cuda_stream_manager.cpp', 'cuda/local_exchange.cu',
'cuda/moe_compute_kernel.cu', 'cuda/global_exchange.cpp',
'cuda/moe_comm_kernel.cu', 'cuda/parallel_linear.cpp',
'cuda/moe_fused_kernel.cu', 'cuda/fmoe_cuda.cpp',
], ],
extra_compile_args={ extra_compile_args={
'cxx': cxx_flags, 'cxx': cxx_flags,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment