#include #include #include #include #include // CUDA runtime #include #include // CUDA and CUBLAS functions //#include //#include const int num_stream=1024; // std::vector void moe_cuda_forward( torch::Tensor input, // [B x D_model] torch::Tensor gate, // [B x N] torch::Tensor weight, // [N x D_model x D_ffn] torch::Tensor bias // [N x D_ffn] ) { const auto batch_size = input.size(0); const auto num_expert = gate.size(1); const auto d_model = weight.size(1); const auto d_ffn = weight.size(2); auto output = input.new_zeros({batch_size, num_expert, d_ffn}); cublasHandle_t handle; cublasCreate(&handle); cudaStream_t stream[num_stream]; for (size_t i=0; i(), 1, weight.index(gate[i][j]).data_ptr(), d_model, &beta, output[i][j].data_ptr(), 1); } else { printf("only support double!!!\n"); } } } for (size_t i=0; i