Commit 84bdd842 authored by Jiezhong Qiu

moe forward and backward

parent 1704dc36
 import torch
-import moe1_cuda
+import moe_cuda
moe.cpp
@@ -4,12 +4,12 @@
 #include <iostream>
 #include <vector>
-std::vector<torch::Tensor> moe1_cuda_forward(
+std::vector<torch::Tensor> moe_cuda_forward(
     torch::Tensor input,
     torch::Tensor gate,
     torch::Tensor weight);
-std::vector<torch::Tensor> moe1_cuda_backward(
+std::vector<torch::Tensor> moe_cuda_backward(
     torch::Tensor grad_output,
     torch::Tensor input,
     torch::Tensor gate,
@@ -22,7 +22,7 @@ std::vector<torch::Tensor> moe1_cuda_backward(
 #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
-std::vector<torch::Tensor> moe1_forward(
+std::vector<torch::Tensor> moe_forward(
     torch::Tensor input,  // [batch_size x in_feat]
     torch::Tensor gate,   // [batch_size]
     torch::Tensor weight  // [num_expert x out_feat x in_feat]
@@ -35,10 +35,10 @@ std::vector<torch::Tensor> moe1_forward(
     Wx+b = [W b] [x]
                  [1]
     */
-    return moe1_cuda_forward(input, gate, weight);
+    return moe_cuda_forward(input, gate, weight);
 }
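The Wx+b comment marks the usual bias-folding trick: appending a constant 1 to the input lets a single matrix multiply absorb the bias. A minimal PyTorch sketch of the identity (the names here are illustrative, not part of the extension):

    import torch

    in_feat, out_feat = 4, 3
    x = torch.randn(in_feat)
    W = torch.randn(out_feat, in_feat)
    b = torch.randn(out_feat)

    # [W b] @ [x; 1] == W @ x + b
    W_aug = torch.cat([W, b.unsqueeze(1)], dim=1)  # [out_feat x (in_feat + 1)]
    x_aug = torch.cat([x, torch.ones(1)])          # [in_feat + 1]
    assert torch.allclose(W_aug @ x_aug, W @ x + b, atol=1e-5)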
-std::vector<torch::Tensor> moe1_backward(
+std::vector<torch::Tensor> moe_backward(
     torch::Tensor grad_output,  // [batch_size x out_feat]
     torch::Tensor input,        // [batch_size x in_feat]
     torch::Tensor gate,         // [batch_size]
@@ -53,7 +53,7 @@ std::vector<torch::Tensor> moe1_backward(
     Wx+b = [W b] [x]
                  [1]
     */
-    return moe1_cuda_forward(input, gate, weight);
+    return moe_cuda_backward(grad_output, input, gate, weight);
 }
@@ -69,6 +69,6 @@ int main() {
 */
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("forward", &moe1_forward, "MoE first linear forward (CUDA)");
-    // m.def("backward", &lltm_backward, "LLTM backward (CUDA)");
+    m.def("forward", &moe_forward, "MoE forward (CUDA)");
+    m.def("backward", &moe_backward, "MoE backward (CUDA)");
 }
\ No newline at end of file
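With the backward binding now exposed, the module can be exercised end to end from Python. A hypothetical smoke test, assuming the extension installs as moe_cuda (the name given in setup.py), that forward takes (input, gate, weight) as declared above, and that gate is int32 since the kernels read it through data_ptr<int>:

    import torch
    import moe_cuda

    batch_size, num_expert, in_feat, out_feat = 8, 4, 16, 32
    input = torch.randn(batch_size, in_feat, device='cuda')
    gate = torch.randint(num_expert, (batch_size,), dtype=torch.int32, device='cuda')
    weight = torch.randn(num_expert, out_feat, in_feat, device='cuda')

    output, = moe_cuda.forward(input, gate, weight)  # [batch_size x out_feat]
    print(output.shape)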
moe_cuda_kernel.cu
@@ -3,6 +3,7 @@
 #include <cstdio>
 #include <iostream>
 #include <vector>
+#include <cassert>
 #include <cuda.h>
@@ -40,6 +41,7 @@ Helper* getHelper(const size_t num_expert) {
     if (!helper) {
         helper = new Helper(num_expert);
     }
+    assert(helper->num_expert == num_expert);
     return helper;
 }
@@ -63,8 +65,7 @@ inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
         const float *Barray[], int ldb,
         const float *beta,
         float *Carray[], int ldc,
-        int batchCount)
-{
+        int batchCount) {
     return cublasSgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
 }
@@ -77,8 +78,7 @@ inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
         const double *Barray[], int ldb,
         const double *beta,
         double *Carray[], int ldc,
-        int batchCount)
-{
+        int batchCount) {
     return cublasDgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
 }
@@ -91,14 +91,46 @@ inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
         const __half *Barray[], int ldb,
         const __half *beta,
         __half *Carray[], int ldc,
-        int batchCount)
-{
+        int batchCount) {
     return cublasHgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
 }
+inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
+        cublasOperation_t transa, cublasOperation_t transb,
+        int m, int n, int k,
+        const float *alpha,
+        const float *A, int lda,
+        const float *B, int ldb,
+        const float *beta,
+        float *C, int ldc) {
+    return cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
+}
+
+inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
+        cublasOperation_t transa, cublasOperation_t transb,
+        int m, int n, int k,
+        const double *alpha,
+        const double *A, int lda,
+        const double *B, int ldb,
+        const double *beta,
+        double *C, int ldc) {
+    return cublasDgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
+}
+
+inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
+        cublasOperation_t transa, cublasOperation_t transb,
+        int m, int n, int k,
+        const __half *alpha,
+        const __half *A, int lda,
+        const __half *B, int ldb,
+        const __half *beta,
+        __half *C, int ldc) {
+    return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
+}
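The three new cublasXgemm overloads mirror the existing cublasXgemmBatched set: overload resolution selects cublasSgemm, cublasDgemm, or cublasHgemm from the pointer types, so code templated on scalar_t (as dispatched by AT_DISPATCH_FLOATING_TYPES) compiles for float, double, and half without explicit branching.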
 template <typename scalar_t>
-void moe1_cuda_forward_impl(
+void moe_cuda_forward_impl(
     const scalar_t* input,
     const int* gate,
     const scalar_t* weight,
@@ -154,12 +186,47 @@ void moe1_cuda_forward_impl(
         batch_size));
     checkCudaErrors(cudaStreamSynchronize(*(h->streams)));
-    // checkCudaErrors(cudaStreamDestroy(st));
-    // checkCudaErrors(cublasDestroy(handle));
 }
+template <typename scalar_t>
+void moe_cuda_grad_weight(
+        const scalar_t* input,
+        const int* gate,
+        const scalar_t* grad_output,
+        scalar_t* grad_weight,  // [num_expert x out_feat x in_feat]
+        const size_t batch_size,
+        const size_t in_feat,
+        const size_t out_feat,
+        const size_t num_expert,
+        cublasOperation_t transb) {
+    Helper* h = getHelper(num_expert);
+    int* gate_host = new int[batch_size];
+    scalar_t alpha = 1, beta = 1;
+    checkCudaErrors(cudaMemcpy(gate_host, gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));
+    for (size_t i=0; i<batch_size; ++i) {
+        checkCudaErrors(cublasSetStream(h->handle, *(h->streams + gate_host[i])));
+        // cublasXgemm, not cublasSgemm: the overloads above pick the GEMM matching
+        // scalar_t, so the double instantiation from AT_DISPATCH compiles too
+        checkCudaErrors(cublasXgemm(h->handle,
+            CUBLAS_OP_N,
+            CUBLAS_OP_N,
+            out_feat,
+            in_feat,
+            1,
+            &alpha,
+            grad_output + i * out_feat,
+            out_feat,
+            input + i * in_feat,
+            1,
+            &beta,
+            grad_weight + gate_host[i] * out_feat * in_feat,
+            out_feat));
+    }
+    checkCudaErrors(cudaDeviceSynchronize());
+    delete[] gate_host;
+}
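Each GEMM in the loop above has k = 1, i.e. a rank-1 update: sample i contributes the outer product of its output gradient and its input to the weight gradient of the expert it was routed to. A PyTorch reference of the intended result, useful for checking the kernel (a sketch, not part of the extension):

    import torch

    def moe_grad_weight_ref(input, gate, grad_output, num_expert):
        # input: [batch_size x in_feat], grad_output: [batch_size x out_feat]
        out_feat, in_feat = grad_output.shape[1], input.shape[1]
        grad_weight = input.new_zeros(num_expert, out_feat, in_feat)
        for i in range(input.shape[0]):
            # rank-1 update on the expert chosen by gate[i]
            grad_weight[gate[i]] += torch.outer(grad_output[i], input[i])
        return grad_weight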
-std::vector<torch::Tensor> moe1_cuda_forward(
+std::vector<torch::Tensor> moe_cuda_forward(
     torch::Tensor input,
     torch::Tensor gate,
     torch::Tensor weight) {
@@ -171,8 +238,8 @@ std::vector<torch::Tensor> moe1_cuda_forward(
     // printf("b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld, topk=%ld\n", batch_size, num_expert, in_feat, out_feat, top_k);
     auto output = input.new_zeros({batch_size, out_feat});
-    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe1_forward_cuda", ([&] {
-        moe1_cuda_forward_impl<scalar_t>(
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] {
+        moe_cuda_forward_impl<scalar_t>(
             input.data_ptr<scalar_t>(),
             gate.data_ptr<int>(),
             weight.data_ptr<scalar_t>(),
@@ -188,7 +255,7 @@ std::vector<torch::Tensor> moe1_cuda_forward(
     return {output, };
 }
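For validating the stream-per-expert batched GEMMs, a plain PyTorch reference of what the forward is expected to compute (shapes taken from the comments above; the per-sample routing is an assumption consistent with them):

    import torch

    def moe_forward_ref(input, gate, weight):
        # input: [batch_size x in_feat], weight: [num_expert x out_feat x in_feat]
        # each sample goes through the expert selected by its gate entry
        return torch.stack([weight[g] @ x for g, x in zip(gate, input)])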
-std::vector<torch::Tensor> moe1_cuda_backward(
+std::vector<torch::Tensor> moe_cuda_backward(
     torch::Tensor grad_output,  // [batch_size x out_feat]
     torch::Tensor input,        // [batch_size x in_feat]
     torch::Tensor gate,         // [batch_size]
@@ -201,9 +268,10 @@ std::vector<torch::Tensor> moe1_cuda_backward(
     auto grad_input = grad_output.new_zeros({batch_size, in_feat});            // batch_size x in_feat
     auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat
-    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe1_cuda_backward", ([&] {
-        moe1_cuda_forward_impl<scalar_t>(
+    // grad_input is easy to compute, exactly the same as forward
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
+        moe_cuda_forward_impl<scalar_t>(
             grad_output.data_ptr<scalar_t>(),
             gate.data_ptr<int>(),
             weight.data_ptr<scalar_t>(),
...
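Per the new comment, grad_input reuses moe_cuda_forward_impl on grad_output; the truncated hunk does not show how the weight transpose is handled, but the quantity being computed should be grad_input[i] = weight[gate[i]]^T @ grad_output[i]. A reference sketch under that assumption:

    import torch

    def moe_grad_input_ref(grad_output, gate, weight):
        # gradient of output[i] = weight[gate[i]] @ input[i] w.r.t. input[i]
        return torch.stack([weight[g].t() @ go for g, go in zip(gate, grad_output)])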
setup.py
@@ -2,10 +2,10 @@ from setuptools import setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 setup(
-    name='moe1_cuda',
+    name='moe_cuda',
     ext_modules=[
         CUDAExtension(
-            name='moe1_cuda',
+            name='moe_cuda',
             sources=[
                 'moe.cpp',
                 'moe_cuda_kernel.cu',
...
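The rename from moe1_cuda to moe_cuda keeps the package and extension names in sync with the Python import at the top of this commit; the build itself still goes through torch.utils.cpp_extension (CUDAExtension with BuildExtension, both imported above), typically via python setup.py install.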