#include <cstdio>
#include <cstdlib>
#include <vector>

#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cublas_v2.h>
#include <helper_cuda.h>

#include <torch/extension.h>

// #include "timer.hh"

#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)

// For each (sample, k) slot, write a pointer to the start of the weight
// matrix selected by its gate index: ptrs[idx] = base + stride * offset[idx].
template <typename scalar_t>
__global__ void generate_ptr_offset_kernel(size_t n, const scalar_t* base,
        size_t stride, const int* offset, const scalar_t** ptrs) {
    size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
    if (idx < n) {
        ptrs[idx] = base + stride * offset[idx];
    }
}

// Type-dispatching wrappers so the templated forward pass can call the
// matching cuBLAS batched GEMM for float, double and half precision.
inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k,
        const float *alpha, const float *Aarray[], int lda,
        const float *Barray[], int ldb,
        const float *beta, float *Carray[], int ldc,
        int batchCount) {
    return cublasSgemmBatched(handle, transa, transb, m, n, k,
            alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
}

inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k,
        const double *alpha, const double *Aarray[], int lda,
        const double *Barray[], int ldb,
        const double *beta, double *Carray[], int ldc,
        int batchCount) {
    return cublasDgemmBatched(handle, transa, transb, m, n, k,
            alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
}

inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k,
        const __half *alpha, const __half *Aarray[], int lda,
        const __half *Barray[], int ldb,
        const __half *beta, __half *Carray[], int ldc,
        int batchCount) {
    return cublasHgemmBatched(handle, transa, transb, m, n, k,
            alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
}

template <typename scalar_t>
void moe1_cuda_forward_impl(
        const scalar_t* input,
        const int* gate,
        const scalar_t* weight,
        scalar_t* output,
        const size_t batch_size,
        const size_t top_k,
        const size_t in_feat,
        const size_t out_feat) {
    cublasHandle_t handle;
    cudaStream_t st;
    cudaStreamCreate(&st);
    checkCudaErrors(cublasCreate(&handle));
    checkCudaErrors(cublasSetStream(handle, st));

    // Set up Aarray, Barray and Carray: one (input row, expert weight,
    // output row) pointer triple per (sample, k) slot.
    std::vector<const scalar_t*> aptrs;
    std::vector<scalar_t*> cptrs;
    const scalar_t **Aarray;
    const scalar_t **Barray;
    scalar_t **Carray;
    checkCudaErrors(cudaMalloc(&Aarray, batch_size * sizeof(const scalar_t*) * top_k));
    checkCudaErrors(cudaMalloc(&Barray, batch_size * sizeof(const scalar_t*) * top_k));
    checkCudaErrors(cudaMalloc(&Carray, batch_size * sizeof(scalar_t*) * top_k));

    for (size_t i = 0; i < batch_size; ++i) {
        for (size_t j = 0; j < top_k; ++j) {
            aptrs.push_back(input + in_feat * i);
            cptrs.push_back(output + out_feat * (i * top_k + j));
        }
    }
    checkCudaErrors(cudaMemcpy(Aarray, aptrs.data(),
            batch_size * top_k * sizeof(const scalar_t*), cudaMemcpyHostToDevice));
    checkCudaErrors(cudaMemcpy(Carray, cptrs.data(),
            batch_size * top_k * sizeof(scalar_t*), cudaMemcpyHostToDevice));

    // Resolve each slot's expert weight pointer on the device from the gate indices.
    generate_ptr_offset_kernel<scalar_t>
            <<<CEIL(batch_size * top_k, 256), 256, 0, st>>>(
            batch_size * top_k, weight, out_feat * in_feat, gate, Barray);

    // One batched GEMM over batch_size * top_k independent products: each
    // batch entry multiplies one input row (Aarray) with its selected expert
    // weight (Barray) into one 1 x out_feat output row (Carray).
    scalar_t alpha = 1, beta = 0;
    checkCudaErrors(cublasXgemmBatched(handle,
            CUBLAS_OP_N, CUBLAS_OP_T,
            1, out_feat, in_feat,
            &alpha,
            Aarray, 1,
            Barray, out_feat,
            &beta,
            Carray, 1,
            batch_size * top_k));

    checkCudaErrors(cudaStreamSynchronize(st));

    // Release the temporary pointer arrays and the per-call handles.
    checkCudaErrors(cudaFree(Aarray));
    checkCudaErrors(cudaFree(Barray));
    checkCudaErrors(cudaFree(Carray));
    checkCudaErrors(cudaStreamDestroy(st));
    checkCudaErrors(cublasDestroy(handle));
}

std::vector<torch::Tensor> moe1_cuda_forward(
        torch::Tensor input,
        torch::Tensor gate,
        torch::Tensor weight) {
    const auto batch_size = input.size(0);
    const auto top_k = gate.size(1);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);
    // printf("b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld, topk=%ld\n",
    //         batch_size, num_expert, in_feat, out_feat, top_k);

    auto output = input.new_zeros({batch_size, top_k, out_feat});

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe1_forward_cuda", ([&] {
        moe1_cuda_forward_impl<scalar_t>(
            input.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            weight.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            batch_size,
            top_k,
            in_feat,
            out_feat
        );
    }));
    return {output, };
}
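// To call moe1_cuda_forward from Python, the extension needs a pybind11
// binding. The following is only a sketch, kept commented out: it assumes no
// separate binding .cpp exists in the build, and the exported name
// "moe1_forward" is illustrative rather than taken from this file.
//
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("moe1_forward", &moe1_cuda_forward,
//           "MoE first-layer forward via batched GEMM (CUDA)");
// }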
/*
// Standalone benchmark (requires timer.hh for timestamp()/getDuration()).
int main() {
    typedef float data_t;
    size_t batch_size = 4096;
    size_t top_k = 2;
    size_t num_expert = 128;
    size_t in_feat = 1024;
    size_t out_feat = 4096;

    data_t *input, *weight;
    data_t *output;
    int *gate;
    checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(data_t)));
    checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(data_t)));
    checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
    checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(int)));

    size_t nt = 16;
    double tsum = 0, tmax = 0;

    // Assign an arbitrary expert to every (sample, k) slot and copy the
    // gate indices to the device.
    int *gate_host = new int[batch_size * top_k];
    for (size_t i = 0; i < batch_size * top_k; ++i) {
        gate_host[i] = rand() % num_expert;
    }
    checkCudaErrors(cudaMemcpy(gate, gate_host,
            batch_size * top_k * sizeof(int), cudaMemcpyHostToDevice));

    // Warm-up run.
    moe1_cuda_forward_impl<data_t>(input, gate, weight, output,
            batch_size, top_k, in_feat, out_feat);

    // Timed runs.
    for (size_t i = 0; i < nt; ++i) {
        timestamp(start);
        moe1_cuda_forward_impl<data_t>(input, gate, weight, output,
                batch_size, top_k, in_feat, out_feat);
        timestamp(end);
        auto t = getDuration(start, end);
        tsum += t;
        if (t > tmax) tmax = t;
    }
    printf("Mean %.3lf us, max %.3lf us\n", tsum / nt * 1e6, tmax * 1e6);
    double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
    printf("%.3lf TFLOPs\n", tflops);
}
*/