OpenDAS / FastMoE · Commits

Commit ef83c893, authored Jan 04, 2021 by Jiezhong Qiu

    stream manager object instead of pointer

Parent: e52c0380

Showing 4 changed files with 47 additions and 39 deletions:
    pytorch/cuda/cuda_stream_manager.cpp    +4   -1
    pytorch/cuda/cuda_stream_manager.h      +20  -14
    pytorch/cuda/moe_cuda_kernel.cu         +22  -23
    pytorch/cuda/setup.py                   +1   -1

pytorch/cuda/cuda_stream_manager.cpp

@@ -3,8 +3,10 @@
 #include "cuda_stream_manager.h"

-thread_local CudaStreamManager* smgr = NULL;
+thread_local CudaStreamManager smgr;

+/*
 CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device) {
     if (!smgr) {
         smgr = new CudaStreamManager(num_expert, device);
@@ -13,3 +15,4 @@ CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device) {
     assert(smgr->device == device);
     return smgr;
 }
+*/
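
The change above is the core of the commit: the per-thread manager is no longer a heap-allocated pointer obtained through getCudaStreamManager(), but a plain thread_local object that the entry points configure lazily via setup() (added in the header and .cu changes below). Below is a minimal sketch of the two patterns with the CUDA specifics stripped out; StreamManagerSketch, ready, and forward_entry are illustrative names, not FastMoE code.

    // A sketch of the two initialization patterns, CUDA details removed.
    // Error handling and the actual stream/handle members are omitted.
    #include <cstdio>
    #include <thread>

    struct StreamManagerSketch {
        size_t num_expert = 0;
        bool ready = false;

        // Old pattern (removed above):
        //   thread_local CudaStreamManager* smgr = NULL;
        //   if (!smgr) smgr = new CudaStreamManager(num_expert, device);
        // New pattern: the object always exists and is configured on first use.
        void setup(const size_t n) {
            num_expert = n;
            ready = true;
            std::printf("setup with %zu experts\n", num_expert);
        }
    };

    // One zero-initialized instance per thread, destroyed automatically at
    // thread exit; no factory function or new/delete is needed.
    thread_local StreamManagerSketch smgr;

    void forward_entry(const size_t num_expert) {
        if (!smgr.ready)              // mirrors `if (smgr.streams == NULL)`
            smgr.setup(num_expert);   // mirrors `smgr.setup(num_expert, device)`
    }

    int main() {
        std::thread t1(forward_entry, 4);
        std::thread t2(forward_entry, 4);
        t1.join();
        t2.join();
        return 0;
    }

Because the instance is thread_local, each thread that calls into the extension gets its own manager, and it is destroyed automatically at thread exit instead of relying on an explicit delete.
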
pytorch/cuda/cuda_stream_manager.h

@@ -10,38 +10,44 @@
 class CudaStreamManager {
 public:
-    CudaStreamManager(const size_t num_expert_, const int device_) : num_expert(num_expert_), device(device_) {
+    CudaStreamManager() : num_expert(0), device(0), streams(NULL) {
         /*
-        Actually, we will see current_device == device,
-        which means pytorch always sets the correct device for us.
-        But for safety, we still manually set device to the desired one.
-        */
         int current_device;
         checkCudaErrors(cudaGetDevice(&current_device));
-        printf("CudaStreamManager construnctor called, get device %d, set device %d\n", current_device, device);
+#ifdef MOE_DEBUG
+        printf("constructor at device %d\n", current_device);
         checkCudaErrors(cudaSetDevice(device));
+#endif
+        */
+    }
+    void setup(const size_t num_expert, const int device) {
+#ifdef MOE_DEBUG
+        printf("setup at device %d\n", device);
+#endif
+        this->num_expert = num_expert;
+        this->device = device;
+        checkCudaErrors(cudaSetDevice(device));
         streams = new cudaStream_t[num_expert];
         checkCudaErrors(cublasCreate(&handle));
         for (size_t i = 0; i < num_expert; ++i) {
             checkCudaErrors(cudaStreamCreate(streams + i));
         }
     }
     ~CudaStreamManager() {
+#ifdef MOE_DEBUG
+        printf("destructor at device %d\n", device);
+#endif
         for (size_t i = 0; i < num_expert; ++i) {
             checkCudaErrors(cudaStreamDestroy(*(streams + i)));
         }
         checkCudaErrors(cublasDestroy(handle));
         delete [] streams;
     }
-    const size_t num_expert;
-    const int device;
+    size_t num_expert;
+    int device;
     cublasHandle_t handle;
     cudaStream_t* streams;
 };

-CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device);
+// CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device);

 #endif // CUDA_STREAM_MANAGER
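
For reference, here is a compilable sketch of the lifecycle the new header describes: setup() pins the device and creates one stream per expert, and the destructor tears the streams down again. The cuBLAS handle is left out and the type is renamed ManagerSketch to keep this an illustration rather than the FastMoE class; FastMoE keeps the instance as a namespace-scope thread_local, while a function-local object is used here so teardown runs while the CUDA context is still alive.

    // Lifecycle sketch: setup() creates one stream per expert on the chosen
    // device; the destructor destroys them. Error checking and the cuBLAS
    // handle are omitted for brevity.
    #include <cuda_runtime.h>

    __global__ void touch(int* x) { x[threadIdx.x] += 1; }

    struct ManagerSketch {
        size_t num_expert = 0;
        cudaStream_t* streams = NULL;

        void setup(const size_t n, const int device) {
            num_expert = n;
            cudaSetDevice(device);
            streams = new cudaStream_t[num_expert];
            for (size_t i = 0; i < num_expert; ++i)
                cudaStreamCreate(streams + i);
        }
        ~ManagerSketch() {
            for (size_t i = 0; i < num_expert; ++i)
                cudaStreamDestroy(*(streams + i));
            delete[] streams;
        }
    };

    int main() {
        // A local object is used here instead of a thread_local so teardown
        // runs while the CUDA context is still alive.
        ManagerSketch smgr;
        if (smgr.streams == NULL)            // lazy init, as in moe_cuda_forward
            smgr.setup(4, 0);

        int* d = NULL;
        cudaMalloc(&d, 32 * sizeof(int));
        cudaMemset(d, 0, 32 * sizeof(int));
        touch<<<1, 32, 0, *(smgr.streams)>>>(d);   // queue work on expert stream 0
        cudaStreamSynchronize(*(smgr.streams));
        cudaFree(d);
        return 0;
    }
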
pytorch/cuda/moe_cuda_kernel.cu

@@ -18,6 +18,7 @@
 #define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)

+thread_local CudaStreamManager smgr;

 template <typename scalar_t>
 __global__
@@ -39,12 +40,9 @@ void moe_cuda_forward_impl(
         const size_t in_feat,
         const size_t out_feat,
         const size_t num_expert,
-        cublasOperation_t transb,
-        const int device) {
-    auto* h = getCudaStreamManager(num_expert, device);
-    checkCudaErrors(cublasSetStream(h->handle, *(h->streams)));
+        cublasOperation_t transb) {
+    checkCudaErrors(cublasSetStream(smgr.handle, *(smgr.streams)));

     // setup Aarray, Barray and Carray
     std::vector<const scalar_t*> aptrs;
@@ -70,11 +68,11 @@ void moe_cuda_forward_impl(
     dim3 griddim(CEIL(batch_size, 256));
     dim3 blockdim(256);
     generate_ptr_offset_kernel<<<griddim, blockdim, 0,
-        *(h->streams)>>>(batch_size, weight, out_feat * in_feat, gate, Barray);
+        *(smgr.streams)>>>(batch_size, weight, out_feat * in_feat, gate, Barray);
     scalar_t alpha = 1, beta = 0;
-    checkCudaErrors(cublasXgemmBatched(h->handle,
+    checkCudaErrors(cublasXgemmBatched(smgr.handle,
         CUBLAS_OP_N,
         transb,
         1, out_feat, in_feat,
@@ -85,7 +83,7 @@ void moe_cuda_forward_impl(
         Carray, 1,
         batch_size));
-    checkCudaErrors(cudaStreamSynchronize(*(h->streams)));
+    checkCudaErrors(cudaStreamSynchronize(*(smgr.streams)));
     checkCudaErrors(cudaFree(Aarray));
     checkCudaErrors(cudaFree(Barray));
     checkCudaErrors(cudaFree(Carray));
@@ -100,17 +98,14 @@ void moe_cuda_grad_weight(
         const size_t batch_size,
         const size_t in_feat,
         const size_t out_feat,
-        const size_t num_expert,
-        const int device) {
-    auto h = getCudaStreamManager(num_expert, device);
+        const size_t num_expert) {
     int* gate_host = new int[batch_size];
     scalar_t alpha = 1, beta = 1;
     checkCudaErrors(cudaMemcpy(gate_host, gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));
     for (size_t i = 0; i < batch_size; ++i) {
-        checkCudaErrors(cublasSetStream(h->handle, *(h->streams + gate_host[i])));
-        checkCudaErrors(cublasXgemm(h->handle,
+        checkCudaErrors(cublasSetStream(smgr.handle, *(smgr.streams + gate_host[i])));
+        checkCudaErrors(cublasXgemm(smgr.handle,
             CUBLAS_OP_N,
             CUBLAS_OP_T,
             out_feat,
@@ -126,7 +121,7 @@ void moe_cuda_grad_weight(
             out_feat));
     }
     for (size_t i = 0; i < num_expert; ++i) {
-        checkCudaErrors(cudaStreamSynchronize(*(h->streams + i)));
+        checkCudaErrors(cudaStreamSynchronize(*(smgr.streams + i)));
     }
     delete [] gate_host;
 }
@@ -143,7 +138,10 @@ std::vector<torch::Tensor> moe_cuda_forward(
 #ifdef MOE_DEBUG
     printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
 #endif
-    int device = device_of(input).value().index();
+    const int device = device_of(input).value().index();
+    if (smgr.streams == NULL) {
+        smgr.setup(num_expert, device);
+    }
     auto output = input.new_zeros({batch_size, out_feat});
     AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] {
@@ -156,8 +154,7 @@ std::vector<torch::Tensor> moe_cuda_forward(
             in_feat,
             out_feat,
             num_expert,
-            CUBLAS_OP_T,
-            device);
+            CUBLAS_OP_T);
     }));
@@ -178,7 +175,11 @@ std::vector<torch::Tensor> moe_cuda_backward(
 #ifdef MOE_DEBUG
     printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
 #endif
-    int device = device_of(input).value().index();
+    const int device = device_of(input).value().index();
+    if (smgr.streams == NULL) {
+        smgr.setup(num_expert, device);
+    }
     auto grad_input = grad_output.new_zeros({batch_size, in_feat});  // batch_size x in_feat
     auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat});  // num_expert x out_feat x in_feat
@@ -193,8 +194,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
             out_feat,
             in_feat,
             num_expert,
-            CUBLAS_OP_N,
-            device);
+            CUBLAS_OP_N);
     }));
@@ -207,8 +207,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
             batch_size,
             in_feat,
             out_feat,
-            num_expert,
-            device);
+            num_expert);
     }));
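
The loop in moe_cuda_grad_weight above is where the per-expert streams pay off: each row's gate index picks the stream its GEMM is queued on, and every expert stream is synchronized at the end. A stand-alone sketch of that dispatch pattern follows, with the cuBLAS call replaced by a trivial kernel and made-up gate values and sizes:

    // Dispatch sketch: the gate index of each row selects the expert stream
    // its work is queued on; all expert streams are synchronized at the end.
    // Error checking is omitted; gate values and sizes are made up.
    #include <cuda_runtime.h>

    __global__ void accumulate(float* buf) { buf[threadIdx.x] += 1.0f; }

    int main() {
        const size_t num_expert = 4, batch_size = 8, width = 32;

        cudaStream_t streams[num_expert];
        for (size_t i = 0; i < num_expert; ++i) cudaStreamCreate(&streams[i]);

        // gate_host[i] names the expert that handles row i (in FastMoE this is
        // copied from the device before the loop, as in the diff above).
        int gate_host[batch_size] = {0, 1, 2, 3, 0, 1, 2, 3};

        float* grad_weight = NULL;
        cudaMalloc(&grad_weight, num_expert * width * sizeof(float));
        cudaMemset(grad_weight, 0, num_expert * width * sizeof(float));

        for (size_t i = 0; i < batch_size; ++i) {
            // Queue row i's contribution on its expert's stream.
            accumulate<<<1, width, 0, streams[gate_host[i]]>>>(
                grad_weight + gate_host[i] * width);
        }
        for (size_t i = 0; i < num_expert; ++i) cudaStreamSynchronize(streams[i]);

        for (size_t i = 0; i < num_expert; ++i) cudaStreamDestroy(streams[i]);
        cudaFree(grad_weight);
        return 0;
    }

Rows routed to different experts can overlap on the device, while rows of the same expert serialize on that expert's stream, which is what keeping one stream per expert buys.
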
pytorch/cuda/setup.py

@@ -11,7 +11,7 @@ setup(
             name='moe_cuda',
             sources=[
                 'moe.cpp',
-                'cuda_stream_manager.cpp',
+                # 'cuda_stream_manager.cpp',
                 'moe_cuda_kernel.cu',
                 ],
             extra_compile_args={'cxx': ['-I{}'.format(CUDA_HELPER)],