OpenDAS / Megatron-LM · Commits

Commit f2c35bb0
authored Oct 01, 2021 by Jared Casper
parent 5ac5571b

Fix inference after T5 pipeline merge

Adds some backward compatibility code so old inference code still works.

Showing 3 changed files with 23 additions and 10 deletions (+23 -10)

megatron/model/language_model.py    +6  -0
megatron/p2p_communication.py       +16 -9
megatron/training.py                +1  -1
megatron/model/language_model.py

@@ -357,6 +357,12 @@ class TransformerLanguageModel(MegatronModule):
     def set_input_tensor(self, input_tensor):
         """ See megatron.model.transformer.set_input_tensor()"""
+
+        # This is usually handled in schedules.py but some inference code still
+        # gives us non-lists or None
+        if not isinstance(input_tensor, list):
+            input_tensor = [input_tensor]
+
         if self.add_encoder and self.add_decoder:
             assert len(input_tensor) == 1, \
                 'input_tensor should only be length 1 for stage with both encoder and decoder'
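Note (not part of the commit): a minimal, self-contained sketch of the two calling conventions this shim reconciles. The TinyStage class and tensor sizes are made up for illustration; only the isinstance normalization mirrors the change above.

import torch

class TinyStage:
    """Illustrative stand-in for a pipeline stage with set_input_tensor."""

    def __init__(self):
        self.input_tensor = None

    def set_input_tensor(self, input_tensor):
        # schedules.py hands every stage a list of tensors; older inference
        # code hands a bare tensor or None, so normalize to a list first.
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]
        self.input_tensor = input_tensor

stage = TinyStage()
stage.set_input_tensor([torch.zeros(4, 2, 8)])   # schedules.py style: list
stage.set_input_tensor(torch.zeros(4, 2, 8))     # legacy inference style: bare tensor
stage.set_input_tensor(None)                     # legacy "no input yet" case -> [None]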
megatron/p2p_communication.py

@@ -53,6 +53,13 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     # if needed.
     tensor_recv_prev = None
     tensor_recv_next = None
+
+    # Some legacy inference code doesn't set the tensor shape, do so now
+    # for the normal values for gpt/bert. This could be removed if inference
+    # code is changed to provide tensor_shape.
+    if tensor_shape is None:
+        tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
+
     override_scatter_gather_tensors_in_pipeline = False
     if args.scatter_gather_tensors_in_pipeline:
         tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
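Note (not part of the commit): the fallback shape follows the (sequence length, micro-batch size, hidden size) activation layout used for GPT/BERT. Below is a small sketch of that default and of the flattened element count computed a few lines later in _communicate; the numeric values are invented for illustration.

from functools import reduce
import operator

# Invented example values standing in for args.seq_length,
# args.micro_batch_size and args.hidden_size.
seq_length, micro_batch_size, hidden_size = 1024, 4, 2048

tensor_shape = None  # what legacy inference callers effectively pass
if tensor_shape is None:
    tensor_shape = (seq_length, micro_batch_size, hidden_size)

# Same flattened size the scatter/gather branch derives with reduce().
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
print(tensor_shape, tensor_chunk_shape)  # (1024, 4, 2048) 8388608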
@@ -143,7 +150,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     return tensor_recv_prev, tensor_recv_next


-def recv_forward(tensor_shape, dtype_=None, timers=None):
+def recv_forward(tensor_shape=None, dtype_=None, timers=None):
     """Receive tensor from previous rank in pipeline (forward receive)."""
     if mpu.is_pipeline_first_stage():
@@ -163,7 +170,7 @@ def recv_forward(tensor_shape, dtype_=None, timers=None):
     return input_tensor


-def recv_backward(tensor_shape, timers=None):
+def recv_backward(tensor_shape=None, timers=None):
     """Receive tensor from next rank in pipeline (backward receive)."""
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None
@@ -181,7 +188,7 @@ def recv_backward(tensor_shape, timers=None):
     return output_tensor_grad


-def send_forward(output_tensor, tensor_shape, dtype_=None, timers=None):
+def send_forward(output_tensor, tensor_shape=None, dtype_=None, timers=None):
     """Send tensor to next rank in pipeline (forward send)."""
     if not mpu.is_pipeline_last_stage():
@@ -198,7 +205,7 @@ def send_forward(output_tensor, tensor_shape, dtype_=None, timers=None):
             timers('forward-send').stop()


-def send_backward(input_tensor_grad, tensor_shape, timers=None):
+def send_backward(input_tensor_grad, tensor_shape=None, timers=None):
     """Send tensor to previous rank in pipeline (backward send)."""
     if not mpu.is_pipeline_first_stage():
         if timers is not None:
@@ -213,7 +220,7 @@ def send_backward(input_tensor_grad, tensor_shape, timers=None):
             timers('backward-send').stop()


-def send_forward_recv_backward(output_tensor, tensor_shape, timers=None):
+def send_forward_recv_backward(output_tensor, tensor_shape=None, timers=None):
     """Batched send and recv with next rank in pipeline."""
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None
@@ -231,7 +238,7 @@ def send_forward_recv_backward(output_tensor, tensor_shape, timers=None):
     return output_tensor_grad


-def send_backward_recv_forward(input_tensor_grad, tensor_shape, timers=None):
+def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None):
     """Batched send and recv with previous rank in pipeline."""
     if mpu.is_pipeline_first_stage():
         input_tensor = None
@@ -249,7 +256,7 @@ def send_backward_recv_forward(input_tensor_grad, tensor_shape, timers=None):
     return input_tensor


-def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape, timers=None):
+def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timers=None):
     """Batched recv from previous rank and send to next rank in pipeline."""
     if timers is not None:
         timers('forward-send-forward-recv').start()
@@ -264,7 +271,7 @@ def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape, timers=None):
     return input_tensor


-def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape, timers=None):
+def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape=None, timers=None):
     """Batched recv from next rank and send to previous rank in pipeline."""
     if timers is not None:
         timers('backward-send-backward-recv').start()
@@ -281,7 +288,7 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape, timers=None):

 def send_forward_backward_recv_forward_backward(
         output_tensor, input_tensor_grad, recv_prev,
-        recv_next, tensor_shape, timers=None):
+        recv_next, tensor_shape=None, timers=None):
     """Batched send and recv with previous and next ranks in pipeline."""
     if timers is not None:
         timers('forward-backward-send-forward-backward-recv').start()
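Note (not part of the commit): every remaining hunk in this file applies the same pattern, giving tensor_shape a default of None so call sites written before the T5 pipeline merge, which never passed a shape, still resolve. The stand-in functions below only mimic the old and new signatures; they are not the real megatron.p2p_communication functions, which need an initialized pipeline-parallel group to run.

def recv_forward_old(tensor_shape, dtype_=None, timers=None):
    # Pre-fix signature: tensor_shape is required.
    return ("recv", tensor_shape)

def recv_forward_new(tensor_shape=None, dtype_=None, timers=None):
    # Post-fix signature: None falls through to the gpt/bert default
    # inside _communicate (see the first hunk above).
    return ("recv", tensor_shape)

# Callers updated for the T5 pipeline merge pass an explicit shape:
recv_forward_new((1024, 4, 2048))

# Legacy inference callers pass nothing; against the old required argument
# this raised TypeError, with the default it simply defers the shape choice:
recv_forward_new()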
megatron/training.py

@@ -193,7 +193,7 @@ def update_train_iters(args):
     print_rank_0('setting training iterations to {}'.format(args.train_iters))


-def get_model(model_provider_func, model_type, wrap_with_ddp=True):
+def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
     """Build the model."""
     args = get_args()
     args.model_type = model_type
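Note (not part of the commit): ModelType.encoder_or_decoder covers GPT- and BERT-style models, so pre-merge scripts that call get_model(model_provider_func) keep working, while T5-style scripts pass ModelType.encoder_and_decoder explicitly. A self-contained sketch of that default-argument behaviour; the enum mirrors megatron.model.enums.ModelType and the function below is a stand-in, not the real get_model.

from enum import Enum

class ModelType(Enum):
    # Mirrors megatron.model.enums.ModelType from this era of the codebase.
    encoder_or_decoder = 1
    encoder_and_decoder = 2

def get_model_sketch(model_provider_func,
                     model_type=ModelType.encoder_or_decoder,
                     wrap_with_ddp=True):
    # Stand-in for megatron.training.get_model: build the model and report
    # which model_type the pipeline schedule would be told to use.
    return model_provider_func(), model_type

# Old GPT/BERT-style call sites, written before model_type existed:
model, mtype = get_model_sketch(lambda: "gpt-model")
assert mtype is ModelType.encoder_or_decoder

# T5-style call sites added with the pipeline merge are explicit:
model, mtype = get_model_sketch(lambda: "t5-model", ModelType.encoder_and_decoder)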