OpenDAS / FastMoE, commit 8cff6ad7
Authored Dec 30, 2020 by Rick Ho
Parent: 2ba58797

    tidy up C code

Showing 9 changed files with 204 additions and 134 deletions (+204 −134):
pytorch/cuda/.gitignore               +2   −0
pytorch/cuda/CMakeLists.txt           +18  −4
pytorch/cuda/cublas_wrapper.h         +78  −0
pytorch/cuda/cuda_stream_manager.cpp  +13  −0
pytorch/cuda/cuda_stream_manager.h    +31  −0
pytorch/cuda/moe_cuda_kernel.cu       +31  −126
pytorch/cuda/moe_test.py              +5   −4
pytorch/cuda/run.sh                   +25  −0
pytorch/cuda/setup.py                 +1   −0
pytorch/cuda/.gitignore (new file, mode 100644)

*.swp
build
pytorch/cuda/CMakeLists.txt

@@ -4,10 +4,24 @@ project(moe)
 find_package(Torch REQUIRED)
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
-include_directories("/home/jiezhong/miniconda3/include/python3.8"
-    "/usr/local/cuda/include"
-    "/usr/local/cuda/samples/common/inc")
-add_executable(moe moe.cpp)
+if (NOT PYTHON_INCLUDE)
+    set(PYTHON_INCLUDE "/home/jiezhong/miniconda3/include/python3.8")
+endif()
+if (NOT CUDA_HOME)
+    set(CUDA_HOME "/usr/local/cuda")
+endif()
+if (NOT CUDA_SAMPLE_INCLUDE)
+    set(CUDA_SAMPLE_INCLUDE "/usr/local/cuda/samples/common/inc")
+endif()
+
+include_directories(
+    "${PYTHON_INCLUDE}"
+    "${CUDA_HOME}/include"
+    "${CUDA_SAMPLE_INCLUDE}")
+
+add_executable(moe moe.cpp cuda_stream_manager.cpp)
 target_link_libraries(moe "${TORCH_LIBRARIES}")
 set_property(TARGET moe PROPERTY CXX_STANDARD 14)
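With these guards, the machine-specific paths become overridable defaults: a build on another machine can supply its own values on the command line, e.g. cmake -DCUDA_HOME=/opt/cuda -DPYTHON_INCLUDE=/usr/include/python3.8 (an illustrative invocation, not part of the commit).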
pytorch/cuda/cublas_wrapper.h (new file, mode 100644)

#ifndef CUBLAS_WRAPPER_H
#define CUBLAS_WRAPPER_H

#include <cublas_v2.h>

inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k,
        const float *alpha,
        const float *Aarray[], int lda,
        const float *Barray[], int ldb,
        const float *beta,
        float *Carray[], int ldc,
        int batchCount) {
    return cublasSgemmBatched(handle, transa, transb, m, n, k,
            alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
}

inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k,
        const double *alpha,
        const double *Aarray[], int lda,
        const double *Barray[], int ldb,
        const double *beta,
        double *Carray[], int ldc,
        int batchCount) {
    return cublasDgemmBatched(handle, transa, transb, m, n, k,
            alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
}

inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k,
        const __half *alpha,
        const __half *Aarray[], int lda,
        const __half *Barray[], int ldb,
        const __half *beta,
        __half *Carray[], int ldc,
        int batchCount) {
    return cublasHgemmBatched(handle, transa, transb, m, n, k,
            alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
}

inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k,
        const float *alpha,
        const float *A, int lda,
        const float *B, int ldb,
        const float *beta,
        float *C, int ldc) {
    return cublasSgemm(handle, transa, transb, m, n, k,
            alpha, A, lda, B, ldb, beta, C, ldc);
}

inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k,
        const double *alpha,
        const double *A, int lda,
        const double *B, int ldb,
        const double *beta,
        double *C, int ldc) {
    return cublasDgemm(handle, transa, transb, m, n, k,
            alpha, A, lda, B, ldb, beta, C, ldc);
}

inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k,
        const __half *alpha,
        const __half *A, int lda,
        const __half *B, int ldb,
        const __half *beta,
        __half *C, int ldc) {
    return cublasHgemm(handle, transa, transb, m, n, k,
            alpha, A, lda, B, ldb, beta, C, ldc);
}

#endif  // CUBLAS_WRAPPER_H
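These overloads exist so that code templated on the scalar type, like moe_cuda_forward_impl below, can call a single name and let C++ overload resolution pick the matching cuBLAS routine. A minimal sketch of the pattern, using a hypothetical helper gemm_simple that is not part of this commit:

    // Hypothetical sketch: one templated call site serves float, double,
    // and __half; overload resolution on scalar_t selects the cuBLAS routine.
    #include "cublas_wrapper.h"

    template <typename scalar_t>
    cublasStatus_t gemm_simple(cublasHandle_t handle, int m, int n, int k,
            const scalar_t* A, const scalar_t* B, scalar_t* C) {
        scalar_t alpha = 1, beta = 0;
        // Column-major layout, no transposes; leading dimensions are m, k, m.
        return cublasXgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                &alpha, A, m, B, k, &beta, C, m);
    }

Instantiated with float this lowers to cublasSgemm and with double to cublasDgemm; the __half overload resolves the same way, though initializing alpha and beta then needs an explicit conversion such as __float2half.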
pytorch/cuda/cuda_stream_manager.cpp (new file, mode 100644)

#include <cassert>
#include "cuda_stream_manager.h"

CudaStreamManager* smgr = NULL;

CudaStreamManager* getCudaStreamManager(const size_t num_expert) {
    if (!smgr) {
        smgr = new CudaStreamManager(num_expert);
    }
    assert(smgr->num_expert == num_expert);
    return smgr;
}
pytorch/cuda/cuda_stream_manager.h (new file, mode 100644)

#ifndef CUDA_STREAM_MANAGER_H
#define CUDA_STREAM_MANAGER_H

#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h>

class CudaStreamManager {
public:
    CudaStreamManager(const size_t num_expert_) : num_expert(num_expert_) {
        streams = new cudaStream_t[num_expert];
        checkCudaErrors(cublasCreate(&handle));
        for (size_t i = 0; i < num_expert; ++i) {
            checkCudaErrors(cudaStreamCreate(streams + i));
        }
    }
    ~CudaStreamManager() {
        for (size_t i = 0; i < num_expert; ++i) {
            checkCudaErrors(cudaStreamDestroy(*(streams + i)));
        }
        checkCudaErrors(cublasDestroy(handle));
    }

    const size_t num_expert;
    cublasHandle_t handle;
    cudaStream_t* streams;
};

CudaStreamManager* getCudaStreamManager(const size_t num_expert);

#endif  // CUDA_STREAM_MANAGER
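getCudaStreamManager yields a lazily created, process-wide singleton: the first call allocates one CUDA stream per expert plus a shared cuBLAS handle, and every later call asserts that it asks for the same expert count. A usage sketch under the per-expert-stream design the header implies; launch_all_experts and expert_kernel are hypothetical names, not code from this commit:

    // Hypothetical sketch of a caller of CudaStreamManager.
    #include "cuda_stream_manager.h"

    void launch_all_experts(size_t num_expert) {
        auto* mgr = getCudaStreamManager(num_expert);  // created on first use
        // The shared cuBLAS handle must be bound to a stream before GEMMs:
        checkCudaErrors(cublasSetStream(mgr->handle, mgr->streams[0]));
        for (size_t e = 0; e < num_expert; ++e) {
            // Work for independent experts can overlap on separate streams:
            // expert_kernel<<<grid, block, 0, mgr->streams[e]>>>(/* args */);
        }
    }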
pytorch/cuda/moe_cuda_kernel.cu

@@ -3,7 +3,6 @@
 #include <cstdio>
 #include <iostream>
 #include <vector>
-#include <cassert>
 #include <cuda.h>
@@ -13,37 +12,10 @@
 // #include "timer.hh"
-#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
-
-class Helper {
-public:
-    Helper(const size_t num_expert_) : num_expert(num_expert_) {
-        streams = new cudaStream_t[num_expert];
-        checkCudaErrors(cublasCreate(&handle));
-        for (size_t i = 0; i < num_expert; ++i) {
-            checkCudaErrors(cudaStreamCreate(streams + i));
-        }
-    }
-    ~Helper() {
-        for (size_t i = 0; i < num_expert; ++i) {
-            checkCudaErrors(cudaStreamDestroy(*(streams + i)));
-        }
-        checkCudaErrors(cublasDestroy(handle));
-    }
-    const size_t num_expert;
-    cublasHandle_t handle;
-    cudaStream_t* streams;
-};
+#include "cublas_wrapper.h"
+#include "cuda_stream_manager.h"
-
-Helper* helper = NULL;
-
-Helper* getHelper(const size_t num_expert) {
-    if (!helper) {
-        helper = new Helper(num_expert);
-    }
-    assert(helper->num_expert == num_expert);
-    return helper;
-}
 #define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)

 template <typename scalar_t>
@@ -56,79 +28,6 @@ void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride, c
 }
-[deleted here: the six inline cublasXgemmBatched / cublasXgemm overloads for
- float, double, and __half, identical to the code that now lives in
- cublas_wrapper.h above]

 template <typename scalar_t>
 void moe_cuda_forward_impl(
         const scalar_t* input,
@@ -141,7 +40,7 @@ void moe_cuda_forward_impl(
         const size_t num_expert,
         cublasOperation_t transb) {
-    Helper* h = getHelper(num_expert);
+    auto* h = getCudaStreamManager(num_expert);
     checkCudaErrors(cublasSetStream(h->handle, *(h->streams)));
@@ -160,25 +59,29 @@ void moe_cuda_forward_impl(
(both sides of this hunk carry the same statements; the diff re-wraps the long lines, and the new side reads:)

        aptrs.push_back(input + in_feat * i);
        cptrs.push_back(output + out_feat * i);
    }
    checkCudaErrors(cudaMemcpy(Aarray, aptrs.data(),
            batch_size * sizeof(const scalar_t*), cudaMemcpyHostToDevice));
    // checkCudaErrors(cudaMemcpy(ptrs + batch_size * top_k, bptrs.data(),
    //         batch_size * sizeof(scalar_t*) * top_k, cudaMemcpyHostToDevice));
    checkCudaErrors(cudaMemcpy(Carray, cptrs.data(),
            batch_size * sizeof(scalar_t*), cudaMemcpyHostToDevice));
    dim3 griddim(CEIL(batch_size, 256));
    dim3 blockdim(256);
    generate_ptr_offset_kernel<<<griddim, blockdim, 0, *(h->streams)>>>(
            batch_size, weight, out_feat * in_feat, gate, Barray);
    scalar_t alpha = 1, beta = 0;
    checkCudaErrors(cublasXgemmBatched(h->handle,
            CUBLAS_OP_N, transb,
            1, out_feat, in_feat,
            &alpha,
            Aarray, 1,
            Barray, (transb == CUBLAS_OP_T) ? out_feat : in_feat,
            &beta,
            Carray, 1,
            batch_size));
    checkCudaErrors(cudaStreamSynchronize(*(h->streams)));
}
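For reference, each entry of this batch is a tiny GEMM: with m = 1, n = out_feat, k = in_feat, row i of the output is the 1 × in_feat input row times the weight matrix of expert gate[i], transposed when transb == CUBLAS_OP_T. generate_ptr_offset_kernel presumably fills Barray with the per-sample weight pointers (base weight plus gate[i] times the out_feat * in_feat stride), so a single cublasXgemmBatched call covers the whole batch on the first stream.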
@@ -194,7 +97,7 @@ void moe_cuda_grad_weight(
         const size_t out_feat,
         const size_t num_expert) {
-    Helper* h = getHelper(num_expert);
+    auto h = getCudaStreamManager(num_expert);
     int* gate_host = new int[batch_size];
     scalar_t alpha = 1, beta = 1;
@@ -231,7 +134,9 @@ std::vector<torch::Tensor> moe_cuda_forward(
     const auto out_feat = weight.size(1);
     const auto in_feat = weight.size(2);
+#ifdef MOE_DEBUG
     printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
             batch_size, num_expert, in_feat, out_feat);
+#endif
     auto output = input.new_zeros({batch_size, out_feat});
     AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] {
@@ -338,4 +243,4 @@ int main() {
     double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
     printf("%.3lf TFLOPs\n", tflops);
 }
-*/
\ No newline at end of file
+*/
pytorch/cuda/moe_test.py

 from moe import MOELayer
 import torch
 import time
+import sys

 def perf():
-    batch_size = 128
-    in_feat = 1024
-    out_feat = 4096
-    num_expert = 4
+    batch_size = int(sys.argv[1])
+    in_feat = int(sys.argv[2])
+    out_feat = int(sys.argv[3])
+    num_expert = int(sys.argv[4])
     inp = torch.rand(batch_size, in_feat).cuda()
     gate = torch.randint(low=0, high=num_expert, size=(batch_size,), requires_grad=False).int().cuda()
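After this change the benchmark reads its sizes from the command line, e.g. python moe_test.py 128 1024 4096 4 for batch size 128, in_feat (d_model) 1024, out_feat (d_ffn) 4096, and 4 experts; this is the argument order the new run.sh below relies on.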
pytorch/cuda/run.sh (new file, mode 100755)

#!/bin/bash
export PYTHONPATH=$PWD/build/lib.linux-x86_64-3.7
export LD_LIBRARY_PATH=/home/laekov/.local/lib/python3.7/site-packages/torch/lib:$LD_LIBRARY_PATH
if [ -z $1 ]
then
	python moe.py
elif [ .$1 = '.test_all' ]
then
	for bs in 4 16 64
	do
		for inf in 1024 4096
		do
			for ouf in 1024 4096
			do
				for nexp in 4 16 64
				do
					echo $bs $nexp ${inf}x${ouf}
					python moe_test.py $bs $inf $ouf $nexp
				done
			done
		done
	done
else
	python $@
fi
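So ./run.sh with no arguments runs moe.py; ./run.sh test_all sweeps batch sizes {4, 16, 64}, input and output feature sizes {1024, 4096}, and expert counts {4, 16, 64} through moe_test.py; any other argument list is handed to python as-is. Note that the exported PYTHONPATH and LD_LIBRARY_PATH point at the author's build tree and torch install, so other machines will need to adjust them.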
pytorch/cuda/setup.py

@@ -11,6 +11,7 @@ setup(
         name='moe_cuda',
         sources=[
             'moe.cpp',
+            'cuda_stream_manager.cpp',
             'moe_cuda_kernel.cu',
         ],
         extra_compile_args={
             'cxx': ['-I{}'.format(CUDA_HELPER)],
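With this line the PyTorch extension build compiles cuda_stream_manager.cpp alongside the kernels, mirroring the CMakeLists.txt change above, so both build paths link the same stream-manager singleton.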