Commit f2c35bb0 authored by Jared Casper

Fix inference after T5 pipeline merge

Adds some backward compatibility code so old inference code still works.
parent 5ac5571b
@@ -357,6 +357,12 @@ class TransformerLanguageModel(MegatronModule):
     def set_input_tensor(self, input_tensor):
         """ See megatron.model.transformer.set_input_tensor()"""
+        # This is usually handled in schedules.py but some inference code still
+        # gives us non-lists or None
+        if not isinstance(input_tensor, list):
+            input_tensor = [input_tensor]
         if self.add_encoder and self.add_decoder:
             assert len(input_tensor) == 1, \
                 'input_tensor should only be length 1 for stage with both encoder and decoder'
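For context: the shim above lets set_input_tensor accept both calling conventions. schedules.py now hands each stage a list of input tensors, while older inference paths still pass a bare tensor or None. A minimal standalone sketch of the normalization (the helper name is illustrative, not part of Megatron):

import torch

def as_input_tensor_list(input_tensor):
    # Wrap non-list inputs (a bare tensor or None) in a one-element
    # list so downstream code can assume the list convention.
    if not isinstance(input_tensor, list):
        input_tensor = [input_tensor]
    return input_tensor

t = torch.zeros(4, 4)
assert as_input_tensor_list(None) == [None]   # first stage: no incoming activation
assert as_input_tensor_list(t)[0] is t        # legacy caller: bare tensor
assert as_input_tensor_list([t])[0] is t      # schedules.py: already a list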
...
@@ -53,6 +53,13 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     # if needed.
     tensor_recv_prev = None
     tensor_recv_next = None
+    # Some legacy inference code doesn't set the tensor shape, do so now
+    # for the normal values for gpt/bert. This could be removed if inference
+    # code is changed to provide tensor_shape.
+    if tensor_shape is None:
+        tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
     override_scatter_gather_tensors_in_pipeline = False
     if args.scatter_gather_tensors_in_pipeline:
         tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
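The fallback shape is the standard GPT/BERT activation layout of (sequence length, micro-batch size, hidden size). A minimal sketch of the fallback, with a SimpleNamespace standing in for Megatron's global args object (the concrete values are illustrative):

from types import SimpleNamespace

# Illustrative stand-in for Megatron's global args object.
args = SimpleNamespace(seq_length=1024, micro_batch_size=4, hidden_size=2048)

def resolve_tensor_shape(tensor_shape=None):
    # Legacy inference code passes no shape; fall back to the
    # (seq_length, micro_batch_size, hidden_size) activation layout.
    if tensor_shape is None:
        tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
    return tensor_shape

assert resolve_tensor_shape() == (1024, 4, 2048)            # legacy call
assert resolve_tensor_shape((512, 1, 1024)) == (512, 1, 1024)  # explicit shape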
@@ -143,7 +150,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     return tensor_recv_prev, tensor_recv_next


-def recv_forward(tensor_shape, dtype_=None, timers=None):
+def recv_forward(tensor_shape=None, dtype_=None, timers=None):
     """Receive tensor from previous rank in pipeline (forward receive)."""
     if mpu.is_pipeline_first_stage():
@@ -163,7 +170,7 @@ def recv_forward(tensor_shape, dtype_=None, timers=None):
     return input_tensor


-def recv_backward(tensor_shape, timers=None):
+def recv_backward(tensor_shape=None, timers=None):
     """Receive tensor from next rank in pipeline (backward receive)."""
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None
@@ -181,7 +188,7 @@ def recv_backward(tensor_shape, timers=None):
     return output_tensor_grad


-def send_forward(output_tensor, tensor_shape, dtype_=None, timers=None):
+def send_forward(output_tensor, tensor_shape=None, dtype_=None, timers=None):
     """Send tensor to next rank in pipeline (forward send)."""
     if not mpu.is_pipeline_last_stage():
@@ -198,7 +205,7 @@ def send_forward(output_tensor, tensor_shape, dtype_=None, timers=None):
             timers('forward-send').stop()


-def send_backward(input_tensor_grad, tensor_shape, timers=None):
+def send_backward(input_tensor_grad, tensor_shape=None, timers=None):
     """Send tensor to previous rank in pipeline (backward send)."""
     if not mpu.is_pipeline_first_stage():
         if timers is not None:
@@ -213,7 +220,7 @@ def send_backward(input_tensor_grad, tensor_shape, timers=None):
             timers('backward-send').stop()


-def send_forward_recv_backward(output_tensor, tensor_shape, timers=None):
+def send_forward_recv_backward(output_tensor, tensor_shape=None, timers=None):
     """Batched send and recv with next rank in pipeline."""
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None
@@ -231,7 +238,7 @@ def send_forward_recv_backward(output_tensor, tensor_shape, timers=None):
     return output_tensor_grad


-def send_backward_recv_forward(input_tensor_grad, tensor_shape, timers=None):
+def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None):
     """Batched send and recv with previous rank in pipeline."""
     if mpu.is_pipeline_first_stage():
         input_tensor = None
@@ -249,7 +256,7 @@ def send_backward_recv_forward(input_tensor_grad, tensor_shape, timers=None):
     return input_tensor


-def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape, timers=None):
+def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timers=None):
     """Batched recv from previous rank and send to next rank in pipeline."""
     if timers is not None:
         timers('forward-send-forward-recv').start()
@@ -264,7 +271,7 @@ def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape, timers=None):
     return input_tensor


-def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape, timers=None):
+def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape=None, timers=None):
     """Batched recv from next rank and send to previous rank in pipeline."""
     if timers is not None:
         timers('backward-send-backward-recv').start()
@@ -281,7 +288,7 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape, timers=None):
 def send_forward_backward_recv_forward_backward(
         output_tensor, input_tensor_grad, recv_prev,
-        recv_next, tensor_shape, timers=None):
+        recv_next, tensor_shape=None, timers=None):
     """Batched send and recv with previous and next ranks in pipeline."""
     if timers is not None:
         timers('forward-backward-send-forward-backward-recv').start()
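All the signature changes in this file follow one pattern: tensor_shape becomes optional with a None default and is forwarded unchanged, so the None case is resolved once inside _communicate (first hunk above) rather than at every call site. A condensed, self-contained sketch of that flow; the _stub names mark it as illustrative:

def _communicate_stub(tensor_shape=None):
    # Stand-in for _communicate: resolve a missing shape in one place
    # instead of in every send/recv wrapper.
    if tensor_shape is None:
        tensor_shape = (1024, 4, 2048)  # illustrative gpt/bert default
    return tensor_shape

def recv_forward_stub(tensor_shape=None, timers=None):
    # Wrappers only need to forward the (possibly None) shape.
    return _communicate_stub(tensor_shape)

assert recv_forward_stub() == (1024, 4, 2048)               # legacy inference call
assert recv_forward_stub((512, 1, 1024)) == (512, 1, 1024)  # pipeline-schedule call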
...
@@ -193,7 +193,7 @@ def update_train_iters(args):
     print_rank_0('setting training iterations to {}'.format(args.train_iters))


-def get_model(model_provider_func, model_type, wrap_with_ddp=True):
+def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
     """Build the model."""
     args = get_args()
     args.model_type = model_type
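Here the compatibility comes from defaulting the new model_type parameter, so pre-merge call sites that pass only a model provider still build an encoder-or-decoder model. A sketch of the pattern, with ModelType reduced to a two-member enum and get_model reduced to a stub for illustration:

from enum import Enum

class ModelType(Enum):
    encoder_or_decoder = 1   # gpt/bert-style stacks
    encoder_and_decoder = 2  # t5-style stacks

def get_model_stub(model_provider_func,
                   model_type=ModelType.encoder_or_decoder,
                   wrap_with_ddp=True):
    # Old callers omit model_type entirely; new T5 code passes
    # ModelType.encoder_and_decoder explicitly.
    return model_type

assert get_model_stub(lambda: None) is ModelType.encoder_or_decoder
assert get_model_stub(lambda: None, ModelType.encoder_and_decoder) is ModelType.encoder_and_decoder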
...