"vscode:/vscode.git/clone" did not exist on "483182fc3adbaa2b3c0a7b7c91b026c8efdf53d9"
Commit 60b93e39 authored by Rick Ho's avatar Rick Ho
Browse files

stream manager across multiple devices

parent 27d8beaa
#include <cuda_runtime.h>
#include <unordered_map>
#include <mutex>
#include <cassert>
#include <thread>
#include "cuda_stream_manager.h"
#include <helper_cuda.h>
#define SMGR_N_STREAMS 4
// Return a CUDA stream for the given index.
// Streams are a fixed pool of SMGR_N_STREAMS created in setup(), so any
// idx is valid: lookups wrap modulo the pool size and never reallocate.
// (The old lazy-growth path that called setup(idx + 1) leaked the
// previously allocated stream/handle arrays and is gone.)
cudaStream_t CudaStreamManager::stream(size_t idx) {
    return this->streams[idx % SMGR_N_STREAMS];
}
// Return the cuBLAS handle paired with stream(idx); same modulo-wrap
// rule as stream(), so handle(i) always runs on stream(i).
cublasHandle_t CudaStreamManager::handle(size_t idx) {
    return this->handles[idx % SMGR_N_STREAMS];
}

// Block the host until the first min(idx, SMGR_N_STREAMS) pooled streams
// have drained. Callers pass the number of streams they enqueued work on
// (e.g. sync(num_expert)); idx <= 0 is a no-op.
void CudaStreamManager::sync(int idx) {
    for (int i = 0; i < idx && i < SMGR_N_STREAMS; ++i) {
        checkCudaErrors(cudaStreamSynchronize(streams[i]));
    }
}
// Create the fixed pool of SMGR_N_STREAMS CUDA streams and cuBLAS handles
// on `device`, binding each handle to its own stream so per-expert GEMMs
// can overlap. Intended to run exactly once per manager (the constructor
// calls it); calling it again would leak the previous arrays.
void CudaStreamManager::setup(const int device) {
#ifdef MOE_DEBUG
    printf("setup at device %d\n", device);
#endif
    checkCudaErrors(cudaSetDevice(device));
    streams = new cudaStream_t[SMGR_N_STREAMS];
    handles = new cublasHandle_t[SMGR_N_STREAMS];
    for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
        checkCudaErrors(cudaStreamCreate(streams + i));
        checkCudaErrors(cublasCreate(handles + i));
        // Route each handle's cuBLAS work onto its dedicated stream;
        // the status was previously ignored.
        checkCudaErrors(cublasSetStream(handles[i], streams[i]));
    }
}
// Tear down every pooled stream/handle pair, then free the owning arrays.
// After this call the manager holds dangling pointers, so it must not be
// used again without a fresh setup().
void CudaStreamManager::destroy() {
    for (size_t k = 0; k < SMGR_N_STREAMS; ++k) {
        checkCudaErrors(cudaStreamDestroy(*(streams + k)));
        checkCudaErrors(cublasDestroy(*(handles + k)));
    }
    delete[] handles;
    delete[] streams;
}
// Process-wide cache: one CudaStreamManager per CUDA device ordinal,
// created lazily by getCudaStreamManager().
std::unordered_map<int, CudaStreamManager*> smgrs;
// Guards concurrent access to smgrs.
std::mutex smgr_mtx;
// Return the (lazily created) CudaStreamManager for `device`.
//
// The previous double-checked pattern ran the first smgrs.find() WITHOUT
// holding smgr_mtx; std::unordered_map gives no thread-safety guarantee
// for a read concurrent with an insert, so that lookup was a data race
// (undefined behavior). std::lock_guard over the whole lookup fixes the
// race and is exception-safe, unlike the manual lock()/unlock() pairs.
// Managers are never destroyed, so the returned pointer stays valid for
// the life of the process.
CudaStreamManager* getCudaStreamManager(const int device) {
    std::lock_guard<std::mutex> lock(smgr_mtx);
    auto it = smgrs.find(device);
    if (it == smgrs.end()) {
        auto smgr = new CudaStreamManager(device);
        smgrs.insert(std::pair<int, CudaStreamManager*>(device, smgr));
        return smgr;
    }
    return it->second;
}
......@@ -3,50 +3,30 @@
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h>
#include <cstdio>
// Per-device owner of a fixed pool of CUDA streams with matching cuBLAS
// handles (pool size is SMGR_N_STREAMS in the .cu file). Instances are
// created and cached per device by getCudaStreamManager(); they own raw
// stream/handle arrays, so do not copy them.
class CudaStreamManager {
public:
    // NOTE(review): retained from the old lazy-setup API; presumably no
    // longer drives the pool size -- TODO confirm and remove.
    size_t num_expert;
    int device;               // CUDA device ordinal this manager is bound to
    cublasHandle_t* handles;  // one cuBLAS handle per pooled stream
    cudaStream_t* streams;    // pooled streams, allocated in setup()

public:
    // Bind to a device and eagerly create the stream/handle pool.
    CudaStreamManager(int device_) : device(device_) {
        this->setup(device);
    }
    // Create SMGR_N_STREAMS streams + handles on the given device.
    void setup(int);
    // Synchronize the first min(n, SMGR_N_STREAMS) streams; n <= 0 is a no-op.
    void sync(int=0);
    // Release every pooled stream and handle and free the arrays.
    void destroy();
    // Stream / handle lookup; idx wraps modulo the pool size.
    cudaStream_t stream(size_t=0);
    cublasHandle_t handle(size_t=0);
    ~CudaStreamManager() {
#ifdef MOE_DEBUG
        printf("destructor at device %d\n", device);
#endif
        this->destroy();
    }
};
// NOTE(review): legacy macro from the old global-singleton API, where a
// default-constructed manager had num_expert == 0 until setup() was called.
// The per-device getCudaStreamManager() path constructs managers fully in
// the constructor, so this appears unused now -- TODO confirm and remove.
#define ENSURE_SMGR(__smgr__, __num_expert__) { \
if (__smgr__.num_expert == 0) { \
__smgr__.setup(__num_expert__); \
} \
}
// CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device);
CudaStreamManager* getCudaStreamManager(const int device);
#endif // CUDA_STREAM_MANAGER
......@@ -4,10 +4,9 @@
#include <iostream>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cublas_v2.h>
#include <helper_cuda.h>
#include <c10/cuda/CUDAGuard.h>
......@@ -18,13 +17,6 @@
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
// #define MOE_BREAKDOWN
// #define MOE_DEBUG
// thread_local CudaStreamManager smgr;
// TODO: handle stream manager faults with torch threads
CudaStreamManager smgr;
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
......@@ -35,7 +27,6 @@ void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
}
}
template <typename scalar_t>
__global__
void batch_scatter_kernel(size_t wid, const int* pos,
......@@ -77,8 +68,6 @@ void moe_cuda_expert_count_impl(
cudaMemcpyHostToDevice));
delete [] gate;
delete [] expert_ptr;
ENSURE_SMGR(smgr, num_expert);
}
template <typename scalar_t>
......@@ -87,11 +76,12 @@ void moe_cuda_local_scatter_impl(
const int* d_pos,
scalar_t* input_buf,
const size_t batch_size,
const size_t in_feat) {
const size_t in_feat,
CudaStreamManager* smgr) {
batch_scatter_kernel<scalar_t>
<<<batch_size, 256, 0, smgr.stream(0)>>>(in_feat, d_pos, input,
<<<batch_size, 256, 0, smgr->stream(0)>>>(in_feat, d_pos, input,
input_buf);
smgr.sync(0);
smgr->sync(1);
}
template <typename scalar_t>
......@@ -111,11 +101,12 @@ void moe_cuda_local_gather_impl(
const int* d_pos,
scalar_t* output,
const size_t batch_size,
const size_t out_feat) {
const size_t out_feat,
CudaStreamManager* smgr) {
batch_gather_kernel<scalar_t>
<<<batch_size, 256, 0, smgr.stream(0)>>>(out_feat, d_pos, output_buf,
<<<batch_size, 256, 0, smgr->stream(0)>>>(out_feat, d_pos, output_buf,
output);
smgr.sync(0);
smgr->sync(1);
}
template <typename scalar_t>
......@@ -126,7 +117,8 @@ void moe_cuda_forward_impl(
scalar_t* output_buf,
const size_t in_feat,
const size_t out_feat,
const size_t num_expert) {
const size_t num_expert,
CudaStreamManager* smgr) {
scalar_t alpha = 1, beta = 0;
for (int i = 0, ptr = 0; i < num_expert; ++i) {
......@@ -134,7 +126,8 @@ void moe_cuda_forward_impl(
continue;
}
// Use T(B) x T(A) = T(C) to produce row-major C
checkCudaErrors(cublasXgemm(smgr.handles[i],
checkCudaErrors(cublasXgemm(
smgr->handle(i),
CUBLAS_OP_T,
CUBLAS_OP_N,
out_feat, expert_count[i], in_feat,
......@@ -147,7 +140,7 @@ void moe_cuda_forward_impl(
ptr += expert_count[i];
}
smgr.sync();
smgr->sync(num_expert);
}
template <typename scalar_t>
......@@ -161,8 +154,8 @@ void moe_cuda_backward_impl(
const size_t batch_size,
const size_t in_feat,
const size_t out_feat,
const size_t num_expert) {
ENSURE_SMGR(smgr, num_expert);
const size_t num_expert,
CudaStreamManager* smgr) {
scalar_t alpha = 1, beta = 0;
for (int i = 0, ptr = 0; i < num_expert; ++i) {
......@@ -174,7 +167,8 @@ void moe_cuda_backward_impl(
// Use T(B) x T(A) = T(C) to produce row-major C
// Backward input: g_i = w @ g_o
checkCudaErrors(cublasXgemm(smgr.handles[i],
checkCudaErrors(cublasXgemm(
smgr->handle(i),
CUBLAS_OP_N,
CUBLAS_OP_N,
in_feat, expert_count[i], out_feat,
......@@ -186,7 +180,8 @@ void moe_cuda_backward_impl(
));
// Backward weight: g_w = i @ g_o
checkCudaErrors(cublasXgemm(smgr.handles[i],
checkCudaErrors(cublasXgemm(
smgr->handle(i),
CUBLAS_OP_N,
CUBLAS_OP_T,
in_feat, out_feat, expert_count[i],
......@@ -199,7 +194,7 @@ void moe_cuda_backward_impl(
ptr += expert_count[i];
}
smgr.sync();
smgr->sync(num_expert);
}
......@@ -229,6 +224,7 @@ std::vector<torch::Tensor> moe_cuda_expert_count(
std::vector<torch::Tensor> moe_cuda_local_scatter(
torch::Tensor input,
torch::Tensor pos) {
auto smgr = getCudaStreamManager(input.device().index());
const auto batch_size = input.size(0);
const auto in_feat = input.size(1);
......@@ -241,7 +237,8 @@ std::vector<torch::Tensor> moe_cuda_local_scatter(
pos.data_ptr<int>(),
input_buf.data_ptr<scalar_t>(),
batch_size,
in_feat);
in_feat,
smgr);
}));
return {input_buf,};
}
......@@ -249,6 +246,7 @@ std::vector<torch::Tensor> moe_cuda_local_scatter(
std::vector<torch::Tensor> moe_cuda_local_gather(
torch::Tensor output_buf,
torch::Tensor pos) {
auto smgr = getCudaStreamManager(output_buf.device().index());
const auto batch_size = output_buf.size(0);
const auto out_feat = output_buf.size(1);
......@@ -261,7 +259,8 @@ std::vector<torch::Tensor> moe_cuda_local_gather(
pos.data_ptr<int>(),
output.data_ptr<scalar_t>(),
batch_size,
out_feat);
out_feat,
smgr);
}));
return {output,};
}
......@@ -271,6 +270,7 @@ std::vector<torch::Tensor> moe_cuda_forward(
torch::Tensor weight,
torch::Tensor expert_count
) {
auto smgr = getCudaStreamManager(input_buf.device().index());
const auto batch_size = input_buf.size(0);
const auto num_expert = weight.size(0);
const auto out_feat = weight.size(1);
......@@ -280,12 +280,6 @@ std::vector<torch::Tensor> moe_cuda_forward(
printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
num_expert, in_feat, out_feat);
#endif
/*
const int device = device_of(input).value().index();
if (smgr.streams == NULL) {
smgr.setup(num_expert, device);
}
*/
auto out_options = torch::TensorOptions()
.device(input_buf.device())
.dtype(input_buf.dtype());
......@@ -300,7 +294,8 @@ std::vector<torch::Tensor> moe_cuda_forward(
output.data_ptr<scalar_t>(),
in_feat,
out_feat,
num_expert
num_expert,
smgr
);
}));
......@@ -313,6 +308,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
torch::Tensor weight, // [num_expert x out_feat x in_feat]
torch::Tensor expert_count
) {
auto smgr = getCudaStreamManager(input_buf.device().index());
const auto batch_size = input_buf.size(0);
const auto num_expert = weight.size(0);
const auto out_feat = weight.size(1);
......@@ -338,7 +334,8 @@ std::vector<torch::Tensor> moe_cuda_backward(
batch_size,
in_feat,
out_feat,
num_expert
num_expert,
smgr
);
}));
......
......@@ -14,11 +14,12 @@ def perf():
num_expert = int(sys.argv[4])
inp = torch.rand(batch_size, in_feat).cuda(dev_name)
inp = torch.rand(batch_size, in_feat, requires_grad=True).cuda(dev_name)
gate = torch.randint(low=0, high=num_expert, size=(batch_size, ),
requires_grad=False).int().cuda(dev_name)
moe = MOELayer(num_expert, in_feat, out_feat).cuda(dev_name)
moe.train()
o = moe(inp, gate)
o = moe(inp, gate)
......@@ -29,6 +30,7 @@ def perf():
n_runs = 16
tott = 0.
backt = 0.
maxt = 0.
sqtot = 0.
for i in range(n_runs):
......@@ -37,14 +39,23 @@ def perf():
ts = time.time()
o = moe(inp, gate)
te = time.time()
loss = o.sum()
bts = time.time()
loss.backward()
bte = time.time()
tott += te - ts
sqtot += (te - ts)**2
maxt = max(maxt, te - ts)
backt = bte - bts
gflops = 2e-9 * n_runs * in_feat * out_feat * batch_size / tott
print('Time mean/max/stdev {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs'.format(
print('Time mean/max/stdev/back {:.3f} {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs'.format(
tott * 1e3 / n_runs, maxt * 1e3,
(sqtot / n_runs - (tott / n_runs)**2) * 1e3 / n_runs, gflops))
(sqtot / n_runs - (tott / n_runs)**2) * 1e3 / n_runs,
backt * 1e3 / n_runs, gflops))
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment