#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <vector>
#include <torch/extension.h>

// CUDA runtime
#include <cuda_runtime.h>
#include <cublas_v2.h>

// CUDA and CUBLAS functions
//#include <helper_functions.h>
#include <helper_cuda.h>

const int num_stream = 512;
// std::vector<cudaStream_t> stream;

// Naive MoE FFN forward: one cuBLAS GEMM per (sample, expert slot) pair,
// round-robined over a pool of CUDA streams. `bias` is not applied yet.
void moe_cuda_forward(
        torch::Tensor input,   // [B x D_model]
        torch::Tensor gate,    // [B x N]
        torch::Tensor weight,  // [N x D_model x D_ffn]
        torch::Tensor bias     // [N x D_ffn]
) {
    const auto batch_size = input.size(0);
    const auto num_expert = gate.size(1);
    const auto d_model = weight.size(1);
    const auto d_ffn = weight.size(2);
    printf("b=%ld, expert=%ld, d_model=%ld, d_ffn=%ld\n",
           batch_size, num_expert, d_model, d_ffn);

    auto output = input.new_zeros({batch_size, num_expert, d_ffn});

    cublasHandle_t handle;
    checkCudaErrors(cublasCreate(&handle));

    cudaStream_t stream[num_stream];
    for (size_t i = 0; i < num_stream; ++i) {
        checkCudaErrors(cudaStreamCreate(&stream[i]));
    }

    // Time the whole batch of GEMMs with CUDA events
    cudaEvent_t start, stop;
    checkCudaErrors(cudaEventCreate(&start));
    checkCudaErrors(cudaEventCreate(&stop));
    checkCudaErrors(cudaEventRecord(start, NULL));

    const float alpha = 1.0f, beta = 0.0f;
    for (size_t i = 0; i < batch_size; ++i) {
        for (size_t j = 0; j < num_expert; ++j) {
            // Round-robin the GEMMs over the stream pool
            checkCudaErrors(cublasSetStream(
                    handle, stream[(i * num_expert + j) % num_stream]));
            if (input.scalar_type() == torch::ScalarType::Float) {
                // output[i][j] (1 x d_ffn) =
                //     input[i] (1 x d_model) x weight[gate[i][j]] (d_model x d_ffn)
                checkCudaErrors(cublasSgemm(
                        handle, CUBLAS_OP_N, CUBLAS_OP_N,
                        1, d_ffn, d_model,
                        &alpha,
                        input[i].data_ptr<float>(), 1,
                        weight.index(gate[i][j]).data_ptr<float>(), d_model,
                        &beta,
                        output[i][j].data_ptr<float>(), 1));
            } else {
                printf("only support float!!!\n");
            }
        }
    }
    // checkCudaErrors(cudaDeviceSynchronize());

    // Record the stop event
    checkCudaErrors(cudaEventRecord(stop, NULL));
    // Wait for the stop event to complete
    checkCudaErrors(cudaEventSynchronize(stop));

    float msecTotal = 0.0f;
    checkCudaErrors(cudaEventElapsedTime(&msecTotal, start, stop));

    // Compute and print the performance
    float msecPerMatrixMul = msecTotal / batch_size / num_expert;
    double flopsPerMatrixMul = 2.0 * (double)d_model * (double)d_ffn;
    double gigaFlops = (flopsPerMatrixMul * 1.0e-9f) / (msecPerMatrixMul / 1000.0f);
    printf("Performance= %.2f GFlop/s, Time= %.3f msec, Size= %.0f Ops\n",
           gigaFlops, msecPerMatrixMul, flopsPerMatrixMul);

    // std::cout << output << std::endl;

    for (size_t i = 0; i < num_stream; ++i) {
        checkCudaErrors(cudaStreamDestroy(stream[i]));
    }
    checkCudaErrors(cublasDestroy(handle));
    checkCudaErrors(cudaEventDestroy(start));
    checkCudaErrors(cudaEventDestroy(stop));
}
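
// A minimal binding sketch (an assumption, not part of the original file):
// the exported name "forward" is hypothetical, but PYBIND11_MODULE and
// TORCH_EXTENSION_NAME are the standard torch C++/CUDA extension machinery
// for exposing moe_cuda_forward to Python.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &moe_cuda_forward,
          "MoE FFN forward (one cuBLAS GEMM per sample/expert, CUDA streams)");
}
// From Python one would build with torch.utils.cpp_extension.load(...) and
// call module.forward(input, gate, weight, bias); exact invocation depends
// on the build setup, which this file does not show.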