OpenDAS / FastMoE · Commits

Commit a6db9526, authored Dec 16, 2020 by Jiezhong Qiu
update
parent ca3ece2c

Showing 3 changed files with 44 additions and 27 deletions:

    .gitignore                     +2   -1
    pytorch/cuda/CMakeLists.txt    +2   -2
    pytorch/cuda/moe.cpp           +40  -24
.gitignore
@@ -2,4 +2,5 @@
 data/
 libtorch-shared-with-deps-*
 pytorch/cuda/build
-exp/
\ No newline at end of file
+exp/
+.vscode/
\ No newline at end of file
pytorch/cuda/CMakeLists.txt
@@ -4,8 +4,8 @@ project(moe)
 find_package(Torch REQUIRED)
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
-include_directories("/home/jiezhong/anaconda3/envs/torch/include/python3.6m"
-    "/usr/local/cuda-/include"
+include_directories("/home/jiezhong/miniconda3/include/python3.8"
+    "/usr/local/cuda/include"
     "/usr/local/cuda/samples/common/inc")
 add_executable(moe moe.cpp)
 target_link_libraries(moe
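The include paths above hard-code one user's conda environment and CUDA install, which is exactly what this commit has to keep editing. For reference, a more portable sketch (my own, not part of the commit; it assumes CMake >= 3.17 for FindCUDAToolkit and leaves the CUDA samples header path, needed for helper_cuda.h, as a manual setting):

    cmake_minimum_required(VERSION 3.17)
    project(moe)

    find_package(Torch REQUIRED)
    # Locate Python.h and the CUDA toolkit instead of hard-coding paths.
    find_package(Python3 REQUIRED COMPONENTS Development)
    find_package(CUDAToolkit REQUIRED)

    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

    add_executable(moe moe.cpp)
    target_include_directories(moe PRIVATE
        ${Python3_INCLUDE_DIRS}
        "/usr/local/cuda/samples/common/inc")  # helper_cuda.h still needs this
    target_link_libraries(moe PRIVATE ${TORCH_LIBRARIES} CUDA::cudart CUDA::cublas)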
pytorch/cuda/moe.cpp
@@ -14,7 +14,7 @@
 #include <helper_cuda.h>
-const int num_stream = 16;
+const int num_stream = 512;
 // std::vector<torch::Tensor>
 void moe_cuda_forward(
@@ -27,8 +27,8 @@ void moe_cuda_forward(
     const auto num_expert = gate.size(1);
     const auto d_model = weight.size(1);
     const auto d_ffn = weight.size(2);
+    printf("b=%d, expert=%d, d_model=%d, d_ffn=%d\n", batch_size, num_expert, d_model, d_ffn);
     auto output = input.new_zeros({batch_size, num_expert, d_ffn});
-    std::cout << output << std::endl;
     cublasHandle_t handle;
@@ -39,50 +39,66 @@ void moe_cuda_forward(
         checkCudaErrors(cudaStreamCreate(&stream[i]));
     }
+    cudaEvent_t start, stop;
+    checkCudaErrors(cudaEventCreate(&start));
+    checkCudaErrors(cudaEventCreate(&stop));
+    // Record the start event
+    checkCudaErrors(cudaEventRecord(start, NULL));
     size_t s;
     for (size_t i = 0; i < batch_size; ++i) {
         for (size_t j = 0; j < num_expert; ++j) {
             s = (i * num_expert + j) % num_stream;
-            printf("i=%d j=%d goes to stream %d\n", i, j, s);
-            cublasSetStream(handle, stream[s]);
+            // printf("i=%d j=%d goes to stream %d\n", i, j, s);
+            checkCudaErrors(cublasSetStream(handle, stream[s]));
             if (input.scalar_type() == torch::ScalarType::Float) {
                 float alpha = 1.0;
                 float beta = 0.0;
-                std::cout << input[i] << std::endl;
-                std::cout << weight.index(gate[i][j]) << std::endl;
-                std::cout << output[i][j] << std::endl;
-                cublasSgemm(handle,
+                checkCudaErrors(cublasSgemm(handle,
                     CUBLAS_OP_N,
                     CUBLAS_OP_N,
                     1, // m
                     d_ffn, // n
                     d_model, // k
                     &alpha,
-                    input.data_ptr<float>() + i * d_model, // input[i].data_ptr<float>(),
+                    input[i].data_ptr<float>(),
                     1,
                     weight.index(gate[i][j]).data_ptr<float>(),
                     d_model,
                     &beta,
-                    output.data_ptr<float>() + i * num_expert * d_ffn + j * d_ffn,
-                    1);
+                    output[i][j].data_ptr<float>(),
+                    1));
             } else {
                 printf("only support float!!!\n");
             }
         }
     }
-    cudaDeviceSynchronize();
-    printf("synchronized\n");
+    // checkCudaErrors(cudaDeviceSynchronize());
+    // Record the stop event
+    checkCudaErrors(cudaEventRecord(stop, NULL));
+    // Wait for the stop event to complete
+    checkCudaErrors(cudaEventSynchronize(stop));
+    float msecTotal = 0.0f;
+    checkCudaErrors(cudaEventElapsedTime(&msecTotal, start, stop));
+    // Compute and print the performance
+    float msecPerMatrixMul = msecTotal / batch_size / num_expert;
+    double flopsPerMatrixMul = 2.0 * (double)d_model * (double)d_ffn;
+    double gigaFlops = (flopsPerMatrixMul * 1.0e-9f) / (msecPerMatrixMul / 1000.0f);
+    printf("Performance= %.2f GFlop/s, Time= %.3f msec, Size= %.0f Ops\n", gigaFlops, msecPerMatrixMul, flopsPerMatrixMul);
+    // std::cout << output << std::endl;
     for (size_t i = 0; i < num_stream; ++i) {
-        cudaStreamDestroy(stream[i]);
+        checkCudaErrors(cudaStreamDestroy(stream[i]));
     }
-    std::cout << output << std::endl;
-    cublasDestroy(handle);
+    checkCudaErrors(cublasDestroy(handle));
 }
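cuBLAS assumes column-major storage while the libtorch tensors here are row-major, so the dimension and leading-dimension arguments in the cublasSgemm call above are easy to get wrong. A minimal sketch of the standard mapping for one token and one expert (illustrative only; the function name and parameter layout are my own, not the commit's):

    #include <cublas_v2.h>

    // Sketch: y = x * W, all pointers on device.
    //   x: d_model floats (one row of input), W: d_model x d_ffn row-major,
    //   y: d_ffn floats (one row of output).
    // The row-major W buffer reads, unchanged, as the column-major
    // d_ffn x d_model matrix W^T, so C (d_ffn x 1) = W^T * x yields y.
    void token_expert_gemm(cublasHandle_t handle, const float* x, const float* W,
                           float* y, int d_model, int d_ffn) {
        const float alpha = 1.0f, beta = 0.0f;
        cublasSgemm(handle,
                    CUBLAS_OP_N, CUBLAS_OP_N,
                    d_ffn,      // m: rows of the W^T view and of C
                    1,          // n: one token
                    d_model,    // k: shared dimension
                    &alpha,
                    W, d_ffn,   // lda = row stride of row-major W
                    x, d_model, // x as a d_model x 1 column
                    &beta,
                    y, d_ffn);
    }

The call in the diff instead passes m = 1 with ldb = d_model, which produces y = x * W only if each expert's weight buffer is actually laid out with leading dimension d_model (i.e. stored as the transpose of a row-major d_model x d_ffn matrix); a mismatched leading dimension scrambles the result silently rather than raising an error.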
@@ -96,10 +112,10 @@ void moe_cuda_forward(
 int main() {
     int device = 2;
-    torch::Tensor input = torch::randn({2, 4}, torch::dtype(torch::kFloat32).device(torch::kCUDA, device));
-    torch::Tensor gate = torch::zeros({2, 1}, torch::dtype(torch::kInt64).device(torch::kCUDA, device));
-    torch::Tensor weight = torch::randn({2, 4, 4}, torch::dtype(torch::kFloat32).device(torch::kCUDA, device));
-    torch::Tensor bias = torch::randn({2, 4}, torch::dtype(torch::kFloat32).device(torch::kCUDA, device));
-    std::cout << input << std::endl;
+    torch::Tensor input = torch::randn({2048, 512}, torch::dtype(torch::kFloat32).device(torch::kCUDA, device));
+    torch::Tensor gate = torch::zeros({2048, 2}, torch::dtype(torch::kInt64).device(torch::kCUDA, device));
+    torch::Tensor weight = torch::randn({2, 512, 2048}, torch::dtype(torch::kFloat32).device(torch::kCUDA, device));
+    torch::Tensor bias = torch::randn({2, 2048}, torch::dtype(torch::kFloat32).device(torch::kCUDA, device));
     checkCudaErrors(cudaSetDevice(device));
     moe_cuda_forward(input, gate, weight, bias);
 }
\ No newline at end of file
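One way to sanity-check the large random inputs that main() now feeds through moe_cuda_forward is to compare a single (token, expert-slot) row against libtorch's own matmul. A sketch, assuming the forward pass is adapted to expose its output tensor (check_one_entry is a hypothetical helper, not part of the commit):

    #include <torch/torch.h>
    #include <iostream>

    // Hypothetical check: expects output of shape {batch_size, num_expert, d_ffn}
    // with output[i][j] = input[i] * weight[gate[i][j]].
    void check_one_entry(const torch::Tensor& input, const torch::Tensor& gate,
                         const torch::Tensor& weight, const torch::Tensor& output,
                         int64_t i, int64_t j) {
        int64_t e = gate[i][j].item<int64_t>();  // expert selected for this slot
        torch::Tensor expected = torch::matmul(input[i], weight[e]);
        std::cout << "output[" << i << "][" << j << "] matches: "
                  << torch::allclose(output[i][j], expected, /*rtol=*/1e-4, /*atol=*/1e-5)
                  << std::endl;
    }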