FastMoE (OpenDAS) commit 1e8e455a

Authored May 30, 2023 by Rick Ho; committed by GitHub on May 30, 2023 (unverified).

Merge pull request #157 from laekov/pytorch2-compat

Fix ProcessGroupNCCL mismatch in pytorch2

Parents: 9fe65ec0, 267eb9cc

4 changed files, with 54 additions and 1 deletion (+54 -1).
Files changed:
    cuda/fmoe_cuda.cpp        +6   -0
    cuda/global_exchange.cpp  +10  -0
    setup.py                  +1   -1
    tests/test_comm.py        +37  -0
cuda/fmoe_cuda.cpp

@@ -8,6 +8,7 @@
 #if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR > 1 || \
         (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13))
+#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
 #include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
 #else
 #include <c10d/ProcessGroupNCCL.hpp>

@@ -26,7 +27,12 @@ torch::Tensor _global_gather(
         torch::Tensor local_expert_count,
         torch::Tensor global_expert_count,
         long batch_size, long n_workers);
+#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR >= 2)
+void _ensure_nccl(c10d::ProcessGroup& p, torch::Tensor t);
+#else
 void _ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t);
+#endif  // TORCH_VERSION

 #endif  // FMOE_USE_NCCL

 // local_exchange
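Two PyTorch changes drive this hunk. Headers first: since torch 1.13 the c10d headers live under torch/csrc/distributed/c10d/, and under torch 2.0 the frontend ProcessGroup.hpp has to be pulled in explicitly, since ProcessGroupNCCL no longer supplies it as a base class. Declaration second: under torch 2.0, Python hands extensions a c10d::ProcessGroup rather than the ProcessGroupNCCL itself, so the exported signature flips on TORCH_VERSION_MAJOR. A minimal sketch of how a caller can stay version-agnostic under these macros; EnsureArg and ensure_nccl_compat are illustrative names, not FastMoE API:

// Sketch only; assumes the gated c10d includes from the hunk above are in
// effect, plus <torch/extension.h> for torch::Tensor and the version macros.
#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR >= 2)
using EnsureArg = c10d::ProcessGroup;      // torch >= 2: frontend group
#else
using EnsureArg = c10d::ProcessGroupNCCL;  // torch < 2: NCCL group directly
#endif

void ensure_nccl_compat(EnsureArg& p, torch::Tensor t) {
    _ensure_nccl(p, t);  // resolves against whichever declaration is active
}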
cuda/global_exchange.cpp

@@ -100,6 +100,7 @@ torch::Tensor _global_gather(
 #if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR > 1 || \
         (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13))
+#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
 #include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
 #else
 #include <c10d/ProcessGroupNCCL.hpp>

@@ -134,12 +135,21 @@ public:
     }
 };

+#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR >= 2)
+void _ensure_nccl(c10d::ProcessGroup& p, torch::Tensor t) {
+#else
 void _ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
+#endif  // TORCH_VERSION
     auto smgr = getCudaStreamManager(t.device().index());
     if (smgr->ncclgood) {
         return;
     }
+#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR >= 2)
+    HackNCCLGroup* h = (HackNCCLGroup*)(void*)
+        (p.getBackend(c10d::ProcessGroup::NCCL).get());
+#else
     HackNCCLGroup* h = (HackNCCLGroup*)(void*)&p;
+#endif  // TORCH_VERSION
     smgr->ncclcomm = h->getcomm(t.device());
     if (smgr->ncclcomm != 0) {
         smgr->ncclgood = 1;
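The second gate here is the actual mismatch fix. Through torch 1.x, c10d::ProcessGroupNCCL derived from c10d::ProcessGroup, so reinterpreting &p as HackNCCLGroup (a ProcessGroupNCCL subclass defined just above this hunk to reach the internal NCCL communicator) landed on the right object. Torch 2.0 made ProcessGroup a frontend that owns per-backend c10d::Backend instances, with ProcessGroupNCCL now a Backend subclass, so the old cast would reinterpret an unrelated object; the NCCL backend must be fetched via getBackend first. A sketch of that retrieval step in isolation, assuming torch >= 2.0 and a group actually initialized with the NCCL backend (getBackend should raise otherwise); nccl_backend_of is an illustrative helper, not FastMoE API:

// Sketch, torch >= 2.0 only: pull the NCCL backend out of a frontend group.
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>

c10d::ProcessGroupNCCL* nccl_backend_of(c10d::ProcessGroup& p) {
    // getBackend returns c10::intrusive_ptr<c10d::Backend>; .get() yields
    // the raw pointer that the HackNCCLGroup cast above reinterprets.
    auto backend = p.getBackend(c10d::ProcessGroup::NCCL);
    // Returning a raw pointer is tolerable only because the ProcessGroup
    // keeps its own reference to the backend for its lifetime.
    return static_cast<c10d::ProcessGroupNCCL*>(backend.get());
}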
setup.py

@@ -41,7 +41,7 @@ else:
 if __name__ == '__main__':
     setuptools.setup(
         name='fastmoe',
-        version='1.0.0',
+        version='1.0.1',
         description='An efficient Mixture-of-Experts system for PyTorch',
         author=', '.join(authors),
         author_email='hja20@mails.tsinghua.edu.cn',
tests/test_comm.py (new file, 0 → 100644)

import pytest
import os
import sys
import json
import math

import torch
import torch.distributed as dist
import torch.nn.functional as F

from fmoe.functions import ensure_comm

from test_ddp import _ensure_initialized, _run_distributed


@pytest.mark.parametrize("n", [1, 2])
def test_ensure(n):
    _run_distributed('_test_ensure', n, dict(), script=__file__)


def _test_ensure():
    _ensure_initialized()
    rank = torch.distributed.get_rank()
    x = torch.rand(10).cuda()
    ensure_comm(x, None)


if __name__ == '__main__':
    if len(sys.argv) >= 3:
        args = json.loads(sys.argv[2])
        locals()[sys.argv[1]](**args)
    else:
        _ensure_initialized()
        _test_ensure()
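The test drives the fixed path end to end in the style of the other FastMoE distributed tests: pytest launches test_ensure, and _run_distributed (from test_ddp) spawns n workers that re-execute this very file, landing in the __main__ block with the target function name in sys.argv[1] and a JSON-encoded kwargs dict in sys.argv[2]. A single worker can therefore be reproduced by hand with something like python tests/test_comm.py _test_ensure '{}', assuming the distributed environment that _ensure_initialized expects is in place. Passing None as the group to ensure_comm presumably selects the default process group, which is exactly where the ProcessGroup/ProcessGroupNCCL mismatch surfaced under torch 2.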