Commit 707652bc authored by Jiezhong Qiu

update

parent 74cc6ec2
@@ -4,4 +4,6 @@ libtorch-shared-with-deps-*
pytorch/cuda/build
exp/
.vscode/
a.out
moe_first_linear_cuda.egg-info
*.egg
\ No newline at end of file
@@ -14,12 +14,33 @@
//#include <helper_functions.h>
#include <helper_cuda.h>
template <typename scalar_t>
void moe_first_linear_cuda_forward(
const scalar_t* input,
const size_t* gate,
const scalar_t* weight,
scalar_t* output,
const size_t batch_size,
const size_t top_k,
const size_t in_feat,
const size_t out_feat);
// C++ interface
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> moe_first_linear_forward(
torch::Tensor input, // [B x D_model]
torch::Tensor gate, // [B x K]
torch::Tensor weight // [N x D_ffn x D_model]
) {
CHECK_INPUT(input);
CHECK_INPUT(gate);
CHECK_INPUT(weight);
/*
The bias term should have been merged into weight. Note the following fact that
Wx+b = [W b] [x]
@@ -31,11 +52,11 @@ std::vector<torch::Tensor> moe_cuda_forward(
const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2);
printf("b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld, topk=%ld\n", batch_size, num_expert, in_feat, out_feat, top_k);
auto output = input.new_zeros({batch_size, top_k, out_feat});
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_first_linear_forward", ([&] {
moe_first_linear_cuda_forward<scalar_t>(
input.data_ptr<scalar_t>(),
gate.data_ptr<size_t>(),
weight.data_ptr<scalar_t>(),
@@ -49,14 +70,8 @@ std::vector<torch::Tensor> moe_cuda_forward(
return {output, };
}
/*
int main() {
int device=2;
torch::Tensor input = torch::randn({2048, 512}, torch::dtype(torch::kFloat32).device(torch::kCUDA, device));
@@ -64,4 +79,10 @@ int main() {
torch::Tensor weight = torch::randn({2, 512, 2048}, torch::dtype(torch::kFloat32).device(torch::kCUDA, device));
checkCudaErrors(cudaSetDevice(device));
moe_cuda_forward(input, gate, weight);
}
*/
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &moe_first_linear_forward, "MoE first linear forward (CUDA)");
// m.def("backward", &lltm_backward, "LLTM backward (CUDA)");
}
\ No newline at end of file
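For reference, the computation exposed here as moe_first_linear_forward (output[b, i] = weight[gate[b, i]] @ input[b], per the shape comments above) can be written directly in PyTorch. The following is a minimal correctness-check sketch, not part of the commit; treating the gate as int64 is an assumption, since the C++ side reads the gate buffer through a size_t pointer.

import torch

def moe_first_linear_reference(input, gate, weight):
    # input: [B, d_model], gate: [B, k] expert indices, weight: [N, d_ffn, d_model]
    # returns output: [B, k, d_ffn] with output[b, i] = weight[gate[b, i]] @ input[b]
    B, k = gate.shape
    expert_w = weight[gate.reshape(-1)]                              # [B*k, d_ffn, d_model]
    x = input.unsqueeze(1).expand(B, k, -1).reshape(B * k, -1, 1)    # [B*k, d_model, 1]
    return torch.bmm(expert_w, x).view(B, k, -1)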
@@ -75,7 +75,7 @@ inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
}
template <typename scalar_t>
void moe_first_linear_cuda_forward(
const scalar_t* input,
const size_t* gate,
const scalar_t* weight,
@@ -155,11 +155,11 @@ int main() {
}
checkCudaErrors(cudaMemcpy(gate, gate_host, batch_size * top_k * sizeof(size_t), cudaMemcpyHostToDevice));
moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
for (size_t i=0; i<nt; ++i) {
timestamp(start);
moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
timestamp(end);
auto t = getDuration(start, end);
tsum += t;
...
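The timing harness in the kernel's main() (one warm-up call, then a timed loop) can be mirrored from Python once the extension is built. A rough sketch, assuming the module name defined in setup.py below; explicit synchronization is needed because kernel launches return asynchronously.

import time
import torch
import moe_first_linear_cuda  # module name assumed from setup.py below

def benchmark(input, gate, weight, iters=16):
    moe_first_linear_cuda.forward(input, gate, weight)  # warm-up, as in main()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        moe_first_linear_cuda.forward(input, gate, weight)
    torch.cuda.synchronize()
    return (time.time() - start) / iters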
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='moe_first_linear_cuda',
ext_modules=[
CUDAExtension(
name='moe_first_linear_cuda',
sources=[
'moe.cpp',
'moe_cuda_kernel.cu',
],
extra_compile_args={'cxx': ['-I/usr/local/cuda/samples/common/inc'],
'nvcc': ['-I/usr/local/cuda/samples/common/inc']}
)
],
cmdclass={
'build_ext': BuildExtension
})
\ No newline at end of file
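The extension should build with the usual setuptools workflow, e.g. python setup.py install run from this directory. A hedged smoke test follows; the tensor shapes are illustrative only, and passing the gate as int64 is an assumption given the size_t pointer on the C++ side.

import torch
import moe_first_linear_cuda  # name from setup.py above

B, k, n_expert, d_model, d_ffn = 2048, 2, 4, 512, 1024
x = torch.randn(B, d_model, device='cuda')
gate = torch.randint(0, n_expert, (B, k), device='cuda', dtype=torch.int64)
w = torch.randn(n_expert, d_ffn, d_model, device='cuda')
out, = moe_first_linear_cuda.forward(x, gate, w)  # forward returns a list with one tensor
print(out.shape)  # expected: torch.Size([2048, 2, 1024])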