Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
bf60a846
Commit
bf60a846
authored
Dec 30, 2020
by
Rick Ho
Browse files
limit max streams
parent
ab153b37
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
10 deletions
+21
-10
pytorch/cuda/cuda_stream_manager.h
pytorch/cuda/cuda_stream_manager.h
+16
-4
pytorch/cuda/moe_cuda_kernel.cu
pytorch/cuda/moe_cuda_kernel.cu
+5
-6
No files found.
pytorch/cuda/cuda_stream_manager.h
View file @
bf60a846
...
...
@@ -6,15 +6,18 @@
#include <helper_cuda.h>
#define MAX_STREAMS 16
struct
CudaStreamManager
{
const
size_t
num_expert
;
cublasHandle_t
*
handles
;
cudaStream_t
*
streams
;
CudaStreamManager
(
const
size_t
num_expert_
)
:
num_expert
(
num_expert_
)
{
streams
=
new
cudaStream_t
[
num_expert
];
handles
=
new
cublasHandle_t
[
num_expert
];
for
(
size_t
i
=
0
;
i
<
num_expert
;
++
i
)
{
streams
=
new
cudaStream_t
[
MAX_STREAMS
];
handles
=
new
cublasHandle_t
[
MAX_STREAMS
];
for
(
size_t
i
=
0
;
i
<
MAX_STREAMS
;
++
i
)
{
checkCudaErrors
(
cublasCreate
(
handles
+
i
));
checkCudaErrors
(
cudaStreamCreate
(
streams
+
i
));
checkCudaErrors
(
cublasSetStream
(
handles
[
i
],
streams
[
i
]));
...
...
@@ -22,11 +25,20 @@ struct CudaStreamManager {
}
~
CudaStreamManager
()
{
for
(
size_t
i
=
0
;
i
<
num_expert
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
MAX_STREAMS
;
++
i
)
{
checkCudaErrors
(
cudaStreamDestroy
(
streams
[
i
]));
checkCudaErrors
(
cublasDestroy
(
handles
[
i
]));
}
}
// Return the CUDA stream assigned to logical index `idx`.
//
// Indices are wrapped into the pool of MAX_STREAMS streams, so several
// indices may share one stream. The returned reference aliases the entry
// stored in `streams`.
inline cudaStream_t& getStream(int idx) {
    // C++ `%` keeps the sign of the dividend, so a negative idx would give a
    // negative subscript; fold the remainder back into [0, MAX_STREAMS).
    int slot = idx % MAX_STREAMS;
    if (slot < 0) {
        slot += MAX_STREAMS;
    }
    return streams[slot];
}
// Return the cuBLAS handle assigned to logical index `idx`.
//
// Indices are wrapped into the pool of MAX_STREAMS handles, so several
// indices may share one handle. The returned reference aliases the entry
// stored in `handles`.
inline cublasHandle_t& getHandle(int idx) {
    // C++ `%` keeps the sign of the dividend, so a negative idx would give a
    // negative subscript; fold the remainder back into [0, MAX_STREAMS).
    int slot = idx % MAX_STREAMS;
    if (slot < 0) {
        slot += MAX_STREAMS;
    }
    return handles[slot];
}
void
sync
();
};
CudaStreamManager
*
getCudaStreamManager
(
const
size_t
num_expert
);
...
...
pytorch/cuda/moe_cuda_kernel.cu
View file @
bf60a846
...
...
@@ -70,7 +70,7 @@ void moe_cuda_forward_impl(
checkCudaErrors
(
cudaMemcpyAsync
(
input_buf
+
target_idx
*
in_feat
,
input
+
i
*
in_feat
,
sizeof
(
scalar_t
)
*
in_feat
,
cudaMemcpyDeviceToDevice
,
h
->
s
tream
s
[
gate
[
i
]
]
));
h
->
getS
tream
(
gate
[
i
]
)
));
}
scalar_t
alpha
=
1
,
beta
=
0
;
...
...
@@ -85,7 +85,7 @@ void moe_cuda_forward_impl(
in_feat
);
#endif
// Use T(B) x T(A) = T(C) to produce row-major C
checkCudaErrors
(
cublasXgemm
(
h
->
h
andle
s
[
i
]
,
checkCudaErrors
(
cublasXgemm
(
h
->
getH
andle
(
i
)
,
(
transb
==
CUBLAS_OP_T
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
,
CUBLAS_OP_N
,
out_feat
,
expert_count
[
i
],
in_feat
,
...
...
@@ -108,12 +108,11 @@ void moe_cuda_forward_impl(
output_buf
+
target_idx
*
out_feat
,
sizeof
(
scalar_t
)
*
out_feat
,
cudaMemcpyDeviceToDevice
,
h
->
s
tream
s
[
gate
[
i
]
]
));
h
->
getS
tream
(
gate
[
i
]
)
));
}
for
(
int
i
=
0
;
i
<
num_expert
;
++
i
)
{
cudaStreamSynchronize
(
h
->
streams
[
i
]);
}
h
->
sync
();
cudaFree
(
input_buf
);
cudaFree
(
output_buf
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment