OpenDAS / Megatron-LM / Commits / f2c35bb0

Commit f2c35bb0 authored Oct 01, 2021 by Jared Casper
Fix inference after T5 pipeline merge
Adds some backward compatibility code so old inference code still works.
parent 5ac5571b
Showing 3 changed files with 23 additions and 10 deletions (+23, -10)
megatron/model/language_model.py   +6  -0
megatron/p2p_communication.py      +16 -9
megatron/training.py               +1  -1
megatron/model/language_model.py
...
@@ -357,6 +357,12 @@ class TransformerLanguageModel(MegatronModule):

     def set_input_tensor(self, input_tensor):
         """ See megatron.model.transformer.set_input_tensor()"""

+        # This is usually handled in schedules.py but some inference code still
+        # gives us non-lists or None
+        if not isinstance(input_tensor, list):
+            input_tensor = [input_tensor]
+
         if self.add_encoder and self.add_decoder:
             assert len(input_tensor) == 1, \
                 'input_tensor should only be length 1 for stage with both encoder and decoder'
...
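A minimal standalone sketch of the compatibility rule added above (illustrative names, not Megatron code): legacy inference paths may hand set_input_tensor a bare tensor or None, while the post-merge schedules.py always passes a list, so normalizing to a list keeps both callers working.

def normalize_input_tensor(input_tensor):
    # Wrap non-list inputs (a single tensor or None) into a one-element list.
    if not isinstance(input_tensor, list):
        input_tensor = [input_tensor]
    return input_tensor

assert normalize_input_tensor(None) == [None]    # legacy inference path
assert normalize_input_tensor("t") == ["t"]      # string stands in for a tensor
assert normalize_input_tensor(["t"]) == ["t"]    # schedules.py path: unchanged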
megatron/p2p_communication.py
...
@@ -53,6 +53,13 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     # if needed.
     tensor_recv_prev = None
     tensor_recv_next = None
+
+    # Some legacy inference code doesn't set the tensor shape, do so now
+    # for the normal values for gpt/bert. This could be removed if inference
+    # code is changed to provide tensor_shape.
+    if tensor_shape is None:
+        tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
+
     override_scatter_gather_tensors_in_pipeline = False
     if args.scatter_gather_tensors_in_pipeline:
         tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
...
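A standalone sketch of the fallback above (argument names assumed, not Megatron code): when a legacy caller does not supply tensor_shape, _communicate now derives the standard GPT/BERT layout (seq_length, micro_batch_size, hidden_size) from the global args, while explicit shapes from the T5 pipeline schedule pass through untouched.

def resolve_tensor_shape(tensor_shape, seq_length, micro_batch_size, hidden_size):
    # Fall back to the standard (seq, batch, hidden) layout when no shape is given.
    if tensor_shape is None:
        tensor_shape = (seq_length, micro_batch_size, hidden_size)
    return tensor_shape

assert resolve_tensor_shape(None, 2048, 4, 1024) == (2048, 4, 1024)       # legacy caller
assert resolve_tensor_shape((128, 1, 512), 2048, 4, 1024) == (128, 1, 512)  # explicit shape wins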
@@ -143,7 +150,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     return tensor_recv_prev, tensor_recv_next


-def recv_forward(tensor_shape, dtype_=None, timers=None):
+def recv_forward(tensor_shape=None, dtype_=None, timers=None):
     """Receive tensor from previous rank in pipeline (forward receive)."""
     if mpu.is_pipeline_first_stage():
...
@@ -163,7 +170,7 @@ def recv_forward(tensor_shape, dtype_=None, timers=None):
     return input_tensor


-def recv_backward(tensor_shape, timers=None):
+def recv_backward(tensor_shape=None, timers=None):
     """Receive tensor from next rank in pipeline (backward receive)."""
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None
...
@@ -181,7 +188,7 @@ def recv_backward(tensor_shape, timers=None):
     return output_tensor_grad


-def send_forward(output_tensor, tensor_shape, dtype_=None, timers=None):
+def send_forward(output_tensor, tensor_shape=None, dtype_=None, timers=None):
     """Send tensor to next rank in pipeline (forward send)."""
     if not mpu.is_pipeline_last_stage():
...
@@ -198,7 +205,7 @@ def send_forward(output_tensor, tensor_shape, dtype_=None, timers=None):
         timers('forward-send').stop()


-def send_backward(input_tensor_grad, tensor_shape, timers=None):
+def send_backward(input_tensor_grad, tensor_shape=None, timers=None):
     """Send tensor to previous rank in pipeline (backward send)."""
     if not mpu.is_pipeline_first_stage():
         if timers is not None:
...
@@ -213,7 +220,7 @@ def send_backward(input_tensor_grad, tensor_shape, timers=None):
         timers('backward-send').stop()


-def send_forward_recv_backward(output_tensor, tensor_shape, timers=None):
+def send_forward_recv_backward(output_tensor, tensor_shape=None, timers=None):
     """Batched send and recv with next rank in pipeline."""
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None
...
@@ -231,7 +238,7 @@ def send_forward_recv_backward(output_tensor, tensor_shape, timers=None):
     return output_tensor_grad


-def send_backward_recv_forward(input_tensor_grad, tensor_shape, timers=None):
+def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None):
     """Batched send and recv with previous rank in pipeline."""
     if mpu.is_pipeline_first_stage():
         input_tensor = None
...
@@ -249,7 +256,7 @@ def send_backward_recv_forward(input_tensor_grad, tensor_shape, timers=None):
     return input_tensor


-def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape, timers=None):
+def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timers=None):
     """Batched recv from previous rank and send to next rank in pipeline."""
     if timers is not None:
         timers('forward-send-forward-recv').start()
...
@@ -264,7 +271,7 @@ def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape, timers=Non
     return input_tensor


-def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape, timers=None):
+def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape=None, timers=None):
     """Batched recv from next rank and send to previous rank in pipeline."""
     if timers is not None:
         timers('backward-send-backward-recv').start()
...
@@ -281,7 +288,7 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape, time
 def send_forward_backward_recv_forward_backward(
         output_tensor, input_tensor_grad, recv_prev,
-        recv_next, tensor_shape, timers=None):
+        recv_next, tensor_shape=None, timers=None):
     """Batched send and recv with previous and next ranks in pipeline."""
     if timers is not None:
         timers('forward-backward-send-forward-backward-recv').start()
...
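A hedged sketch of why the defaulted tensor_shape matters for call compatibility (stub bodies only, not the real torch.distributed implementations): pre-merge inference code calls these helpers without a tensor_shape argument, which would raise TypeError against a required-positional signature but is accepted once the parameter defaults to None.

def send_forward_required(output_tensor, tensor_shape, dtype_=None, timers=None):
    pass  # pre-commit style: tensor_shape is required

def send_forward_defaulted(output_tensor, tensor_shape=None, dtype_=None, timers=None):
    pass  # post-commit style: tensor_shape defaults to None

try:
    send_forward_required("out")          # legacy call omits tensor_shape
except TypeError:
    print("legacy call fails against the required-positional signature")
send_forward_defaulted("out")             # accepted after this commit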
megatron/training.py
...
@@ -193,7 +193,7 @@ def update_train_iters(args):
     print_rank_0('setting training iterations to {}'.format(args.train_iters))


-def get_model(model_provider_func, model_type, wrap_with_ddp=True):
+def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
     """Build the model."""
     args = get_args()
     args.model_type = model_type
...
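A standalone sketch of the training.py change (the enum and stub are illustrative, not the Megatron definitions): get_model now defaults model_type to ModelType.encoder_or_decoder, so older scripts that call get_model(model_provider_func) without a model_type keep working, while T5-style callers can still pass an explicit type.

import enum

class ModelType(enum.Enum):
    encoder_or_decoder = 1
    encoder_and_decoder = 2

def get_model_stub(model_provider_func, model_type=ModelType.encoder_or_decoder,
                   wrap_with_ddp=True):
    # Stub: the real get_model builds and wraps the model; here we only
    # return the resolved model_type to show the default in action.
    return model_type

assert get_model_stub(lambda: None) is ModelType.encoder_or_decoder              # legacy call
assert get_model_stub(lambda: None, ModelType.encoder_and_decoder) \
    is ModelType.encoder_and_decoder                                             # T5-style call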