GitLab: OpenDAS / FastMoE

Commit 84bdd842: "moe forward and backward"
Authored Dec 29, 2020 by Jiezhong Qiu
Parent commit: 1704dc36

Showing 4 changed files with 95 additions and 27 deletions (+95, -27):

- pytorch/cuda/a.py (+1, -1)
- pytorch/cuda/moe.cpp (+8, -8)
- pytorch/cuda/moe_cuda_kernel.cu (+84, -16)
- pytorch/cuda/setup.py (+2, -2)
pytorch/cuda/a.py

```diff
 import torch
-import moe1_cuda
+import moe_cuda
```
pytorch/cuda/moe.cpp

```diff
@@ -4,12 +4,12 @@
 #include <iostream>
 #include <vector>
 
-std::vector<torch::Tensor> moe1_cuda_forward(
+std::vector<torch::Tensor> moe_cuda_forward(
         torch::Tensor input,
         torch::Tensor gate,
         torch::Tensor weight);
 
-std::vector<torch::Tensor> moe1_cuda_backward(
+std::vector<torch::Tensor> moe_cuda_backward(
         torch::Tensor grad_output,
         torch::Tensor input,
         torch::Tensor gate,
@@ -22,7 +22,7 @@ std::vector<torch::Tensor> moe1_cuda_backward(
 #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
 
-std::vector<torch::Tensor> moe1_forward(
+std::vector<torch::Tensor> moe_forward(
         torch::Tensor input,  // [batch_size x in_feat]
         torch::Tensor gate,   // [batch_size]
         torch::Tensor weight  // [num_expert x out_feat x in_feat]
@@ -35,10 +35,10 @@ std::vector<torch::Tensor> moe1_forward(
     Wx+b = [W b] [x]
                  [1]
     */
-    return moe1_cuda_forward(input, gate, weight);
+    return moe_cuda_forward(input, gate, weight);
 }
 
-std::vector<torch::Tensor> moe1_backward(
+std::vector<torch::Tensor> moe_backward(
         torch::Tensor grad_output, // [batch_size x out_feat]
         torch::Tensor input,       // [batch_size x out_feat]
         torch::Tensor gate,        // [batch_size]
@@ -53,7 +53,7 @@ std::vector<torch::Tensor> moe1_backward(
     Wx+b = [W b] [x]
                  [1]
     */
-    return moe1_cuda_forward(input, gate, weight);
+    return moe_cuda_forward(input, gate, weight);
 }
@@ -69,6 +69,6 @@ int main() {
 */
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("forward", &moe1_forward, "MoE first linear forward (CUDA)");
-    // m.def("backward", &lltm_backward, "LLTM backward (CUDA)");
+    m.def("forward", &moe_forward, "MoE forward (CUDA)");
+    m.def("backward", &moe_backward, "MoE backward (CUDA)");
 }
\ No newline at end of file
```
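For orientation, here is a minimal usage sketch of the interface these bindings expose. It is an illustration, not part of the commit: the module name `moe_cuda` and the `forward`/`backward` entry points are taken from `setup.py` and the `PYBIND11_MODULE` block above, and the single int32 expert index per sample in `gate` is an assumption based on the shape comments in `moe.cpp`.

```python
# Illustrative only (not in the commit). Assumes the extension was built
# as `moe_cuda` and that `gate` holds one int32 expert index per sample.
import torch
import moe_cuda

batch_size, num_expert, in_feat, out_feat = 4, 2, 8, 16
x = torch.randn(batch_size, in_feat, device="cuda")
gate = torch.randint(num_expert, (batch_size,), dtype=torch.int32, device="cuda")
w = torch.randn(num_expert, out_feat, in_feat, device="cuda")

# forward returns a one-element list (the C++ side returns {output, }),
# hence the single-element unpacking.
output, = moe_cuda.forward(x, gate, w)
print(output.shape)  # expected: torch.Size([4, 16])
```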
pytorch/cuda/moe_cuda_kernel.cu

```diff
@@ -3,6 +3,7 @@
 #include <cstdio>
 #include <iostream>
 #include <vector>
+#include <cassert>
 
 #include <cuda.h>
@@ -40,6 +41,7 @@ Helper* getHelper(const size_t num_expert) {
     if (!helper) {
         helper = new Helper(num_expert);
     }
+    assert(helper->num_expert == num_expert);
     return helper;
 }
@@ -63,8 +65,7 @@ inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
         const float *Barray[], int ldb,
         const float *beta,
         float *Carray[], int ldc,
-        int batchCount)
-{
+        int batchCount) {
     return cublasSgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
 }
@@ -77,8 +78,7 @@ inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
         const double *Barray[], int ldb,
         const double *beta,
         double *Carray[], int ldc,
-        int batchCount)
-{
+        int batchCount) {
     return cublasDgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
 }
@@ -91,14 +91,46 @@ inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
         const __half *Barray[], int ldb,
         const __half *beta,
         __half *Carray[], int ldc,
-        int batchCount)
-{
+        int batchCount) {
     return cublasHgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
 }
 
+inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
+        cublasOperation_t transa, cublasOperation_t transb,
+        int m, int n, int k,
+        const float *alpha,
+        const float *A, int lda,
+        const float *B, int ldb,
+        const float *beta,
+        float *C, int ldc) {
+    return cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
+}
+
+inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
+        cublasOperation_t transa, cublasOperation_t transb,
+        int m, int n, int k,
+        const double *alpha,
+        const double *A, int lda,
+        const double *B, int ldb,
+        const double *beta,
+        double *C, int ldc) {
+    return cublasDgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
+}
+
+inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
+        cublasOperation_t transa, cublasOperation_t transb,
+        int m, int n, int k,
+        const __half *alpha,
+        const __half *A, int lda,
+        const __half *B, int ldb,
+        const __half *beta,
+        __half *C, int ldc) {
+    return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
+}
+
 template <typename scalar_t>
-void moe1_cuda_forward_impl(
+void moe_cuda_forward_impl(
         const scalar_t* input,
         const int* gate,
         const scalar_t* weight,
@@ -154,12 +186,47 @@ void moe1_cuda_forward_impl(
             batch_size));
     checkCudaErrors(cudaStreamSynchronize(*(h->streams)));
     // checkCudaErrors(cudaStreamDestroy(st));
     // checkCudaErrors(cublasDestroy(handle));
 }
 
+template <typename scalar_t>
+void moe_cuda_grad_weight(
+        const scalar_t* input,
+        const int* gate,
+        const scalar_t* grad_output,
+        scalar_t* grad_weight, // [num_expert x out_feat x in_feat]
+        const size_t batch_size,
+        const size_t in_feat,
+        const size_t out_feat,
+        const size_t num_expert,
+        cublasOperation_t transb) {
+    Helper* h = getHelper(num_expert);
+    int* gate_host = new int[batch_size];
+    scalar_t alpha = 1, beta = 1;
+    checkCudaErrors(cudaMemcpy(gate_host, gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));
+    for (size_t i = 0; i < batch_size; ++i) {
+        checkCudaErrors(cublasSetStream(h->handle, *(h->streams + gate_host[i])));
+        checkCudaErrors(cublasSgemm(h->handle,
+                CUBLAS_OP_N, CUBLAS_OP_N,
+                out_feat, in_feat, 1,
+                &alpha,
+                grad_output + i * out_feat, out_feat,
+                input + i * in_feat, 1,
+                &beta,
+                grad_weight + gate_host[i] * out_feat * in_feat, out_feat));
+    }
+    checkCudaErrors(cudaDeviceSynchronize());
+    delete [] gate_host;
+}
+
-std::vector<torch::Tensor> moe1_cuda_forward(
+std::vector<torch::Tensor> moe_cuda_forward(
         torch::Tensor input,
         torch::Tensor gate,
         torch::Tensor weight) {
@@ -171,8 +238,8 @@ std::vector<torch::Tensor> moe1_cuda_forward(
     // printf("b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld, topk=%ld\n", batch_size, num_expert, in_feat, out_feat, top_k);
     auto output = input.new_zeros({batch_size, out_feat});
-    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe1_forward_cuda", ([&] {
-        moe1_cuda_forward_impl<scalar_t>(
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] {
+        moe_cuda_forward_impl<scalar_t>(
             input.data_ptr<scalar_t>(),
             gate.data_ptr<int>(),
             weight.data_ptr<scalar_t>(),
@@ -188,7 +255,7 @@ std::vector<torch::Tensor> moe1_cuda_forward(
     return {output, };
 }
 
-std::vector<torch::Tensor> moe1_cuda_backward(
+std::vector<torch::Tensor> moe_cuda_backward(
         torch::Tensor grad_output, // [batch_size x out_feat]
         torch::Tensor input,       // [batch_size x out_feat]
         torch::Tensor gate,        // [batch_size]
@@ -202,8 +269,9 @@ std::vector<torch::Tensor> moe1_cuda_backward(
     auto grad_input = grad_output.new_zeros({batch_size, in_feat});            // batch_size x in_feat
     auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat
-    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe1_cuda_backward", ([&] {
-        moe1_cuda_forward_impl<scalar_t>(
+    // grad_input is easy to compute, exactly the same as forward
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
+        moe_cuda_forward_impl<scalar_t>(
             grad_output.data_ptr<scalar_t>(),
             gate.data_ptr<int>(),
             weight.data_ptr<scalar_t>(),
```
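The newly added `moe_cuda_grad_weight` is the core of this commit's backward pass: each `cublasSgemm` call uses `k = 1`, so it is a rank-1 update that accumulates (`beta = 1`) the outer product of sample `i`'s output gradient and input row into the weight-gradient slice of the expert the sample was routed to, with the GEMM for each sample issued on that expert's stream. Note the loop calls `cublasSgemm` directly rather than the `cublasXgemm` dispatcher added above, so the template is effectively float-only for now. A plain-PyTorch reference sketch of the same computation (an illustration only, ignoring cuBLAS's column-major layout details):

```python
# Reference sketch (not in the commit) of the math moe_cuda_grad_weight
# performs: one rank-1 accumulation per sample into its expert's slice.
import torch

def moe_grad_weight_ref(x, gate, grad_output, num_expert):
    batch_size, in_feat = x.shape
    out_feat = grad_output.shape[1]
    grad_weight = x.new_zeros((num_expert, out_feat, in_feat))
    for i in range(batch_size):
        # dL/dW_e += grad_output_i (outer) x_i, for expert e = gate[i]
        grad_weight[gate[i]] += torch.outer(grad_output[i], x[i])
    return grad_weight
```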
pytorch/cuda/setup.py

```diff
@@ -2,10 +2,10 @@ from setuptools import setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 
 setup(
-    name='moe1_cuda',
+    name='moe_cuda',
     ext_modules=[
         CUDAExtension(
-            name='moe1_cuda',
+            name='moe_cuda',
             sources=[
                 'moe.cpp',
                 'moe_cuda_kernel.cu',
```
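With the rename complete, the extension would typically be built from `pytorch/cuda/` with `python setup.py install`, after which the import in `a.py` should succeed. A hypothetical smoke test along those lines (not part of the commit):

```python
# Hypothetical smoke test after building: the module should expose the
# two pybind11 entry points registered in moe.cpp.
import moe_cuda

assert hasattr(moe_cuda, "forward")
assert hasattr(moe_cuda, "backward")
print("moe_cuda loaded OK")
```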