Commit 60b93e39 authored by Rick Ho

Stream manager across multiple devices

Replace the single global CudaStreamManager with one manager per CUDA device,
created lazily and cached by device index via getCudaStreamManager(). Each
manager now owns a fixed pool of SMGR_N_STREAMS CUDA streams and cuBLAS handles
instead of one stream per expert, and the CUDA kernels receive the manager as
an explicit argument.

parent 27d8beaa
-#include <cuda_runtime.h>
+#include <unordered_map>
+#include <mutex>
 #include <cassert>
 #include <thread>
 
 #include "cuda_stream_manager.h"
+#include <helper_cuda.h>
+
+#define SMGR_N_STREAMS 4
 
 cudaStream_t CudaStreamManager::stream(size_t idx) {
-    if (num_expert <= idx) {
-        this->setup(idx + 1);
-    }
-    return this->streams[idx];
+    return this->streams[idx % SMGR_N_STREAMS];
 }
 
-void CudaStreamManager::sync(int i) {
-    if (i > -1) {
-        cudaStreamSynchronize(streams[i]);
-        return;
-    }
-    for (size_t i = 0; i < this->num_expert; ++i) {
+cublasHandle_t CudaStreamManager::handle(size_t idx) {
+    return this->handles[idx % SMGR_N_STREAMS];
+}
+
+void CudaStreamManager::sync(int idx) {
+    for (int i = 0; i < idx && i < SMGR_N_STREAMS; ++i) {
         cudaStreamSynchronize(streams[i]);
     }
 }
 
-void CudaStreamManager::setup(const size_t num_expert, const int device) {
-#ifdef MOE_DEBUG
-    printf("setup at device %d\n", device);
-#endif
-    this->num_expert = num_expert;
-    if (device == -1) {
-        checkCudaErrors(cudaGetDevice(&this->device));
-    } else {
-        this->device = device;
-    }
-    checkCudaErrors(cudaSetDevice(this->device));
-    streams = new cudaStream_t[num_expert];
-    handles = new cublasHandle_t[num_expert];
-    for (size_t i = 0; i < num_expert; ++i) {
+void CudaStreamManager::setup(const int device) {
+    checkCudaErrors(cudaSetDevice(device));
+    streams = new cudaStream_t[SMGR_N_STREAMS];
+    handles = new cublasHandle_t[SMGR_N_STREAMS];
+    for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
         checkCudaErrors(cudaStreamCreate(streams + i));
         checkCudaErrors(cublasCreate(handles + i));
         cublasSetStream(handles[i], streams[i]);
     }
 }
 
+void CudaStreamManager::destroy() {
+    for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
+        checkCudaErrors(cudaStreamDestroy(streams[i]));
+        checkCudaErrors(cublasDestroy(handles[i]));
+    }
+    delete[] streams;
+    delete[] handles;
+}
+
+std::unordered_map<int, CudaStreamManager*> smgrs;
+std::mutex smgr_mtx;
+
+CudaStreamManager* getCudaStreamManager(const int device) {
+    auto it = smgrs.find(device);
+    if (it == smgrs.end()) {
+        smgr_mtx.lock();
+        it = smgrs.find(device);
+        if (it == smgrs.end()) {
+            auto smgr = new CudaStreamManager(device);
+            smgrs.insert(std::pair<int, CudaStreamManager*>(device, smgr));
+            smgr_mtx.unlock();
+            return smgr;
+        } else {
+            smgr_mtx.unlock();
+        }
+    }
+    return it->second;
+}
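Note: a minimal usage sketch of the new per-device manager (illustrative only; the function and variable names below are hypothetical and not part of this commit):

// Illustrative sketch only -- run_on_device and n_jobs are made-up names.
#include "cuda_stream_manager.h"

void run_on_device(int device, int n_jobs) {
    // Lazily creates (and afterwards reuses) the manager for this device.
    CudaStreamManager* smgr = getCudaStreamManager(device);
    for (int i = 0; i < n_jobs; ++i) {
        // stream(i) and handle(i) wrap around modulo SMGR_N_STREAMS,
        // so any number of jobs is spread over a fixed pool of streams.
        cudaStream_t s = smgr->stream(i);
        cublasHandle_t h = smgr->handle(i);  // already bound to stream(i) in setup()
        // ... launch kernels on s / issue cuBLAS calls through h ...
        (void) s; (void) h;
    }
    // sync(n) waits on the first min(n, SMGR_N_STREAMS) streams.
    smgr->sync(n_jobs);
}

The double-checked lookup in getCudaStreamManager() means the first caller on a device pays the setup cost and later callers reuse the cached manager.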
@@ -3,50 +3,30 @@
 #include <cuda_runtime.h>
 #include <cublas_v2.h>
-#include <helper_cuda.h>
-#include <cstdio>
 
 class CudaStreamManager {
 public:
-    size_t num_expert;
     int device;
     cublasHandle_t* handles;
     cudaStream_t* streams;
 
 public:
-    CudaStreamManager() : num_expert(0), streams(NULL) {
-        int current_device;
-        checkCudaErrors(cudaGetDevice(&current_device));
-#ifdef MOE_DEBUG
-        printf("constructor at device %d\n", current_device);
-#endif
+    CudaStreamManager(int device_): device(device_) {
+        this->setup(device);
     }
 
-    void setup(const size_t num_expert, const int device=-1);
+    void setup(int);
+    void sync(int=0);
+    void destroy();
 
     cudaStream_t stream(size_t=0);
+    cublasHandle_t handle(size_t=0);
 
     ~CudaStreamManager() {
-#ifdef MOE_DEBUG
-        printf("destructor at device %d\n", device);
-#endif
-        for (size_t i = 0; i < num_expert; ++i) {
-            checkCudaErrors(cudaStreamDestroy(*(streams + i)));
-            checkCudaErrors(cublasDestroy(handles[i]));
-        }
-        delete[] streams;
+        this->destroy();
     }
-
-    void sync(int=-1);
 };
 
-#define ENSURE_SMGR(__smgr__, __num_expert__) { \
-    if (__smgr__.num_expert == 0) { \
-        __smgr__.setup(__num_expert__); \
-    } \
-}
-
-// CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device);
+CudaStreamManager* getCudaStreamManager(const int device);
 
 #endif // CUDA_STREAM_MANAGER
@@ -4,10 +4,9 @@
 #include <iostream>
 #include <vector>
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cublas_v2.h>
 #include <helper_cuda.h>
 #include <c10/cuda/CUDAGuard.h>
@@ -18,13 +17,6 @@
 #define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
 
-// #define MOE_BREAKDOWN
-// #define MOE_DEBUG
-
-// thread_local CudaStreamManager smgr;
-// TODO: handle stream manager faults with torch threads
-CudaStreamManager smgr;
 
 template <typename scalar_t>
 __global__
 void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
@@ -35,7 +27,6 @@ void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
     }
 }
 
 template <typename scalar_t>
 __global__
 void batch_scatter_kernel(size_t wid, const int* pos,
@@ -77,8 +68,6 @@ void moe_cuda_expert_count_impl(
             cudaMemcpyHostToDevice));
     delete [] gate;
     delete [] expert_ptr;
-
-    ENSURE_SMGR(smgr, num_expert);
 }
 
 template <typename scalar_t>
@@ -87,11 +76,12 @@ void moe_cuda_local_scatter_impl(
         const int* d_pos,
         scalar_t* input_buf,
         const size_t batch_size,
-        const size_t in_feat) {
+        const size_t in_feat,
+        CudaStreamManager* smgr) {
     batch_scatter_kernel<scalar_t>
-        <<<batch_size, 256, 0, smgr.stream(0)>>>(in_feat, d_pos, input,
+        <<<batch_size, 256, 0, smgr->stream(0)>>>(in_feat, d_pos, input,
                 input_buf);
-    smgr.sync(0);
+    smgr->sync(1);
 }
 
 template <typename scalar_t>
@@ -111,11 +101,12 @@ void moe_cuda_local_gather_impl(
         const int* d_pos,
         scalar_t* output,
         const size_t batch_size,
-        const size_t out_feat) {
+        const size_t out_feat,
+        CudaStreamManager* smgr) {
     batch_gather_kernel<scalar_t>
-        <<<batch_size, 256, 0, smgr.stream(0)>>>(out_feat, d_pos, output_buf,
+        <<<batch_size, 256, 0, smgr->stream(0)>>>(out_feat, d_pos, output_buf,
                 output);
-    smgr.sync(0);
+    smgr->sync(1);
 }
 
 template <typename scalar_t>
@@ -126,7 +117,8 @@ void moe_cuda_forward_impl(
         scalar_t* output_buf,
         const size_t in_feat,
         const size_t out_feat,
-        const size_t num_expert) {
+        const size_t num_expert,
+        CudaStreamManager* smgr) {
     scalar_t alpha = 1, beta = 0;
     for (int i = 0, ptr = 0; i < num_expert; ++i) {
@@ -134,7 +126,8 @@ void moe_cuda_forward_impl(
             continue;
         }
         // Use T(B) x T(A) = T(C) to produce row-major C
-        checkCudaErrors(cublasXgemm(smgr.handles[i],
+        checkCudaErrors(cublasXgemm(
+                smgr->handle(i),
                 CUBLAS_OP_T,
                 CUBLAS_OP_N,
                 out_feat, expert_count[i], in_feat,
@@ -147,7 +140,7 @@ void moe_cuda_forward_impl(
         ptr += expert_count[i];
     }
-    smgr.sync();
+    smgr->sync(num_expert);
 }
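Aside: the "Use T(B) x T(A) = T(C)" comment above refers to the usual trick for driving column-major cuBLAS with row-major buffers; the forward call additionally passes CUBLAS_OP_T because each expert weight is stored as [out_feat x in_feat] and multiplied as W^T. A self-contained sketch of the plain C = A * B case, with made-up sizes, for readers unfamiliar with the trick:

// Standalone illustration (not part of the commit): row-major C[m x n] = A[m x k] * B[k x n]
// via column-major cuBLAS. cuBLAS sees the row-major buffers as A^T, B^T and C^T, and since
// C^T = B^T * A^T, the operands are simply passed in swapped order.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>
#include <cublas_v2.h>

int main() {
    const int m = 2, k = 3, n = 4;
    std::vector<float> A(m * k, 1.0f), B(k * n, 1.0f), C(m * n, 0.0f);  // row-major host buffers
    float *dA, *dB, *dC;
    cudaMalloc(&dA, A.size() * sizeof(float));
    cudaMalloc(&dB, B.size() * sizeof(float));
    cudaMalloc(&dC, C.size() * sizeof(float));
    cudaMemcpy(dA, A.data(), A.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dB, B.data(), B.size() * sizeof(float), cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);
    const float alpha = 1.0f, beta = 0.0f;
    // Compute C^T (n x m) = B^T (n x k) * A^T (k x m); leading dimensions are the row lengths.
    cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
            n, m, k,
            &alpha,
            dB, n,
            dA, k,
            &beta,
            dC, n);
    cudaMemcpy(C.data(), dC, C.size() * sizeof(float), cudaMemcpyDeviceToHost);
    printf("C[0][0] = %.1f (expected %d with all-ones inputs)\n", C[0], k);

    cublasDestroy(handle);
    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    return 0;
}

Build with nvcc and link against cuBLAS (e.g. -lcublas).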
 template <typename scalar_t>
@@ -161,8 +154,8 @@ void moe_cuda_backward_impl(
         const size_t batch_size,
         const size_t in_feat,
         const size_t out_feat,
-        const size_t num_expert) {
-    ENSURE_SMGR(smgr, num_expert);
+        const size_t num_expert,
+        CudaStreamManager* smgr) {
     scalar_t alpha = 1, beta = 0;
     for (int i = 0, ptr = 0; i < num_expert; ++i) {
@@ -174,7 +167,8 @@ void moe_cuda_backward_impl(
         // Use T(B) x T(A) = T(C) to produce row-major C
         // Backward input: g_i = w @ g_o
-        checkCudaErrors(cublasXgemm(smgr.handles[i],
+        checkCudaErrors(cublasXgemm(
+                smgr->handle(i),
                 CUBLAS_OP_N,
                 CUBLAS_OP_N,
                 in_feat, expert_count[i], out_feat,
@@ -186,7 +180,8 @@ void moe_cuda_backward_impl(
         ));
         // Backward weight: g_w = i @ g_o
-        checkCudaErrors(cublasXgemm(smgr.handles[i],
+        checkCudaErrors(cublasXgemm(
+                smgr->handle(i),
                 CUBLAS_OP_N,
                 CUBLAS_OP_T,
                 in_feat, out_feat, expert_count[i],
@@ -199,7 +194,7 @@ void moe_cuda_backward_impl(
         ptr += expert_count[i];
     }
-    smgr.sync();
+    smgr->sync(num_expert);
 }
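For reference, the shapes behind the two backward GEMMs above, reading the buffers as row-major batch matrices with cnt = expert_count[i] for the current expert: the forward pass computes Y[cnt x out_feat] = X[cnt x in_feat] * W^T with W stored as [out_feat x in_feat], so the first GEMM produces grad_X = grad_Y * W and the second produces grad_W = grad_Y^T * X, each issued through the same column-major T(B) x T(A) = T(C) trick as the forward pass.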
@@ -229,6 +224,7 @@ std::vector<torch::Tensor> moe_cuda_expert_count(
 std::vector<torch::Tensor> moe_cuda_local_scatter(
         torch::Tensor input,
         torch::Tensor pos) {
+    auto smgr = getCudaStreamManager(input.device().index());
     const auto batch_size = input.size(0);
     const auto in_feat = input.size(1);
@@ -241,7 +237,8 @@ std::vector<torch::Tensor> moe_cuda_local_scatter(
             pos.data_ptr<int>(),
             input_buf.data_ptr<scalar_t>(),
             batch_size,
-            in_feat);
+            in_feat,
+            smgr);
     }));
     return {input_buf,};
 }
@@ -249,6 +246,7 @@ std::vector<torch::Tensor> moe_cuda_local_scatter(
 std::vector<torch::Tensor> moe_cuda_local_gather(
         torch::Tensor output_buf,
         torch::Tensor pos) {
+    auto smgr = getCudaStreamManager(output_buf.device().index());
     const auto batch_size = output_buf.size(0);
     const auto out_feat = output_buf.size(1);
@@ -261,7 +259,8 @@ std::vector<torch::Tensor> moe_cuda_local_gather(
             pos.data_ptr<int>(),
             output.data_ptr<scalar_t>(),
             batch_size,
-            out_feat);
+            out_feat,
+            smgr);
     }));
     return {output,};
 }
@@ -271,6 +270,7 @@ std::vector<torch::Tensor> moe_cuda_forward(
         torch::Tensor weight,
         torch::Tensor expert_count
         ) {
+    auto smgr = getCudaStreamManager(input_buf.device().index());
     const auto batch_size = input_buf.size(0);
     const auto num_expert = weight.size(0);
     const auto out_feat = weight.size(1);
@@ -280,12 +280,6 @@ std::vector<torch::Tensor> moe_cuda_forward(
     printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
             num_expert, in_feat, out_feat);
 #endif
-    /*
-    const int device = device_of(input).value().index();
-    if (smgr.streams == NULL) {
-        smgr.setup(num_expert, device);
-    }
-    */
     auto out_options = torch::TensorOptions()
         .device(input_buf.device())
         .dtype(input_buf.dtype());
@@ -300,7 +294,8 @@ std::vector<torch::Tensor> moe_cuda_forward(
                 output.data_ptr<scalar_t>(),
                 in_feat,
                 out_feat,
-                num_expert
+                num_expert,
+                smgr
         );
     }));
@@ -313,6 +308,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
         torch::Tensor weight, // [num_expert x out_feat x in_feat]
         torch::Tensor expert_count
         ) {
+    auto smgr = getCudaStreamManager(input_buf.device().index());
     const auto batch_size = input_buf.size(0);
     const auto num_expert = weight.size(0);
     const auto out_feat = weight.size(1);
@@ -338,7 +334,8 @@ std::vector<torch::Tensor> moe_cuda_backward(
                 batch_size,
                 in_feat,
                 out_feat,
-                num_expert
+                num_expert,
+                smgr
         );
     }));
@@ -14,11 +14,12 @@ def perf():
     num_expert = int(sys.argv[4])
 
-    inp = torch.rand(batch_size, in_feat).cuda(dev_name)
+    inp = torch.rand(batch_size, in_feat, requires_grad=True).cuda(dev_name)
     gate = torch.randint(low=0, high=num_expert, size=(batch_size, ),
             requires_grad=False).int().cuda(dev_name)
     moe = MOELayer(num_expert, in_feat, out_feat).cuda(dev_name)
+    moe.train()
 
     o = moe(inp, gate)
     o = moe(inp, gate)
@@ -29,6 +30,7 @@ def perf():
     n_runs = 16
     tott = 0.
+    backt = 0.
     maxt = 0.
     sqtot = 0.
     for i in range(n_runs):
@@ -37,14 +39,23 @@ def perf():
         ts = time.time()
         o = moe(inp, gate)
         te = time.time()
+
+        loss = o.sum()
+
+        bts = time.time()
+        loss.backward()
+        bte = time.time()
+
         tott += te - ts
         sqtot += (te - ts)**2
         maxt = max(maxt, te - ts)
+        backt += bte - bts
 
     gflops = 2e-9 * n_runs * in_feat * out_feat * batch_size / tott
-    print('Time mean/max/stdev {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs'.format(
+    print('Time mean/max/stdev/back {:.3f} {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs'.format(
         tott * 1e3 / n_runs, maxt * 1e3,
-        (sqtot / n_runs - (tott / n_runs)**2) * 1e3 / n_runs, gflops))
+        (sqtot / n_runs - (tott / n_runs)**2) * 1e3 / n_runs,
+        backt * 1e3 / n_runs, gflops))
 
 if __name__ == '__main__':