#include <cstdio>
#include <cstdlib>
#include <vector>
#include <iostream>

// CUDA runtime
#include <cuda_runtime.h>
#include <cuda_fp16.h>

// CUDA and CUBLAS functions
#include <cublas_v2.h>
#include <helper_cuda.h>

#include <torch/extension.h>

const int num_stream = 512;

// Type-dispatched wrappers so the templated code below can call a single
// cublasXgemmBatched name for float, double and half operands.
inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k,
        const float *alpha,
        const float *Aarray[], int lda,
        const float *Barray[], int ldb,
        const float *beta,
        float *Carray[], int ldc,
        int batchCount) {
    return cublasSgemmBatched(handle, transa, transb, m, n, k,
            alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
}

inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k,
        const double *alpha,
        const double *Aarray[], int lda,
        const double *Barray[], int ldb,
        const double *beta,
        double *Carray[], int ldc,
        int batchCount) {
    return cublasDgemmBatched(handle, transa, transb, m, n, k,
            alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
}

inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k,
        const __half *alpha,
        const __half *Aarray[], int lda,
        const __half *Barray[], int ldb,
        const __half *beta,
        __half *Carray[], int ldc,
        int batchCount) {
    return cublasHgemmBatched(handle, transa, transb, m, n, k,
            alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
}

template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input,
        const size_t* gate,
        const scalar_t* weight,
        scalar_t* output,
        size_t batch_size,
        size_t top_k,
        size_t in_feat,
        size_t out_feat) {
    cublasHandle_t handle;
    checkCudaErrors(cublasCreate(&handle));

    // The expert indices are read on the host below; copy them over, since
    // gate is assumed to point at device memory.
    std::vector<size_t> gate_host(batch_size * top_k);
    checkCudaErrors(cudaMemcpy(gate_host.data(), gate,
            batch_size * top_k * sizeof(size_t), cudaMemcpyDeviceToHost));

    // setup Aarray, Barray and Carray: one (weight, input, output) pointer
    // triple per (sample, selected expert) pair, packed into a single device
    // allocation of 3 * batch_size * top_k pointers
    std::vector<const scalar_t*> aptrs, bptrs;
    std::vector<scalar_t*> cptrs;
    scalar_t **ptrs;
    checkCudaErrors(cudaMalloc(&ptrs, batch_size * sizeof(scalar_t*) * top_k * 3));
    for (size_t i = 0; i < batch_size; ++i) {
        for (size_t k = 0; k < top_k; ++k) {
            aptrs.push_back(weight + gate_host[i * top_k + k] * in_feat * out_feat);
            bptrs.push_back(input + i * in_feat);
            cptrs.push_back(output + (i * top_k + k) * out_feat);
        }
    }
    const size_t num_gemm = batch_size * top_k;
    checkCudaErrors(cudaMemcpy(ptrs, aptrs.data(),
            num_gemm * sizeof(scalar_t*), cudaMemcpyHostToDevice));
    checkCudaErrors(cudaMemcpy(ptrs + num_gemm, bptrs.data(),
            num_gemm * sizeof(scalar_t*), cudaMemcpyHostToDevice));
    checkCudaErrors(cudaMemcpy(ptrs + 2 * num_gemm, cptrs.data(),
            num_gemm * sizeof(scalar_t*), cudaMemcpyHostToDevice));

    // One batched GEMM computes all (sample, expert) products at once:
    // each C (out_feat x 1) is op(A) * B, with op(A) the transposed expert
    // weight block and B the input column vector.
    scalar_t alpha = 1, beta = 0;
    checkCudaErrors(cublasXgemmBatched(handle,
            CUBLAS_OP_T, CUBLAS_OP_N,
            out_feat, 1, in_feat,
            &alpha,
            (const scalar_t**)ptrs, in_feat,
            (const scalar_t**)(ptrs + num_gemm), in_feat,
            &beta,
            ptrs + 2 * num_gemm, out_feat,
            (int)num_gemm));

    checkCudaErrors(cudaFree(ptrs));
    checkCudaErrors(cublasDestroy(handle));
}

torch::Tensor moe_cuda_forward(
        torch::Tensor input,   // [B x D_model]
        torch::Tensor gate,    // [B x K]
        torch::Tensor weight   // [N x D_model x D_ffn]
        ) {
    const auto batch_size = input.size(0);
    const auto top_k = gate.size(1);
    const auto in_feat = weight.size(1);
    const auto out_feat = weight.size(2);
    auto output = input.new_zeros({batch_size, top_k, out_feat});
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_forward", ([&] {
        moe_cuda_forward_impl<scalar_t>(
            input.data_ptr<scalar_t>(),
            (const size_t*)gate.data_ptr(),  // gate must hold 64-bit indices
            weight.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            batch_size,
            top_k,
            in_feat,
            out_feat
        );
    }));
    return output;
}

// Naive baseline: one cuBLAS GEMV per (sample, expert) pair, round-robined
// over a pool of CUDA streams, with event-based timing. bias is accepted for
// interface parity but not applied yet.
void moe_cuda_forward_v1(
        torch::Tensor input,   // [B x D_model]
        torch::Tensor gate,    // [B x N]
        torch::Tensor weight,  // [N x D_model x D_ffn]
        torch::Tensor bias     // [N x D_ffn]
        ) {
    const auto batch_size = input.size(0);
    const auto num_expert = gate.size(1);
    const auto d_model = weight.size(1);
    const auto d_ffn = weight.size(2);
    printf("b=%ld, expert=%ld, d_model=%ld, d_ffn=%ld\n",
            (long)batch_size, (long)num_expert, (long)d_model, (long)d_ffn);
    auto output = input.new_zeros({batch_size, num_expert, d_ffn});

    cublasHandle_t handle;
    checkCudaErrors(cublasCreate(&handle));
    cudaStream_t stream[num_stream];
    for (size_t i = 0; i < num_stream; ++i) {
        checkCudaErrors(cudaStreamCreate(&stream[i]));
    }

    // Allocate start/stop events and record the start event
    cudaEvent_t start, stop;
    checkCudaErrors(cudaEventCreate(&start));
    checkCudaErrors(cudaEventCreate(&stop));
    checkCudaErrors(cudaEventRecord(start, NULL));

    const float alpha = 1.0f, beta = 0.0f;
    for (size_t i = 0; i < batch_size; ++i) {
        for (size_t j = 0; j < num_expert; ++j) {
            // Round-robin the independent GEMVs over the stream pool so
            // they can overlap on the device.
            checkCudaErrors(cublasSetStream(handle,
                    stream[(i * num_expert + j) % num_stream]));
            if (input.scalar_type() == torch::ScalarType::Float) {
                const auto e = gate[i][j].item<int64_t>();
                checkCudaErrors(cublasSgemv(handle, CUBLAS_OP_T,
                        d_model, d_ffn,
                        &alpha,
                        weight[e].data_ptr<float>(),
                        d_model,
                        input[i].data_ptr<float>(), 1,
                        &beta,
                        output[i][j].data_ptr<float>(), 1));
            } else {
                printf("only support float!!!\n");
            }
        }
    }
    // checkCudaErrors(cudaDeviceSynchronize());

    // Record the stop event and wait for it to complete
    checkCudaErrors(cudaEventRecord(stop, NULL));
    checkCudaErrors(cudaEventSynchronize(stop));

    float msecTotal = 0.0f;
    checkCudaErrors(cudaEventElapsedTime(&msecTotal, start, stop));

    // Compute and print the performance
    float msecPerMatrixMul = msecTotal / batch_size / num_expert;
    double flopsPerMatrixMul = 2.0 * (double)d_model * (double)d_ffn;
    double gigaFlops = (flopsPerMatrixMul * 1.0e-9f) / (msecPerMatrixMul / 1000.0f);
    printf(
        "Performance= %.2f GFlop/s, Time= %.3f msec, Size= %.0f Ops\n",
        gigaFlops, msecPerMatrixMul, flopsPerMatrixMul);
    // std::cout << output << std::endl;

    for (size_t i = 0; i < num_stream; ++i) {
        checkCudaErrors(cudaStreamDestroy(stream[i]));
    }
    checkCudaErrors(cudaEventDestroy(start));
    checkCudaErrors(cudaEventDestroy(stop));
    checkCudaErrors(cublasDestroy(handle));
}
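
// A minimal binding sketch so the two entry points above can be called from
// Python. This block is an assumption, not part of the original file: the
// exported names "forward" and "forward_v1" are illustrative choices, and a
// real build may keep its bindings in a separate .cpp instead.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &moe_cuda_forward,
          "MoE forward: one batched GEMM over the experts selected by gate");
    m.def("forward_v1", &moe_cuda_forward_v1,
          "MoE forward baseline: one GEMV per (sample, expert) on a stream pool");
}
// Hypothetical usage from Python, assuming the extension is built as moe_cuda:
//   out = moe_cuda.forward(input, gate, weight)     # gate: int64 [B x K]
//   moe_cuda.forward_v1(input, gate, weight, bias)  # prints measured GFlop/s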