OpenDAS / FastMoE · Commits

Commit 2565f2fa
authored Jan 09, 2021 by Rick Ho

stream manager fixed

parent b8a212ef
Showing 3 changed files with 16 additions and 24 deletions.
pytorch/cuda/cuda_stream_manager.cpp  +0 -14
pytorch/cuda/cuda_stream_manager.h  +14 -8
pytorch/cuda/moe_cuda_kernel.cu  +2 -2
pytorch/cuda/cuda_stream_manager.cpp
```diff
-/* TODO: make it ke xue
 #include <cuda_runtime.h>
 #include <cassert>
 #include <thread>

 #include "cuda_stream_manager.h"

-thread_local CudaStreamManager smgr;
-
-CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device) {
-    if (!smgr) {
-        smgr = new CudaStreamManager(num_expert, device);
-    }
-<<<<<<< HEAD
-    return smgr;
-}
 void CudaStreamManager::sync(int i) {
     if (i > -1) {
         cudaStreamSynchronize(streams[i]);
@@ -25,5 +13,3 @@ void CudaStreamManager::sync(int i) {
         cudaStreamSynchronize(streams[i]);
-        }
     }
 }
-*/
```
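Reading the deletions together: the whole translation unit had been commented out behind a "TODO: make it ke xue" ("make it scientific") comment, and the commit re-enables it while dropping a broken `getCudaStreamManager` accessor that still contained a `<<<<<<< HEAD` merge-conflict marker (it also declared `smgr` as an object yet assigned `new` to it). That accessor was a lazily constructed, thread-local singleton. For reference, a minimal self-contained sketch of that pattern in corrected form; the `Manager` struct here is a stand-in for illustration, not FastMoE's actual class:

```cpp
#include <cstddef>
#include <cstdio>

// Stand-in for CudaStreamManager; holds only what the sketch needs.
struct Manager {
    size_t num_expert;
    int device;
    Manager(size_t n, int d) : num_expert(n), device(d) {}
};

// One pointer per host thread; null until the first call constructs it.
thread_local Manager* smgr = nullptr;

Manager* getManager(const size_t num_expert, const int device) {
    if (!smgr) {
        smgr = new Manager(num_expert, device);
    }
    return smgr;
}

int main() {
    const Manager* m = getManager(4, 0);
    std::printf("num_expert=%zu device=%d\n", m->num_expert, m->device);
    return 0;
}
```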
pytorch/cuda/cuda_stream_manager.h
```diff
@@ -9,6 +9,12 @@
 class CudaStreamManager {
+public:
+    size_t num_expert;
+    int device;
+    cublasHandle_t* handles;
+    cudaStream_t* streams;
+
 public:
     CudaStreamManager() : num_expert(0), device(0), streams(NULL) {
         int current_device;
@@ -26,10 +32,12 @@ public:
         this->device = device;
         checkCudaErrors(cudaSetDevice(device));
         streams = new cudaStream_t[num_expert];
-        checkCudaErrors(cublasCreate(&handle));
+        handles = new cublasHandle_t[num_expert];
         for (size_t i = 0; i < num_expert; ++i) {
             checkCudaErrors(cudaStreamCreate(streams + i));
+            checkCudaErrors(cublasCreate(handles + i));
+            cublasSetStream(handles[i], streams[i]);
         }
     }

     ~CudaStreamManager() {
@@ -38,14 +46,12 @@ public:
 #endif
         for (size_t i = 0; i < num_expert; ++i) {
             checkCudaErrors(cudaStreamDestroy(*(streams + i)));
+            checkCudaErrors(cublasDestroy(handles[i]));
         }
-        checkCudaErrors(cublasDestroy(handle));
         delete[] streams;
     }

-    size_t num_expert;
-    int device;
     void sync(int = -1);
-    cublasHandle_t handle;
-    cudaStream_t* streams;
 };

 // CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device);
```
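The substantive change in the header: the single `cublasHandle_t handle` becomes one handle per expert, each bound to that expert's stream via `cublasSetStream`, so GEMMs issued through different handles can overlap on the GPU. A minimal standalone sketch of the same setup and teardown, assuming nothing from the repo: it swaps the CUDA samples' `checkCudaErrors` helper for local assert-based macros, and adds the `delete[] handles` that the diff's destructor appears to omit:

```cpp
#include <cassert>
#include <cstddef>
#include <cublas_v2.h>
#include <cuda_runtime.h>

#define CHECK_CUDA(x)   do { cudaError_t e = (x); assert(e == cudaSuccess); } while (0)
#define CHECK_CUBLAS(x) do { cublasStatus_t s = (x); assert(s == CUBLAS_STATUS_SUCCESS); } while (0)

int main() {
    const size_t num_expert = 4;
    CHECK_CUDA(cudaSetDevice(0));

    cudaStream_t*   streams = new cudaStream_t[num_expert];
    cublasHandle_t* handles = new cublasHandle_t[num_expert];

    // One stream per expert, one cuBLAS handle per stream: work issued
    // through handles[i] is ordered on streams[i] and can overlap across i.
    for (size_t i = 0; i < num_expert; ++i) {
        CHECK_CUDA(cudaStreamCreate(streams + i));
        CHECK_CUBLAS(cublasCreate(handles + i));
        CHECK_CUBLAS(cublasSetStream(handles[i], streams[i]));
    }

    // ... per-expert GEMMs would be launched here via handles[i] ...

    for (size_t i = 0; i < num_expert; ++i) {
        CHECK_CUBLAS(cublasDestroy(handles[i]));
        CHECK_CUDA(cudaStreamDestroy(streams[i]));
    }
    delete[] streams;
    delete[] handles;  // not in the diff above; avoids leaking the array
    return 0;
}
```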
pytorch/cuda/moe_cuda_kernel.cu
```diff
@@ -147,7 +147,7 @@ void moe_cuda_forward_impl(
             in_feat);
 #endif
     // Use T(B) x T(A) = T(C) to produce row-major C
-    checkCudaErrors(cublasXgemm(smgr.handle, // h->getHandle(i),
+    checkCudaErrors(cublasXgemm(smgr.handles[0], // h->getHandle(i),
             CUBLAS_OP_T,
             CUBLAS_OP_N,
             out_feat, expert_count[i], in_feat,
@@ -204,7 +204,7 @@ void moe_cuda_grad_weight(
     checkCudaErrors(cudaMemcpy(gate_host, gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));
     for (size_t i = 0; i < batch_size; ++i) {
         // checkCudaErrors(cublasSetStream);
-        checkCudaErrors(cublasXgemm(smgr.handle,
+        checkCudaErrors(cublasXgemm(smgr.handles[0],
             CUBLAS_OP_N,
             CUBLAS_OP_T,
             out_feat,
```
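Both call sites simply swap `smgr.handle` for `smgr.handles[0]`, so everything still funnels through one handle and hence one stream; the retained `// h->getHandle(i)` comment hints at the per-expert dispatch (`handles[i]`) that the new arrays make possible. The `T(B) x T(A) = T(C)` comment names the standard trick for getting row-major output from column-major cuBLAS: a row-major buffer reinterpreted as column-major is the transpose, so swapping the operands computes C^T and leaves C in row-major layout with no explicit transpose. A hedged illustration with plain `cublasSgemm` (the diff's `cublasXgemm` is presumably a dtype-dispatching wrapper in FastMoE; it is not shown here):

```cpp
#include <cassert>
#include <cublas_v2.h>

#define CHECK_CUBLAS(x) do { cublasStatus_t s = (x); assert(s == CUBLAS_STATUS_SUCCESS); } while (0)

// Row-major C (m x n) = A (m x k) * B (k x n) on a column-major BLAS.
// Reading the row-major buffers as column-major gives A^T, B^T, C^T, so
// swapping the operands computes C^T = B^T * A^T, i.e. C in row-major.
void sgemm_row_major(cublasHandle_t h, int m, int n, int k,
                     const float* A, const float* B, float* C) {
    const float alpha = 1.0f, beta = 0.0f;
    CHECK_CUBLAS(cublasSgemm(h, CUBLAS_OP_N, CUBLAS_OP_N,
                             n, m, k,    // C^T is n x m
                             &alpha,
                             B, n,       // B^T is n x k, leading dimension n
                             A, k,       // A^T is k x m, leading dimension k
                             &beta,
                             C, n));     // leading dimension of C^T is n
}
```

The calls in the diff additionally pass `CUBLAS_OP_T` for one operand, reflecting how the weight tensors are stored; the operand-swap idea is the same.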