Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
49c97411
Commit
49c97411
authored
Jan 28, 2021
by
Rick Ho
Browse files
support fp16
parent
952e3135
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
24 additions
and
7 deletions
+24
-7
cuda/cublas_wrapper.h
cuda/cublas_wrapper.h
+17
-0
cuda/moe_comm_kernel.cu
cuda/moe_comm_kernel.cu
+2
-2
cuda/moe_compute_kernel.cu
cuda/moe_compute_kernel.cu
+4
-4
cuda/moe_fused_kernel.cu
cuda/moe_fused_kernel.cu
+1
-1
No files found.
cuda/cublas_wrapper.h
View file @
49c97411
#ifndef CUBLAS_WRAPPER_H
#ifndef CUBLAS_WRAPPER_H
#define CUBLAS_WRAPPER_H
#define CUBLAS_WRAPPER_H
#include <cublas_v2.h>
#include <cublas_v2.h>
#include <c10/util/Half.h>
inline
cublasStatus_t
cublasXgemmBatched
(
cublasHandle_t
handle
,
inline
cublasStatus_t
cublasXgemmBatched
(
cublasHandle_t
handle
,
cublasOperation_t
transa
,
cublasOperation_t
transa
,
...
@@ -74,5 +75,21 @@ inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
...
@@ -74,5 +75,21 @@ inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
__half
*
C
,
int
ldc
)
{
__half
*
C
,
int
ldc
)
{
return
cublasHgemm
(
handle
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
return
cublasHgemm
(
handle
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
}
/// @brief cublasXgemm overload for PyTorch's c10::Half scalar type.
///
/// Dispatches to cublasHgemm (FP16 GEMM). c10::Half is a single 16-bit
/// half-precision value intended to be bit-compatible with CUDA's __half,
/// so the pointers are reinterpreted rather than converted element-wise.
/// (NOTE(review): bit-compatibility of c10::Half and __half is assumed
/// here, as the original code did — confirm against the c10 headers.)
///
/// @param handle  cuBLAS library handle.
/// @param transa, transb  Op applied to A / B (N, T, or C).
/// @param m, n, k  GEMM dimensions.
/// @param alpha, beta  Host/device scalars, as c10::Half.
/// @param A, B  Input matrices; C  output matrix (all c10::Half).
/// @param lda, ldb, ldc  Leading dimensions.
/// @return cuBLAS status code from cublasHgemm.
inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
        cublasOperation_t transa,
        cublasOperation_t transb,
        int m, int n, int k,
        const c10::Half* alpha,
        const c10::Half* A, int lda,
        const c10::Half* B, int ldb,
        const c10::Half* beta,
        c10::Half* C, int ldc) {
    // reinterpret_cast makes the c10::Half -> __half type pun explicit and
    // greppable (the original used C-style casts).
    return cublasHgemm(handle, transa, transb,
            m, n, k,
            reinterpret_cast<const __half*>(alpha),
            reinterpret_cast<const __half*>(A), lda,
            reinterpret_cast<const __half*>(B), ldb,
            reinterpret_cast<const __half*>(beta),
            reinterpret_cast<__half*>(C), ldc);
}
#endif // CUBLAS_WRAPPER_H
#endif // CUBLAS_WRAPPER_H
cuda/moe_comm_kernel.cu
View file @
49c97411
...
@@ -112,7 +112,7 @@ std::vector<torch::Tensor> moe_cuda_global_scatter(
...
@@ -112,7 +112,7 @@ std::vector<torch::Tensor> moe_cuda_global_scatter(
auto
global_input_buf
=
input_buf
.
new_empty
({
batch_size
,
in_feat
});
auto
global_input_buf
=
input_buf
.
new_empty
({
batch_size
,
in_feat
});
auto
smgr
=
getCudaStreamManager
(
input_buf
.
device
().
index
());
auto
smgr
=
getCudaStreamManager
(
input_buf
.
device
().
index
());
AT_DISPATCH_FLOATING_TYPES
(
input_buf
.
scalar_type
(),
AT_DISPATCH_FLOATING_TYPES
_AND_HALF
(
input_buf
.
scalar_type
(),
"moe_cuda_global_scatter"
,
([
&
]
{
"moe_cuda_global_scatter"
,
([
&
]
{
moe_cuda_global_scatter_impl
<
scalar_t
>
(
moe_cuda_global_scatter_impl
<
scalar_t
>
(
input_buf
.
data_ptr
<
scalar_t
>
(),
input_buf
.
data_ptr
<
scalar_t
>
(),
...
@@ -182,7 +182,7 @@ std::vector<torch::Tensor> moe_cuda_global_gather(
...
@@ -182,7 +182,7 @@ std::vector<torch::Tensor> moe_cuda_global_gather(
auto
local_output_buf
=
output_buf
.
new_empty
({
batch_size
,
out_feat
});
auto
local_output_buf
=
output_buf
.
new_empty
({
batch_size
,
out_feat
});
auto
smgr
=
getCudaStreamManager
(
output_buf
.
device
().
index
());
auto
smgr
=
getCudaStreamManager
(
output_buf
.
device
().
index
());
AT_DISPATCH_FLOATING_TYPES
(
output_buf
.
scalar_type
(),
AT_DISPATCH_FLOATING_TYPES
_AND_HALF
(
output_buf
.
scalar_type
(),
"moe_cuda_global_gather"
,
([
&
]
{
"moe_cuda_global_gather"
,
([
&
]
{
moe_cuda_global_gather_impl
<
scalar_t
>
(
moe_cuda_global_gather_impl
<
scalar_t
>
(
output_buf
.
data_ptr
<
scalar_t
>
(),
output_buf
.
data_ptr
<
scalar_t
>
(),
...
...
cuda/moe_compute_kernel.cu
View file @
49c97411
...
@@ -233,7 +233,7 @@ std::vector<torch::Tensor> moe_cuda_local_scatter(
...
@@ -233,7 +233,7 @@ std::vector<torch::Tensor> moe_cuda_local_scatter(
auto
input_buf
=
torch
::
empty_like
(
input
);
auto
input_buf
=
torch
::
empty_like
(
input
);
AT_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"moe_local_scatter_cuda"
,
AT_DISPATCH_FLOATING_TYPES
_AND_HALF
(
input
.
scalar_type
(),
"moe_local_scatter_cuda"
,
([
&
]
{
([
&
]
{
moe_cuda_local_scatter_impl
<
scalar_t
>
(
moe_cuda_local_scatter_impl
<
scalar_t
>
(
input
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
...
@@ -255,7 +255,7 @@ std::vector<torch::Tensor> moe_cuda_local_gather(
...
@@ -255,7 +255,7 @@ std::vector<torch::Tensor> moe_cuda_local_gather(
auto
output
=
torch
::
empty_like
(
output_buf
);
auto
output
=
torch
::
empty_like
(
output_buf
);
AT_DISPATCH_FLOATING_TYPES
(
output_buf
.
scalar_type
(),
"moe_local_gather_cuda"
,
AT_DISPATCH_FLOATING_TYPES
_AND_HALF
(
output_buf
.
scalar_type
(),
"moe_local_gather_cuda"
,
([
&
]
{
([
&
]
{
moe_cuda_local_gather_impl
<
scalar_t
>
(
moe_cuda_local_gather_impl
<
scalar_t
>
(
output_buf
.
data_ptr
<
scalar_t
>
(),
output_buf
.
data_ptr
<
scalar_t
>
(),
...
@@ -288,7 +288,7 @@ std::vector<torch::Tensor> moe_cuda_forward(
...
@@ -288,7 +288,7 @@ std::vector<torch::Tensor> moe_cuda_forward(
.
dtype
(
input_buf
.
dtype
());
.
dtype
(
input_buf
.
dtype
());
auto
output
=
torch
::
empty
({
batch_size
,
out_feat
},
out_options
);
auto
output
=
torch
::
empty
({
batch_size
,
out_feat
},
out_options
);
AT_DISPATCH_FLOATING_TYPES
(
input_buf
.
scalar_type
(),
"moe_forward_cuda"
,
AT_DISPATCH_FLOATING_TYPES
_AND_HALF
(
input_buf
.
scalar_type
(),
"moe_forward_cuda"
,
([
&
]
{
([
&
]
{
moe_cuda_forward_impl
<
scalar_t
>
(
moe_cuda_forward_impl
<
scalar_t
>
(
input_buf
.
data_ptr
<
scalar_t
>
(),
input_buf
.
data_ptr
<
scalar_t
>
(),
...
@@ -326,7 +326,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
...
@@ -326,7 +326,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
auto
grad_input_buf
=
grad_output_buf
.
new_empty
({
batch_size
,
in_feat
});
auto
grad_input_buf
=
grad_output_buf
.
new_empty
({
batch_size
,
in_feat
});
auto
grad_weight
=
grad_output_buf
.
new_empty
({
num_expert
,
out_feat
,
in_feat
});
auto
grad_weight
=
grad_output_buf
.
new_empty
({
num_expert
,
out_feat
,
in_feat
});
AT_DISPATCH_FLOATING_TYPES
(
input_buf
.
scalar_type
(),
"moe_cuda_backward"
,
([
&
]
{
AT_DISPATCH_FLOATING_TYPES
_AND_HALF
(
input_buf
.
scalar_type
(),
"moe_cuda_backward"
,
([
&
]
{
moe_cuda_backward_impl
<
scalar_t
>
(
moe_cuda_backward_impl
<
scalar_t
>
(
grad_output_buf
.
data_ptr
<
scalar_t
>
(),
grad_output_buf
.
data_ptr
<
scalar_t
>
(),
input_buf
.
data_ptr
<
scalar_t
>
(),
input_buf
.
data_ptr
<
scalar_t
>
(),
...
...
cuda/moe_fused_kernel.cu
View file @
49c97411
...
@@ -127,7 +127,7 @@ std::vector<torch::Tensor> moe_cuda_global_fused_forward(
...
@@ -127,7 +127,7 @@ std::vector<torch::Tensor> moe_cuda_global_fused_forward(
auto
global_input_buf
=
input_buf
.
new_empty
({
global_batch_size
,
in_feat
});
auto
global_input_buf
=
input_buf
.
new_empty
({
global_batch_size
,
in_feat
});
auto
global_output_buf
=
input_buf
.
new_empty
({
global_batch_size
,
out_feat
});
auto
global_output_buf
=
input_buf
.
new_empty
({
global_batch_size
,
out_feat
});
auto
output_buf
=
input_buf
.
new_empty
({
local_batch_size
,
out_feat
});
auto
output_buf
=
input_buf
.
new_empty
({
local_batch_size
,
out_feat
});
AT_DISPATCH_FLOATING_TYPES
(
input_buf
.
scalar_type
(),
AT_DISPATCH_FLOATING_TYPES
_AND_HALF
(
input_buf
.
scalar_type
(),
"moe_cuda_global_fused_forward"
,
([
&
]
{
"moe_cuda_global_fused_forward"
,
([
&
]
{
moe_cuda_global_fused_forward_impl
(
moe_cuda_global_fused_forward_impl
(
input_buf
.
data_ptr
<
scalar_t
>
(),
input_buf
.
data_ptr
<
scalar_t
>
(),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment