Commit 390ddef8 authored by mshoeybi

added multi-batch inference

parent b6555b71
@@ -22,9 +22,20 @@ from megatron.p2p_communication import recv_forward, send_forward
 from megatron import get_args


-def forward_step(model, tokens, position_ids, attention_mask,
-                 set_inference_key_value_memory=False,
-                 inference_max_sequence_len=None):
+class InferenceParams:
+
+    def __init__(self, micro_batch_size_list, max_sequence_len):
+        assert isinstance(micro_batch_size_list, list)
+        assert max_sequence_len > 0
+        self.micro_batch_size_list = micro_batch_size_list
+        self.max_sequence_len = max_sequence_len
+        self.allocate_key_value_memory = False
+        self.micro_batch_size_index = 0
+
+
+def forward_step(model, tokens, position_ids, attention_mask, inference_params):

     # Hidden size changes when not using recompute, need to tell p2p_communicate
     # functions the correct size
@@ -37,10 +48,8 @@ def forward_step(model, tokens, position_ids, attention_mask,

     # Forward pass through the model.
     model.set_input_tensor(input_tensor)
-    output_tensor = model(
-        tokens, position_ids, attention_mask,
-        set_inference_key_value_memory=set_inference_key_value_memory,
-        inference_max_sequence_len=inference_max_sequence_len)
+    output_tensor = model(tokens, position_ids, attention_mask,
+                          inference_params=inference_params)

     send_forward(output_tensor)
...
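The two inference keyword arguments (set_inference_key_value_memory, inference_max_sequence_len) are folded into a single InferenceParams object that is threaded through every forward call. A minimal sketch of how a caller might drive the new forward_step (the model and the tokens, position_ids, and attention_mask tensors are assumed to already exist):

    # Sketch only: one micro batch of size 4, generating up to 1024 tokens.
    inference_params = InferenceParams([4], 1024)

    # First step: have every attention layer allocate its key/value cache.
    inference_params.allocate_key_value_memory = True
    forward_step(model, tokens, position_ids, attention_mask, inference_params)

    # Subsequent steps reuse the allocated cache.
    inference_params.allocate_key_value_memory = False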
@@ -25,7 +25,7 @@ from .communication import (
     copy_from_last_to_first_pipeline_stage,
     broadcast_from_last_pipeline_stage,
     broadcast_from_last_to_first_pipeline_stage)
-from .forward_step import forward_step
+from .forward_step import forward_step, InferenceParams
 from .sampling import sample
@@ -109,6 +109,9 @@ def generate_tokens_probs_and_return_on_first_stage(
     attention_mask, position_ids = _build_attention_mask_and_position_ids(
         tokens)

+    # Set inference params
+    inference_params = InferenceParams([batch_size], max_sequence_length)
+
     model.eval()
     with torch.no_grad():
         prev_context_length = 0
@@ -117,7 +120,8 @@ def generate_tokens_probs_and_return_on_first_stage(
            # If we are starting from scratch, allocate memory for the entire
            # context, otherwise set this to false so the memory is not
            # reallocated.
-            set_inference_key_value_memory = (prev_context_length == 0)
+            inference_params.allocate_key_value_memory = \
+                (prev_context_length == 0)

            # Pick the slice that we need to pass through the network.
            tokens2use = tokens[:, prev_context_length:context_length]
@@ -126,10 +130,8 @@ def generate_tokens_probs_and_return_on_first_stage(
                ..., prev_context_length:context_length, :context_length]

            # logits will be meanigful only in the last pipeline stage.
-            logits = forward_step(
-                model, tokens2use, positions2use, attention_mask2use,
-                set_inference_key_value_memory=set_inference_key_value_memory,
-                inference_max_sequence_len=max_sequence_length)
+            logits = forward_step(model, tokens2use, positions2use,
+                                  attention_mask2use, inference_params)

            if mpu.is_pipeline_last_stage():
                # Always the last stage should have an output.
...
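The generation loop itself is unchanged apart from the flag handling; a condensed sketch of the pattern it now follows (prompt_length is a stand-in for however the surrounding function decides where generation starts, and sampling is omitted):

    inference_params = InferenceParams([batch_size], max_sequence_length)
    prev_context_length = 0
    for context_length in range(prompt_length, max_sequence_length):
        # Allocate key/value memory only on the first pass over the prompt.
        inference_params.allocate_key_value_memory = (prev_context_length == 0)
        tokens2use = tokens[:, prev_context_length:context_length]
        positions2use = position_ids[:, prev_context_length:context_length]
        attention_mask2use = attention_mask[
            ..., prev_context_length:context_length, :context_length]
        logits = forward_step(model, tokens2use, positions2use,
                              attention_mask2use, inference_params)
        prev_context_length = context_length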
@@ -82,16 +82,13 @@ class GPTModel(MegatronModule):
         self.language_model.set_input_tensor(input_tensor)

     def forward(self, input_ids, position_ids, attention_mask, labels=None,
-                tokentype_ids=None,
-                set_inference_key_value_memory=False,
-                inference_max_sequence_len=None):
+                tokentype_ids=None, inference_params=None):

         lm_output = self.language_model(
             input_ids,
             position_ids,
             attention_mask,
-            set_inference_key_value_memory=set_inference_key_value_memory,
-            inference_max_sequence_len=inference_max_sequence_len)
+            inference_params=inference_params)

         if self.post_process:
             return post_language_model_processing(
...
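In GPTModel the two keyword arguments collapse into inference_params, which defaults to None so the training path is untouched. A sketch of the two call patterns (the gpt_model instance and the tensors are assumed):

    # Training / evaluation path: no inference_params, so no key/value cache
    # is kept between calls (matches the old inference_max_sequence_len=None).
    output = gpt_model(input_ids, position_ids, attention_mask)

    # Incremental decoding path: pass the shared InferenceParams object down
    # to every attention layer.
    output = gpt_model(tokens2use, positions2use, attention_mask2use,
                       inference_params=inference_params)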
@@ -335,8 +335,7 @@ class TransformerLanguageModel(MegatronModule):
     def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                 dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
                 enc_dec_attn_mask=None, tokentype_ids=None,
-                set_inference_key_value_memory=False,
-                inference_max_sequence_len=None,
+                inference_params=None,
                 pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):
@@ -353,8 +352,7 @@ class TransformerLanguageModel(MegatronModule):
             encoder_output = self.encoder(
                 encoder_input,
                 enc_attn_mask,
-                set_inference_key_value_memory=set_inference_key_value_memory,
-                inference_max_sequence_len=inference_max_sequence_len)
+                inference_params=inference_params)
         else:
             encoder_output = enc_hidden_states.to(encoder_input.dtype)
@@ -381,8 +379,7 @@ class TransformerLanguageModel(MegatronModule):
             dec_attn_mask,
             encoder_output=encoder_output,
             enc_dec_attn_mask=enc_dec_attn_mask,
-            set_inference_key_value_memory=set_inference_key_value_memory,
-            inference_max_sequence_len=inference_max_sequence_len)
+            inference_params=inference_params)

         if self.add_pooler and self.post_process:
             return decoder_output, encoder_output, pooled_output
...
@@ -180,9 +180,9 @@ class ParallelAttention(MegatronModule):
             skip_bias_add=True)

         # Inference key-value memory
-        self.inference_key_memory = None
-        self.inference_value_memory = None
-        self.inference_current_sequence_len = 0
+        self.inference_key_memory_list = None
+        self.inference_value_memory_list = None
+        self.inference_current_sequence_len_list = None

     def _allocate_memory(self, inference_max_sequence_len, batch_size):
@@ -196,35 +196,32 @@ class ParallelAttention(MegatronModule):
     def forward(self, hidden_states, attention_mask,
-                encoder_output=None,
-                set_inference_key_value_memory=False,
-                inference_max_sequence_len=None):
+                encoder_output=None, inference_params=None):
         # hidden_states: [sq, b, h]

         # =================================================
         # Pre-allocate memory for key-values for inference.
         # =================================================
-        if set_inference_key_value_memory:
-            assert inference_max_sequence_len and inference_max_sequence_len > 0
-            self.inference_key_memory = self._allocate_memory(
-                inference_max_sequence_len, hidden_states.size(1))
-            self.inference_value_memory = self._allocate_memory(
-                inference_max_sequence_len, hidden_states.size(1))
-            self.inference_current_sequence_len = 0
-        # Some consistency check.
-        if inference_max_sequence_len:
-            assert self.inference_current_sequence_len < \
-                self.inference_key_memory.size(0)
-            assert inference_max_sequence_len == \
-                self.inference_key_memory.size(0)
-        # This is added for safety. In case inference_max_sequence_len
+        if inference_params:
+            if inference_params.allocate_key_value_memory:
+                inf_max_seq_len = inference_params.max_sequence_len
+                inf_batch_sizes = inference_params.micro_batch_size_list
+                self.inference_key_memory_list = [
+                    self._allocate_memory(inf_max_seq_len, inf_batch_size)
+                    for inf_batch_size in inf_batch_sizes]
+                self.inference_value_memory_list = [
+                    self._allocate_memory(inf_max_seq_len, inf_batch_size)
+                    for inf_batch_size in inf_batch_sizes]
+                self.inference_current_sequence_len_list = [
+                    0 for _ in inf_batch_sizes]
+        # This is added for safety. In case inference_params
         # is not provided, make sure there is no potential memory left
         # from previous inference.
-        if not inference_max_sequence_len:
-            self.inference_key_memory = None
-            self.inference_value_memory = None
+        else:
+            self.inference_key_memory_list = None
+            self.inference_value_memory_list = None
+            self.inference_current_sequence_len_list = None

         # =====================
         # Query, Key, and Value
@@ -267,20 +264,27 @@ class ParallelAttention(MegatronModule):
         query_layer = query_layer.view(*new_tensor_shape)

-        # ===================================================
-        # Adjust key, value, and attention mask for inference
-        # ===================================================
-        if inference_max_sequence_len:
+        # ==================================
+        # Adjust key and value for inference
+        # ==================================
+        if inference_params:
+            inf_batch_index = inference_params.micro_batch_size_index
+            assert key_layer.size(1) == \
+                inference_params.micro_batch_size_list[inf_batch_index]
             # Adjust the range variables.
-            start = self.inference_current_sequence_len
-            self.inference_current_sequence_len += key_layer.size(0)
-            end = self.inference_current_sequence_len
+            start = self.inference_current_sequence_len_list[inf_batch_index]
+            end = start + key_layer.size(0)
+            self.inference_current_sequence_len_list[inf_batch_index] = end
             # Copy key and values.
-            self.inference_key_memory[start:end, ...] = key_layer
-            self.inference_value_memory[start:end, ...] = value_layer
-            key_layer = self.inference_key_memory[:end, ...]
-            value_layer = self.inference_value_memory[:end, ...]
+            self.inference_key_memory_list[inf_batch_index][start:end, ...] =\
+                key_layer
+            self.inference_value_memory_list[inf_batch_index][start:end, ...] =\
+                value_layer
+            key_layer = \
+                self.inference_key_memory_list[inf_batch_index][:end, ...]
+            value_layer = \
+                self.inference_value_memory_list[inf_batch_index][:end, ...]

         # ===================================
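Read on its own, the cache update above is an append into a pre-allocated per-micro-batch buffer followed by a slice over everything written so far. A standalone sketch of the same bookkeeping with plain tensors (the [seq_len, batch, heads, head_dim] layout is assumed, matching the key/value tensors in the attention code):

    import torch

    def update_kv_cache(key_memory_list, value_memory_list, current_len_list,
                        batch_index, key_layer, value_layer):
        # key_memory_list / value_memory_list: per-micro-batch lists of
        # pre-allocated buffers of shape [max_seq_len, batch, ...];
        # current_len_list tracks how much of each buffer is already filled.
        start = current_len_list[batch_index]
        end = start + key_layer.size(0)
        key_memory_list[batch_index][start:end, ...] = key_layer
        value_memory_list[batch_index][start:end, ...] = value_layer
        current_len_list[batch_index] = end
        # Attention then runs against everything cached so far.
        return (key_memory_list[batch_index][:end, ...],
                value_memory_list[batch_index][:end, ...])

    # Toy usage: max_seq_len=8, batch=2, 1 head, head_dim=4, one micro batch.
    keys, values, lengths = [torch.zeros(8, 2, 1, 4)], [torch.zeros(8, 2, 1, 4)], [0]
    k, v = update_kv_cache(keys, values, lengths, 0,
                           torch.randn(3, 2, 1, 4), torch.randn(3, 2, 1, 4))
    assert k.size(0) == 3 and lengths[0] == 3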
@@ -459,10 +463,8 @@ class ParallelTransformerLayer(MegatronModule):
             output_layer_init_method)

     def forward(self, hidden_states, attention_mask,
-                encoder_output=None,
-                enc_dec_attn_mask=None,
-                set_inference_key_value_memory=False,
-                inference_max_sequence_len=None):
+                encoder_output=None, enc_dec_attn_mask=None,
+                inference_params=None):
         # hidden_states: [b, s, h]

         # Layer norm at the beginning of the transformer layer.
@@ -472,8 +474,7 @@ class ParallelTransformerLayer(MegatronModule):
             self.self_attention(
                 layernorm_output,
                 attention_mask,
-                set_inference_key_value_memory=set_inference_key_value_memory,
-                inference_max_sequence_len=inference_max_sequence_len)
+                inference_params=inference_params)

         # Residual connection.
         if self.apply_residual_connection_post_layernorm:
@@ -686,13 +687,11 @@ class ParallelTransformer(MegatronModule):
         self.input_tensor = input_tensor

     def forward(self, hidden_states, attention_mask,
-                encoder_output=None,
-                enc_dec_attn_mask=None,
-                set_inference_key_value_memory=False,
-                inference_max_sequence_len=None):
+                encoder_output=None, enc_dec_attn_mask=None,
+                inference_params=None):

         # Checks.
-        if inference_max_sequence_len:
+        if inference_params:
             assert self.activations_checkpoint_method is None, \
                 'inference does not work with activation checkpointing'
@@ -724,8 +723,8 @@ class ParallelTransformer(MegatronModule):
                     attention_mask,
                     encoder_output=encoder_output,
                     enc_dec_attn_mask=enc_dec_attn_mask,
-                    set_inference_key_value_memory=set_inference_key_value_memory,
-                    inference_max_sequence_len=inference_max_sequence_len)
+                    inference_params=inference_params)

         # Final layer norm.
         if self.post_process:
...
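Nothing in this commit selects anything other than micro batch 0 (generation.py still passes a single-element list), but the list-of-caches layout is what lets several prompt groups decode against one model instance. A hypothetical driver, not part of this commit, showing how micro_batch_size_index would be used (num_steps and the per-step batches_for helper are assumptions for illustration):

    # Hypothetical usage: two prompt groups of batch size 4 and 8 share the
    # model; each keeps its own key/value cache inside every attention layer.
    inference_params = InferenceParams([4, 8], 1024)

    first_call = True
    for step in range(num_steps):                      # num_steps: assumed
        for index, (toks, poss, mask) in enumerate(batches_for(step)):  # assumed helper
            inference_params.micro_batch_size_index = index
            # Caches for *all* micro batch sizes are allocated on the very
            # first forward call; every later call just appends to them.
            inference_params.allocate_key_value_memory = first_call
            first_call = False
            logits = forward_step(model, toks, poss, mask, inference_params)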