OpenDAS / FastMoE · Commits

Commit 3a458fa7, authored Dec 23, 2020 by Jiezhong Qiu
Commit message: "updarte"
Parent: 707652bc
Showing 5 changed files with 63 additions and 69 deletions.
.gitignore                       +1  −1
pytorch/cuda/a.py                +2  −0
pytorch/cuda/moe.cpp             +7  −45
pytorch/cuda/moe_cuda_kernel.cu  +51 −21
pytorch/cuda/setup.py            +2  −2
.gitignore

@@ -5,5 +5,5 @@ pytorch/cuda/build
 exp/
 .vscode/
 a.out
-moe_first_linear_cuda.egg-info
+*.egg-info
 *.egg
\ No newline at end of file
pytorch/cuda/a.py (new file, mode 100644)

+import torch
+import moe1_cuda
pytorch/cuda/moe.cpp

 #include <torch/extension.h>
-#include <torch/torch.h>
 #include <cstdio>
 #include <iostream>
 #include <vector>
-// CUDA runtime
-#include <cuda.h>
-#include <cuda_runtime.h>
-#include <cublas_v2.h>
-// CUDA and CUBLAS functions
-//#include <helper_functions.h>
-#include <helper_cuda.h>
-template <typename scalar_t>
-void moe_first_linear_cuda_forward(
-        const scalar_t* input,
-        const size_t* gate,
-        const scalar_t* weight,
-        scalar_t* output,
-        const size_t batch_size,
-        const size_t top_k,
-        const size_t in_feat,
-        const size_t out_feat);
+std::vector<torch::Tensor> moe1_cuda_forward(
+        torch::Tensor input,
+        torch::Tensor gate,
+        torch::Tensor weight);
 // C++ interface

@@ -32,8 +16,7 @@ void moe_first_linear_cuda_forward(
 #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
-std::vector<torch::Tensor> moe_first_linear_forward(
+std::vector<torch::Tensor> moe1_forward(
         torch::Tensor input,   // [B x D_model]
         torch::Tensor gate,    // [B x K]
         torch::Tensor weight   // [N x D_ffn x D_model]

@@ -46,28 +29,7 @@ std::vector<torch::Tensor> moe_first_linear_forward(
     Wx+b = [W b] [x]
                  [1]
     */
-    const auto batch_size = input.size(0);
-    const auto top_k = gate.size(1);
-    const auto num_expert = weight.size(0);
-    const auto out_feat = weight.size(1);
-    const auto in_feat = weight.size(2);
-    printf("b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld, topk=%ld\n", batch_size, num_expert, in_feat, out_feat, top_k);
-    auto output = input.new_zeros({batch_size, top_k, out_feat});
-    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_first_linear_forward", ([&] {
-        moe_first_linear_cuda_forward<scalar_t>(
-            input.data_ptr<scalar_t>(),
-            gate.data_ptr<size_t>(),
-            weight.data_ptr<scalar_t>(),
-            output.data_ptr<scalar_t>(),
-            batch_size,
-            top_k,
-            in_feat,
-            out_feat);
-    }));
-    return {output, };
+    return moe1_cuda_forward(input, gate, weight);
 }

@@ -83,6 +45,6 @@ int main() {
 */
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("forward", &moe_first_linear_forward, "MoE first linear forward (CUDA)");
+    m.def("forward", &moe1_forward, "MoE first linear forward (CUDA)");
     // m.def("backward", &lltm_backward, "LLTM backward (CUDA)");
 }
\ No newline at end of file
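The Wx+b comment kept as context in the third hunk is the standard bias-folding identity: appending a constant 1 to the input vector lets a single GEMM absorb the bias term. With the shapes from the comments in this file (W taken from the [N x D_ffn x D_model] weight tensor):

\[
Wx + b \;=\; \begin{bmatrix} W & b \end{bmatrix} \begin{bmatrix} x \\ 1 \end{bmatrix},
\qquad W \in \mathbb{R}^{d_\mathrm{ffn} \times d_\mathrm{model}},\;
x \in \mathbb{R}^{d_\mathrm{model}},\; b \in \mathbb{R}^{d_\mathrm{ffn}}.
\]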
pytorch/cuda/moe_cuda_kernel.cu

-#include <torch/torch.h>
+#include <torch/extension.h>
 #include <cstdio>
 #include <iostream>
 #include <vector>
-// CUDA runtime
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cublas_v2.h>
-// CUDA and CUBLAS functions
-//#include <helper_functions.h>
 #include <helper_cuda.h>
-#include "timer.hh"
+// #include "timer.hh"
-typedef float data_t;
-size_t batch_size = 4096;
-size_t top_k = 2;
-size_t num_expert = 128;
-size_t in_feat = 1024;
-size_t out_feat = 4096;
 #define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
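As a side note, CEIL above is ceiling division on positive integers, the usual idiom for computing how many fixed-size thread blocks cover n elements. A minimal self-contained check (only the macro itself comes from this file; the asserts are illustrative):

#include <cstddef>

// Ceiling division as defined in moe_cuda_kernel.cu: for x > 0,
// CEIL(x, y) equals ceil(x / y) in exact arithmetic.
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)

// 4096 elements fit exactly into 16 blocks of 256 threads;
// one extra element requires a 17th block.
static_assert(CEIL(4096, 256) == 16, "exact multiple");
static_assert(CEIL(4097, 256) == 17, "rounds up");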
 template <typename scalar_t>
 __global__
-void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride, const size_t* offset, const scalar_t** ptrs) {
+void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride, const int* offset, const scalar_t** ptrs) {
     size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
     if (idx < n) {
         ptrs[idx] = base + stride * offset[idx];
     }
 }

 inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
     cublasOperation_t transa,
     cublasOperation_t transb,
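To see what generate_ptr_offset_kernel is for: the cuBLAS batched-GEMM API takes arrays of device pointers, one matrix per batch entry, so the per-token expert indices in gate must first be expanded into per-token weight pointers on the device. Below is a minimal sketch of that pattern, hard-coded to float; the names fill_ptrs and batched_expert_gemm are mine, not from this commit, and the commit's own implementation (in the part of this file the diff elides) uses generate_ptr_offset_kernel plus the cublasXgemmBatched wrapper instead. Each routed token becomes one (out_feat x in_feat) by (in_feat x 1) GEMM against its selected expert's weight.

#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cstddef>

// Ceiling division, as defined earlier in this file.
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)

// Build the per-token pointer arrays for cublasSgemmBatched:
// token i multiplies its own row of `input` by the weight of expert gate[i].
__global__ void fill_ptrs(size_t n,
                          const float* input, const float* weight, float* output,
                          const int* gate, size_t in_feat, size_t out_feat,
                          const float** Aarray, const float** Barray, float** Carray) {
    size_t i = threadIdx.x + blockDim.x * blockIdx.x;
    if (i < n) {
        Aarray[i] = weight + (size_t)gate[i] * in_feat * out_feat;  // expert weight
        Barray[i] = input + i * in_feat;                            // token input row
        Carray[i] = output + i * out_feat;                          // token output row
    }
}

// One batched GEMM over n = batch_size * top_k routed tokens.
void batched_expert_gemm(cublasHandle_t handle, cudaStream_t st,
                         const float* input, const int* gate, const float* weight,
                         float* output, size_t n, size_t in_feat, size_t out_feat) {
    const float **Aarray, **Barray;
    float **Carray;
    cudaMalloc((void**)&Aarray, n * sizeof(float*));
    cudaMalloc((void**)&Barray, n * sizeof(float*));
    cudaMalloc((void**)&Carray, n * sizeof(float*));
    fill_ptrs<<<CEIL(n, 256), 256, 0, st>>>(n, input, weight, output, gate,
                                            in_feat, out_feat, Aarray, Barray, Carray);
    const float alpha = 1.0f, beta = 0.0f;
    cublasSetStream(handle, st);
    // A row-major weight [out_feat x in_feat] is a column-major [in_feat x out_feat]
    // matrix, so op(A) = T recovers W; each B is a single column (n = 1).
    cublasSgemmBatched(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                       (int)out_feat, 1, (int)in_feat, &alpha,
                       Aarray, (int)in_feat, Barray, (int)in_feat,
                       &beta, Carray, (int)out_feat, (int)n);
    cudaStreamSynchronize(st);  // pointer arrays must outlive the GEMM
    cudaFree(Aarray); cudaFree(Barray); cudaFree(Carray);
}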
@@ -74,10 +66,11 @@ inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
     return cublasHgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
 }

 template <typename scalar_t>
-void moe_first_linear_cuda_forward(
+void moe1_cuda_forward_impl(
     const scalar_t* input,
-    const size_t* gate,
+    const int* gate,
     const scalar_t* weight,
     scalar_t* output,
     const size_t batch_size,

@@ -85,7 +78,6 @@ void moe_first_linear_cuda_forward(
     const size_t in_feat,
     const size_t out_feat) {
     cublasHandle_t handle;
     cudaStream_t st;
     cudaStreamCreate(&st);
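The body between these hunks is elided by the diff view; for a cuBLAS call to run on the private stream declared above, the handle must be both created and bound to that stream somewhere in it. A generic sketch of that lifecycle (hypothetical function name; the actual elided code may order things differently):

#include <cublas_v2.h>
#include <cuda_runtime.h>

// Canonical cuBLAS handle/stream setup and teardown around GEMM work.
void cublas_stream_setup_example() {
    cublasHandle_t handle;
    cudaStream_t st;
    cudaStreamCreate(&st);
    cublasCreate(&handle);        // required before any cublas* math call
    cublasSetStream(handle, st);  // subsequent calls run on st, not stream 0
    // ... enqueue batched GEMMs here ...
    cudaStreamSynchronize(st);
    cublasDestroy(handle);
    cudaStreamDestroy(st);
}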
@@ -136,7 +128,44 @@ void moe_first_linear_cuda_forward(
 }
+std::vector<torch::Tensor> moe1_cuda_forward(
+        torch::Tensor input,
+        torch::Tensor gate,
+        torch::Tensor weight) {
+    const auto batch_size = input.size(0);
+    const auto top_k = gate.size(1);
+    const auto num_expert = weight.size(0);
+    const auto out_feat = weight.size(1);
+    const auto in_feat = weight.size(2);
+    // printf("b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld, topk=%ld\n", batch_size, num_expert, in_feat, out_feat, top_k);
+    auto output = input.new_zeros({batch_size, top_k, out_feat});
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe1_forward_cuda", ([&] {
+        moe1_cuda_forward_impl<scalar_t>(
+            input.data_ptr<scalar_t>(),
+            gate.data_ptr<int>(),
+            weight.data_ptr<scalar_t>(),
+            output.data_ptr<scalar_t>(),
+            batch_size,
+            top_k,
+            in_feat,
+            out_feat);
+    }));
+    return {output, };
+}
+/*
 int main() {
+    typedef float data_t;
+    size_t batch_size = 4096;
+    size_t top_k = 2;
+    size_t num_expert = 128;
+    size_t in_feat = 1024;
+    size_t out_feat = 4096;
     data_t *input, *weight;
     data_t *output;
     size_t *gate;

@@ -168,4 +197,5 @@ int main() {
     printf("Mean %.3lf us, max %.3lf us\n", tsum / nt * 1e6, tmax * 1e6);
     double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
     printf("%.3lf TFLOPs\n", tflops);
 }
+*/
\ No newline at end of file
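The throughput line in the commented-out benchmark counts two FLOPs (a multiply and an add) per weight element per routed token. With B = batch_size, K = top_k, and n_t timed iterations over total time t_sum:

\[
\mathrm{TFLOP/s} \;=\; \frac{2 \, B \, K \, d_\mathrm{in} \, d_\mathrm{out} \, n_t}{t_\mathrm{sum}} \times 10^{-12},
\]

which is exactly the 2e-12 / tsum scaling in the printf above.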
pytorch/cuda/setup.py

@@ -2,10 +2,10 @@ from setuptools import setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 setup(
-    name='moe_first_linear_cuda',
+    name='moe1_cuda',
     ext_modules=[
         CUDAExtension(
-            name='moe_first_linear_cuda',
+            name='moe1_cuda',
             sources=[
                 'moe.cpp',
                 'moe_cuda_kernel.cu',
...