Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
ea66e5e5
Commit
ea66e5e5
authored
Feb 03, 2021
by
Rick Ho
Browse files
fix ensure device index bug
parent
ae2c434e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
6 additions
and
2 deletions
+6
-2
cuda/cuda_stream_manager.cpp
cuda/cuda_stream_manager.cpp
+4
-0
cuda/cuda_stream_manager.h
cuda/cuda_stream_manager.h
+1
-1
cuda/moe_comm_kernel.cu
cuda/moe_comm_kernel.cu
+1
-1
No files found.
cuda/cuda_stream_manager.cpp
View file @
ea66e5e5
...
@@ -57,6 +57,10 @@ void CudaStreamManager::sync(int idx) {
...
@@ -57,6 +57,10 @@ void CudaStreamManager::sync(int idx) {
}
}
void
CudaStreamManager
::
setup
(
const
int
device
)
{
void
CudaStreamManager
::
setup
(
const
int
device
)
{
#ifdef MOE_USE_NCCL
this
->
ncclgood
=
0
;
#endif
this
->
device
=
device
;
checkCudaErrors
(
cudaSetDevice
(
device
));
checkCudaErrors
(
cudaSetDevice
(
device
));
streams
=
new
cudaStream_t
[
SMGR_N_STREAMS
];
streams
=
new
cudaStream_t
[
SMGR_N_STREAMS
];
handles
=
new
cublasHandle_t
[
SMGR_N_STREAMS
];
handles
=
new
cublasHandle_t
[
SMGR_N_STREAMS
];
...
...
cuda/cuda_stream_manager.h
View file @
ea66e5e5
...
@@ -29,7 +29,7 @@ public:
...
@@ -29,7 +29,7 @@ public:
#endif
#endif
public:
public:
CudaStreamManager
(
int
device_
)
:
device
(
device_
)
,
ncclgood
(
0
)
{
CudaStreamManager
(
int
device_
)
:
device
(
device_
)
{
this
->
setup
(
device
);
this
->
setup
(
device
);
}
}
...
...
cuda/moe_comm_kernel.cu
View file @
ea66e5e5
...
@@ -197,7 +197,7 @@ std::vector<torch::Tensor> moe_cuda_global_gather(
...
@@ -197,7 +197,7 @@ std::vector<torch::Tensor> moe_cuda_global_gather(
}
}
void
moe_ensure_nccl
(
c10d
::
ProcessGroupNCCL
&
p
,
torch
::
Tensor
t
)
{
void
moe_ensure_nccl
(
c10d
::
ProcessGroupNCCL
&
p
,
torch
::
Tensor
t
)
{
auto
smgr
=
getCudaStreamManager
(
0
);
auto
smgr
=
getCudaStreamManager
(
t
.
device
().
index
()
);
smgr
->
ensure
((
void
*
)
&
p
,
t
.
device
());
smgr
->
ensure
((
void
*
)
&
p
,
t
.
device
());
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment