OpenDAS / Megatron-LM, commit c3e688d3
Authored by Vijay Korthikanti on Nov 10, 2022; committed by Jared Casper on Nov 10, 2022.
Support for variable sequence lengths across micro-batches
Parent: b4297c6a

Showing 2 changed files with 131 additions and 16 deletions:

    megatron/arguments.py           +8    -1
    megatron/p2p_communication.py   +123  -15
megatron/arguments.py
@@ -186,6 +186,13 @@ def validate_args(args, defaults={}):
     args.consumed_train_samples = 0
     args.consumed_valid_samples = 0
 
+    # Support for variable sequence lengths across batches/micro-batches.
+    # Set it if the dataloader supports generation of variable sequence
+    # lengths across batches/micro-batches. Due to the additional
+    # communication overhead during pipeline parallelism, it should not be
+    # set if the sequence length is constant during training.
+    args.variable_seq_lengths = False
+
     # Iteration-based training.
     if args.train_iters:
         # If we use iteration-based training, make sure the
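The flag above matters only when the dataloader genuinely produces differently shaped micro-batches. As a hypothetical illustration (the collate function and tensors below are invented for this note, not part of the commit), per-batch padding is one way such variation arises:

import torch
from torch.nn.utils.rnn import pad_sequence

def collate_pad_to_batch_max(samples):
    # Pad only to the longest sample in THIS micro-batch, so different
    # micro-batches may end up with different sequence lengths.
    return pad_sequence(samples, batch_first=False)

micro_batches = [
    [torch.ones(5, dtype=torch.long), torch.ones(9, dtype=torch.long)],
    [torch.ones(3, dtype=torch.long), torch.ones(4, dtype=torch.long)],
]
print([collate_pad_to_batch_max(b).shape for b in micro_batches])
# [torch.Size([9, 2]), torch.Size([4, 2])] -- sequence length varies across
# micro-batches, so a setup like this would set args.variable_seq_lengths = True.

With a constant sequence length, the flag should stay False to avoid the extra shape-communication round trip added in megatron/p2p_communication.py below.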
@@ -883,7 +890,7 @@ def _add_data_args(parser):
                        help="Maximum decoder sequence length to process.")
     group.add_argument('--retriever-seq-length', type=int, default=256,
                        help='Maximum sequence length for the biencoder model '
-                       ' for retriever')
+                       'for retriever')
     group.add_argument('--sample-rate', type=float, default=1.0,
                        help='sample rate for training data. Supposed to be 0 '
                        ' < sample_rate < 1')
megatron/p2p_communication.py
@@ -8,6 +8,96 @@ from megatron import get_args, core
 from megatron.core import mpu
 
 
+def _communicate_shapes(tensor_send_next, tensor_send_prev,
+                        recv_prev, recv_next):
+    """Communicate tensor shapes between stages. Used to communicate
+    tensor shapes before the actual tensor communication happens.
+    This is required when the sequence lengths across micro batches
+    are not uniform.
+
+    Takes the following arguments:
+        tensor_send_next: tensor to send to next rank (no tensor sent if
+                          set to None).
+        tensor_send_prev: tensor to send to prev rank (no tensor sent if
+                          set to None).
+        recv_prev: boolean for whether tensor should be received from
+                   previous rank.
+        recv_next: boolean for whether tensor should be received from
+                   next rank.
+    Returns:
+        (recv_prev_shape, recv_next_shape)
+    """
+    args = get_args()
+
+    recv_prev_shape_tensor = None
+    recv_next_shape_tensor = None
+    send_prev_shape_tensor = None
+    send_next_shape_tensor = None
+    if recv_prev:
+        recv_prev_shape_tensor = torch.empty((3),
+                                             device=torch.cuda.current_device(),
+                                             dtype=torch.int64)
+    if recv_next:
+        recv_next_shape_tensor = torch.empty((3),
+                                             device=torch.cuda.current_device(),
+                                             dtype=torch.int64)
+    if tensor_send_prev is not None:
+        send_prev_shape_tensor = torch.tensor(tensor_send_prev.size(),
+                                              device=torch.cuda.current_device(),
+                                              dtype=torch.int64)
+    if tensor_send_next is not None:
+        send_next_shape_tensor = torch.tensor(tensor_send_next.size(),
+                                              device=torch.cuda.current_device(),
+                                              dtype=torch.int64)
+
+    if args.use_ring_exchange_p2p:
+        torch.distributed.ring_exchange(tensor_send_prev=send_prev_shape_tensor,
+                                        tensor_recv_prev=recv_prev_shape_tensor,
+                                        tensor_send_next=send_next_shape_tensor,
+                                        tensor_recv_next=recv_next_shape_tensor,
+                                        group=mpu.get_pipeline_model_parallel_group())
+    else:
+        ops = []
+        if send_prev_shape_tensor is not None:
+            send_prev_op = torch.distributed.P2POp(
+                torch.distributed.isend, send_prev_shape_tensor,
+                mpu.get_pipeline_model_parallel_prev_rank())
+            ops.append(send_prev_op)
+        if recv_prev_shape_tensor is not None:
+            recv_prev_op = torch.distributed.P2POp(
+                torch.distributed.irecv, recv_prev_shape_tensor,
+                mpu.get_pipeline_model_parallel_prev_rank())
+            ops.append(recv_prev_op)
+        if send_next_shape_tensor is not None:
+            send_next_op = torch.distributed.P2POp(
+                torch.distributed.isend, send_next_shape_tensor,
+                mpu.get_pipeline_model_parallel_next_rank())
+            ops.append(send_next_op)
+        if recv_next_shape_tensor is not None:
+            recv_next_op = torch.distributed.P2POp(
+                torch.distributed.irecv, recv_next_shape_tensor,
+                mpu.get_pipeline_model_parallel_next_rank())
+            ops.append(recv_next_op)
+        if len(ops) > 0:
+            reqs = torch.distributed.batch_isend_irecv(ops)
+            for req in reqs:
+                req.wait()
+
+        # To protect against race condition when using batch_isend_irecv().
+        # Should take this out once the bug with batch_isend_irecv is resolved.
+        torch.cuda.synchronize()
+
+    recv_prev_shape = [0, 0, 0]
+    if recv_prev_shape_tensor is not None:
+        recv_prev_shape = recv_prev_shape_tensor.tolist()
+
+    recv_next_shape = [0, 0, 0]
+    if recv_next_shape_tensor is not None:
+        recv_next_shape = recv_next_shape_tensor.tolist()
+
+    return recv_prev_shape, recv_next_shape
+
+
 def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                  tensor_shape, dtype_=None):
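The function above implements a shapes-first handshake: shape metadata travels ahead of the payload so each receiver can allocate a correctly sized buffer. For orientation, here is a minimal two-rank sketch of the same pattern in plain torch.distributed (an illustrative demo assuming a gloo process group launched via torchrun, not the commit's code):

import torch
import torch.distributed as dist

def send_variable_length(tensor, dst):
    # Send the shape first, then the payload itself.
    dist.send(torch.tensor(tensor.size(), dtype=torch.int64), dst)
    dist.send(tensor, dst)

def recv_variable_length(src, ndim=3, dtype=torch.float32):
    # Receive the shape first so the payload buffer can be sized correctly.
    shape = torch.empty(ndim, dtype=torch.int64)
    dist.recv(shape, src)
    buf = torch.empty(shape.tolist(), dtype=dtype)
    dist.recv(buf, src)
    return buf

if __name__ == "__main__":
    dist.init_process_group("gloo")  # run with: torchrun --nproc_per_node=2 demo.py
    if dist.get_rank() == 0:
        send_variable_length(torch.randn(7, 2, 4), dst=1)  # any seq length works
    else:
        print(recv_variable_length(src=0).shape)           # torch.Size([7, 2, 4])

Unlike this serial sketch, _communicate_shapes batches the four directional shape transfers into a single batch_isend_irecv call (or one ring_exchange when args.use_ring_exchange_p2p is set), so neighboring pipeline stages exchange shapes concurrently.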
@@ -41,21 +131,39 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     # Some legacy inference code doesn't set the tensor shape, do so now
     # for the normal values for gpt/bert. This could be removed if inference
     # code is changed to provide tensor_shape.
-    if tensor_shape is None:
-        tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
+    if not args.variable_seq_lengths:
+        if tensor_shape is None:
+            recv_prev_shape = (args.seq_length, args.micro_batch_size,
+                               args.hidden_size)
+            recv_next_shape = (args.seq_length, args.micro_batch_size,
+                               args.hidden_size)
+        else:
+            recv_prev_shape = tensor_shape
+            recv_next_shape = tensor_shape
+    else:
+        recv_prev_shape, recv_next_shape = \
+            _communicate_shapes(tensor_send_next, tensor_send_prev,
+                                recv_prev, recv_next)
 
     override_scatter_gather_tensors_in_pipeline = False
     if args.scatter_gather_tensors_in_pipeline and \
             not args.sequence_parallel:
-        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
-        if tensor_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0:
-            tensor_chunk_shape = tensor_chunk_shape // \
-                mpu.get_tensor_model_parallel_world_size()
+        recv_prev_chunk_shape = reduce(operator.mul, recv_prev_shape, 1)
+        recv_next_chunk_shape = reduce(operator.mul, recv_next_shape, 1)
+        if recv_prev_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0 and \
+                recv_next_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0:
+            recv_prev_chunk_shape = recv_prev_chunk_shape // \
+                mpu.get_tensor_model_parallel_world_size()
+            recv_next_chunk_shape = recv_next_chunk_shape // \
+                mpu.get_tensor_model_parallel_world_size()
         else:
-            tensor_chunk_shape = tensor_shape
+            recv_prev_chunk_shape = recv_prev_shape
+            recv_next_chunk_shape = recv_next_shape
             override_scatter_gather_tensors_in_pipeline = True
     else:
-        tensor_chunk_shape = tensor_shape
+        recv_prev_chunk_shape = recv_prev_shape
+        recv_next_chunk_shape = recv_next_shape
 
     dtype = args.params_dtype
     if args.fp32_residual_connection:
         dtype = torch.float
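To make the chunking arithmetic above concrete, a small worked example with made-up values (the shape and world size are illustrative, not from the commit):

from functools import reduce
import operator

recv_prev_shape = (1024, 4, 2048)  # (seq_length, micro_batch_size, hidden_size)
tp_world_size = 8                  # tensor-model-parallel world size

numel = reduce(operator.mul, recv_prev_shape, 1)  # 1024 * 4 * 2048 = 8388608
assert numel % tp_world_size == 0                 # divisible: scatter-gather applies
print(numel // tp_world_size)                     # 1048576 elements per rank

Because the previous-direction and next-direction shapes can now differ, the divisibility test has to pass for both directions before the scatter-gather optimization is kept; otherwise the code falls back to communicating the full tensors and sets override_scatter_gather_tensors_in_pipeline.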
@@ -66,12 +174,12 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
         requires_grad = False
 
     if recv_prev:
-        tensor_recv_prev = torch.empty(tensor_chunk_shape,
+        tensor_recv_prev = torch.empty(recv_prev_chunk_shape,
                                        requires_grad=requires_grad,
                                        device=torch.cuda.current_device(),
                                        dtype=dtype)
    if recv_next:
-        tensor_recv_next = torch.empty(tensor_chunk_shape,
+        tensor_recv_next = torch.empty(recv_next_chunk_shape,
                                        requires_grad=requires_grad,
                                        device=torch.cuda.current_device(),
                                        dtype=dtype)
@@ -128,17 +236,17 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
             not args.sequence_parallel:
         if recv_prev:
             tensor_recv_prev = core.tensor_parallel.gather_split_1d_tensor(
-                tensor_recv_prev).view(tensor_shape).requires_grad_()
+                tensor_recv_prev).view(recv_prev_shape).requires_grad_()
             tensor_recv_prev = core.utils.make_viewless_tensor(tensor_recv_prev,
                                                                requires_grad=True,
                                                                keep_graph=False)
 
         if recv_next:
             tensor_recv_next = core.tensor_parallel.gather_split_1d_tensor(
-                tensor_recv_next).view(tensor_shape).requires_grad_()
+                tensor_recv_next).view(recv_next_shape).requires_grad_()
             tensor_recv_next = core.utils.make_viewless_tensor(tensor_recv_next,
                                                                requires_grad=True,
                                                                keep_graph=False)
 
     return tensor_recv_prev, tensor_recv_next
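As a shape-level illustration of why the view target must match the received payload (plain PyTorch standing in for Megatron's gather_split_1d_tensor; a single-process sketch, not the commit's code):

import torch

tp_world_size = 4
full_shape = (6, 2, 8)             # varies per micro-batch when seq lengths vary
payload = torch.randn(full_shape)

# Sender side: flatten and split into one chunk per tensor-parallel rank.
chunks = list(torch.chunk(payload.reshape(-1), tp_world_size))

# Receiver side: gather the chunks into a flat buffer, then view it with the
# communicated shape. Reusing a stale shape from another micro-batch would
# fail right here -- hence the per-direction recv_prev_shape / recv_next_shape.
restored = torch.cat(chunks).view(full_shape)
assert torch.equal(restored, payload)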