#include "parallel_linear.h"
#include "utils/fmoe_utils.h"

#include <torch/extension.h>

#include <vector>

/*
 * Forward pass of the expert-parallel linear layer.
 *
 * input_buf:    [batch_size x in_feat], rows already grouped by expert.
 * weight:       [num_expert x out_feat x in_feat].
 * expert_count: per-expert row counts describing how input_buf rows are
 *               partitioned among experts (assumed int64 — the dispatch
 *               below reads it as long*; confirm against the CUDA impl).
 *
 * Returns {output} with output: [batch_size x out_feat].
 */
std::vector<torch::Tensor> _linear_forward(
        torch::Tensor input_buf,
        torch::Tensor weight,
        torch::Tensor expert_count
        ) {
    CHECK_INPUT(input_buf);
    CHECK_INPUT(weight);
    // Stream manager is per-device; keyed off the input's device index.
    auto smgr = getCudaStreamManager(input_buf.device().index());

    const auto batch_size = input_buf.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef FMOE_DEBUG
    printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
            num_expert, in_feat, out_feat);
#endif

    // Output inherits device and dtype from the input buffer.
    auto out_options = torch::TensorOptions()
        .device(input_buf.device())
        .dtype(input_buf.dtype());
    auto output = torch::empty({batch_size, out_feat}, out_options);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
            "fmoe_linear_forward", ([&] {
        fmoe_cuda_forward_impl(
            input_buf.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            expert_count.data_ptr<long>(),
            output.data_ptr<scalar_t>(),
            in_feat,
            out_feat,
            num_expert,
            smgr
        );
    }));
    return {output};
}

/*
 * Backward pass of the expert-parallel linear layer.
 *
 * grad_output_buf: [batch_size x out_feat] — gradient w.r.t. the forward
 *                  output.
 * input_buf:       [batch_size x in_feat] — the forward input (the previous
 *                  comment said out_feat, which contradicts the weight
 *                  layout and grad_input_buf shape below).
 * weight:          [num_expert x out_feat x in_feat].
 * expert_count:    per-expert row counts (read as long* — confirm dtype).
 *
 * Returns {grad_input_buf, grad_weight}:
 *   grad_input_buf: [batch_size x in_feat]
 *   grad_weight:    [num_expert x out_feat x in_feat]
 */
std::vector<torch::Tensor> _linear_backward(
        torch::Tensor grad_output_buf,  // [batch_size x out_feat]
        torch::Tensor input_buf,        // [batch_size x in_feat]
        torch::Tensor weight,           // [num_expert x out_feat x in_feat]
        torch::Tensor expert_count
        ) {
    CHECK_INPUT(grad_output_buf);
    CHECK_INPUT(input_buf);
    CHECK_INPUT(weight);
    auto smgr = getCudaStreamManager(input_buf.device().index());

    const auto batch_size = input_buf.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef FMOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
            "out_feat (d_ffn)=%ld\n",
            batch_size, num_expert, in_feat, out_feat);
#endif

    // new_empty keeps device/dtype consistent with grad_output_buf.
    auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
    auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});

    // NOTE: dispatch name fixed from "ffmoe_linear_backward" (typo) to match
    // the "fmoe_linear_forward" naming above; the name only appears in
    // dispatch error messages.
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
            "fmoe_linear_backward", ([&] {
        fmoe_cuda_backward_impl(
            grad_output_buf.data_ptr<scalar_t>(),
            input_buf.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            expert_count.data_ptr<long>(),
            grad_input_buf.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            out_feat,
            num_expert,
            smgr
        );
    }));
    return {grad_input_buf, grad_weight};
}