Commit ff2f0a05 authored by mshoeybi

further refactoring

parent 390ddef8
......@@ -15,15 +15,20 @@
"""Forward step utilities."""
import torch
from collections.abc import Iterable
from enum import Enum
from megatron.p2p_communication import recv_forward, send_forward
from megatron import get_args
class ForwardStepTypes(Enum):
NO_PIPELINING = 1
class InferenceParams:
def __init__(self, micro_batch_size_list, max_sequence_len):
assert isinstance(micro_batch_size_list, list)
......@@ -31,10 +36,67 @@ class InferenceParams:
self.micro_batch_size_list = micro_batch_size_list
self.max_sequence_len = max_sequence_len
self.allocate_key_value_memory = False
self.allocate_key_value_memory = True
self.micro_batch_size_index = 0
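The allocate_key_value_memory flag, together with micro_batch_size_list and max_sequence_len, is what an attention layer would consult to size its key/value cache once and then reuse it on later incremental steps. A minimal sketch of that pattern, with a hypothetical helper name and an assumed cache layout (not part of this commit):

import torch

def maybe_allocate_kv_cache(inference_params, num_heads, head_dim,
                            dtype=torch.half, device='cuda'):
    # Hypothetical helper: allocate the key/value cache only on the first
    # call (allocate_key_value_memory == True), sized for the full sequence.
    if not inference_params.allocate_key_value_memory:
        return None
    batch = inference_params.micro_batch_size_list[
        inference_params.micro_batch_size_index]
    seq = inference_params.max_sequence_len
    # Assumed cache layout: [sequence, batch, heads, head-dim]; a real model
    # would keep one such pair per transformer layer.
    keys = torch.empty(seq, batch, num_heads, head_dim,
                       dtype=dtype, device=device)
    values = torch.empty_like(keys)
    return keys, values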
class InferenceForwardStep:
def __init__(self, model, batch_size, max_sequence_len):
if isinstance(model, Iterable):
for this_model in model:
this_model.eval()
else:
model.eval()
self.model = model
self.inference_params = InferenceParams([batch_size], max_sequence_len)
self.forward_step_type = ForwardStepTypes.NO_PIPELINING
def __call__(self, tokens, position_ids, attention_mask):
if self.forward_step_type == ForwardStepTypes.NO_PIPELINING:
return self._forward_step_no_pipelining(tokens, position_ids,
attention_mask)
raise Exception('unknown forward step type {}'.format(
self.forward_step_type))
def _forward_step_no_pipelining(self, tokens, position_ids, attention_mask):
# Need to tell p2p_communicate functions the correct size.
args = get_args()
orig_seq_length = args.seq_length
args.seq_length = tokens.shape[1]
assert args.seq_length <= self.inference_params.max_sequence_len
args.micro_batch_size = tokens.shape[0]
assert self.inference_params.micro_batch_size_list[0] == tokens.shape[0]
assert self.inference_params.micro_batch_size_index == 0
# Receive from previous stage.
input_tensor = recv_forward()
# Forward pass through the model.
self.model.set_input_tensor(input_tensor)
output_tensor = self.model(tokens, position_ids, attention_mask,
inference_params=self.inference_params)
# Send output to the next stage.
send_forward(output_tensor)
# Reset the sequence length to whatever it was before.
args.seq_length = orig_seq_length
# Make sure we do not allocate context memory anymore.
if self.inference_params.allocate_key_value_memory:
self.inference_params.allocate_key_value_memory = False
return output_tensor
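For reference, a minimal sketch of how this wrapper is meant to be driven, mirroring the generation loop changed later in this commit; the model, token tensors, and the sampling step are assumed inputs, not part of this file:

def greedy_generate_sketch(model, tokens, position_ids, attention_mask,
                           min_prompt_length, max_sequence_len):
    # Construct the wrapper once; it puts the model in eval mode and holds
    # the InferenceParams used for key/value-cache allocation.
    forward_step = InferenceForwardStep(model, tokens.size(0), max_sequence_len)
    prev_context_length = 0
    for context_length in range(min_prompt_length, max_sequence_len):
        # Only the not-yet-processed slice of the context is passed in; the
        # attention mask still covers all previously cached positions.
        tokens2use = tokens[:, prev_context_length:context_length]
        positions2use = position_ids[:, prev_context_length:context_length]
        mask2use = attention_mask[
            ..., prev_context_length:context_length, :context_length]
        logits = forward_step(tokens2use, positions2use, mask2use)
        # ... on the last pipeline stage, sample from logits[:, -1, :] and
        # write the new token into tokens[:, context_length] ...
        prev_context_length = context_length
    return tokens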
def forward_step(model, tokens, position_ids, attention_mask, inference_params):
# Hidden size changes when not using recompute, need to tell p2p_communicate
......@@ -56,7 +118,3 @@ def forward_step(model, tokens, position_ids, attention_mask, inference_params):
args.seq_length = orig_seq_length
return output_tensor
......@@ -15,7 +15,6 @@
"""Generation utilities."""
import torch
import torch.nn.functional as F
......@@ -25,7 +24,7 @@ from .communication import (
copy_from_last_to_first_pipeline_stage,
broadcast_from_last_pipeline_stage,
broadcast_from_last_to_first_pipeline_stage)
from .forward_step import forward_step, InferenceParams
from .forward_step import InferenceForwardStep
from .sampling import sample
......@@ -66,6 +65,9 @@ def generate_tokens_probs_and_return_on_first_stage(
max_sequence_length = tokens.size(1)
max_sequence_length = min(max_sequence_length, args.max_position_embeddings)
# forward step.
forward_step = InferenceForwardStep(model, batch_size, max_sequence_length)
# Added termination_id to support terminating the generation once that
# id is generated.
if hasattr(args, 'eos_id'):
......@@ -109,20 +111,10 @@ def generate_tokens_probs_and_return_on_first_stage(
attention_mask, position_ids = _build_attention_mask_and_position_ids(
tokens)
# Set inference params
inference_params = InferenceParams([batch_size], max_sequence_length)
model.eval()
with torch.no_grad():
prev_context_length = 0
for context_length in range(min_prompt_length, max_sequence_length):
# If we are starting from scratch, allocate memory for the entire
# context, otherwise set this to false so the memory is not
# reallocated.
inference_params.allocate_key_value_memory = \
(prev_context_length == 0)
# Pick the slice that we need to pass through the network.
tokens2use = tokens[:, prev_context_length:context_length]
positions2use = position_ids[:, prev_context_length:context_length]
......@@ -130,8 +122,7 @@ def generate_tokens_probs_and_return_on_first_stage(
..., prev_context_length:context_length, :context_length]
# logits will be meaningful only in the last pipeline stage.
logits = forward_step(model, tokens2use, positions2use,
attention_mask2use, inference_params)
logits = forward_step(tokens2use, positions2use, attention_mask2use)
if mpu.is_pipeline_last_stage():
# Always the last stage should have an output.
......