#include <torch/extension.h>

#include <cublas_v2.h>     // cuBLAS handle / cublasSgemv
#include <cuda_runtime.h>  // CUDA runtime (streams, sync)

#include <cstdio>
#include <cstdlib>

// Number of CUDA streams used to overlap the per-(token, expert) GEMV calls.
const int num_stream = 16;

// Abort with a readable message when a CUDA runtime call fails.
#define CUDA_CHECK(call)                                                      \
    do {                                                                      \
        cudaError_t err_ = (call);                                            \
        if (err_ != cudaSuccess) {                                            \
            fprintf(stderr, "CUDA error %s:%d: %s\n", __FILE__, __LINE__,     \
                    cudaGetErrorString(err_));                                \
            std::abort();                                                     \
        }                                                                     \
    } while (0)

// Abort with a readable message when a cuBLAS call fails.
// (cuBLAS returns cublasStatus_t, not cudaError_t, so it needs its own macro.)
#define CUBLAS_CHECK(call)                                                    \
    do {                                                                      \
        cublasStatus_t st_ = (call);                                          \
        if (st_ != CUBLAS_STATUS_SUCCESS) {                                   \
            fprintf(stderr, "cuBLAS error %s:%d: status %d\n", __FILE__,      \
                    __LINE__, static_cast<int>(st_));                         \
            std::abort();                                                     \
        }                                                                     \
    } while (0)

// Mixture-of-experts forward pass.
//
// For every token i and every gate slot j, computes
//     output[i][j] = weight[gate[i][j]]^T @ input[i]
// with one cublasSgemv per (i, j) pair, round-robined over `num_stream`
// CUDA streams so independent GEMVs can overlap.
//
// Arguments (shapes as annotated by the original author):
//   input  : [B x D_model]           float32, CUDA, contiguous
//   gate   : [B x N]                 expert index per (token, slot)
//                                    — assumed integral; TODO confirm
//   weight : [N_experts x D_model x D_ffn]  float32, CUDA, contiguous
//   bias   : [N_experts x D_ffn]     currently unused — TODO(review):
//                                    presumably meant to be added to output.
//
// NOTE(review): the function allocates `output` but returns void, so the
// result is discarded; callers likely need this to return the tensor
// eventually. The void interface is preserved here.
void moe_cuda_forward(
    torch::Tensor input,   // [B x D_model]
    torch::Tensor gate,    // [B x N]
    torch::Tensor weight,  // [N x D_model x D_ffn]
    torch::Tensor bias     // [N x D_ffn]
) {
    const auto batch_size = input.size(0);
    const auto num_expert = gate.size(1);
    const auto d_model = weight.size(1);
    const auto d_ffn = weight.size(2);

    // Fail loudly instead of printing and silently returning zeros
    // (the original only printf'd "only support float!!!" and continued).
    TORCH_CHECK(input.scalar_type() == torch::ScalarType::Float,
                "moe_cuda_forward: only float32 input is supported");
    TORCH_CHECK(weight.scalar_type() == torch::ScalarType::Float,
                "moe_cuda_forward: only float32 weight is supported");

    // [B x N x D_ffn], zero-initialized on the same device/dtype as input.
    auto output = input.new_zeros({batch_size, num_expert, d_ffn});

    cublasHandle_t handle;
    CUBLAS_CHECK(cublasCreate(&handle));

    cudaStream_t stream[num_stream];
    for (int s = 0; s < num_stream; ++s) {
        CUDA_CHECK(cudaStreamCreate(&stream[s]));
    }

    const float alpha = 1.0f;
    const float beta = 0.0f;

    const float* input_ptr = input.data_ptr<float>();
    const float* weight_ptr = weight.data_ptr<float>();
    float* output_ptr = output.data_ptr<float>();

    for (int64_t i = 0; i < batch_size; ++i) {
        for (int64_t j = 0; j < num_expert; ++j) {
            // Expert id chosen by the gate for (token i, slot j).
            // .item() forces a device->host copy per element — simple but
            // serializing; a batched host copy of `gate` would be faster.
            const int64_t e = gate[i][j].item<int64_t>();

            // Round-robin the GEMVs over the streams so they can overlap.
            // (The original created the streams but never bound cuBLAS to
            // them, so everything ran on the default stream.)
            CUBLAS_CHECK(cublasSetStream(
                handle, stream[(i * num_expert + j) % num_stream]));

            // weight[e] is row-major [D_model x D_ffn]; cuBLAS is
            // column-major, so the same memory with lda = d_ffn is a
            // [D_ffn x D_model] matrix equal to W^T, and OP_N yields
            // y = W^T x directly.
            // TODO(review): confirm weight really is row-major contiguous
            // [D_model x D_ffn] per expert — the original call was garbled.
            CUBLAS_CHECK(cublasSgemv(
                handle, CUBLAS_OP_N,
                static_cast<int>(d_ffn), static_cast<int>(d_model),
                &alpha,
                weight_ptr + e * d_model * d_ffn, static_cast<int>(d_ffn),
                input_ptr + i * d_model, 1,
                &beta,
                output_ptr + (i * num_expert + j) * d_ffn, 1));
        }
    }

    // Wait for all streams before tearing them down.
    CUDA_CHECK(cudaDeviceSynchronize());

    for (int s = 0; s < num_stream; ++s) {
        CUDA_CHECK(cudaStreamDestroy(stream[s]));
    }
    // The original leaked the handle (cublasDestroy was never called).
    CUBLAS_CHECK(cublasDestroy(handle));
}