OpenDAS / FastMoE · Commits

Commit 1704dc36, authored Dec 23, 2020 by Jiezhong Qiu
parent eb47044a

    update

Showing 2 changed files with 105 additions and 14 deletions:

    pytorch/cuda/moe.cpp              +27  -3
    pytorch/cuda/moe_cuda_kernel.cu   +78  -11
pytorch/cuda/moe.cpp
@@ -9,6 +9,12 @@ std::vector<torch::Tensor> moe1_cuda_forward(
     torch::Tensor gate,
     torch::Tensor weight);
+std::vector<torch::Tensor> moe1_cuda_backward(
+    torch::Tensor grad_output,
+    torch::Tensor input,
+    torch::Tensor gate,
+    torch::Tensor weight);
+
 // C++ interface
 // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
@@ -17,10 +23,28 @@ std::vector<torch::Tensor> moe1_cuda_forward(
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
 
 std::vector<torch::Tensor> moe1_forward(
-    torch::Tensor input,   // [B x D_model]
-    torch::Tensor gate,    // [B]
-    torch::Tensor weight   // [N x D_ffn x D_model]
+    torch::Tensor input,   // [batch_size x in_feat]
+    torch::Tensor gate,    // [batch_size]
+    torch::Tensor weight   // [num_expert x out_feat x in_feat]
 ) {
     CHECK_INPUT(input);
     CHECK_INPUT(gate);
     CHECK_INPUT(weight);
     /*
         The bias term should have been merged into weight. Note the following fact that
         Wx+b = [W b] [x]
                      [1]
     */
     return moe1_cuda_forward(input, gate, weight);
 }
+
+std::vector<torch::Tensor> moe1_backward(
+    torch::Tensor grad_output, // [batch_size x out_feat]
+    torch::Tensor input,       // [batch_size x in_feat]
+    torch::Tensor gate,        // [batch_size]
+    torch::Tensor weight       // [num_expert x out_feat x in_feat]
+) {
+    CHECK_INPUT(grad_output);
+    CHECK_INPUT(input);
+    CHECK_INPUT(gate);
+    CHECK_INPUT(weight);
...
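The bias-folding identity in the comment above says a separate bias term is unnecessary: append b as an extra column of W and a constant 1 to x, and [W b][x; 1] = Wx + b. A quick torch-level illustration (mine, not part of the commit):

#include <torch/torch.h>

// Merge a bias into the weight matrix: append b as an extra column of W
// and a constant 1 to x; then [W b] [x; 1] == W x + b.
void bias_fold_demo() {
    auto W = torch::randn({4, 3});
    auto b = torch::randn({4});
    auto x = torch::randn({3});
    auto W_aug = torch::cat({W, b.unsqueeze(1)}, /*dim=*/1); // [4 x 4]
    auto x_aug = torch::cat({x, torch::ones({1})});          // length 4
    TORCH_CHECK(torch::allclose(W_aug.matmul(x_aug), W.matmul(x) + b));
}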
pytorch/cuda/moe_cuda_kernel.cu
@@ -14,6 +14,36 @@
 #define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
 
+class Helper {
+public:
+    Helper(const size_t num_expert_) : num_expert(num_expert_) {
+        streams = new cudaStream_t[num_expert];
+        checkCudaErrors(cublasCreate(&handle));
+        for (size_t i = 0; i < num_expert; ++i) {
+            checkCudaErrors(cudaStreamCreate(streams + i));
+        }
+    }
+    ~Helper() {
+        for (size_t i = 0; i < num_expert; ++i) {
+            checkCudaErrors(cudaStreamDestroy(*(streams + i)));
+        }
+        checkCudaErrors(cublasDestroy(handle));
+    }
+    const size_t num_expert;
+    cublasHandle_t handle;
+    cudaStream_t* streams;
+};
+
+Helper* helper = NULL;
+
+Helper* getHelper(const size_t num_expert) {
+    if (!helper) {
+        helper = new Helper(num_expert);
+    }
+    return helper;
+}
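getHelper lazily builds a process-wide singleton holding one cuBLAS handle plus one stream per expert (note it ignores num_expert on any later call, and the forward path below only ever binds streams[0]). As a sketch of how the per-expert streams could be exercised, here is one SGEMM per expert, each issued on its own stream so independent experts may overlap; the m/n/k sizes and the packed column-major A/B/C layouts are assumptions of mine, not code from this commit:

// Hypothetical usage sketch (relies on the Helper and checkCudaErrors above):
// bind the shared handle to expert e's stream, then issue that expert's GEMM;
// work queued on different streams can overlap on the device.
void per_expert_gemms(Helper* h, int m, int n, int k,
                      const float* A, const float* B, float* C) {
    const float alpha = 1.0f, beta = 0.0f;
    for (size_t e = 0; e < h->num_expert; ++e) {
        checkCudaErrors(cublasSetStream(h->handle, h->streams[e]));
        checkCudaErrors(cublasSgemm(h->handle, CUBLAS_OP_N, CUBLAS_OP_N,
                                    m, n, k, &alpha,
                                    A + e * (size_t)m * k, m,
                                    B + e * (size_t)k * n, k, &beta,
                                    C + e * (size_t)m * n, m));
    }
}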
 template <typename scalar_t>
 __global__
 void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride, const int* offset, const scalar_t** ptrs) {
...
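The body of generate_ptr_offset_kernel is collapsed in this diff. Judging from the call site below (base = weight, stride = out_feat * in_feat, offset = gate), a plausible body assigns each row a pointer to its gated expert's weight block:

// Plausible reconstruction only; the committed body is hidden by the diff.
// Thread i points ptrs[i] at the weight block of the expert offset[i] selects.
template <typename scalar_t>
__global__ void ptr_offset_sketch(size_t n, const scalar_t* base, size_t stride,
                                  const int* offset, const scalar_t** ptrs) {
    size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        ptrs[i] = base + stride * offset[i];
    }
}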
@@ -75,13 +105,18 @@ void moe1_cuda_forward_impl(
     scalar_t* output,
     const size_t batch_size,
     const size_t in_feat,
-    const size_t out_feat) {
+    const size_t out_feat,
+    const size_t num_expert,
+    cublasOperation_t transb) {
+    /*
     cublasHandle_t handle;
     cudaStream_t st;
     checkCudaErrors(cudaStreamCreate(&st));
     checkCudaErrors(cublasCreate(&handle));
     checkCudaErrors(cublasSetStream(handle, st));
+    */
+    Helper* h = getHelper(num_expert);
+    checkCudaErrors(cublasSetStream(h->handle, *(h->streams)));
     // setup Aarray, Barray and Carray
     std::vector<const scalar_t*> aptrs;
...
@@ -104,12 +139,12 @@ void moe1_cuda_forward_impl(
     dim3 griddim(CEIL(batch_size, 256));
     dim3 blockdim(256);
-    generate_ptr_offset_kernel<<<griddim, blockdim, 0, st>>>(batch_size, weight, out_feat * in_feat, gate, Barray);
+    generate_ptr_offset_kernel<<<griddim, blockdim, 0, *(h->streams)>>>(batch_size, weight, out_feat * in_feat, gate, Barray);
     scalar_t alpha = 1, beta = 0;
-    checkCudaErrors(cublasXgemmBatched(handle,
+    checkCudaErrors(cublasXgemmBatched(h->handle,
         CUBLAS_OP_N,
-        CUBLAS_OP_T,
+        transb,
         1, out_feat, in_feat,
         &alpha,
         Aarray, 1,
...
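For intuition, the batched call above performs one tiny (1 x in_feat) by (in_feat x out_feat) GEMM per row, with row b's B pointer aimed at expert gate[b]'s weights by the pointer-offset kernel. A torch-level reference of the same math when transb == CUBLAS_OP_T (my sketch, not part of the commit):

#include <torch/torch.h>

// Reference semantics of the gated per-row GEMM: output[b] = W_{gate[b]} x_b.
torch::Tensor moe1_forward_reference(torch::Tensor input,   // [batch_size x in_feat]
                                     torch::Tensor gate,    // [batch_size], int
                                     torch::Tensor weight)  // [num_expert x out_feat x in_feat]
{
    auto output = input.new_empty({input.size(0), weight.size(1)});
    for (int64_t b = 0; b < input.size(0); ++b) {
        auto e = gate[b].item<int64_t>();
        output[b].copy_(weight[e].matmul(input[b])); // length-out_feat row
    }
    return output;
}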
@@ -118,9 +153,9 @@ void moe1_cuda_forward_impl(
         Carray, 1,
         batch_size));
-    checkCudaErrors(cudaStreamSynchronize(st));
-    checkCudaErrors(cudaStreamDestroy(st));
-    checkCudaErrors(cublasDestroy(handle));
+    checkCudaErrors(cudaStreamSynchronize(*(h->streams)));
+    // checkCudaErrors(cudaStreamDestroy(st));
+    // checkCudaErrors(cublasDestroy(handle));
 }
...
@@ -144,13 +179,45 @@ std::vector<torch::Tensor> moe1_cuda_forward(
         output.data_ptr<scalar_t>(),
         batch_size,
         in_feat,
-        out_feat
+        out_feat,
+        num_expert,
+        CUBLAS_OP_T
     );
     }));
     return {output, };
 }
+
+std::vector<torch::Tensor> moe1_cuda_backward(
+    torch::Tensor grad_output, // [batch_size x out_feat]
+    torch::Tensor input,       // [batch_size x in_feat]
+    torch::Tensor gate,        // [batch_size]
+    torch::Tensor weight       // [num_expert x out_feat x in_feat]
+) {
+    const auto batch_size = input.size(0);
+    const auto num_expert = weight.size(0);
+    const auto out_feat = weight.size(1);
+    const auto in_feat = weight.size(2);
+    auto grad_input = grad_output.new_zeros({batch_size, in_feat});            // batch_size x in_feat
+    auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe1_cuda_backward", ([&] {
+        moe1_cuda_forward_impl<scalar_t>(
+            grad_output.data_ptr<scalar_t>(),
+            gate.data_ptr<int>(),
+            weight.data_ptr<scalar_t>(),
+            grad_input.data_ptr<scalar_t>(),
+            batch_size,
+            out_feat,
+            in_feat,
+            num_expert,
+            CUBLAS_OP_N
+        );
+    }));
+    return {grad_input, grad_weight};
+}
 
 /*
 int main() {
...
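The backward pass reuses moe1_cuda_forward_impl: if y_b = W_e x_b with e = gate[b], then dL/dx_b = W_e^T dL/dy_b, which is the same gated per-row GEMM with the transpose flag flipped to CUBLAS_OP_N and in_feat/out_feat swapped. Note that grad_weight is allocated as zeros but not yet computed in this commit. A torch-level sketch of the grad_input half (mine, for illustration):

#include <torch/torch.h>

// Reference semantics of the input gradient: grad_input[b] = W_{gate[b]}^T dy_b.
torch::Tensor moe1_grad_input_reference(torch::Tensor grad_output, // [batch_size x out_feat]
                                        torch::Tensor gate,        // [batch_size], int
                                        torch::Tensor weight)      // [num_expert x out_feat x in_feat]
{
    auto grad_input = grad_output.new_empty({grad_output.size(0), weight.size(2)});
    for (int64_t b = 0; b < grad_output.size(0); ++b) {
        auto e = gate[b].item<int64_t>();
        grad_input[b].copy_(weight[e].t().matmul(grad_output[b]));
    }
    return grad_input;
}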