Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
7eb40a4a
"vscode:/vscode.git/clone" did not exist on "55477476cbf7c2298c2d5175554085d7a909a7ea"
Commit
7eb40a4a
authored
Dec 15, 2020
by
Jiezhong Qiu
Browse files
update
parent
efa510bb
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
37 additions
and
23 deletions
+37
-23
pytorch/cuda/CMakeLists.txt
pytorch/cuda/CMakeLists.txt
+3
-2
pytorch/cuda/moe.cpp
pytorch/cuda/moe.cpp
+34
-21
No files found.
# Propagate the compile flags exported by the Torch CMake package.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

# NOTE(review): these are machine-specific absolute paths — prefer
# find_package(Python3 COMPONENTS Development) and find_package(CUDAToolkit)
# so the build works outside this one workstation.
include_directories(
    "/home/jiezhong/anaconda3/envs/torch/include/python3.6m"
    # BUG FIX: the committed path "/usr/local/cuda-/include" is malformed
    # (looks like a truncated "cuda-<version>"); the "cuda" symlink is the
    # conventional version-independent location.
    "/usr/local/cuda/include"
    # helper_cuda.h (checkCudaErrors) lives in the CUDA samples tree.
    "/usr/local/cuda/samples/common/inc"
)

add_executable(moe moe.cpp)
target_link_libraries(moe "${TORCH_LIBRARIES}")
pytorch/cuda/moe.cpp
View file @
7eb40a4a
...
@@ -11,10 +11,10 @@
...
@@ -11,10 +11,10 @@
// CUDA and CUBLAS functions
// CUDA and CUBLAS functions
//#include <helper_functions.h>
//#include <helper_functions.h>
//
#include <helper_cuda.h>
#include <helper_cuda.h>
// Number of CUDA streams used by moe_cuda_forward to round-robin the
// per-(sample, expert) GEMM launches so independent products can overlap.
// This commit reduces the pool from 1024 to 16.
const int num_stream = 16;
// std::vector<torch::Tensor>
// std::vector<torch::Tensor>
void
moe_cuda_forward
(
void
moe_cuda_forward
(
...
@@ -28,48 +28,60 @@ void moe_cuda_forward(
...
@@ -28,48 +28,60 @@ void moe_cuda_forward(
const
auto
d_model
=
weight
.
size
(
1
);
const
auto
d_model
=
weight
.
size
(
1
);
const
auto
d_ffn
=
weight
.
size
(
2
);
const
auto
d_ffn
=
weight
.
size
(
2
);
auto
output
=
input
.
new_zeros
({
batch_size
,
num_expert
,
d_ffn
});
auto
output
=
input
.
new_zeros
({
batch_size
,
num_expert
,
d_ffn
});
std
::
cout
<<
output
<<
std
::
endl
;
cublasHandle_t
handle
;
cublasHandle_t
handle
;
cublasCreate
(
&
handle
);
checkCudaErrors
(
cublasCreate
(
&
handle
)
)
;
cudaStream_t
stream
[
num_stream
];
cudaStream_t
stream
[
num_stream
];
for
(
size_t
i
=
0
;
i
<
num_stream
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
num_stream
;
++
i
)
{
cudaStreamCreate
(
&
stream
[
i
]);
checkCudaErrors
(
cudaStreamCreate
(
&
stream
[
i
])
)
;
}
}
size_t
s
;
size_t
s
;
for
(
size_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
for
(
size_t
j
=
0
;
j
<
num_expert
;
++
j
)
{
for
(
size_t
j
=
0
;
j
<
num_expert
;
++
j
)
{
s
=
(
i
*
num_expert
+
j
)
%
num_stream
;
s
=
(
i
*
num_expert
+
j
)
%
num_stream
;
printf
(
"i=%d j=%d goes to stream %d
\n
"
,
i
,
j
,
s
);
printf
(
"i=%d j=%d goes to stream %d
\n
"
,
i
,
j
,
s
);
cublasSetStream
(
handle
,
stream
[
s
]);
cublasSetStream
(
handle
,
stream
[
s
]);
if
(
input
.
scalar_type
()
==
torch
::
ScalarType
::
Double
)
{
if
(
input
.
scalar_type
()
==
torch
::
ScalarType
::
Float
)
{
double
alpha
=
1.0
;
float
alpha
=
1.0
;
double
beta
=
0.0
;
float
beta
=
0.0
;
cublasDgemm
(
handle
,
std
::
cout
<<
input
[
i
]
<<
std
::
endl
;
std
::
cout
<<
weight
.
index
(
gate
[
i
][
j
])
<<
std
::
endl
;
std
::
cout
<<
output
[
i
][
j
]
<<
std
::
endl
;
cublasSgemm
(
handle
,
CUBLAS_OP_N
,
CUBLAS_OP_N
,
CUBLAS_OP_N
,
CUBLAS_OP_N
,
1
,
1
,
// m
d_ffn
,
d_ffn
,
// n
d_model
,
d_model
,
// k
&
alpha
,
&
alpha
,
input
[
i
].
data_ptr
<
double
>
(),
input
.
data_ptr
<
float
>
()
+
i
*
d_model
,
// input[i].data_ptr<float>(),
1
,
1
,
weight
.
index
(
gate
[
i
][
j
]).
data_ptr
<
double
>
(),
weight
.
index
(
gate
[
i
][
j
]).
data_ptr
<
float
>
(),
d_model
,
d_model
,
&
beta
,
&
beta
,
output
[
i
][
j
]
.
data_ptr
<
double
>
()
,
output
.
data_ptr
<
float
>
()
+
i
*
num_expert
*
d_ffn
+
j
*
d_ffn
,
1
);
1
);
}
else
{
}
else
{
printf
(
"only support
double
!!!
\n
"
);
printf
(
"only support
float
!!!
\n
"
);
}
}
}
}
}
}
cudaDeviceSynchronize
();
printf
(
"synchronized
\n
"
);
for
(
size_t
i
=
0
;
i
<
num_stream
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
num_stream
;
++
i
)
{
cudaStreamDestroy
(
stream
[
i
]);
cudaStreamDestroy
(
stream
[
i
]);
}
}
std
::
cout
<<
output
<<
std
::
endl
;
cublasDestroy
(
handle
);
cublasDestroy
(
handle
);
}
}
...
@@ -83,10 +95,11 @@ void moe_cuda_forward(
...
@@ -83,10 +95,11 @@ void moe_cuda_forward(
int
main
()
{
int
main
()
{
torch
::
Tensor
input
=
torch
::
randn
({
2
,
4
},
torch
::
dtype
(
torch
::
kFloat64
).
device
(
torch
::
kCUDA
,
3
));
int
device
=
2
;
torch
::
Tensor
gate
=
torch
::
ones
({
2
,
1
},
torch
::
dtype
(
torch
::
kInt64
).
device
(
torch
::
kCUDA
,
3
));
torch
::
Tensor
input
=
torch
::
randn
({
2
,
4
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
,
device
));
torch
::
Tensor
weight
=
torch
::
randn
({
2
,
4
,
4
},
torch
::
dtype
(
torch
::
kFloat64
).
device
(
torch
::
kCUDA
,
3
));
torch
::
Tensor
gate
=
torch
::
zeros
({
2
,
1
},
torch
::
dtype
(
torch
::
kInt64
).
device
(
torch
::
kCUDA
,
device
));
torch
::
Tensor
bias
=
torch
::
randn
({
2
,
4
},
torch
::
dtype
(
torch
::
kFloat64
).
device
(
torch
::
kCUDA
,
3
));
torch
::
Tensor
weight
=
torch
::
randn
({
2
,
4
,
4
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
,
device
));
torch
::
Tensor
bias
=
torch
::
randn
({
2
,
4
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
,
device
));
std
::
cout
<<
input
<<
std
::
endl
;
std
::
cout
<<
input
<<
std
::
endl
;
moe_cuda_forward
(
input
,
gate
,
weight
,
bias
);
moe_cuda_forward
(
input
,
gate
,
weight
,
bias
);
}
}
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment