OpenDAS / FastMoE / Commits

Commit 707652bc, authored Dec 19, 2020 by Jiezhong Qiu
Commit message: update
Parent: 74cc6ec2

Showing 4 changed files with 57 additions and 15 deletions (+57 −15)
.gitignore                       +3  −1
pytorch/cuda/moe.cpp             +32 −11
pytorch/cuda/moe_cuda_kernel.cu  +3  −3
pytorch/cuda/setup.py            +19 −0
.gitignore  (view file @ 707652bc)

@@ -5,3 +5,5 @@ pytorch/cuda/build
 exp/
 .vscode/
-a.out
\ No newline at end of file
+a.out
+moe_first_linear_cuda.egg-info
+*.egg
\ No newline at end of file
pytorch/cuda/moe.cpp  (view file @ 707652bc)
@@ -14,12 +14,33 @@
 //#include <helper_functions.h>
 #include <helper_cuda.h>
 
+template <typename scalar_t>
+void moe_first_linear_cuda_forward(
+    const scalar_t* input,
+    const size_t* gate,
+    const scalar_t* weight,
+    scalar_t* output,
+    const size_t batch_size,
+    const size_t top_k,
+    const size_t in_feat,
+    const size_t out_feat);
+
-std::vector<torch::Tensor> moe_cuda_forward(
+// C++ interface
+// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
+#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+std::vector<torch::Tensor> moe_first_linear_forward(
     torch::Tensor input,   // [B x D_model]
     torch::Tensor gate,    // [B x K]
     torch::Tensor weight   // [N x D_ffn x D_model]
 ) {
     CHECK_INPUT(input);
     CHECK_INPUT(gate);
     CHECK_INPUT(weight);
 /*
     The bias term should have been merged into weight. Note the following fact that
     Wx+b = [W b] [x]
@@ -31,11 +52,11 @@ std::vector<torch::Tensor> moe_cuda_forward(
     const auto out_feat = weight.size(1);
     const auto in_feat = weight.size(2);
-    printf("b=%d, expert=%d, in_feat (d_model)=%d, out_feat (d_ffn)=%d, topk=%d\n", batch_size, num_expert, in_feat, out_feat, top_k);
+    printf("b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld, topk=%ld\n", batch_size, num_expert, in_feat, out_feat, top_k);
     auto output = input.new_zeros({batch_size, top_k, out_feat});
-    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_forward", ([&] {
-        moe_cuda_forward_impl<scalar_t>(
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_first_linear_forward", ([&] {
+        moe_first_linear_cuda_forward<scalar_t>(
             input.data_ptr<scalar_t>(),
             gate.data_ptr<size_t>(),
             weight.data_ptr<scalar_t>(),
@@ -49,14 +70,8 @@ std::vector<torch::Tensor> moe_cuda_forward(
     return {output, };
 }
 
-// C++ interface
-// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
-#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
-#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
-
 /*
 int main() {
     int device=2;
     torch::Tensor input = torch::randn({2048, 512}, torch::dtype(torch::kFloat32).device(torch::kCUDA, device));

@@ -65,3 +80,9 @@ int main() {
     checkCudaErrors(cudaSetDevice(device));
     moe_cuda_forward(input, gate, weight);
 }
 */
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("forward", &moe_first_linear_forward, "MoE first linear forward (CUDA)");
+    // m.def("backward", &lltm_backward, "LLTM backward (CUDA)");
+}
\ No newline at end of file
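For reference: given the shape comments above (input [B x D_model], gate [B x K], weight [N x D_ffn x D_model]) and the output allocated as {batch_size, top_k, out_feat}, the renamed entry point computes one expert projection per (token, slot) pair. A minimal pure-PyTorch sketch of that semantics, for illustration only (the kernel in this commit does this on the GPU via cuBLAS; the function name below is hypothetical):

import torch

def moe_first_linear_reference(input, gate, weight):
    # input:  [B, d_model], one token embedding per row
    # gate:   [B, K], indices of each token's top-k experts
    # weight: [N, d_ffn, d_model], one first-layer weight per expert
    B, K = gate.shape
    output = input.new_zeros(B, K, weight.size(1))  # mirrors input.new_zeros(...)
    for b in range(B):
        for k in range(K):
            # Project token b with the weight of its k-th selected expert.
            output[b, k] = weight[gate[b, k]] @ input[b]
    return output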
pytorch/cuda/moe_cuda_kernel.cu  (view file @ 707652bc)
@@ -75,7 +75,7 @@ inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
 }
 
 template <typename scalar_t>
-void moe_cuda_forward_impl(
+void moe_first_linear_cuda_forward(
     const scalar_t* input,
     const size_t* gate,
     const scalar_t* weight,
@@ -155,11 +155,11 @@ int main() {
     }
     checkCudaErrors(cudaMemcpy(gate, gate_host, batch_size * top_k * sizeof(size_t), cudaMemcpyHostToDevice));
-    moe_cuda_forward_impl<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
+    moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
     for (size_t i = 0; i < nt; ++i) {
         timestamp(start);
-        moe_cuda_forward_impl<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
+        moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
         timestamp(end);
         auto t = getDuration(start, end);
         tsum += t;
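The benchmark in main() follows the usual warm-up-then-time pattern: one untimed call, then nt timed iterations whose durations are summed. A rough PyTorch equivalent of that harness, as a sketch (the bench helper and its arguments are illustrative, not part of this commit):

import torch

def bench(fn, nt=16):
    """Average fn's runtime over nt iterations after one warm-up call (ms)."""
    fn()  # warm-up, analogous to the untimed first call in main()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    tsum = 0.0
    for _ in range(nt):
        start.record()
        fn()
        end.record()
        end.synchronize()  # wait for the kernel before reading the timer
        tsum += start.elapsed_time(end)
    return tsum / nt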
pytorch/cuda/setup.py  (new file @ 707652bc, mode 0 → 100644)

+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+setup(
+    name='moe_first_linear_cuda',
+    ext_modules=[
+        CUDAExtension(
+            name='moe_first_linear_cuda',
+            sources=[
+                'moe.cpp',
+                'moe_cuda_kernel.cu',
+            ],
+            extra_compile_args={
+                'cxx': ['-I/usr/local/cuda/samples/common/inc'],
+                'nvcc': ['-I/usr/local/cuda/samples/common/inc']})
+    ],
+    cmdclass={
+        'build_ext': BuildExtension
+    })
\ No newline at end of file
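Assuming the extension builds as declared above, it would be installed and invoked roughly as follows. The B and d_model sizes mirror the commented-out main() in moe.cpp; the K, N, and d_ffn values are illustrative, and the int64 gate dtype is an assumption (the C++ side reads the gate through a size_t pointer):

# Build first, from pytorch/cuda:  python setup.py install
import torch
import moe_first_linear_cuda

B, K, N, d_model, d_ffn = 2048, 2, 4, 512, 1024  # illustrative sizes
input = torch.randn(B, d_model, device='cuda')
gate = torch.randint(N, (B, K), device='cuda', dtype=torch.int64)  # expert ids
weight = torch.randn(N, d_ffn, d_model, device='cuda')

# forward returns a vector of tensors; here it holds just the output [B, K, d_ffn].
output, = moe_first_linear_cuda.forward(input, gate, weight)
print(output.shape)  # torch.Size([2048, 2, 1024])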