OpenDAS / Megatron-LM

Commit e36cdd79, authored Jul 26, 2022 by Mohammad Shoeybi
Parent: c464a103

added a flag to be able to switch between pytorch and ring exchange p2p
Showing 2 changed files with 7 additions and 6 deletions

megatron/arguments.py           +4 -0
megatron/p2p_communication.py   +3 -6
megatron/arguments.py

@@ -761,6 +761,10 @@ def _add_distributed_args(parser):
     group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
                        help='Use scatter/gather to optimize communication of tensors in pipeline',
                        dest='scatter_gather_tensors_in_pipeline')
+    group.add_argument('--use-ring-exchange-p2p', action='store_true',
+                       default=False, help='If set, use custom-built ring exchange '
+                       'for p2p communications. Note that this option will require '
+                       'a custom built image that support ring-exchange p2p.')
     group.add_argument('--local_rank', type=int, default=None,
                        help='local rank passed from distributed launcher.')
     group.add_argument('--lazy-mpu-init', type=bool, required=False,
...
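The two neighbouring flags in this hunk behave differently and are easy to misread: the pre-existing `--no-scatter-gather-tensors-in-pipeline` is an inverted `store_false` switch writing into a renamed destination, while the new `--use-ring-exchange-p2p` is a plain opt-in that defaults to off. A minimal, self-contained argparse sketch (illustrative only, not Megatron code) confirming how both resolve:

```python
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('distributed')

# Inverted flag: absent -> True (optimization on); passing it stores False
# into the renamed dest 'scatter_gather_tensors_in_pipeline'.
group.add_argument('--no-scatter-gather-tensors-in-pipeline',
                   action='store_false',
                   dest='scatter_gather_tensors_in_pipeline')

# New flag from this commit: absent -> False (default PyTorch p2p),
# present -> True (ring exchange). Dashes become underscores in the namespace.
group.add_argument('--use-ring-exchange-p2p', action='store_true',
                   default=False)

args = parser.parse_args([])
assert args.scatter_gather_tensors_in_pipeline is True
assert args.use_ring_exchange_p2p is False

args = parser.parse_args(['--use-ring-exchange-p2p',
                          '--no-scatter-gather-tensors-in-pipeline'])
assert args.use_ring_exchange_p2p is True
assert args.scatter_gather_tensors_in_pipeline is False
```

Note the caveat in the new flag's help text: turning it on routes communication into torch.distributed.ring_exchange(), which only exists in a custom-built PyTorch image.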
megatron/p2p_communication.py

@@ -23,7 +23,6 @@ from megatron import mpu
 def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                  tensor_shape,
-                 use_ring_exchange=False,
                  dtype_=None):
     """Communicate tensors between stages. Used as helper method in other
     communication methods that are used in megatron/schedules.py.

@@ -40,8 +39,6 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
         tensor_shape: shape of tensor to receive (this method assumes that all
                       tensors sent and received in a single function call are
                       the same shape).
-        use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
-                           API should be used.
         dtype_: optional, this is used when the tensor that needs to be
                 communicated is different from args.params_dtype.
     Returns:
...
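The signature change works because Megatron keeps the parsed arguments in a process-global store, so `_communicate` can look the switch up itself instead of having every caller thread a `use_ring_exchange` keyword through. A hedged sketch of that pattern, assuming only the real `get_args()` accessor (the function body here is illustrative):

```python
# Sketch of the pattern this commit adopts: read the switch from Megatron's
# global args store rather than from a keyword argument.
from megatron import get_args  # real Megatron accessor for parsed args

def _communicate_sketch(tensor_send_next, tensor_send_prev,
                        recv_prev, recv_next, tensor_shape, dtype_=None):
    args = get_args()  # same namespace that arguments.py populated
    if args.use_ring_exchange_p2p:
        # custom-built torch.distributed.ring_exchange() path (requires a
        # custom PyTorch image, per the new flag's help text)
        ...
    else:
        # default PyTorch point-to-point path
        ...
```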
@@ -103,7 +100,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
...
@@ -103,7 +100,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
tensor_send_prev
=
mpu
.
split_tensor_into_1d_equal_chunks
(
tensor_send_prev
)
tensor_send_prev
=
mpu
.
split_tensor_into_1d_equal_chunks
(
tensor_send_prev
)
# Send tensors in both the forward and backward directions as appropriate.
# Send tensors in both the forward and backward directions as appropriate.
if
use_ring_exchange
:
if
args
.
use_ring_exchange
_p2p
:
torch
.
distributed
.
ring_exchange
(
tensor_send_prev
=
tensor_send_prev
,
torch
.
distributed
.
ring_exchange
(
tensor_send_prev
=
tensor_send_prev
,
tensor_recv_prev
=
tensor_recv_prev
,
tensor_recv_prev
=
tensor_recv_prev
,
tensor_send_next
=
tensor_send_next
,
tensor_send_next
=
tensor_send_next
,
...
...
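For contrast with the `ring_exchange` branch above, the default PyTorch path (the `else` branch, which this diff does not show) conventionally batches the four possible sends and receives with `torch.distributed.P2POp` and `batch_isend_irecv`. A hedged sketch of that conventional pattern; the function name and the rank parameters are illustrative:

```python
# Sketch of the default (stock PyTorch) branch that the new flag bypasses.
# This follows the standard torch.distributed batched p2p pattern; the exact
# Megatron else-branch is not shown in this diff.
import torch

def _pytorch_p2p(tensor_send_prev, tensor_recv_prev,
                 tensor_send_next, tensor_recv_next,
                 prev_rank, next_rank):
    ops = []
    if tensor_send_prev is not None:
        ops.append(torch.distributed.P2POp(
            torch.distributed.isend, tensor_send_prev, prev_rank))
    if tensor_recv_prev is not None:
        ops.append(torch.distributed.P2POp(
            torch.distributed.irecv, tensor_recv_prev, prev_rank))
    if tensor_send_next is not None:
        ops.append(torch.distributed.P2POp(
            torch.distributed.isend, tensor_send_next, next_rank))
    if tensor_recv_next is not None:
        ops.append(torch.distributed.P2POp(
            torch.distributed.irecv, tensor_recv_next, next_rank))
    if ops:
        # Launch all sends/receives as one batch and wait for completion.
        reqs = torch.distributed.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()
```

Both `P2POp` and `batch_isend_irecv` are stock `torch.distributed` APIs, unlike `ring_exchange`, which the flag's help text warns requires a custom build.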