Commit b31e1296 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'fix-inference' into 'main'

Fix inference after T5 pipeline merge

See merge request ADLR/megatron-lm!332
parents cdc614cf f2c35bb0
......@@ -357,6 +357,12 @@ class TransformerLanguageModel(MegatronModule):
def set_input_tensor(self, input_tensor):
""" See megatron.model.transformer.set_input_tensor()"""
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
if self.add_encoder and self.add_decoder:
assert len(input_tensor) == 1, \
'input_tensor should only be length 1 for stage with both encoder and decoder'
......
......@@ -53,6 +53,13 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
# if needed.
tensor_recv_prev = None
tensor_recv_next = None
# Some legacy inference code doesn't set the tensor shape, do so now
# for the normal values for gpt/bert. This could be removed if inference
# code is changed to provide tensor_shape.
if tensor_shape is None:
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
override_scatter_gather_tensors_in_pipeline = False
if args.scatter_gather_tensors_in_pipeline:
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
......@@ -143,7 +150,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
return tensor_recv_prev, tensor_recv_next
def recv_forward(tensor_shape, dtype_=None, timers=None):
def recv_forward(tensor_shape=None, dtype_=None, timers=None):
"""Receive tensor from previous rank in pipeline (forward receive)."""
if mpu.is_pipeline_first_stage():
......@@ -163,7 +170,7 @@ def recv_forward(tensor_shape, dtype_=None, timers=None):
return input_tensor
def recv_backward(tensor_shape, timers=None):
def recv_backward(tensor_shape=None, timers=None):
"""Receive tensor from next rank in pipeline (backward receive)."""
if mpu.is_pipeline_last_stage():
output_tensor_grad = None
......@@ -181,7 +188,7 @@ def recv_backward(tensor_shape, timers=None):
return output_tensor_grad
def send_forward(output_tensor, tensor_shape, dtype_=None, timers=None):
def send_forward(output_tensor, tensor_shape=None, dtype_=None, timers=None):
"""Send tensor to next rank in pipeline (forward send)."""
if not mpu.is_pipeline_last_stage():
......@@ -198,7 +205,7 @@ def send_forward(output_tensor, tensor_shape, dtype_=None, timers=None):
timers('forward-send').stop()
def send_backward(input_tensor_grad, tensor_shape, timers=None):
def send_backward(input_tensor_grad, tensor_shape=None, timers=None):
"""Send tensor to previous rank in pipeline (backward send)."""
if not mpu.is_pipeline_first_stage():
if timers is not None:
......@@ -213,7 +220,7 @@ def send_backward(input_tensor_grad, tensor_shape, timers=None):
timers('backward-send').stop()
def send_forward_recv_backward(output_tensor, tensor_shape, timers=None):
def send_forward_recv_backward(output_tensor, tensor_shape=None, timers=None):
"""Batched send and recv with next rank in pipeline."""
if mpu.is_pipeline_last_stage():
output_tensor_grad = None
......@@ -231,7 +238,7 @@ def send_forward_recv_backward(output_tensor, tensor_shape, timers=None):
return output_tensor_grad
def send_backward_recv_forward(input_tensor_grad, tensor_shape, timers=None):
def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None):
"""Batched send and recv with previous rank in pipeline."""
if mpu.is_pipeline_first_stage():
input_tensor = None
......@@ -249,7 +256,7 @@ def send_backward_recv_forward(input_tensor_grad, tensor_shape, timers=None):
return input_tensor
def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape, timers=None):
def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timers=None):
"""Batched recv from previous rank and send to next rank in pipeline."""
if timers is not None:
timers('forward-send-forward-recv').start()
......@@ -264,7 +271,7 @@ def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape, timers=Non
return input_tensor
def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape, timers=None):
def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape=None, timers=None):
"""Batched recv from next rank and send to previous rank in pipeline."""
if timers is not None:
timers('backward-send-backward-recv').start()
......@@ -281,7 +288,7 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape, time
def send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad, recv_prev,
recv_next, tensor_shape, timers=None):
recv_next, tensor_shape=None, timers=None):
"""Batched send and recv with previous and next ranks in pipeline."""
if timers is not None:
timers('forward-backward-send-forward-backward-recv').start()
......
......@@ -193,7 +193,7 @@ def update_train_iters(args):
print_rank_0('setting training iterations to {}'.format(args.train_iters))
def get_model(model_provider_func, model_type, wrap_with_ddp=True):
def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
"""Build the model."""
args = get_args()
args.model_type = model_type
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment