Commit 49732231 authored by Rick Ho

split operators and make forward run

parent 143e21cc
@@ -13,3 +13,24 @@ void CudaStreamManager::sync(int i) {
         cudaStreamSynchronize(streams[i]);
     }
 }
+
+void CudaStreamManager::setup(const size_t num_expert, const int device) {
+#ifdef MOE_DEBUG
+    printf("setup at device %d\n", device);
+#endif
+    this->num_expert = num_expert;
+    if (device == -1) {
+        checkCudaErrors(cudaGetDevice(&this->device));
+    } else {
+        this->device = device;
+    }
+    checkCudaErrors(cudaSetDevice(this->device));
+    streams = new cudaStream_t[num_expert];
+    handles = new cublasHandle_t[num_expert];
+    for (size_t i=0; i<num_expert; ++i) {
+        checkCudaErrors(cudaStreamCreate(streams+i));
+        checkCudaErrors(cublasCreate(handles + i));
+        cublasSetStream(handles[i], streams[i]);
+    }
+}
@@ -16,7 +16,7 @@ public:
     cudaStream_t* streams;
 public:
-    CudaStreamManager() : num_expert(0), device(0), streams(NULL) {
+    CudaStreamManager() : num_expert(0), streams(NULL) {
         int current_device;
         checkCudaErrors(cudaGetDevice(&current_device));
 #ifdef MOE_DEBUG
@@ -24,21 +24,7 @@ public:
 #endif
     }
 
-    void setup(const size_t num_expert, const int device) {
-#ifdef MOE_DEBUG
-        printf("setup at device %d\n", device);
-#endif
-        this->num_expert = num_expert;
-        this->device = device;
-        checkCudaErrors(cudaSetDevice(device));
-        streams = new cudaStream_t[num_expert];
-        handles = new cublasHandle_t[num_expert];
-        for (size_t i=0; i<num_expert; ++i) {
-            checkCudaErrors(cudaStreamCreate(streams+i));
-            checkCudaErrors(cublasCreate(handles + i));
-            cublasSetStream(handles[i], streams[i]);
-        }
-    }
+    void setup(const size_t num_expert, const int device=-1);
 
     ~CudaStreamManager() {
 #ifdef MOE_DEBUG
@@ -54,6 +40,12 @@ public:
     void sync(int=-1);
 };
 
+#define ENSURE_SMGR(__smgr__, __num_expert__) { \
+    if (__smgr__.num_expert == 0) { \
+        __smgr__.setup(__num_expert__); \
+    } \
+}
+
 // CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device);
 #endif  // CUDA_STREAM_MANAGER
@@ -4,10 +4,22 @@
 #include <iostream>
 #include <vector>
 
-std::vector<torch::Tensor> moe_cuda_forward(
-        torch::Tensor input,
-        torch::Tensor gate,
-        torch::Tensor weight);
+std::vector<torch::Tensor> moe_cuda_expert_count(
+        torch::Tensor weight, // TODO: pass num-experts in another way?
+        torch::Tensor gate);
+
+std::vector<torch::Tensor> moe_cuda_local_scatter(
+        torch::Tensor input,
+        torch::Tensor pos);
+
+std::vector<torch::Tensor> moe_cuda_local_gather(
+        torch::Tensor output_buf,
+        torch::Tensor pos);
+
+std::vector<torch::Tensor> moe_cuda_forward(
+        torch::Tensor input_buf,
+        torch::Tensor weight,
+        torch::Tensor expert_count);
 
 std::vector<torch::Tensor> moe_cuda_backward(
         torch::Tensor grad_output,
@@ -22,20 +34,41 @@ std::vector<torch::Tensor> moe_cuda_backward(
 #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
 
+std::vector<torch::Tensor> moe_expert_count(
+        torch::Tensor weight,
+        torch::Tensor gate) {
+    CHECK_INPUT(gate);
+    return moe_cuda_expert_count(weight, gate);
+}
+
+std::vector<torch::Tensor> moe_local_scatter(
+        torch::Tensor input,
+        torch::Tensor pos) {
+    CHECK_INPUT(input);
+    return moe_cuda_local_scatter(input, pos);
+}
+
+std::vector<torch::Tensor> moe_local_gather(
+        torch::Tensor output_buf,
+        torch::Tensor pos) {
+    CHECK_INPUT(output_buf);
+    return moe_cuda_local_gather(output_buf, pos);
+}
+
 std::vector<torch::Tensor> moe_forward(
-        torch::Tensor input,        // [batch_size x in_feat]
-        torch::Tensor gate,         // [batch_size]
-        torch::Tensor weight        // [num_expert x hidden_feat x in_feat]
+        torch::Tensor input_buf,    // [batch_size x in_feat]
+        torch::Tensor weight,       // [num_expert x hidden_feat x in_feat]
+        torch::Tensor expert_count  // [num_expert]
 ) {
-    CHECK_INPUT(input);
-    CHECK_INPUT(gate);
+    CHECK_INPUT(input_buf);
     CHECK_INPUT(weight);
     /*
        The bias term should have been merged into weight. Note the following fact that
        Wx+b = [W b] [x]
                     [1]
     */
-    return moe_cuda_forward(input, gate, weight);
+    return moe_cuda_forward(input_buf, weight, expert_count);
 }
 
 std::vector<torch::Tensor> moe_backward(
@@ -69,6 +102,9 @@ int main() {
 */
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("expert_count", &moe_expert_count, "MoE expert count (CUDA)");
+    m.def("local_scatter", &moe_local_scatter, "MoE local scatter (CUDA)");
+    m.def("local_gather", &moe_local_gather, "MoE local gather (CUDA)");
     m.def("forward", &moe_forward, "MoE forward (CUDA)");
     m.def("backward", &moe_backward, "MoE backward (CUDA)");
 }
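Reviewer note: with the forward pass now split into four bound operators, their composition can be sanity-checked from Python against a dense reference. The sketch below is not part of this commit; it assumes the extension has been built and imported as `moe_cuda` and, matching the kernels, that `gate` is a contiguous int32 CUDA tensor.

import torch
import moe_cuda  # assumes the extension above has been built

batch_size, num_expert, in_feat, out_feat = 8, 4, 16, 32
inp = torch.randn(batch_size, in_feat, device='cuda')
gate = torch.randint(num_expert, (batch_size,), dtype=torch.int32, device='cuda')
weight = torch.randn(num_expert, out_feat, in_feat, device='cuda')

expert_count, pos = moe_cuda.expert_count(weight, gate)          # rows per expert + destination slots
input_buf, = moe_cuda.local_scatter(inp, pos)                    # regroup rows by expert
output_buf, = moe_cuda.forward(input_buf, weight, expert_count)  # one GEMM per expert
output, = moe_cuda.local_gather(output_buf, pos)                 # restore original row order

# Each row should be its own expert's linear map: output[i] = weight[gate[i]] @ inp[i]
ref = torch.bmm(weight[gate.long()], inp.unsqueeze(-1)).squeeze(-1)
print(torch.allclose(output, ref, atol=1e-4))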
@@ -11,8 +11,12 @@ class MOEFunction(Function):
     def forward(ctx, inp, gate, weight):
         # out_feat, in_feat = weight.size()[1:]
         # weight_column_major = weight.transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)
-        output = moe_cuda.forward(inp, gate, weight)
-        variables = [inp, gate, weight]
+        expert_count, pos = moe_cuda.expert_count(weight, gate)
+        input_buf, = moe_cuda.local_scatter(inp, pos)
+        output_buf, = moe_cuda.forward(input_buf, weight, expert_count)
+        output = moe_cuda.local_gather(output_buf, pos)
+
+        variables = [inp, gate, weight, expert_count, pos]
         ctx.save_for_backward(*variables)
         return output[0]
@@ -138,5 +142,5 @@ def test_dp():
 
 if __name__ == '__main__':
-    # test()
-    test_dp()
+    test()
+    # test_dp()
\ No newline at end of file
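Reviewer note: for readers following the new pipeline, here is a pure-PyTorch rendering (a sketch of mine, not code from this commit) of what the four steps compute. It makes the meaning of `pos` explicit: `pos[i]` is the slot that row `i` occupies once the batch is regrouped by expert, so `local_scatter` and `local_gather` apply mutually inverse permutations, and saving `pos` is enough to undo the reordering in backward.

import torch

def moe_forward_reference(inp, gate, weight):
    # inp: [batch, in_feat]; gate: [batch]; weight: [num_expert, out_feat, in_feat]
    num_expert = weight.size(0)
    expert_count = torch.bincount(gate, minlength=num_expert)  # the expert_count step
    order = torch.sort(gate, stable=True).indices              # row indices grouped by expert
    pos = torch.empty_like(order)
    pos[order] = torch.arange(order.numel())                   # pos[i]: slot of row i in the buffer
    input_buf = inp[order]                                     # local_scatter
    output_buf = torch.empty(inp.size(0), weight.size(1), dtype=inp.dtype)
    ptr = 0
    for e in range(num_expert):                                # one GEMM per expert, as in the cuBLAS loop
        cnt = int(expert_count[e])
        output_buf[ptr:ptr + cnt] = input_buf[ptr:ptr + cnt] @ weight[e].t()
        ptr += cnt
    return output_buf[pos]                                     # local_gather: back to input order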
@@ -36,7 +36,7 @@ void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
 
 template <typename scalar_t>
 __global__
-void batch_scatter_kernel(int wid, int* pos,
+void batch_scatter_kernel(size_t wid, const int* pos,
         const scalar_t* inbuf, scalar_t* oubuf) {
     inbuf += wid * blockIdx.x;
     oubuf += wid * pos[blockIdx.x];
@@ -45,39 +45,14 @@ void batch_scatter_kernel(int wid, int* pos,
     }
 }
 
-template <typename scalar_t>
-__global__
-void batch_gather_kernel(int wid, int* pos,
-        const scalar_t* inbuf, scalar_t* oubuf) {
-    inbuf += wid * pos[blockIdx.x];
-    oubuf += wid * blockIdx.x;
-    for (int i = threadIdx.x; i < wid; i += blockDim.x) {
-        oubuf[i] = inbuf[i];
-    }
-}
-
-template <typename scalar_t>
-void moe_cuda_forward_impl(
-        const scalar_t* input,
+void moe_cuda_expert_count_impl(
         const int* d_gate,
-        const scalar_t* weight,
-        scalar_t* output,
-        const size_t batch_size,
-        const size_t in_feat,
-        const size_t out_feat,
-        const size_t num_expert,
-        cublasOperation_t transb) {
-    scalar_t *input_buf, *output_buf;
-    checkCudaErrors(cudaMalloc(&input_buf, sizeof(scalar_t) * batch_size *
-                in_feat));
-    checkCudaErrors(cudaMalloc(&output_buf, sizeof(scalar_t) * batch_size *
-                out_feat));
+        int* expert_count,
+        int* d_pos,
+        const size_t num_expert,
+        const size_t batch_size) {
     int *gate = new int[batch_size];
-    int *expert_count = new int[num_expert], *expert_ptr = new int[num_expert];
+    int *expert_ptr = new int[num_expert];
     memset(expert_count, 0, sizeof(int) * num_expert);
     checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
@@ -92,19 +67,65 @@ void moe_cuda_forward_impl(
     }
 
     int *pos = new int[batch_size];
-    int *d_pos;
-    checkCudaErrors(cudaMalloc(&d_pos, sizeof(int) * batch_size));
     for (int i = 0; i < batch_size; ++i) {
         pos[i] = expert_ptr[gate[i]]++;
     }
     checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
                 cudaMemcpyHostToDevice));
+    delete [] gate;
+    delete [] expert_ptr;
+    ENSURE_SMGR(smgr, num_expert);
+}
+
+template <typename scalar_t>
+void moe_cuda_local_scatter_impl(
+        const scalar_t* input,
+        const int* d_pos,
+        scalar_t* input_buf,
+        const size_t batch_size,
+        const size_t in_feat) {
     batch_scatter_kernel<scalar_t>
         <<<batch_size, 256, 0, smgr.streams[0]>>>(in_feat, d_pos, input,
                 input_buf);
     smgr.sync(0);
+}
+
+template <typename scalar_t>
+__global__
+void batch_gather_kernel(size_t wid, const int* pos,
+        const scalar_t* inbuf, scalar_t* oubuf) {
+    inbuf += wid * pos[blockIdx.x];
+    oubuf += wid * blockIdx.x;
+    for (int i = threadIdx.x; i < wid; i += blockDim.x) {
+        oubuf[i] = inbuf[i];
+    }
+}
+
+template <typename scalar_t>
+void moe_cuda_local_gather_impl(
+        const scalar_t* output_buf,
+        const int* d_pos,
+        scalar_t* output,
+        const size_t batch_size,
+        const size_t out_feat) {
+    batch_gather_kernel<scalar_t>
+        <<<batch_size, 256, 0, smgr.streams[0]>>>(out_feat, d_pos, output_buf,
+                output);
+    smgr.sync(0);
+}
+
+template <typename scalar_t>
+void moe_cuda_forward_impl(
+        const scalar_t* input_buf,
+        const scalar_t* weight,
+        const int* expert_count,
+        scalar_t* output_buf,
+        const size_t in_feat,
+        const size_t out_feat,
+        const size_t num_expert,
+        cublasOperation_t transb) {
 
     scalar_t alpha = 1, beta = 0;
@@ -126,17 +147,7 @@ void moe_cuda_forward_impl(
         ptr += expert_count[i];
     }
-    smgr.sync();
-    batch_gather_kernel<scalar_t>
-        <<<batch_size, 256, 0, smgr.streams[0]>>>(out_feat, d_pos, output_buf,
-                output);
-    smgr.sync(0);
-    cudaFree(input_buf);
-    cudaFree(output_buf);
-    cudaFree(d_pos);
-    delete [] pos;
-    delete [] gate;
 }
 
 template <typename scalar_t>
@@ -176,37 +187,107 @@ void moe_cuda_grad_weight(
     delete[] gate_host;
 }
 
+std::vector<torch::Tensor> moe_cuda_expert_count(
+        torch::Tensor weight,
+        torch::Tensor gate) {
+    const auto batch_size = gate.size(0);
+    const auto num_expert = weight.size(0);
+    auto ec_options = torch::TensorOptions().dtype(torch::kInt32);
+    auto expert_count = torch::empty(num_expert, ec_options);
+    auto pos_options = torch::TensorOptions()
+        .device(gate.device())
+        .dtype(torch::kInt32);
+    auto pos = torch::empty(batch_size, pos_options);
+    moe_cuda_expert_count_impl(
+            gate.data_ptr<int>(),
+            expert_count.data_ptr<int>(),
+            pos.data_ptr<int>(),
+            num_expert,
+            batch_size);
+    return {expert_count, pos};
+}
+
+std::vector<torch::Tensor> moe_cuda_local_scatter(
+        torch::Tensor input,
+        torch::Tensor pos) {
+    const auto batch_size = input.size(0);
+    const auto in_feat = input.size(1);
+    auto input_buf = torch::empty_like(input);
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_local_scatter_cuda",
+            ([&] {
+        moe_cuda_local_scatter_impl<scalar_t>(
+            input.data_ptr<scalar_t>(),
+            pos.data_ptr<int>(),
+            input_buf.data_ptr<scalar_t>(),
+            batch_size,
+            in_feat);
+    }));
+    return {input_buf,};
+}
+
+std::vector<torch::Tensor> moe_cuda_local_gather(
+        torch::Tensor output_buf,
+        torch::Tensor pos) {
+    const auto batch_size = output_buf.size(0);
+    const auto out_feat = output_buf.size(1);
+    auto output = torch::empty_like(output_buf);
+    AT_DISPATCH_FLOATING_TYPES(output_buf.scalar_type(), "moe_local_gather_cuda",
+            ([&] {
+        moe_cuda_local_gather_impl<scalar_t>(
+            output_buf.data_ptr<scalar_t>(),
+            pos.data_ptr<int>(),
+            output.data_ptr<scalar_t>(),
+            batch_size,
+            out_feat);
+    }));
+    return {output,};
+}
+
 std::vector<torch::Tensor> moe_cuda_forward(
-        torch::Tensor input,
-        torch::Tensor gate,
-        torch::Tensor weight
+        torch::Tensor input_buf,
+        torch::Tensor weight,
+        torch::Tensor expert_count
 ) {
-    const auto batch_size = input.size(0);
+    const auto batch_size = input_buf.size(0);
     const auto num_expert = weight.size(0);
     const auto out_feat = weight.size(1);
     const auto in_feat = weight.size(2);
 
 #ifdef MOE_DEBUG
-    printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
+    printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
+            num_expert, in_feat, out_feat);
 #endif
 
+    /*
     const int device = device_of(input).value().index();
     if (smgr.streams == NULL) {
         smgr.setup(num_expert, device);
     }
-    auto output = input.new_zeros({batch_size, out_feat});
+    */
+    auto out_options = torch::TensorOptions()
+        .device(input_buf.device())
+        .dtype(input_buf.dtype());
+    auto output = torch::empty({batch_size, out_feat}, out_options);
 
-    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] {
-        moe_cuda_forward_impl<scalar_t>(
-            input.data_ptr<scalar_t>(),
-            gate.data_ptr<int>(),
-            weight.data_ptr<scalar_t>(),
-            output.data_ptr<scalar_t>(),
-            batch_size,
-            in_feat,
-            out_feat,
-            num_expert,
-            CUBLAS_OP_T
-        );
-    }));
+    AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_forward_cuda",
+            ([&] {
+        moe_cuda_forward_impl<scalar_t>(
+            input_buf.data_ptr<scalar_t>(),
+            weight.data_ptr<scalar_t>(),
+            expert_count.data_ptr<int>(),
+            output.data_ptr<scalar_t>(),
+            in_feat,
+            out_feat,
+            num_expert,
+            CUBLAS_OP_T
+        );
+    }));
 
     return {output, };
...
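Reviewer note: the host side of `moe_cuda_expert_count_impl` is a counting sort over `gate`. Its histogram and prefix-sum steps fall on lines elided from the hunks above, so the plain-Python rendering below fills them in under that assumption (my sketch, not the commit's code):

def expert_count_reference(gate, num_expert):
    expert_count = [0] * num_expert
    for g in gate:                      # histogram: rows assigned to each expert
        expert_count[g] += 1
    expert_ptr = [0] * num_expert       # exclusive prefix sum: first slot of each expert
    for e in range(1, num_expert):
        expert_ptr[e] = expert_ptr[e - 1] + expert_count[e - 1]
    pos = [0] * len(gate)
    for i, g in enumerate(gate):        # mirrors pos[i] = expert_ptr[gate[i]]++
        pos[i] = expert_ptr[g]
        expert_ptr[g] += 1
    return expert_count, pos

# e.g. expert_count_reference([1, 0, 1, 2], 3) == ([1, 2, 1], [1, 0, 2, 3])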