Commit 8cff6ad7 authored by Rick Ho's avatar Rick Ho
Browse files

tidy up C code

parent 2ba58797
...@@ -4,10 +4,24 @@ project(moe) ...@@ -4,10 +4,24 @@ project(moe)
find_package(Torch REQUIRED) find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
include_directories("/home/jiezhong/miniconda3/include/python3.8" if(NOT PYTHON_INCLUDE)
"/usr/local/cuda/include" set(PYTHON_INCLUDE "/home/jiezhong/miniconda3/include/python3.8")
"/usr/local/cuda/samples/common/inc") endif()
add_executable(moe moe.cpp)
if(NOT CUDA_HOME)
set(CUDA_HOME "/usr/local/cuda")
endif()
if(NOT CUDA_SAMPLE_INCLUDE)
set(CUDA_SAMPLE_INCLUDE "/usr/local/cuda/samples/common/inc")
endif()
include_directories(
"${PYTHON_INCLUDE}"
"${CUDA_HOME}/include"
"${CUDA_SAMPLE_INCLUDE}"
)
add_executable(moe moe.cpp cuda_stream_manager.cpp)
target_link_libraries(moe target_link_libraries(moe
"${TORCH_LIBRARIES}") "${TORCH_LIBRARIES}")
set_property(TARGET moe PROPERTY CXX_STANDARD 14) set_property(TARGET moe PROPERTY CXX_STANDARD 14)
......
#ifndef CUBLAS_WRAPPER_H
#define CUBLAS_WRAPPER_H
#include <cublas_v2.h>
// Type-dispatch overload: batched GEMM for single precision (float).
// Pure pass-through — every argument is forwarded unchanged to
// cublasSgemmBatched; see the cuBLAS reference for parameter semantics.
inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
                                         cublasOperation_t transa,
                                         cublasOperation_t transb,
                                         int m, int n, int k,
                                         const float *alpha,
                                         const float *Aarray[], int lda,
                                         const float *Barray[], int ldb,
                                         const float *beta,
                                         float *Carray[], int ldc,
                                         int batchCount) {
    return cublasSgemmBatched(handle, transa, transb, m, n, k,
                              alpha, Aarray, lda, Barray, ldb,
                              beta, Carray, ldc, batchCount);
}
// Type-dispatch overload: batched GEMM for double precision.
// Pure pass-through — every argument is forwarded unchanged to
// cublasDgemmBatched; see the cuBLAS reference for parameter semantics.
inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
                                         cublasOperation_t transa,
                                         cublasOperation_t transb,
                                         int m, int n, int k,
                                         const double *alpha,
                                         const double *Aarray[], int lda,
                                         const double *Barray[], int ldb,
                                         const double *beta,
                                         double *Carray[], int ldc,
                                         int batchCount) {
    return cublasDgemmBatched(handle, transa, transb, m, n, k,
                              alpha, Aarray, lda, Barray, ldb,
                              beta, Carray, ldc, batchCount);
}
// Type-dispatch overload: batched GEMM for half precision (__half).
// Pure pass-through — every argument is forwarded unchanged to
// cublasHgemmBatched; see the cuBLAS reference for parameter semantics.
inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
                                         cublasOperation_t transa,
                                         cublasOperation_t transb,
                                         int m, int n, int k,
                                         const __half *alpha,
                                         const __half *Aarray[], int lda,
                                         const __half *Barray[], int ldb,
                                         const __half *beta,
                                         __half *Carray[], int ldc,
                                         int batchCount) {
    return cublasHgemmBatched(handle, transa, transb, m, n, k,
                              alpha, Aarray, lda, Barray, ldb,
                              beta, Carray, ldc, batchCount);
}
// Type-dispatch overload: single (non-batched) GEMM for float.
// Pure pass-through to cublasSgemm; cuBLAS expects column-major layout.
inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
                                  cublasOperation_t transa,
                                  cublasOperation_t transb,
                                  int m, int n, int k,
                                  const float *alpha,
                                  const float *A, int lda,
                                  const float *B, int ldb,
                                  const float *beta,
                                  float *C, int ldc) {
    return cublasSgemm(handle, transa, transb, m, n, k,
                       alpha, A, lda, B, ldb, beta, C, ldc);
}
// Type-dispatch overload: single (non-batched) GEMM for double.
// Pure pass-through to cublasDgemm; cuBLAS expects column-major layout.
inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
                                  cublasOperation_t transa,
                                  cublasOperation_t transb,
                                  int m, int n, int k,
                                  const double *alpha,
                                  const double *A, int lda,
                                  const double *B, int ldb,
                                  const double *beta,
                                  double *C, int ldc) {
    return cublasDgemm(handle, transa, transb, m, n, k,
                       alpha, A, lda, B, ldb, beta, C, ldc);
}
// Type-dispatch overload: single (non-batched) GEMM for __half.
// Pure pass-through to cublasHgemm; cuBLAS expects column-major layout.
inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
                                  cublasOperation_t transa,
                                  cublasOperation_t transb,
                                  int m, int n, int k,
                                  const __half *alpha,
                                  const __half *A, int lda,
                                  const __half *B, int ldb,
                                  const __half *beta,
                                  __half *C, int ldc) {
    return cublasHgemm(handle, transa, transb, m, n, k,
                       alpha, A, lda, B, ldb, beta, C, ldc);
}
#endif // CUBLAS_WRAPPER_H
#include <cassert>
#include "cuda_stream_manager.h"
// Process-wide singleton instance, lazily created on first request.
// NOTE(review): not thread-safe — assumes single-threaded first call.
CudaStreamManager* smgr = NULL;

// Return the shared CudaStreamManager, constructing it on first use.
// Every caller must pass the same num_expert; a mismatch with the value
// used at construction time trips the assert.
CudaStreamManager* getCudaStreamManager(const size_t num_expert) {
    if (smgr == NULL) {
        smgr = new CudaStreamManager(num_expert);
    }
    assert(smgr->num_expert == num_expert);
    return smgr;
}
#ifndef CUDA_STREAM_MANAGER_H
#define CUDA_STREAM_MANAGER_H
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h>
// Owns one cuBLAS handle plus one CUDA stream per expert, so each expert's
// GEMM work can be issued on its own stream. Intended to be used as a
// singleton via getCudaStreamManager().
class CudaStreamManager {
public:
    // num_expert_: number of experts; one stream is created per expert.
    // Aborts (via checkCudaErrors) if handle or stream creation fails.
    CudaStreamManager(const size_t num_expert_) : num_expert(num_expert_) {
        streams = new cudaStream_t[num_expert];
        checkCudaErrors(cublasCreate(&handle));
        for (size_t i = 0; i < num_expert; ++i) {
            checkCudaErrors(cudaStreamCreate(streams + i));
        }
    }
    // Destroys every stream and the cuBLAS handle, then releases the
    // stream array itself (previously leaked: delete[] was missing).
    ~CudaStreamManager() {
        for (size_t i = 0; i < num_expert; ++i) {
            checkCudaErrors(cudaStreamDestroy(*(streams + i)));
        }
        checkCudaErrors(cublasDestroy(handle));
        delete[] streams;
    }
    const size_t num_expert;   // fixed at construction; checked by the getter
    cublasHandle_t handle;     // shared cuBLAS handle; bind a stream before use
    cudaStream_t* streams;     // array of num_expert streams, owned here
};
CudaStreamManager* getCudaStreamManager(const size_t num_expert);
#endif  // CUDA_STREAM_MANAGER_H
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
#include <cstdio> #include <cstdio>
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include <cassert>
#include <cuda.h> #include <cuda.h>
...@@ -13,37 +12,10 @@ ...@@ -13,37 +12,10 @@
// #include "timer.hh" // #include "timer.hh"
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1) #include "cublas_wrapper.h"
#include "cuda_stream_manager.h"
class Helper {
public:
Helper(const size_t num_expert_) : num_expert(num_expert_) {
streams = new cudaStream_t[num_expert];
checkCudaErrors(cublasCreate(&handle));
for (size_t i=0; i<num_expert; ++i) {
checkCudaErrors(cudaStreamCreate(streams+i));
}
}
~Helper() {
for (size_t i=0; i<num_expert; ++i) {
checkCudaErrors(cudaStreamDestroy(*(streams+i)));
}
checkCudaErrors(cublasDestroy(handle));
}
const size_t num_expert;
cublasHandle_t handle;
cudaStream_t* streams;
};
Helper* helper = NULL; #define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
Helper* getHelper(const size_t num_expert) {
if (!helper) {
helper = new Helper(num_expert);
}
assert(helper->num_expert == num_expert);
return helper;
}
template <typename scalar_t> template <typename scalar_t>
...@@ -56,79 +28,6 @@ void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride, c ...@@ -56,79 +28,6 @@ void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride, c
} }
inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m, int n, int k,
const float *alpha,
const float *Aarray[], int lda,
const float *Barray[], int ldb,
const float *beta,
float *Carray[], int ldc,
int batchCount) {
return cublasSgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
}
inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m, int n, int k,
const double *alpha,
const double *Aarray[], int lda,
const double *Barray[], int ldb,
const double *beta,
double *Carray[], int ldc,
int batchCount) {
return cublasDgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
}
inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m, int n, int k,
const __half *alpha,
const __half *Aarray[], int lda,
const __half *Barray[], int ldb,
const __half *beta,
__half *Carray[], int ldc,
int batchCount) {
return cublasHgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
}
inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
cublasOperation_t transa, cublasOperation_t transb,
int m, int n, int k,
const float *alpha,
const float *A, int lda,
const float *B, int ldb,
const float *beta,
float *C, int ldc) {
return cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
}
inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
cublasOperation_t transa, cublasOperation_t transb,
int m, int n, int k,
const double *alpha,
const double *A, int lda,
const double *B, int ldb,
const double *beta,
double *C, int ldc) {
return cublasDgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
}
inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
cublasOperation_t transa, cublasOperation_t transb,
int m, int n, int k,
const __half *alpha,
const __half *A, int lda,
const __half *B, int ldb,
const __half *beta,
__half *C, int ldc) {
return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
}
template <typename scalar_t> template <typename scalar_t>
void moe_cuda_forward_impl( void moe_cuda_forward_impl(
const scalar_t* input, const scalar_t* input,
...@@ -141,7 +40,7 @@ void moe_cuda_forward_impl( ...@@ -141,7 +40,7 @@ void moe_cuda_forward_impl(
const size_t num_expert, const size_t num_expert,
cublasOperation_t transb) { cublasOperation_t transb) {
Helper* h = getHelper(num_expert); auto* h = getCudaStreamManager(num_expert);
checkCudaErrors(cublasSetStream(h->handle, *(h->streams))); checkCudaErrors(cublasSetStream(h->handle, *(h->streams)));
...@@ -160,25 +59,29 @@ void moe_cuda_forward_impl( ...@@ -160,25 +59,29 @@ void moe_cuda_forward_impl(
aptrs.push_back(input + in_feat * i); aptrs.push_back(input + in_feat * i);
cptrs.push_back(output + out_feat * i); cptrs.push_back(output + out_feat * i);
} }
checkCudaErrors(cudaMemcpy(Aarray, aptrs.data(), batch_size * sizeof(const scalar_t*), cudaMemcpyHostToDevice)); checkCudaErrors(cudaMemcpy(Aarray, aptrs.data(), batch_size * sizeof(const
// checkCudaErrors(cudaMemcpy(ptrs + batch_size * top_k, bptrs.data(), batch_size * sizeof(scalar_t*) * top_k, cudaMemcpyHostToDevice)); scalar_t*), cudaMemcpyHostToDevice));
checkCudaErrors(cudaMemcpy(Carray, cptrs.data(), batch_size * sizeof(scalar_t*), cudaMemcpyHostToDevice)); // checkCudaErrors(cudaMemcpy(ptrs + batch_size * top_k, bptrs.data(),
// batch_size * sizeof(scalar_t*) * top_k, cudaMemcpyHostToDevice));
dim3 griddim(CEIL(batch_size, 256)); checkCudaErrors(cudaMemcpy(Carray, cptrs.data(), batch_size *
dim3 blockdim(256); sizeof(scalar_t*), cudaMemcpyHostToDevice));
generate_ptr_offset_kernel<<<griddim, blockdim, 0, *(h->streams)>>>(batch_size, weight, out_feat * in_feat, gate, Barray);
dim3 griddim(CEIL(batch_size, 256)); dim3 blockdim(256);
scalar_t alpha = 1, beta = 0; generate_ptr_offset_kernel<<<griddim, blockdim, 0,
checkCudaErrors(cublasXgemmBatched(h->handle, *(h->streams)>>>(batch_size, weight, out_feat * in_feat, gate, Barray);
CUBLAS_OP_N,
transb, scalar_t alpha = 1, beta = 0;
1, out_feat, in_feat,
&alpha, checkCudaErrors(cublasXgemmBatched(h->handle,
Aarray, 1, CUBLAS_OP_N,
Barray, (transb == CUBLAS_OP_T) ? out_feat : in_feat, transb,
&beta, 1, out_feat, in_feat,
Carray, 1, &alpha,
batch_size)); Aarray, 1,
Barray, (transb == CUBLAS_OP_T) ? out_feat : in_feat,
&beta,
Carray, 1,
batch_size));
checkCudaErrors(cudaStreamSynchronize(*(h->streams))); checkCudaErrors(cudaStreamSynchronize(*(h->streams)));
} }
...@@ -194,7 +97,7 @@ void moe_cuda_grad_weight( ...@@ -194,7 +97,7 @@ void moe_cuda_grad_weight(
const size_t out_feat, const size_t out_feat,
const size_t num_expert) { const size_t num_expert) {
Helper* h = getHelper(num_expert); auto h = getCudaStreamManager(num_expert);
int* gate_host = new int[batch_size]; int* gate_host = new int[batch_size];
scalar_t alpha = 1, beta = 1; scalar_t alpha = 1, beta = 1;
...@@ -231,7 +134,9 @@ std::vector<torch::Tensor> moe_cuda_forward( ...@@ -231,7 +134,9 @@ std::vector<torch::Tensor> moe_cuda_forward(
const auto out_feat = weight.size(1); const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2); const auto in_feat = weight.size(2);
#ifdef MOE_DEBUG
printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat); printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif
auto output = input.new_zeros({batch_size, out_feat}); auto output = input.new_zeros({batch_size, out_feat});
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] { AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] {
...@@ -338,4 +243,4 @@ int main() { ...@@ -338,4 +243,4 @@ int main() {
double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum; double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
printf("%.3lf TFLOPs\n", tflops); printf("%.3lf TFLOPs\n", tflops);
} }
*/ */
\ No newline at end of file
from moe import MOELayer from moe import MOELayer
import torch import torch
import time import time
import sys
def perf(): def perf():
batch_size = 128 batch_size = int(sys.argv[1])
in_feat = 1024 in_feat = int(sys.argv[2])
out_feat = 4096 out_feat = int(sys.argv[3])
num_expert = 4 num_expert = int(sys.argv[4])
inp = torch.rand(batch_size, in_feat).cuda() inp = torch.rand(batch_size, in_feat).cuda()
gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda() gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()
......
#!/bin/bash
# Launcher for the MoE python scripts.
#   no args      -> run the default module (moe.py)
#   test_all     -> sweep batch size, feature sizes and expert counts
#   anything else-> forwarded verbatim to python
export PYTHONPATH="$PWD/build/lib.linux-x86_64-3.7"
export LD_LIBRARY_PATH=/home/laekov/.local/lib/python3.7/site-packages/torch/lib:$LD_LIBRARY_PATH

# Quote "$1": unquoted, an empty/whitespace argument breaks the test
# expression and a glob-looking argument would be expanded.
if [ -z "$1" ]
then
	python moe.py
elif [ ".$1" = '.test_all' ]
then
	for bs in 4 16 64
	do
		for inf in 1024 4096
		do
			for ouf in 1024 4096
			do
				for nexp in 4 16 64
				do
					echo $bs $nexp ${inf}x${ouf}
					python moe_test.py $bs $inf $ouf $nexp
				done
			done
		done
	done
else
	# "$@" preserves each argument's word boundaries; bare $@ would
	# re-split arguments containing spaces.
	python "$@"
fi
...@@ -11,6 +11,7 @@ setup( ...@@ -11,6 +11,7 @@ setup(
name='moe_cuda', name='moe_cuda',
sources=[ sources=[
'moe.cpp', 'moe.cpp',
'cuda_stream_manager.cpp',
'moe_cuda_kernel.cu', 'moe_cuda_kernel.cu',
], ],
extra_compile_args={'cxx': ['-I{}'.format(CUDA_HELPER)], extra_compile_args={'cxx': ['-I{}'.format(CUDA_HELPER)],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment