OpenDAS / Megatron-LM

Commit d899988e authored Jan 04, 2021 by Deepak Narayanan
parent 2348c99a

Use batched send and recv instead of torch.distributed.ring_exchange()

Showing 4 changed files with 39 additions and 12 deletions
megatron/arguments.py       +0  -4
megatron/mpu/__init__.py    +2  -0
megatron/mpu/initialize.py  +16 -2
megatron/training.py        +21 -6
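The diffs below replace the custom torch.distributed.ring_exchange() call with standard point-to-point primitives that ship with stock PyTorch. As a quick orientation, here is a minimal sketch of the batched pattern the new code relies on, built from torch.distributed.P2POp and torch.distributed.batch_isend_irecv; the two-process launch, the gloo backend, and the tensor shapes are illustrative assumptions, not part of the commit.

# Illustrative neighbour exchange using the batched p2p API (assumed setup).
import torch
import torch.distributed as dist

def exchange_with_neighbors():
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    send_buf = torch.full((4,), float(rank))
    recv_buf = torch.empty(4)

    # Each rank sends to the next rank and receives from the previous one,
    # with wraparound -- the same pattern the new communicate() code uses.
    ops = [
        dist.P2POp(dist.isend, send_buf, (rank + 1) % world_size),
        dist.P2POp(dist.irecv, recv_buf, (rank - 1) % world_size),
    ]
    # batch_isend_irecv launches every op together and returns async requests.
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return recv_buf

if __name__ == "__main__":
    # Assumes a launcher such as `torchrun --nproc_per_node=2 demo.py`.
    dist.init_process_group(backend="gloo")
    print(f"rank {dist.get_rank()} received {exchange_with_neighbors()}")

Because batch_isend_irecv is part of stock torch.distributed, the commit can also drop the build-time check for the ring_exchange extension in megatron/arguments.py.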
megatron/arguments.py

@@ -64,10 +64,6 @@ def parse_args(extra_args_provider=None, defaults={},
     args.pipeline_model_parallel_size = min(
         args.pipeline_model_parallel_size,
         (args.world_size // args.tensor_model_parallel_size))
-    if args.pipeline_model_parallel_size > 1:
-        if "ring_exchange" not in dir(torch.distributed):
-            raise Exception('PyTorch with torch.distributed.ring_exchange '
-                            'needed to run pipeline MP!')
     # Checks.
     model_parallel_size = args.pipeline_model_parallel_size * \
                           args.tensor_model_parallel_size
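The removed guard raised unless the custom ring_exchange extension was compiled into PyTorch; the commit simply deletes it rather than replacing it. If an equivalent check were wanted for the batched path, it could probe for the standard primitive instead -- a hypothetical sketch, not something this commit adds:

# Hypothetical guard (not part of this commit): the batched path only needs
# stock torch.distributed.
import torch

def check_pipeline_mp_support(pipeline_model_parallel_size):
    """Raise if this PyTorch build lacks the batched p2p API."""
    if pipeline_model_parallel_size > 1 and \
            not hasattr(torch.distributed, 'batch_isend_irecv'):
        raise Exception('PyTorch with torch.distributed.batch_isend_irecv '
                        'needed to run pipeline MP!')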
megatron/mpu/__init__.py

@@ -36,6 +36,8 @@ from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
 from .initialize import get_tensor_model_parallel_src_rank
 from .initialize import get_pipeline_model_parallel_first_rank
 from .initialize import get_pipeline_model_parallel_last_rank
+from .initialize import get_pipeline_model_parallel_next_rank
+from .initialize import get_pipeline_model_parallel_prev_rank
 from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
 from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
 from .initialize import initialize_model_parallel
megatron/mpu/initialize.py

@@ -276,16 +276,30 @@ def get_tensor_model_parallel_src_rank():
     local_world_size = get_tensor_model_parallel_world_size()
     return (global_rank // local_world_size) * local_world_size


 def get_pipeline_model_parallel_first_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"
     return _PIPELINE_GLOBAL_RANKS[0]


 def get_pipeline_model_parallel_last_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"
     last_rank_local = get_pipeline_model_parallel_world_size() - 1
     return _PIPELINE_GLOBAL_RANKS[last_rank_local]


-def get_pipeline_model_parallel_first_rank():
+def get_pipeline_model_parallel_next_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"
-    return _PIPELINE_GLOBAL_RANKS[0]
+    rank_in_pipeline = get_pipeline_model_parallel_rank()
+    world_size = get_pipeline_model_parallel_world_size()
+    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
+
+
+def get_pipeline_model_parallel_prev_rank():
+    assert _PIPELINE_GLOBAL_RANKS is not None, \
+        "Pipeline parallel group is not initialized"
+    rank_in_pipeline = get_pipeline_model_parallel_rank()
+    world_size = get_pipeline_model_parallel_world_size()
+    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]


 def get_data_parallel_world_size():
     """Return world size for the data parallel group."""
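The new get_pipeline_model_parallel_next_rank and get_pipeline_model_parallel_prev_rank helpers index into _PIPELINE_GLOBAL_RANKS with modulo wraparound, so the last stage's next neighbour is the first stage and vice versa. A standalone illustration of that arithmetic follows; the example ranks list (4 pipeline stages at global ranks 0, 4, 8, 12) is an assumption for illustration only.

# Wraparound neighbour lookup, mirroring the new helpers (example ranks only).
pipeline_global_ranks = [0, 4, 8, 12]
world_size = len(pipeline_global_ranks)

for rank_in_pipeline, global_rank in enumerate(pipeline_global_ranks):
    next_rank = pipeline_global_ranks[(rank_in_pipeline + 1) % world_size]
    prev_rank = pipeline_global_ranks[(rank_in_pipeline - 1) % world_size]
    print(f"stage {rank_in_pipeline} (global rank {global_rank}): "
          f"next={next_rank}, prev={prev_rank}")
# The last stage wraps to next=0 and the first stage wraps to prev=12.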
megatron/training.py

@@ -325,7 +325,7 @@ def setup_model_and_optimizer(model_provider_func):
 def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward):
-    """Communicate tensors between stages using torch.distributed.ring_exchange(.) API."""
+    """Communicate tensors between stages."""
     args = get_args()

     # Create placeholder tensors for receive in forward and backward directions

@@ -348,11 +348,26 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward)
                                   dtype=dtype)

     # Send tensors in both the forward and backward directions as appropriate.
-    torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
-                                    tensor_recv_prev=tensor_recv_prev,
-                                    tensor_send_next=tensor_send_next,
-                                    tensor_recv_next=tensor_recv_next,
-                                    group=mpu.get_pipeline_model_parallel_group())
+    ops = []
+    if tensor_send_prev is not None:
+        send_prev_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_prev,
+                                               mpu.get_pipeline_model_parallel_prev_rank())
+        ops.append(send_prev_op)
+    if tensor_recv_prev is not None:
+        recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev,
+                                               mpu.get_pipeline_model_parallel_prev_rank())
+        ops.append(recv_prev_op)
+    if tensor_send_next is not None:
+        send_next_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_next,
+                                               mpu.get_pipeline_model_parallel_next_rank())
+        ops.append(send_next_op)
+    if tensor_recv_next is not None:
+        recv_next_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_next,
+                                               mpu.get_pipeline_model_parallel_next_rank())
+        ops.append(recv_next_op)
+    reqs = torch.distributed.batch_isend_irecv(ops)
+    for req in reqs:
+        req.wait()

     return tensor_recv_prev, tensor_recv_next
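With the batched ops in place, a pipeline stage exchanges activations and gradients with both neighbours in a single call. A hedged usage sketch of the updated communicate() follows, assuming recv_forward allocates the buffer received from the previous stage and recv_backward the one received from the next stage; the tensors and surrounding schedule are placeholders, not code from this commit.

# Hedged sketch, as if written inside megatron/training.py after distributed
# and argument initialization; tensor shapes are placeholders.
import torch

output_tensor = torch.randn(16, 1024)      # activations produced by this stage
input_tensor_grad = torch.randn(16, 1024)  # gradients w.r.t. this stage's input

# Forward step: send activations to the next stage, receive activations
# coming forward from the previous stage.
input_tensor, _ = communicate(tensor_send_next=output_tensor,
                              tensor_send_prev=None,
                              recv_forward=True,
                              recv_backward=False)

# Backward step: send input gradients to the previous stage, receive output
# gradients coming backward from the next stage.
_, output_tensor_grad = communicate(tensor_send_next=None,
                                    tensor_send_prev=input_tensor_grad,
                                    recv_forward=False,
                                    recv_backward=True)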