Commit ef83c893 authored by Jiezhong Qiu

stream manager object instead of pointer

parent e52c0380
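In short: the per-thread CudaStreamManager becomes a thread_local value object that call sites lazily initialize through a new setup(num_expert, device) method, replacing the heap-allocated instance reached through getCudaStreamManager(). Since the old getter only ever new'ed the manager (no delete appears in this diff), the value object also gains automatic destruction at thread exit. Below is a minimal sketch of the pattern under simplified assumptions: plain C++, an int* standing in for cudaStream_t*, and Manager/use_manager as illustrative names that are not part of this repository.

#include <cstdio>

// Sketch of the pattern this commit adopts: a thread_local value object
// with deferred setup(), replacing a lazily new'ed pointer behind a getter.
struct Manager {
    int device = -1;
    int* streams = nullptr;                 // stands in for cudaStream_t*
    void setup(int device_, int n) {        // runs once per thread, on first use
        device = device_;
        streams = new int[n];
    }
    ~Manager() { delete[] streams; }        // delete[] on nullptr is a no-op
};

thread_local Manager smgr;                  // one instance per thread; no getter

void use_manager(int device, int n) {
    if (smgr.streams == nullptr) {          // same guard the kernels add below
        smgr.setup(device, n);
    }
    std::printf("using device %d\n", smgr.device);
}

int main() {
    use_manager(0, 4);
    use_manager(0, 4);                      // second call skips setup()
    return 0;
}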
cuda_stream_manager.cpp
@@ -3,8 +3,10 @@
 #include "cuda_stream_manager.h"
 
-thread_local CudaStreamManager* smgr = NULL;
+thread_local CudaStreamManager smgr;
 
+/*
 CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device) {
     if (!smgr) {
         smgr = new CudaStreamManager(num_expert, device);
@@ -13,3 +15,4 @@ CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device) {
     assert(smgr->device == device);
     return smgr;
 }
+*/
cuda_stream_manager.h
@@ -10,38 +10,44 @@
 class CudaStreamManager {
 public:
-    CudaStreamManager(const size_t num_expert_, const int device_) : num_expert(num_expert_), device(device_) {
-        /*
-        Actually, we will see current_device == device,
-        which means pytorch always sets the correct device for us.
-        But for safety, we still manually set device to the desired one.
-        */
-        /*
+    CudaStreamManager() : num_expert(0), device(0), streams(NULL) {
         int current_device;
         checkCudaErrors(cudaGetDevice(&current_device));
-        printf("CudaStreamManager construnctor called, get device %d, set device %d\n", current_device, device);
-        */
-        checkCudaErrors(cudaSetDevice(device));
+#ifdef MOE_DEBUG
+        printf("constructor at device %d\n", current_device);
+#endif
+    }
+    void setup(const size_t num_expert, const int device) {
+#ifdef MOE_DEBUG
+        printf("setup at device %d\n", device);
+#endif
+        this->num_expert = num_expert;
+        this->device = device;
+        checkCudaErrors(cudaSetDevice(device));
         streams = new cudaStream_t[num_expert];
         checkCudaErrors(cublasCreate(&handle));
         for (size_t i=0; i<num_expert; ++i) {
             checkCudaErrors(cudaStreamCreate(streams+i));
         }
     }
     ~CudaStreamManager() {
+#ifdef MOE_DEBUG
+        printf("destructor at device %d\n", device);
+#endif
         for (size_t i=0; i<num_expert; ++i) {
             checkCudaErrors(cudaStreamDestroy(*(streams+i)));
         }
         checkCudaErrors(cublasDestroy(handle));
         delete[] streams;
     }
-    const size_t num_expert;
-    const int device;
+    size_t num_expert;
+    int device;
     cublasHandle_t handle;
     cudaStream_t* streams;
 };
 
-CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device);
+// CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device);
 
 #endif // CUDA_STREAM_MANAGER
moe_cuda_kernel.cu
@@ -18,6 +18,7 @@
 #define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
 
+thread_local CudaStreamManager smgr;
 
 template <typename scalar_t>
 __global__
@@ -39,12 +40,9 @@ void moe_cuda_forward_impl(
     const size_t in_feat,
     const size_t out_feat,
     const size_t num_expert,
-    cublasOperation_t transb,
-    const int device) {
-
-    auto* h = getCudaStreamManager(num_expert, device);
+    cublasOperation_t transb) {
 
-    checkCudaErrors(cublasSetStream(h->handle, *(h->streams)));
+    checkCudaErrors(cublasSetStream(smgr.handle, *(smgr.streams)));
 
     // setup Aarray, Barray and Carray
     std::vector<const scalar_t*> aptrs;
@@ -70,11 +68,11 @@ void moe_cuda_forward_impl(
     dim3 griddim(CEIL(batch_size, 256)); dim3 blockdim(256);
     generate_ptr_offset_kernel<<<griddim, blockdim, 0,
-        *(h->streams)>>>(batch_size, weight, out_feat * in_feat, gate, Barray);
+        *(smgr.streams)>>>(batch_size, weight, out_feat * in_feat, gate, Barray);
 
     scalar_t alpha = 1, beta = 0;
-    checkCudaErrors(cublasXgemmBatched(h->handle,
+    checkCudaErrors(cublasXgemmBatched(smgr.handle,
         CUBLAS_OP_N,
         transb,
         1, out_feat, in_feat,
@@ -85,7 +83,7 @@ void moe_cuda_forward_impl(
         Carray, 1,
         batch_size));
 
-    checkCudaErrors(cudaStreamSynchronize(*(h->streams)));
+    checkCudaErrors(cudaStreamSynchronize(*(smgr.streams)));
 
     checkCudaErrors(cudaFree(Aarray));
     checkCudaErrors(cudaFree(Barray));
     checkCudaErrors(cudaFree(Carray));
@@ -100,17 +98,14 @@ void moe_cuda_grad_weight(
     const size_t batch_size,
     const size_t in_feat,
     const size_t out_feat,
-    const size_t num_expert,
-    const int device) {
-
-    auto h = getCudaStreamManager(num_expert, device);
+    const size_t num_expert) {
 
     int* gate_host = new int[batch_size];
     scalar_t alpha = 1, beta = 1;
     checkCudaErrors(cudaMemcpy(gate_host, gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));
     for (size_t i=0; i<batch_size; ++i) {
-        checkCudaErrors(cublasSetStream(h->handle, *(h->streams + gate_host[i])));
-        checkCudaErrors(cublasXgemm(h->handle,
+        checkCudaErrors(cublasSetStream(smgr.handle, *(smgr.streams + gate_host[i])));
+        checkCudaErrors(cublasXgemm(smgr.handle,
             CUBLAS_OP_N,
             CUBLAS_OP_T,
             out_feat,
@@ -126,7 +121,7 @@ void moe_cuda_grad_weight(
             out_feat));
     }
     for (size_t i=0; i<num_expert; ++i) {
-        checkCudaErrors(cudaStreamSynchronize(*(h->streams + i)));
+        checkCudaErrors(cudaStreamSynchronize(*(smgr.streams + i)));
     }
     delete[] gate_host;
 }
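For context on the loop above: moe_cuda_grad_weight enqueues each sample's GEMM on the stream owned by its expert (smgr.streams + gate_host[i]) and then synchronizes every expert stream before freeing host-side state. A standalone CUDA sketch of the same scheduling idea, with a trivial kernel standing in for cublasXgemm; all names here are illustrative, not from the repository:

#include <cstdio>
#include <cuda_runtime.h>

// Stand-in for the per-expert cublasXgemm calls; each launch updates
// one expert's 32-float slice of the buffer.
__global__ void touch(float* x) { x[threadIdx.x] += 1.f; }

int main() {
    const int num_expert = 4, batch = 16;
    cudaStream_t streams[num_expert];
    for (int e = 0; e < num_expert; ++e) cudaStreamCreate(&streams[e]);

    float* buf;
    cudaMalloc(&buf, num_expert * 32 * sizeof(float));
    cudaMemset(buf, 0, num_expert * 32 * sizeof(float));

    int gate[batch];                          // host-side gate, like gate_host
    for (int i = 0; i < batch; ++i) gate[i] = i % num_expert;

    // Enqueue each sample's work on its expert's stream; launches on the
    // same stream serialize, different experts run concurrently.
    for (int i = 0; i < batch; ++i)
        touch<<<1, 32, 0, streams[gate[i]]>>>(buf + gate[i] * 32);

    // Synchronize every expert stream before using results on the host.
    for (int e = 0; e < num_expert; ++e) cudaStreamSynchronize(streams[e]);

    for (int e = 0; e < num_expert; ++e) cudaStreamDestroy(streams[e]);
    cudaFree(buf);
    std::printf("done\n");
    return 0;
}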
@@ -143,7 +138,10 @@ std::vector<torch::Tensor> moe_cuda_forward(
 #ifdef MOE_DEBUG
     printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
 #endif
-    int device = device_of(input).value().index();
+    const int device = device_of(input).value().index();
+    if (smgr.streams == NULL) {
+        smgr.setup(num_expert, device);
+    }
 
     auto output = input.new_zeros({batch_size, out_feat});
     AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] {
@@ -156,8 +154,7 @@ std::vector<torch::Tensor> moe_cuda_forward(
             in_feat,
             out_feat,
             num_expert,
-            CUBLAS_OP_T,
-            device
+            CUBLAS_OP_T
         );
     }));
@@ -178,7 +175,11 @@ std::vector<torch::Tensor> moe_cuda_backward(
 #ifdef MOE_DEBUG
     printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
 #endif
-    int device = device_of(input).value().index();
+    const int device = device_of(input).value().index();
+    if (smgr.streams == NULL) {
+        smgr.setup(num_expert, device);
+    }
 
     auto grad_input = grad_output.new_zeros({batch_size, in_feat});  // batch_size x in_feat
     auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat
@@ -193,8 +194,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
             out_feat,
             in_feat,
             num_expert,
-            CUBLAS_OP_N,
-            device
+            CUBLAS_OP_N
         );
     }));
@@ -207,8 +207,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
             batch_size,
             in_feat,
             out_feat,
-            num_expert,
-            device
+            num_expert
         );
     }));
setup.py
@@ -11,7 +11,7 @@ setup(
         name='moe_cuda',
         sources=[
             'moe.cpp',
-            'cuda_stream_manager.cpp',
+            # 'cuda_stream_manager.cpp',
             'moe_cuda_kernel.cu',
         ],
         extra_compile_args={'cxx': ['-I{}'.format(CUDA_HELPER)],
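The build change is consistent with the code changes above: with cuda_stream_manager.cpp dropped from sources, the only live definition of the thread_local smgr is the one added at the top of moe_cuda_kernel.cu, the header now carries the whole class, and the old getCudaStreamManager getter survives only as commented-out code in both files.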