OpenDAS / FastMoE

Commit 267eb9cc
authored May 29, 2023 by Rick Ho

fix processgroupnccl mismatch in pytorch2

parent 9fe65ec0
Showing 4 changed files with 54 additions and 1 deletion (+54 / -1)
cuda/fmoe_cuda.cpp        +6  -0
cuda/global_exchange.cpp  +10 -0
setup.py                  +1  -1
tests/test_comm.py        +37 -0
cuda/fmoe_cuda.cpp
...
...
@@ -8,6 +8,7 @@
#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR > 1 || \
(TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13))
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#else
#include <c10d/ProcessGroupNCCL.hpp>
...
...
@@ -26,7 +27,12 @@ torch::Tensor _global_gather(
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long batch_size, long n_workers);

#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR >= 2)
void _ensure_nccl(c10d::ProcessGroup& p, torch::Tensor t);
#else
void _ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t);
#endif // TORCH_VERSION
#endif // FMOE_USE_NCCL

// local_exchange
...
...
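The header change mirrors the c10d refactor in PyTorch 2: the Python bindings now hand extensions a generic c10d::ProcessGroup rather than a ProcessGroupNCCL, so the declaration of _ensure_nccl is gated on the Torch major version and accepts the generic group on PyTorch >= 2 while keeping the old NCCL-specific signature for earlier releases.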
cuda/global_exchange.cpp
...
...
@@ -100,6 +100,7 @@ torch::Tensor _global_gather(
#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR > 1 || \
(TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13))
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#else
#include <c10d/ProcessGroupNCCL.hpp>
...
...
@@ -134,12 +135,21 @@ public:
    }
};

#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR >= 2)
void _ensure_nccl(c10d::ProcessGroup& p, torch::Tensor t) {
#else
void _ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
#endif // TORCH_VERSION
    auto smgr = getCudaStreamManager(t.device().index());
    if (smgr->ncclgood) {
        return;
    }
#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR >= 2)
    HackNCCLGroup* h = (HackNCCLGroup*)(void*)
        (p.getBackend(c10d::ProcessGroup::NCCL).get());
#else
    HackNCCLGroup* h = (HackNCCLGroup*)(void*)&p;
#endif // TORCH_VERSION
    smgr->ncclcomm = h->getcomm(t.device());
    if (smgr->ncclcomm != 0) {
        smgr->ncclgood = 1;
...
...
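The implementation above is where the mismatch named in the commit title is resolved: on PyTorch >= 2 the NCCL-specific object can no longer be taken directly from the argument, so it is looked up with getBackend(c10d::ProcessGroup::NCCL) and then cast to the HackNCCLGroup helper type. As a minimal sketch of that lookup, separate from this commit and using a hypothetical helper name, the same pattern in plain C++ might look like:

// Sketch only (not part of this commit): recovering the NCCL backend from a
// generic process group when building against PyTorch >= 2. The helper name
// get_nccl_backend is hypothetical.
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>

c10d::ProcessGroupNCCL* get_nccl_backend(c10d::ProcessGroup& pg) {
    // getBackend returns c10::intrusive_ptr<c10d::Backend>; ProcessGroupNCCL
    // derives from Backend in PyTorch 2, so the downcast is valid as long as
    // the group was created with backend="nccl".
    auto backend = pg.getBackend(c10d::ProcessGroup::NCCL);
    return static_cast<c10d::ProcessGroupNCCL*>(backend.get());
}

FastMoE itself casts to HackNCCLGroup, which appears to be its subclass trick for reaching the protected NCCL communicator, which is why the diff goes through a void* cast rather than a plain downcast.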
setup.py
...
...
@@ -41,7 +41,7 @@ else:
if __name__ == '__main__':
    setuptools.setup(
        name='fastmoe',
        version='1.0.0',
        version='1.0.1',
        description='An efficient Mixture-of-Experts system for PyTorch',
        author=', '.join(authors),
        author_email='hja20@mails.tsinghua.edu.cn',
...
tests/test_comm.py
new file, mode 100644
import pytest
import os
import sys
import json
import math

import torch
import torch.distributed as dist
import torch.nn.functional as F

from fmoe.functions import ensure_comm

from test_ddp import _ensure_initialized, _run_distributed


@pytest.mark.parametrize("n", [1, 2])
def test_ensure(n):
    _run_distributed('_test_ensure', n, dict(), script=__file__)


def _test_ensure():
    _ensure_initialized()
    rank = torch.distributed.get_rank()
    x = torch.rand(10).cuda()
    ensure_comm(x, None)


if __name__ == '__main__':
    if len(sys.argv) >= 3:
        args = json.loads(sys.argv[2])
        locals()[sys.argv[1]](**args)
    else:
        _ensure_initialized()
        _test_ensure()