Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
efa510bb
Commit
efa510bb
authored
Dec 15, 2020
by
Jiezhong Qiu
Browse files
can run
parent
15319915
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
116 additions
and
0 deletions
+116
-0
pytorch/cuda/CMakeLists.txt
pytorch/cuda/CMakeLists.txt
+24
-0
pytorch/cuda/moe.cpp
pytorch/cuda/moe.cpp
+92
-0
No files found.
pytorch/cuda/CMakeLists.txt
0 → 100644
View file @
efa510bb
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(moe)

find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

# FIX: the include paths were hard-coded to one developer's machine
# (/home/jiezhong/...). Keep the same defaults for backward compatibility,
# but expose them as cache variables so other machines can override them
# with -DMOE_PYTHON_INCLUDE_DIR=... / -DMOE_CUDA_INCLUDE_DIR=... .
set(MOE_PYTHON_INCLUDE_DIR
    "/home/jiezhong/anaconda3/envs/torch/include/python3.6m"
    CACHE PATH "Directory containing Python.h for the target interpreter")
set(MOE_CUDA_INCLUDE_DIR
    "/usr/local/cuda/include"
    CACHE PATH "Directory containing the CUDA toolkit headers")
include_directories("${MOE_PYTHON_INCLUDE_DIR}" "${MOE_CUDA_INCLUDE_DIR}")

add_executable(moe moe.cpp)
target_link_libraries(moe "${TORCH_LIBRARIES}")
set_property(TARGET moe PROPERTY CXX_STANDARD 14)

# The following code block is suggested to be used on Windows.
# According to https://github.com/pytorch/pytorch/issues/25457,
# the DLLs need to be copied to avoid memory errors.
if(MSVC)
  file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
  add_custom_command(TARGET moe
                     POST_BUILD
                     COMMAND ${CMAKE_COMMAND} -E copy_if_different
                             ${TORCH_DLLS}
                             $<TARGET_FILE_DIR:moe>)
endif(MSVC)
\ No newline at end of file
pytorch/cuda/moe.cpp
0 → 100644
View file @
efa510bb
#include <torch/extension.h>
#include <torch/torch.h>
#include <cstdio>
#include <iostream>
#include <vector>
// CUDA runtime
#include <cuda_runtime.h>
#include <cublas_v2.h>
// CUDA and CUBLAS functions
//#include <helper_functions.h>
//#include <helper_cuda.h>
const
int
num_stream
=
1024
;
// std::vector<torch::Tensor>
/// Forward pass of a mixture-of-experts layer: for every sample i and every
/// gate slot j, multiplies input[i] (1 x d_model) by the expert weight matrix
/// selected by gate[i][j] (d_model x d_ffn), writing the result into
/// output[i][j]. GEMMs are distributed round-robin over `num_stream` CUDA
/// streams via one shared cuBLAS handle.
///
/// @param input  [B x D_model] activations; only torch::kFloat64 is supported.
/// @param gate   [B x N] expert indices used to select rows of `weight`.
/// @param weight [N x D_model x D_ffn] per-expert weight matrices.
/// @param bias   [N x D_ffn] — NOTE(review): accepted but currently unused.
void moe_cuda_forward(
        torch::Tensor input,   // [B x D_model]
        torch::Tensor gate,    // [B x N]
        torch::Tensor weight,  // [N x D_model x D_ffn]
        torch::Tensor bias     // [N x D_ffn]
) {
    const auto batch_size = input.size(0);
    const auto num_expert = gate.size(1);
    const auto d_model = weight.size(1);
    const auto d_ffn = weight.size(2);
    // Result buffer shares input's dtype/device; zero-filled so unsupported
    // dtypes fall through leaving zeros.
    auto output = input.new_zeros({batch_size, num_expert, d_ffn});

    cublasHandle_t handle;
    cublasCreate(&handle);

    // One stream per (i, j) slot modulo num_stream; num_stream is a
    // compile-time constant so the array size is valid C++.
    cudaStream_t stream[num_stream];
    for (size_t i = 0; i < num_stream; ++i) {
        cudaStreamCreate(&stream[i]);
    }

    size_t s;
    for (size_t i = 0; i < batch_size; ++i) {
        for (size_t j = 0; j < num_expert; ++j) {
            s = (i * num_expert + j) % num_stream;
            // FIX: the original used %d for size_t arguments, which is
            // undefined behavior; %zu is the correct conversion specifier.
            printf("i=%zu j=%zu goes to stream %zu\n", i, j, s);
            cublasSetStream(handle, stream[s]);
            if (input.scalar_type() == torch::ScalarType::Double) {
                double alpha = 1.0;
                double beta = 0.0;
                // Computes output[i][j] = input[i] * weight[gate[i][j]]
                // as a 1 x d_ffn GEMM (row vector times matrix).
                cublasDgemm(handle,
                        CUBLAS_OP_N,
                        CUBLAS_OP_N,
                        1,
                        d_ffn,
                        d_model,
                        &alpha,
                        input[i].data_ptr<double>(),
                        1,
                        weight.index(gate[i][j]).data_ptr<double>(),
                        d_model,
                        &beta,
                        output[i][j].data_ptr<double>(),
                        1);
            } else {
                printf("only support double!!!\n");
            }
        }
    }

    // FIX: the GEMMs above are launched asynchronously; synchronize each
    // stream before destroying it so no work is still in flight when the
    // cuBLAS handle is torn down.
    for (size_t i = 0; i < num_stream; ++i) {
        cudaStreamSynchronize(stream[i]);
        cudaStreamDestroy(stream[i]);
    }
    cublasDestroy(handle);
}
// C++ interface
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
// NOTE(review): these validation macros are defined but never invoked in this
// file — moe_cuda_forward does not verify its tensors are CUDA-resident or
// contiguous before passing raw data pointers to cuBLAS.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
// Smoke test: build small double-precision tensors on CUDA device 3, print
// the input, and run a single forward pass through moe_cuda_forward.
int main() {
    // All floating-point tensors share the same dtype/device options.
    const auto fp_opts = torch::dtype(torch::kFloat64).device(torch::kCUDA, 3);
    const auto idx_opts = torch::dtype(torch::kInt64).device(torch::kCUDA, 3);

    torch::Tensor input  = torch::randn({2, 4}, fp_opts);
    torch::Tensor gate   = torch::ones({2, 1}, idx_opts);   // every sample routed to expert 1
    torch::Tensor weight = torch::randn({2, 4, 4}, fp_opts);
    torch::Tensor bias   = torch::randn({2, 4}, fp_opts);

    std::cout << input << std::endl;
    moe_cuda_forward(input, gate, weight, bias);
}
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment