Commit ef83c893 authored by Jiezhong Qiu's avatar Jiezhong Qiu
Browse files

stream manager object instead of pointer

parent e52c0380
......@@ -3,8 +3,10 @@
#include "cuda_stream_manager.h"
thread_local CudaStreamManager* smgr = NULL;
thread_local CudaStreamManager smgr;
/*
CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device) {
if (!smgr) {
smgr = new CudaStreamManager(num_expert, device);
......@@ -13,3 +15,4 @@ CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int devic
assert(smgr->device == device);
return smgr;
}
*/
......@@ -10,38 +10,44 @@
class CudaStreamManager {
public:
CudaStreamManager(const size_t num_expert_, const int device_) : num_expert(num_expert_), device(device_) {
/*
Actually, we will see current_device == device,
which means pytorch always sets the correct device for us.
But for safety, we still manually set device to the desired one.
*/
/*
CudaStreamManager() : num_expert(0), device(0), streams(NULL) {
int current_device;
checkCudaErrors(cudaGetDevice(&current_device));
printf("CudaStreamManager construnctor called, get device %d, set device %d\n", current_device, device);
*/
checkCudaErrors(cudaSetDevice(device));
#ifdef MOE_DEBUG
printf("constructor at device %d\n", current_device);
#endif
}
void setup(const size_t num_expert, const int device) {
#ifdef MOE_DEBUG
printf("setup at device %d\n", device);
#endif
this->num_expert = num_expert;
this->device = device;
checkCudaErrors(cudaSetDevice(device));
streams = new cudaStream_t[num_expert];
checkCudaErrors(cublasCreate(&handle));
for (size_t i=0; i<num_expert; ++i) {
checkCudaErrors(cudaStreamCreate(streams+i));
}
}
~CudaStreamManager() {
#ifdef MOE_DEBUG
printf("destructor at device %d\n", device);
#endif
for (size_t i=0; i<num_expert; ++i) {
checkCudaErrors(cudaStreamDestroy(*(streams+i)));
}
checkCudaErrors(cublasDestroy(handle));
delete[] streams;
}
const size_t num_expert;
const int device;
size_t num_expert;
int device;
cublasHandle_t handle;
cudaStream_t* streams;
};
CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device);
// CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device);
#endif // CUDA_STREAM_MANAGER
......@@ -18,6 +18,7 @@
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
thread_local CudaStreamManager smgr;
template <typename scalar_t>
__global__
......@@ -39,12 +40,9 @@ void moe_cuda_forward_impl(
const size_t in_feat,
const size_t out_feat,
const size_t num_expert,
cublasOperation_t transb,
const int device) {
cublasOperation_t transb) {
auto* h = getCudaStreamManager(num_expert, device);
checkCudaErrors(cublasSetStream(h->handle, *(h->streams)));
checkCudaErrors(cublasSetStream(smgr.handle, *(smgr.streams)));
// setup Aarray, Barray and Carray
std::vector<const scalar_t*> aptrs;
......@@ -70,11 +68,11 @@ void moe_cuda_forward_impl(
dim3 griddim(CEIL(batch_size, 256)); dim3 blockdim(256);
generate_ptr_offset_kernel<<<griddim, blockdim, 0,
*(h->streams)>>>(batch_size, weight, out_feat * in_feat, gate, Barray);
*(smgr.streams)>>>(batch_size, weight, out_feat * in_feat, gate, Barray);
scalar_t alpha = 1, beta = 0;
checkCudaErrors(cublasXgemmBatched(h->handle,
checkCudaErrors(cublasXgemmBatched(smgr.handle,
CUBLAS_OP_N,
transb,
1, out_feat, in_feat,
......@@ -85,7 +83,7 @@ void moe_cuda_forward_impl(
Carray, 1,
batch_size));
checkCudaErrors(cudaStreamSynchronize(*(h->streams)));
checkCudaErrors(cudaStreamSynchronize(*(smgr.streams)));
checkCudaErrors(cudaFree(Aarray));
checkCudaErrors(cudaFree(Barray));
checkCudaErrors(cudaFree(Carray));
......@@ -100,17 +98,14 @@ void moe_cuda_grad_weight(
const size_t batch_size,
const size_t in_feat,
const size_t out_feat,
const size_t num_expert,
const int device) {
const size_t num_expert) {
auto h = getCudaStreamManager(num_expert, device);
int* gate_host = new int[batch_size];
scalar_t alpha = 1, beta = 1;
checkCudaErrors(cudaMemcpy(gate_host, gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));
for (size_t i=0; i<batch_size; ++i) {
checkCudaErrors(cublasSetStream(h->handle, *(h->streams + gate_host[i])));
checkCudaErrors(cublasXgemm(h->handle,
checkCudaErrors(cublasSetStream(smgr.handle, *(smgr.streams + gate_host[i])));
checkCudaErrors(cublasXgemm(smgr.handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
out_feat,
......@@ -126,7 +121,7 @@ void moe_cuda_grad_weight(
out_feat));
}
for (size_t i=0; i<num_expert; ++i) {
checkCudaErrors(cudaStreamSynchronize(*(h->streams + i)));
checkCudaErrors(cudaStreamSynchronize(*(smgr.streams + i)));
}
delete[] gate_host;
}
......@@ -143,7 +138,10 @@ std::vector<torch::Tensor> moe_cuda_forward(
#ifdef MOE_DEBUG
printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif
int device = device_of(input).value().index();
const int device = device_of(input).value().index();
if (smgr.streams == NULL) {
smgr.setup(num_expert, device);
}
auto output = input.new_zeros({batch_size, out_feat});
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] {
......@@ -156,8 +154,7 @@ std::vector<torch::Tensor> moe_cuda_forward(
in_feat,
out_feat,
num_expert,
CUBLAS_OP_T,
device
CUBLAS_OP_T
);
}));
......@@ -178,7 +175,11 @@ std::vector<torch::Tensor> moe_cuda_backward(
#ifdef MOE_DEBUG
printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif
int device = device_of(input).value().index();
const int device = device_of(input).value().index();
if (smgr.streams == NULL) {
smgr.setup(num_expert, device);
}
auto grad_input = grad_output.new_zeros({batch_size, in_feat}); // batch_size x in_feat
auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat
......@@ -193,8 +194,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
out_feat,
in_feat,
num_expert,
CUBLAS_OP_N,
device
CUBLAS_OP_N
);
}));
......@@ -207,8 +207,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
batch_size,
in_feat,
out_feat,
num_expert,
device
num_expert
);
}));
......
......@@ -11,7 +11,7 @@ setup(
name='moe_cuda',
sources=[
'moe.cpp',
'cuda_stream_manager.cpp',
# 'cuda_stream_manager.cpp',
'moe_cuda_kernel.cu',
],
extra_compile_args={'cxx': ['-I{}'.format(CUDA_HELPER)],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment