Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
27fc4689
Commit
27fc4689
authored
Jan 20, 2021
by
Deepak Narayanan
Browse files
Break up tensors sent between pipeline stages into smaller chunks that can be all-gathered
parent
8e922d5b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
32 additions
and
4 deletions
+32
-4
megatron/arguments.py
megatron/arguments.py
+2
-0
megatron/mpu/__init__.py
megatron/mpu/__init__.py
+2
-0
megatron/p2p_communication.py
megatron/p2p_communication.py
+26
-2
megatron/training.py
megatron/training.py
+2
-2
No files found.
megatron/arguments.py
View file @
27fc4689
...
...
@@ -566,6 +566,8 @@ def _add_distributed_args(parser):
choices
=
[
'local'
,
'torch'
],
help
=
'which DistributedDataParallel implementation '
'to use.'
)
group
.
add_argument
(
'--scatter-gather-tensors-in-pipeline'
,
action
=
'store_true'
,
help
=
'Use scatter/gather to optimize communication of tensors in pipeline'
)
group
.
add_argument
(
'--local_rank'
,
type
=
int
,
default
=
None
,
help
=
'local rank passed from distributed launcher.'
)
group
.
add_argument
(
'--lazy-mpu-init'
,
type
=
bool
,
required
=
False
,
...
...
megatron/mpu/__init__.py
View file @
27fc4689
...
...
@@ -59,6 +59,8 @@ from .random import get_cuda_rng_tracker
from
.random
import
init_checkpointed_activations_memory_buffer
from
.random
import
model_parallel_cuda_manual_seed
from
.random
import
reset_checkpointed_activations_memory_buffer
from
.random
import
gather_split_1d_tensor
from
.random
import
split_tensor_into_1d_equal_chunks
from
.utils
import
divide
from
.utils
import
split_tensor_along_last_dim
megatron/p2p_communication.py
View file @
27fc4689
...
...
@@ -29,20 +29,33 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
tensor_recv_prev
=
None
tensor_recv_next
=
None
tensor_shape
=
(
args
.
seq_length
,
args
.
micro_batch_size
,
args
.
hidden_size
)
if
args
.
scatter_gather_tensors_in_pipeline
:
tensor_chunk_shape
=
(
args
.
seq_length
*
args
.
micro_batch_size
*
args
.
hidden_size
)
//
\
mpu
.
get_tensor_model_parallel_world_size
()
else
:
tensor_chunk_shape
=
tensor_shape
dtype
=
args
.
params_dtype
if
args
.
fp32_residual_connection
:
dtype
=
torch
.
float
if
recv_prev
:
tensor_recv_prev
=
torch
.
empty
(
tensor_shape
,
tensor_recv_prev
=
torch
.
empty
(
tensor_
chunk_
shape
,
requires_grad
=
True
,
device
=
torch
.
cuda
.
current_device
(),
dtype
=
dtype
)
if
recv_next
:
tensor_recv_next
=
torch
.
empty
(
tensor_shape
,
tensor_recv_next
=
torch
.
empty
(
tensor_
chunk_
shape
,
requires_grad
=
True
,
device
=
torch
.
cuda
.
current_device
(),
dtype
=
dtype
)
if
args
.
scatter_gather_tensors_in_pipeline
:
if
tensor_send_next
is
not
None
:
tensor_send_next
=
mpu
.
split_tensor_into_1d_equal_chunks
(
tensor_send_next
)
if
tensor_send_prev
is
not
None
:
tensor_send_prev
=
mpu
.
split_tensor_into_1d_equal_chunks
(
tensor_send_prev
)
# Send tensors in both the forward and backward directions as appropriate.
if
use_ring_exchange
:
torch
.
distributed
.
ring_exchange
(
tensor_send_prev
=
tensor_send_prev
,
...
...
@@ -71,6 +84,17 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
reqs
=
torch
.
distributed
.
batch_isend_irecv
(
ops
)
for
req
in
reqs
:
req
.
wait
()
torch
.
cuda
.
synchronize
()
tensor_recv_prev_before
=
tensor_recv_prev
if
args
.
scatter_gather_tensors_in_pipeline
:
if
recv_prev
:
tensor_recv_prev
=
mpu
.
gather_split_1d_tensor
(
tensor_recv_prev
).
view
(
tensor_shape
).
requires_grad_
()
if
recv_next
:
tensor_recv_next
=
mpu
.
gather_split_1d_tensor
(
tensor_recv_next
).
view
(
tensor_shape
).
requires_grad_
()
return
tensor_recv_prev
,
tensor_recv_next
...
...
megatron/training.py
View file @
27fc4689
...
...
@@ -381,11 +381,11 @@ def train_step(forward_step_func, data_iterator,
# Update parameters.
timers
(
'optimizer'
).
start
()
update_successful
l
,
grad_norm
=
optimizer
.
step
()
update_successful
,
grad_norm
=
optimizer
.
step
()
timers
(
'optimizer'
).
stop
()
# Update learning rate.
if
update_successful
l
:
if
update_successful
:
increment
=
get_num_microbatches
()
*
\
args
.
micro_batch_size
*
\
args
.
data_parallel_size
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment