Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
74cc6ec2
Commit
74cc6ec2
authored
Dec 18, 2020
by
Jiezhong Qiu
Browse files
0.4T Flops
parent
eb1525ec
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
28 additions
and
5 deletions
+28
-5
pytorch/cuda/moe_cuda_kernel.cu
pytorch/cuda/moe_cuda_kernel.cu
+28
-5
No files found.
pytorch/cuda/moe_cuda_kernel.cu
View file @
74cc6ec2
...
@@ -11,13 +11,15 @@
...
@@ -11,13 +11,15 @@
//#include <helper_functions.h>
//#include <helper_functions.h>
#include <helper_cuda.h>
#include <helper_cuda.h>
#include "timer.hh"
typedef
float
data_t
;
typedef
float
data_t
;
size_t
batch_size
=
4096
;
size_t
batch_size
=
4096
;
size_t
top_k
=
2
;
size_t
top_k
=
2
;
size_t
num_expert
=
128
;
size_t
num_expert
=
128
;
size_t
in_feat
=
512
;
size_t
in_feat
=
1024
;
size_t
out_feat
=
2048
;
size_t
out_feat
=
4096
;
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
...
@@ -126,7 +128,7 @@ void moe_cuda_forward_impl(
...
@@ -126,7 +128,7 @@ void moe_cuda_forward_impl(
Barray
,
out_feat
,
Barray
,
out_feat
,
&
beta
,
&
beta
,
Carray
,
1
,
Carray
,
1
,
batch_size
));
batch_size
*
top_k
));
checkCudaErrors
(
cudaStreamSynchronize
(
st
));
checkCudaErrors
(
cudaStreamSynchronize
(
st
));
checkCudaErrors
(
cudaStreamDestroy
(
st
));
checkCudaErrors
(
cudaStreamDestroy
(
st
));
...
@@ -142,7 +144,28 @@ int main() {
...
@@ -142,7 +144,28 @@ int main() {
checkCudaErrors
(
cudaMalloc
(
&
input
,
batch_size
*
in_feat
*
sizeof
(
data_t
)));
checkCudaErrors
(
cudaMalloc
(
&
input
,
batch_size
*
in_feat
*
sizeof
(
data_t
)));
checkCudaErrors
(
cudaMalloc
(
&
weight
,
num_expert
*
in_feat
*
out_feat
*
sizeof
(
data_t
)));
checkCudaErrors
(
cudaMalloc
(
&
weight
,
num_expert
*
in_feat
*
out_feat
*
sizeof
(
data_t
)));
checkCudaErrors
(
cudaMalloc
(
&
output
,
batch_size
*
top_k
*
out_feat
*
sizeof
(
data_t
)));
checkCudaErrors
(
cudaMalloc
(
&
output
,
batch_size
*
top_k
*
out_feat
*
sizeof
(
data_t
)));
checkCudaErrors
(
cudaMalloc
(
&
gate
,
batch_size
*
top_k
*
sizeof
(
size_t
)));
checkCudaErrors
(
cudaMalloc
(
&
gate
,
batch_size
*
top_k
*
sizeof
(
size_t
)));
size_t
nt
=
16
;
double
tsum
=
0
,
tmax
=
0
;
moe_cuda_forward_impl
<
data_t
>
(
input
,
gate
,
weight
,
output
,
batch_size
,
top_k
,
in_feat
,
out_feat
);
size_t
*
gate_host
=
new
size_t
[
batch_size
*
top_k
];
for
(
size_t
i
=
0
;
i
<
batch_size
*
top_k
;
++
i
)
{
gate_host
[
i
]
=
rand
()
%
num_expert
;
}
checkCudaErrors
(
cudaMemcpy
(
gate
,
gate_host
,
batch_size
*
top_k
*
sizeof
(
size_t
),
cudaMemcpyHostToDevice
));
moe_cuda_forward_impl
<
data_t
>
(
input
,
gate
,
weight
,
output
,
batch_size
,
top_k
,
in_feat
,
out_feat
);
for
(
size_t
i
=
0
;
i
<
nt
;
++
i
)
{
timestamp
(
start
);
moe_cuda_forward_impl
<
data_t
>
(
input
,
gate
,
weight
,
output
,
batch_size
,
top_k
,
in_feat
,
out_feat
);
timestamp
(
end
);
auto
t
=
getDuration
(
start
,
end
);
tsum
+=
t
;
if
(
t
>
tmax
)
tmax
=
t
;
}
printf
(
"Mean %.3lf us, max %.3lf us
\n
"
,
tsum
/
nt
*
1e6
,
tmax
*
1e6
);
double
tflops
=
(
double
)
batch_size
*
top_k
*
in_feat
*
out_feat
*
nt
*
2e-12
/
tsum
;
printf
(
"%.3lf TFLOPs
\n
"
,
tflops
);
}
}
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment