OpenDAS / FastMoE

Commit bf2fd0c0, authored Feb 05, 2021 by Rick Ho

support multiple PyTorch versions' private APIs

parent 481f5c4f
Showing 3 changed files with 25 additions and 3 deletions:

fmoe/distributed.py   +3  -2
fmoe/functions.py     +2  -1
fmoe/utils.py        +20  -0
fmoe/distributed.py  View file @ bf2fd0c0

 import torch
 import torch.nn as nn
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from .utils import get_torch_default_comm


 class DistributedGroupedDataParallel(nn.Module):
...
@@ -17,9 +18,9 @@ class DistributedGroupedDataParallel(nn.Module):
         if dp_group is not None:
             self.comms['dp'] = dp_group
         else:
-            self.comms['dp'] = torch.distributed.distributed_c10d._get_default_group()
+            self.comms['dp'] = get_torch_default_comm()
         if world_group is None:
-            self.comms['world'] = torch.distributed.distributed_c10d._get_default_group()
+            self.comms['world'] = get_torch_default_comm()
         else:
             self.comms['world'] = world_group
...
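The hunk above replaces direct calls to PyTorch's private torch.distributed.distributed_c10d._get_default_group() with the new get_torch_default_comm() helper whenever no explicit process group is supplied. The standalone sketch below illustrates that fallback pattern; the resolve_group function name is illustrative only and is not part of the commit.

from fmoe.utils import get_torch_default_comm

def resolve_group(explicit_group=None):
    # Prefer a caller-supplied process group; otherwise fall back to the
    # version-agnostic helper instead of calling a private c10d API directly.
    if explicit_group is not None:
        return explicit_group
    return get_torch_default_comm()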
fmoe/functions.py  View file @ bf2fd0c0

...
@@ -7,6 +7,7 @@ computation.
 import torch
 from torch.autograd import Function
 import fmoe_cuda
+from .utils import get_torch_default_comm


 def moe_prepare_forward(gate, num_expert, world_size, comm=None):
...
@@ -21,7 +22,7 @@ def moe_prepare_forward(gate, num_expert, world_size, comm=None):
         comm: the communicator of all workers in the expert-parallel group.
     """
     if comm is None:
-        comm = torch.distributed.distributed_c10d._get_default_group()
+        comm = get_torch_default_comm()
     if world_size > 1:
         fmoe_cuda.ensure_nccl(comm, gate)
...
fmoe/utils.py  0 → 100644  View file @ bf2fd0c0

+import torch.distributed as dist
+
+
+def get_torch_default_comm():
+    try:
+        comm = dist.distributed_c10d._get_default_group()
+        return comm
+    except Exception as e:
+        print('Error {}'.format(e))
+        pass
+    try:
+        comm = dist.distributed_c10d._default_pg
+        if comm is not None:
+            return comm
+    except Exception as _:
+        pass
+    raise RuntimeError('Unsupported PyTorch version')
+    return None
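As a rough usage sketch (not part of the commit), the new helper can be called once the default process group has been initialized: it first tries the newer private _get_default_group() API and then falls back to the older _default_pg attribute, raising RuntimeError only if neither private interface exists on the installed PyTorch.

import torch.distributed as dist
from fmoe.utils import get_torch_default_comm

# Assumes a distributed launcher (or the usual MASTER_ADDR / MASTER_PORT /
# RANK / WORLD_SIZE environment variables) has been set up beforehand.
dist.init_process_group(backend='gloo')

comm = get_torch_default_comm()  # the default process group on any supported PyTorch version
print(type(comm))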