OpenDAS / FastMoE · Commit 15f98a10
Authored Feb 04, 2021 by Rick Ho
adapt with pytorch 1.8.0 (deprecated 1.6.0)
Parent: 585604fe
Showing 5 changed files with 5 additions and 6 deletions (+5 -6):

cuda/cuda_stream_manager.cpp   +1 -1
cuda/moe_comm_kernel.cu        +0 -1
fmoe/distributed.py            +2 -2
fmoe/functions.py              +1 -1
fmoe/megatron.py               +1 -1
cuda/cuda_stream_manager.cpp

@@ -18,7 +18,7 @@ class HackNCCLGroup: public c10d::ProcessGroupNCCL {
 public:
     ncclComm_t getcomm(at::Device dev) {
         auto key = std::to_string(dev.index());
-        auto v = getNCCLComm(key, {dev});
+        auto v = getNCCLComm(key, {dev}, c10d::OpType::ALLTOALL);
         if (v.size() == 0) {
             std::cerr << "PyTorch has nothing\n";
             return 0;
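This hunk tracks the PyTorch 1.8.0 C++ API, where ProcessGroupNCCL::getNCCLComm takes an additional c10d::OpType argument, so an extension built from this revision no longer matches 1.6.0. A minimal, illustrative version guard one might place next to the fmoe_cuda import; this is not part of the commit, and the error message text is made up:

    import torch

    # Assumption: comparing only major.minor is enough; torch.__version__ may carry
    # suffixes such as "1.8.0+cu111", so the patch component is ignored.
    _major, _minor = (int(x) for x in torch.__version__.split(".")[:2])
    if (_major, _minor) < (1, 8):
        raise ImportError(
            "this revision of fmoe_cuda targets the PyTorch 1.8.0 C++ API "
            "(ProcessGroupNCCL::getNCCLComm now takes an OpType argument); "
            "found torch " + torch.__version__
        )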
cuda/moe_comm_kernel.cu

@@ -8,7 +8,6 @@
 #include <cuda_runtime.h>
 #include <cublas_v2.h>
 #include <helper_cuda.h>
-#include <c10/cuda/CUDAGuard.h>
 #include "cuda_stream_manager.h"
fmoe/distributed.py

@@ -17,9 +17,9 @@ class DistributedGroupedDataParallel(nn.Module):
         if dp_group is not None:
             self.comms['dp'] = dp_group
         else:
-            self.comms['dp'] = torch.distributed.distributed_c10d._default_pg
+            self.comms['dp'] = torch.distributed.distributed_c10d._get_default_group()
         if world_group is None:
-            self.comms['world'] = torch.distributed.distributed_c10d._default_pg
+            self.comms['world'] = torch.distributed.distributed_c10d._get_default_group()
         else:
             self.comms['world'] = world_group
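For context, both this file and fmoe/functions.py below adapt to the same change in PyTorch's private distributed internals: the module-level _default_pg handle is gone in 1.8.0 and the default group is returned by _get_default_group(). A small compatibility sketch, assuming torch.distributed has already been initialized; the pre-1.8 fallback branch is an assumption and is not part of this commit:

    from torch.distributed import distributed_c10d as c10d

    def default_process_group():
        # PyTorch >= 1.8.0 exposes the accessor used in this commit.
        if hasattr(c10d, "_get_default_group"):
            return c10d._get_default_group()
        # Older releases (e.g. 1.6.0) kept the group in a private module attribute.
        return c10d._default_pg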
fmoe/functions.py

@@ -21,7 +21,7 @@ def moe_prepare_forward(gate, num_expert, world_size, comm=None):
         comm: the communicator of all workers in the expert-parallel group.
     """
     if comm is None:
-        comm = torch.distributed.distributed_c10d._default_pg
+        comm = torch.distributed.distributed_c10d._get_default_group()
     if world_size > 1:
         fmoe_cuda.ensure_nccl(comm, gate)
fmoe/megatron.py

@@ -4,7 +4,7 @@ from .distributed import DistributedGroupedDataParallel
 def create_moe_mlp(args, group):
     assert (
-        args.seq_length * args.batch_size % args.model_parallel_size == 0
+        args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size == 0
     ), "Batch size x sequence length should be multiple of mp size"
     if not args.distributed_experts:
         world_size = 1
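The renamed attributes follow the newer Megatron-LM argument names (micro_batch_size, tensor_model_parallel_size). A hedged sketch of how the same divisibility check could tolerate either naming scheme; the older attribute names are assumed fallbacks and the numbers are illustrative only, not part of this commit:

    from types import SimpleNamespace

    def _resolve_megatron_args(args):
        # Newer Megatron-LM: micro_batch_size / tensor_model_parallel_size;
        # older releases used batch_size / model_parallel_size (assumed fallback).
        batch = getattr(args, "micro_batch_size", None) or args.batch_size
        mp = getattr(args, "tensor_model_parallel_size", None) or args.model_parallel_size
        return batch, mp

    # Illustrative values only.
    args = SimpleNamespace(seq_length=1024, micro_batch_size=4, tensor_model_parallel_size=2)
    batch_size, mp_size = _resolve_megatron_args(args)
    assert (
        args.seq_length * batch_size % mp_size == 0
    ), "Batch size x sequence length should be multiple of mp size"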