Commit 1704dc36 authored by Jiezhong Qiu

update

parent eb47044a
@@ -9,6 +9,12 @@ std::vector<torch::Tensor> moe1_cuda_forward(
torch::Tensor gate,
torch::Tensor weight);
std::vector<torch::Tensor> moe1_cuda_backward(
torch::Tensor grad_output,
torch::Tensor input,
torch::Tensor gate,
torch::Tensor weight);
// C++ interface
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
@@ -17,10 +23,28 @@ std::vector<torch::Tensor> moe1_cuda_forward(
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> moe1_forward(
-torch::Tensor input, // [B x D_model]
-torch::Tensor gate, // [B]
-torch::Tensor weight // [N x D_ffn x D_model]
+torch::Tensor input, // [batch_size x in_feat]
+torch::Tensor gate, // [batch_size]
+torch::Tensor weight // [num_expert x out_feat x in_feat]
) {
CHECK_INPUT(input);
CHECK_INPUT(gate);
CHECK_INPUT(weight);
/*
The bias term should have been merged into weight beforehand; note that
    Wx + b = [W b] [x]
                   [1]
*/
return moe1_cuda_forward(input, gate, weight);
}
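The merged-bias convention above means the caller is expected to fold any per-expert bias into the weight tensor before invoking moe1_forward. A minimal caller-side sketch of that folding (hypothetical helper, not part of this commit; it assumes a bias of shape [num_expert x out_feat]):

    #include <torch/torch.h>

    // Append b as an extra weight column and a constant 1 to each input row,
    // so that [W b] * [x; 1] == W x + b.
    std::pair<torch::Tensor, torch::Tensor> fold_bias(
        torch::Tensor input,   // [batch_size x in_feat]
        torch::Tensor weight,  // [num_expert x out_feat x in_feat]
        torch::Tensor bias) {  // [num_expert x out_feat]
      auto ones = torch::ones({input.size(0), 1}, input.options());
      auto input_aug  = torch::cat({input, ones}, /*dim=*/1);                 // [batch_size x (in_feat + 1)]
      auto weight_aug = torch::cat({weight, bias.unsqueeze(2)}, /*dim=*/2);   // [num_expert x out_feat x (in_feat + 1)]
      return {input_aug, weight_aug};
    }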
std::vector<torch::Tensor> moe1_backward(
torch::Tensor grad_output, // [batch_size x out_feat]
torch::Tensor input, // [batch_size x in_feat]
torch::Tensor gate, // [batch_size]
torch::Tensor weight // [num_expert x out_feat x in_feat]
) {
CHECK_INPUT(grad_output);
CHECK_INPUT(input);
CHECK_INPUT(gate);
CHECK_INPUT(weight);
......
@@ -14,6 +14,36 @@
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
class Helper {
public:
Helper(const size_t num_expert_) : num_expert(num_expert_) {
streams = new cudaStream_t[num_expert];
checkCudaErrors(cublasCreate(&handle));
for (size_t i=0; i<num_expert; ++i) {
checkCudaErrors(cudaStreamCreate(streams+i));
}
}
~Helper() {
for (size_t i=0; i<num_expert; ++i) {
checkCudaErrors(cudaStreamDestroy(*(streams+i)));
}
checkCudaErrors(cublasDestroy(handle));
}
const size_t num_expert;
cublasHandle_t handle;
cudaStream_t* streams;
};
Helper* helper = NULL;
Helper* getHelper(const size_t num_expert) {
if (!helper) {
helper = new Helper(num_expert);
}
return helper;
}
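getHelper lazily builds a single process-wide Helper on first use: one shared cuBLAS handle plus one CUDA stream per expert, reused across calls instead of being created and destroyed on every forward pass (compare the commented-out per-call setup further down). An illustrative usage sketch (the function name is hypothetical; note the forward implementation below currently binds the handle to streams[0] only):

    void run_per_expert_gemms(size_t num_expert) {
        Helper* h = getHelper(num_expert);  // created once, intentionally never freed
        for (size_t i = 0; i < h->num_expert; ++i) {
            checkCudaErrors(cublasSetStream(h->handle, h->streams[i]));
            // ... enqueue this expert's cuBLAS work on stream i ...
        }
    }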
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride, const int* offset, const scalar_t** ptrs) {
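The kernel body is elided from this diff. Inferred from the signature and the call site further down (base = weight, stride = out_feat * in_feat, offset = gate), it presumably materializes the per-sample weight pointers consumed by cublasXgemmBatched. An illustrative reconstruction, not the author's actual body:

    template <typename scalar_t>
    __global__
    void generate_ptr_offset_kernel_sketch(size_t n, const scalar_t* base, size_t stride,
                                           const int* offset, const scalar_t** ptrs) {
        size_t i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
            ptrs[i] = base + offset[i] * stride;  // ptrs[i] -> weight block of expert offset[i]
        }
    }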
@@ -75,13 +105,18 @@ void moe1_cuda_forward_impl(
scalar_t* output,
const size_t batch_size,
const size_t in_feat,
-const size_t out_feat) {
+const size_t out_feat,
+const size_t num_expert,
+cublasOperation_t transb) {
/*
cublasHandle_t handle;
cudaStream_t st;
-cudaStreamCreate(&st);
+checkCudaErrors(cudaStreamCreate(&st));
checkCudaErrors(cublasCreate(&handle));
checkCudaErrors(cublasSetStream(handle, st));
*/
Helper* h = getHelper(num_expert);
checkCudaErrors(cublasSetStream(h->handle, *(h->streams)));
// setup Aarray, Barray and Carray
std::vector<const scalar_t*> aptrs;
@@ -104,12 +139,12 @@ void moe1_cuda_forward_impl(
dim3 griddim(CEIL(batch_size, 256));
dim3 blockdim(256);
-generate_ptr_offset_kernel<<<griddim, blockdim, 0, st>>>(batch_size, weight, out_feat * in_feat, gate, Barray);
+generate_ptr_offset_kernel<<<griddim, blockdim, 0, *(h->streams)>>>(batch_size, weight, out_feat * in_feat, gate, Barray);
scalar_t alpha = 1, beta = 0;
-checkCudaErrors(cublasXgemmBatched(handle,
+checkCudaErrors(cublasXgemmBatched(h->handle,
CUBLAS_OP_N,
-CUBLAS_OP_T,
+transb,
1, out_feat, in_feat,
&alpha,
Aarray, 1,
@@ -118,9 +153,9 @@
Carray, 1,
batch_size));
-checkCudaErrors(cudaStreamSynchronize(st));
-checkCudaErrors(cudaStreamDestroy(st));
-checkCudaErrors(cublasDestroy(handle));
+checkCudaErrors(cudaStreamSynchronize(*(h->streams)));
+// checkCudaErrors(cudaStreamDestroy(st));
+// checkCudaErrors(cublasDestroy(handle));
}
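The batched GEMM treats every sample as its own 1 x in_feat by in_feat x out_feat product (m = 1), with the B pointer array steering each sample to its expert's weight block. In row-major terms, with transb == CUBLAS_OP_T as used by the forward pass, the call computes output[i] = input[i] * weight[gate[i]]^T. An illustrative CPU reference of those semantics (sketch only, float specialization):

    void moe1_forward_reference(const float* input, const int* gate,
                                const float* weight, float* output,
                                size_t batch_size, size_t in_feat, size_t out_feat) {
        for (size_t i = 0; i < batch_size; ++i) {
            // weight block of the expert selected for sample i
            const float* w = weight + (size_t)gate[i] * out_feat * in_feat;
            for (size_t o = 0; o < out_feat; ++o) {
                float acc = 0;
                for (size_t k = 0; k < in_feat; ++k)
                    acc += input[i * in_feat + k] * w[o * in_feat + k];
                output[i * out_feat + o] = acc;
            }
        }
    }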
@@ -144,13 +179,45 @@ std::vector<torch::Tensor> moe1_cuda_forward(
output.data_ptr<scalar_t>(),
batch_size,
in_feat,
-out_feat
+out_feat,
+num_expert,
+CUBLAS_OP_T
);
}));
return {output, };
}
std::vector<torch::Tensor> moe1_cuda_backward(
torch::Tensor grad_output, // [batch_size x out_feat]
torch::Tensor input, // [batch_size x in_feat]
torch::Tensor gate, // [batch_size]
torch::Tensor weight // [num_expert x out_feat x in_feat]
) {
const auto batch_size = input.size(0);
const auto num_expert = weight.size(0);
const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2);
auto grad_input = grad_output.new_zeros({batch_size, in_feat}); // batch_size x in_feat
auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe1_cuda_backward", ([&] {
moe1_cuda_forward_impl<scalar_t>(
grad_output.data_ptr<scalar_t>(),
gate.data_ptr<int>(),
weight.data_ptr<scalar_t>(),
grad_input.data_ptr<scalar_t>(),
batch_size,
out_feat,
in_feat,
num_expert,
CUBLAS_OP_N
);
}));
return {grad_input, grad_weight};
}
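The backward pass reuses moe1_cuda_forward_impl with in_feat and out_feat swapped and transb = CUBLAS_OP_N: since output[i] = input[i] * W^T, the input gradient is grad_input[i] = grad_output[i] * weight[gate[i]]. Note that in the code shown, grad_weight is allocated as zeros and returned without being filled; what it would need is a per-expert outer-product accumulation, sketched here as an illustrative CPU reference (not part of this commit):

    // grad_weight[gate[i]] += outer(grad_output[i], input[i]) over all samples.
    void moe1_grad_weight_reference(const float* grad_output, const float* input,
                                    const int* gate, float* grad_weight,
                                    size_t batch_size, size_t in_feat, size_t out_feat) {
        for (size_t i = 0; i < batch_size; ++i) {
            float* gw = grad_weight + (size_t)gate[i] * out_feat * in_feat;
            for (size_t o = 0; o < out_feat; ++o)
                for (size_t k = 0; k < in_feat; ++k)
                    gw[o * in_feat + k] += grad_output[i * out_feat + o] * input[i * in_feat + k];
        }
    }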
/*
int main() {
......