Commit 390ddef8 authored by mshoeybi

added multi-batch inference

parent b6555b71
......@@ -22,9 +22,20 @@ from megatron.p2p_communication import recv_forward, send_forward
from megatron import get_args
def forward_step(model, tokens, position_ids, attention_mask,
set_inference_key_value_memory=False,
inference_max_sequence_len=None):
class InferenceParams:
def __init__(self, micro_batch_size_list, max_sequence_len):
assert isinstance(micro_batch_size_list, list)
assert max_sequence_len > 0
self.micro_batch_size_list = micro_batch_size_list
self.max_sequence_len = max_sequence_len
self.allocate_key_value_memory = False
self.micro_batch_size_index = 0
def forward_step(model, tokens, position_ids, attention_mask, inference_params):
# Hidden size changes when not using recompute, need to tell p2p_communicate
# functions the correct size
......@@ -37,10 +48,8 @@ def forward_step(model, tokens, position_ids, attention_mask,
# Forward pass through the model.
model.set_input_tensor(input_tensor)
output_tensor = model(
tokens, position_ids, attention_mask,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len)
output_tensor = model(tokens, position_ids, attention_mask,
inference_params=inference_params)
send_forward(output_tensor)
......
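The two old inference keyword arguments are folded into a single InferenceParams object that the caller owns and mutates between passes. A minimal sketch of the new calling convention, assuming the module path megatron.text_generation.forward_step and pre-built model, tokens, position_ids, and attention_mask (none of which are part of this diff):

```python
# Sketch only: the import path and the prepared inputs are assumptions,
# not part of this commit.
from megatron.text_generation.forward_step import InferenceParams, forward_step

# Two micro batches of sizes 8 and 4, decoding at most 256 positions.
inference_params = InferenceParams([8, 4], max_sequence_len=256)

# First pass: ask every attention layer to allocate its key/value caches.
inference_params.allocate_key_value_memory = True
inference_params.micro_batch_size_index = 0    # run the first micro batch
logits = forward_step(model, tokens, position_ids, attention_mask,
                      inference_params)
```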
......@@ -25,7 +25,7 @@ from .communication import (
copy_from_last_to_first_pipeline_stage,
broadcast_from_last_pipeline_stage,
broadcast_from_last_to_first_pipeline_stage)
from .forward_step import forward_step
from .forward_step import forward_step, InferenceParams
from .sampling import sample
......@@ -109,6 +109,9 @@ def generate_tokens_probs_and_return_on_first_stage(
attention_mask, position_ids = _build_attention_mask_and_position_ids(
tokens)
# Set inference params
inference_params = InferenceParams([batch_size], max_sequence_length)
model.eval()
with torch.no_grad():
prev_context_length = 0
......@@ -117,7 +120,8 @@ def generate_tokens_probs_and_return_on_first_stage(
# If we are starting from scratch, allocate memory for the entire
# context, otherwise set this to false so the memory is not
# reallocated.
set_inference_key_value_memory = (prev_context_length == 0)
inference_params.allocate_key_value_memory = \
(prev_context_length == 0)
# Pick the slice that we need to pass through the network.
tokens2use = tokens[:, prev_context_length:context_length]
......@@ -126,10 +130,8 @@ def generate_tokens_probs_and_return_on_first_stage(
..., prev_context_length:context_length, :context_length]
# logits will be meaningful only in the last pipeline stage.
logits = forward_step(
model, tokens2use, positions2use, attention_mask2use,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=max_sequence_length)
logits = forward_step(model, tokens2use, positions2use,
attention_mask2use, inference_params)
if mpu.is_pipeline_last_stage():
# The last stage should always have an output.
......
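Restated as a simplified sketch (prompt_length and the prepared tensors are assumptions), the generation loop allocates the key/value cache once for the prompt chunk and then reuses it for every single-token step:

```python
# Simplified version of the loop above; `model`, `tokens`, `position_ids`,
# `attention_mask`, `prompt_length`, and `max_sequence_length` are assumed.
inference_params = InferenceParams([tokens.size(0)], max_sequence_length)

prev_context_length = 0
for context_length in range(prompt_length, max_sequence_length):
    # Allocate the key/value memory only on the very first pass.
    inference_params.allocate_key_value_memory = (prev_context_length == 0)

    # Pick the slice that still needs to go through the network.
    tokens2use = tokens[:, prev_context_length:context_length]
    positions2use = position_ids[:, prev_context_length:context_length]
    attention_mask2use = attention_mask[
        ..., prev_context_length:context_length, :context_length]

    logits = forward_step(model, tokens2use, positions2use,
                          attention_mask2use, inference_params)
    prev_context_length = context_length
```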
......@@ -82,16 +82,13 @@ class GPTModel(MegatronModule):
self.language_model.set_input_tensor(input_tensor)
def forward(self, input_ids, position_ids, attention_mask, labels=None,
tokentype_ids=None,
set_inference_key_value_memory=False,
inference_max_sequence_len=None):
tokentype_ids=None, inference_params=None):
lm_output = self.language_model(
input_ids,
position_ids,
attention_mask,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len)
inference_params=inference_params)
if self.post_process:
return post_language_model_processing(
......
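Because inference_params defaults to None, training and evaluation call sites keep working unchanged; only inference code passes the new object. A hedged sketch (gpt_model and the input tensors are assumed):

```python
# Training / evaluation path: identical to the old signature.
lm_logits = gpt_model(input_ids, position_ids, attention_mask)

# Inference path: same positional arguments, plus the shared params object.
params = InferenceParams([input_ids.size(0)], max_sequence_len=1024)
params.allocate_key_value_memory = True
lm_logits = gpt_model(input_ids, position_ids, attention_mask,
                      inference_params=params)
```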
......@@ -335,8 +335,7 @@ class TransformerLanguageModel(MegatronModule):
def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
enc_dec_attn_mask=None, tokentype_ids=None,
set_inference_key_value_memory=False,
inference_max_sequence_len=None,
inference_params=None,
pooling_sequence_index=0,
enc_hidden_states=None, output_enc_hidden=False):
......@@ -353,8 +352,7 @@ class TransformerLanguageModel(MegatronModule):
encoder_output = self.encoder(
encoder_input,
enc_attn_mask,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len)
inference_params=inference_params)
else:
encoder_output = enc_hidden_states.to(encoder_input.dtype)
......@@ -381,8 +379,7 @@ class TransformerLanguageModel(MegatronModule):
dec_attn_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len)
inference_params=inference_params)
if self.add_pooler and self.post_process:
return decoder_output, encoder_output, pooled_output
......
......@@ -180,9 +180,9 @@ class ParallelAttention(MegatronModule):
skip_bias_add=True)
# Inference key-value memory
self.inference_key_memory = None
self.inference_value_memory = None
self.inference_current_sequence_len = 0
self.inference_key_memory_list = None
self.inference_value_memory_list = None
self.inference_current_sequence_len_list = None
def _allocate_memory(self, inference_max_sequence_len, batch_size):
......@@ -196,35 +196,32 @@ class ParallelAttention(MegatronModule):
def forward(self, hidden_states, attention_mask,
encoder_output=None,
set_inference_key_value_memory=False,
inference_max_sequence_len=None):
encoder_output=None, inference_params=None):
# hidden_states: [sq, b, h]
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
if set_inference_key_value_memory:
assert inference_max_sequence_len and inference_max_sequence_len > 0
self.inference_key_memory = self._allocate_memory(
inference_max_sequence_len, hidden_states.size(1))
self.inference_value_memory = self._allocate_memory(
inference_max_sequence_len, hidden_states.size(1))
self.inference_current_sequence_len = 0
# Some consistency check.
if inference_max_sequence_len:
assert self.inference_current_sequence_len < \
self.inference_key_memory.size(0)
assert inference_max_sequence_len == \
self.inference_key_memory.size(0)
# This is added for safety. In case inference_max_sequence_len
if inference_params:
if inference_params.allocate_key_value_memory:
inf_max_seq_len = inference_params.max_sequence_len
inf_batch_sizes = inference_params.micro_batch_size_list
self.inference_key_memory_list = [
self._allocate_memory(inf_max_seq_len, inf_batch_size)
for inf_batch_size in inf_batch_sizes]
self.inference_value_memory_list = [
self._allocate_memory(inf_max_seq_len, inf_batch_size)
for inf_batch_size in inf_batch_sizes]
self.inference_current_sequence_len_list = [
0 for _ in inf_batch_sizes]
# This is added for safety. In case inference_params
# is not provided, make sure there is no potential memory left
# from previous inference.
if not inference_max_sequence_len:
self.inference_key_memory = None
self.inference_value_memory = None
else:
self.inference_key_memory_list = None
self.inference_value_memory_list = None
self.inference_current_sequence_len_list = None
# =====================
# Query, Key, and Value
......@@ -267,20 +264,27 @@ class ParallelAttention(MegatronModule):
query_layer = query_layer.view(*new_tensor_shape)
# ===================================================
# Adjust key, value, and attention mask for inference
# ===================================================
# ==================================
# Adjust key and value for inference
# ==================================
if inference_max_sequence_len:
if inference_params:
inf_batch_index = inference_params.micro_batch_size_index
assert key_layer.size(1) == \
inference_params.micro_batch_size_list[inf_batch_index]
# Adjust the range variables.
start = self.inference_current_sequence_len
self.inference_current_sequence_len += key_layer.size(0)
end = self.inference_current_sequence_len
start = self.inference_current_sequence_len_list[inf_batch_index]
end = start + key_layer.size(0)
self.inference_current_sequence_len_list[inf_batch_index] = end
# Copy key and values.
self.inference_key_memory[start:end, ...] = key_layer
self.inference_value_memory[start:end, ...] = value_layer
key_layer = self.inference_key_memory[:end, ...]
value_layer = self.inference_value_memory[:end, ...]
self.inference_key_memory_list[inf_batch_index][start:end, ...] =\
key_layer
self.inference_value_memory_list[inf_batch_index][start:end, ...] =\
value_layer
key_layer = \
self.inference_key_memory_list[inf_batch_index][:end, ...]
value_layer = \
self.inference_value_memory_list[inf_batch_index][:end, ...]
# ===================================
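The per-micro-batch cache bookkeeping above can be isolated into a standalone PyTorch sketch (illustrative only, not Megatron code): each micro batch owns one entry in the key/value lists and its own sequence offset, and each forward pass appends the new keys/values and reads back the cached prefix.

```python
import torch

# Shapes follow the attention code: [sequence, batch, ...]; head and hidden
# dimensions are collapsed into a single `hidden` dimension for brevity.
max_seq_len, hidden = 16, 8
micro_batch_sizes = [4, 2]

key_memory_list = [torch.zeros(max_seq_len, b, hidden) for b in micro_batch_sizes]
value_memory_list = [torch.zeros(max_seq_len, b, hidden) for b in micro_batch_sizes]
current_len_list = [0 for _ in micro_batch_sizes]

def append_and_read(index, key_layer, value_layer):
    """Append new key/value states for micro batch `index` and return the
    cached prefix, mirroring the slicing in ParallelAttention.forward."""
    assert key_layer.size(1) == micro_batch_sizes[index]
    start = current_len_list[index]
    end = start + key_layer.size(0)
    current_len_list[index] = end
    key_memory_list[index][start:end, ...] = key_layer
    value_memory_list[index][start:end, ...] = value_layer
    return (key_memory_list[index][:end, ...],
            value_memory_list[index][:end, ...])

# Prompt step for micro batch 0 (5 tokens), then one decoded token.
k, v = append_and_read(0, torch.randn(5, 4, hidden), torch.randn(5, 4, hidden))
k, v = append_and_read(0, torch.randn(1, 4, hidden), torch.randn(1, 4, hidden))
print(k.shape)  # torch.Size([6, 4, 8])
```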
......@@ -459,10 +463,8 @@ class ParallelTransformerLayer(MegatronModule):
output_layer_init_method)
def forward(self, hidden_states, attention_mask,
encoder_output=None,
enc_dec_attn_mask=None,
set_inference_key_value_memory=False,
inference_max_sequence_len=None):
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
# hidden_states: [b, s, h]
# Layer norm at the beginning of the transformer layer.
......@@ -472,8 +474,7 @@ class ParallelTransformerLayer(MegatronModule):
self.self_attention(
layernorm_output,
attention_mask,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len)
inference_params=inference_params)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
......@@ -686,13 +687,11 @@ class ParallelTransformer(MegatronModule):
self.input_tensor = input_tensor
def forward(self, hidden_states, attention_mask,
encoder_output=None,
enc_dec_attn_mask=None,
set_inference_key_value_memory=False,
inference_max_sequence_len=None):
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
# Checks.
if inference_max_sequence_len:
if inference_params:
assert self.activations_checkpoint_method is None, \
'inference does not work with activation checkpointing'
......@@ -724,8 +723,8 @@ class ParallelTransformer(MegatronModule):
attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len)
inference_params=inference_params)
# Final layer norm.
if self.post_process:
......
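Taken together, multi-batch inference interleaves independent micro batches through the same model by switching micro_batch_size_index between passes, so each batch advances its own cache offset. A hedged driver-loop sketch (the inputs list and the prepared tensors are assumptions; the generation function in this commit still uses a single-element batch-size list):

```python
# `inputs[i]` holds (tokens, position_ids, attention_mask) for micro batch i.
inference_params = InferenceParams([8, 4], max_sequence_len=256)

# Prompt pass for each micro batch. The first call with
# allocate_key_value_memory=True allocates caches for *all* micro batch
# sizes in the list; later calls must leave it False to keep the caches.
inference_params.allocate_key_value_memory = True
for i, (tok, pos, mask) in enumerate(inputs):
    inference_params.micro_batch_size_index = i
    forward_step(model, tok, pos, mask, inference_params)
    inference_params.allocate_key_value_memory = False

# Subsequent single-token steps can alternate between micro batches; each one
# resumes from its own entry in inference_current_sequence_len_list.
```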