Commit 8b9fe87b authored by Mohammad Shoeybi, committed by Jared Casper

Inference context optimization

parent f47aa770
@@ -29,23 +29,15 @@ from .utils import scaled_init_method_normal
 def post_language_model_processing(lm_output, labels, logit_weights,
-                                   get_key_value, parallel_output,
-                                   forward_method_parallel_output,
+                                   parallel_output,
                                    fp16_lm_cross_entropy):
-    if get_key_value:
-        lm_output, presents = lm_output
 
     # Output.
-    if forward_method_parallel_output is not None:
-        parallel_output = forward_method_parallel_output
     output = parallel_lm_logits(
         lm_output,
         logit_weights,
         parallel_output)
 
-    if get_key_value:
-        output = [output, presents]
-
     if labels is None:
         return output
     else:
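With get_key_value and forward_method_parallel_output gone, post_language_model_processing reduces to a logits projection plus an optional cross-entropy loss. A minimal single-GPU sketch of the same flow (simple_post_processing and the plain matmul are illustrations standing in for Megatron's tensor-parallel parallel_lm_logits and vocab-parallel cross entropy, not the project's API):

import torch
import torch.nn.functional as F

def simple_post_processing(lm_output, labels, logit_weights,
                           fp16_lm_cross_entropy=False):
    # Project hidden states [s, b, h] to vocabulary logits [s, b, v].
    output = torch.matmul(lm_output, logit_weights.t())
    if labels is None:
        return output
    # Compute the loss either on the stored dtype or after upcasting to fp32.
    logits = output if fp16_lm_cross_entropy else output.float()
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
                           labels.view(-1), reduction='none')
    return loss.view_as(labels)

# Example with dummy shapes: 4 tokens, batch 2, hidden 8, vocab 10.
loss = simple_post_processing(torch.randn(4, 2, 8),
                              torch.randint(0, 10, (4, 2)),
                              torch.randn(10, 8))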
@@ -90,23 +82,22 @@ class GPTModel(MegatronModule):
         self.language_model.set_input_tensor(input_tensor)
 
     def forward(self, input_ids, position_ids, attention_mask, labels=None,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
-                forward_method_parallel_output=None):
+                tokentype_ids=None,
+                set_inference_key_value_memory=False,
+                inference_max_sequence_len=None):
 
         lm_output = self.language_model(
             input_ids,
             position_ids,
             attention_mask,
-            layer_past=layer_past,
-            get_key_value=get_key_value)
+            set_inference_key_value_memory=set_inference_key_value_memory,
+            inference_max_sequence_len=inference_max_sequence_len)
 
         if self.post_process:
             return post_language_model_processing(
                 lm_output, labels,
                 self.word_embeddings_weight(),
-                get_key_value,
                 self.parallel_output,
-                forward_method_parallel_output,
                 self.fp16_lm_cross_entropy)
         else:
             return lm_output
...
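The replacement for layer_past/get_key_value is stateful: the caller makes one pass over the whole prompt with set_inference_key_value_memory=True and the maximum sequence length, then feeds a single new token per step with the flag off. A self-contained toy sketch of that calling convention (ToyGPT is a stub with the same keyword interface and random logits, not the real GPTModel):

import torch
import torch.nn as nn

class ToyGPT(nn.Module):
    def __init__(self, vocab_size=100):
        super().__init__()
        self.vocab_size = vocab_size

    def forward(self, input_ids, position_ids, attention_mask,
                set_inference_key_value_memory=False,
                inference_max_sequence_len=None):
        b, s = input_ids.shape
        return torch.randn(b, s, self.vocab_size)

model, max_len = ToyGPT(), 16
prompt = torch.randint(0, 100, (1, 4))
positions = torch.arange(prompt.size(1)).unsqueeze(0)
mask = torch.ones(1, 1, max_len, max_len, dtype=torch.bool)

# Prompt phase: allocate the key/value memory and fill it with the context.
logits = model(prompt, positions, mask,
               set_inference_key_value_memory=True,
               inference_max_sequence_len=max_len)
next_token = logits[:, -1].argmax(-1, keepdim=True)

# Decode phase: one new token per call; the cache persists inside the model.
for step in range(prompt.size(1), max_len):
    logits = model(next_token, torch.tensor([[step]]), mask,
                   set_inference_key_value_memory=False,
                   inference_max_sequence_len=max_len)
    next_token = logits[:, -1].argmax(-1, keepdim=True)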
@@ -334,8 +334,10 @@ class TransformerLanguageModel(MegatronModule):
 
     def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                 dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
-                enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
-                get_key_value=False, pooling_sequence_index=0,
+                enc_dec_attn_mask=None, tokentype_ids=None,
+                set_inference_key_value_memory=False,
+                inference_max_sequence_len=None,
+                pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):
 
         # Embeddings.
@@ -348,10 +350,11 @@ class TransformerLanguageModel(MegatronModule):
         # encoder.
         if enc_hidden_states is None:
-            encoder_output = self.encoder(encoder_input,
-                                          enc_attn_mask,
-                                          layer_past=layer_past,
-                                          get_key_value=get_key_value)
+            encoder_output = self.encoder(
+                encoder_input,
+                enc_attn_mask,
+                set_inference_key_value_memory=set_inference_key_value_memory,
+                inference_max_sequence_len=inference_max_sequence_len)
         else:
             encoder_output = enc_hidden_states.to(encoder_input.dtype)
@@ -373,12 +376,13 @@ class TransformerLanguageModel(MegatronModule):
         dec_embedding_output = self.embedding(dec_input_ids,
                                               dec_position_ids)
 
         # decoder
-        decoder_output = self.decoder(dec_embedding_output,
-                                      dec_attn_mask,
-                                      layer_past=layer_past,
-                                      get_key_value=get_key_value,
-                                      encoder_output=encoder_output,
-                                      enc_dec_attn_mask=enc_dec_attn_mask)
+        decoder_output = self.decoder(
+            dec_embedding_output,
+            dec_attn_mask,
+            encoder_output=encoder_output,
+            enc_dec_attn_mask=enc_dec_attn_mask,
+            set_inference_key_value_memory=set_inference_key_value_memory,
+            inference_max_sequence_len=inference_max_sequence_len)
 
         if self.add_pooler and self.post_process:
             return decoder_output, encoder_output, pooled_output
...
@@ -118,6 +118,7 @@ class ParallelAttention(MegatronModule):
         self.layer_number = max(1, layer_number)
         self.attention_type = attention_type
         self.attn_mask_type = attn_mask_type
+        self.params_dtype = args.params_dtype
 
         projection_size = args.kv_channels * args.num_attention_heads
@@ -178,10 +179,53 @@ class ParallelAttention(MegatronModule):
             init_method=output_layer_init_method,
             skip_bias_add=True)
 
-    def forward(self, hidden_states, attention_mask, layer_past=None,
-                get_key_value=False, encoder_output=None):
+        # Inference key-value memory
+        self.inference_key_memory = None
+        self.inference_value_memory = None
+        self.inference_current_sequence_len = 0
+
+    def _allocate_memory(self, inference_max_sequence_len, batch_size):
+        return torch.empty(
+            inference_max_sequence_len,
+            batch_size,
+            self.num_attention_heads_per_partition,
+            self.hidden_size_per_attention_head,
+            dtype=self.params_dtype,
+            device=torch.cuda.current_device())
+
+    def forward(self, hidden_states, attention_mask,
+                encoder_output=None,
+                set_inference_key_value_memory=False,
+                inference_max_sequence_len=None):
         # hidden_states: [sq, b, h]
 
+        # =================================================
+        # Pre-allocate memory for key-values for inference.
+        # =================================================
+        if set_inference_key_value_memory:
+            assert inference_max_sequence_len and inference_max_sequence_len > 0
+            self.inference_key_memory = self._allocate_memory(
+                inference_max_sequence_len, hidden_states.size(1))
+            self.inference_value_memory = self._allocate_memory(
+                inference_max_sequence_len, hidden_states.size(1))
+            self.inference_current_sequence_len = 0
+        # Some consistency check.
+        if inference_max_sequence_len:
+            assert self.inference_current_sequence_len < \
+                self.inference_key_memory.size(0)
+            assert inference_max_sequence_len == \
+                self.inference_key_memory.size(0)
+        # This is added for safety. In case inference_max_sequence_len
+        # is not provided, make sure there is no potential memory left
+        # from previous inference.
+        if not inference_max_sequence_len:
+            self.inference_key_memory = None
+            self.inference_value_memory = None
+
         # =====================
         # Query, Key, and Value
         # =====================
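Because the key/value buffers are allocated once per attention layer at the full inference_max_sequence_len, the memory cost can be budgeted up front. A rough helper, assuming one key and one value buffer per layer with the shape used in _allocate_memory (the function name and example numbers are illustrative only):

import torch

def kv_cache_bytes(max_seq_len, batch_size, heads_per_partition, head_dim,
                   num_layers, dtype=torch.float16):
    # One key buffer and one value buffer per layer, each of shape
    # [max_seq_len, batch, heads_per_partition, head_dim] on this rank.
    elems = max_seq_len * batch_size * heads_per_partition * head_dim
    return 2 * num_layers * elems * torch.finfo(dtype).bits // 8

# Illustrative numbers: 96 layers, a 2048-token budget, batch 8, 16 heads of
# 128 dims on this tensor-parallel rank, fp16 -> 12 GiB of cache per rank.
print(kv_cache_bytes(2048, 8, 16, 128, num_layers=96) / 2**30)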
@@ -222,18 +266,24 @@ class ParallelAttention(MegatronModule):
                                      self.hidden_size_per_attention_head)
         query_layer = query_layer.view(*new_tensor_shape)
 
-        # ==================================
-        # Adjust key and value for inference
-        # ==================================
+        # ===================================================
+        # Adjust key, value, and attention mask for inference
+        # ===================================================
 
-        if layer_past is not None:
-            past_key, past_value = layer_past
-            key_layer = torch.cat((past_key.type_as(key_layer),
-                                   key_layer), dim=0)
-            value_layer = torch.cat((past_value.type_as(value_layer),
-                                     value_layer), dim=0)
-        if get_key_value:
-            present = (key_layer, value_layer)
+        if inference_max_sequence_len:
+            # Adjust the range variables.
+            start = self.inference_current_sequence_len
+            self.inference_current_sequence_len += key_layer.size(0)
+            end = self.inference_current_sequence_len
+            # Copy key and values.
+            self.inference_key_memory[start:end, ...] = key_layer
+            self.inference_value_memory[start:end, ...] = value_layer
+            key_layer = self.inference_key_memory[:end, ...]
+            value_layer = self.inference_value_memory[:end, ...]
+            # Adjust attention mask
+            attention_mask = attention_mask[..., start:end, :end]
 
         # ===================================
         # Raw attention scores. [b, np, s, s]
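The inference branch writes the new keys/values into the pre-allocated buffers at [start:end) and then attends over everything up to end, slicing the attention mask to match. A stand-alone toy reproduction of that bookkeeping with made-up shapes (ToyKVCache is illustrative, not Megatron code):

import torch

class ToyKVCache:
    def __init__(self, max_len, batch, heads, head_dim):
        self.key_memory = torch.empty(max_len, batch, heads, head_dim)
        self.value_memory = torch.empty_like(self.key_memory)
        self.full_mask = torch.tril(
            torch.ones(max_len, max_len, dtype=torch.bool))[None, None]
        self.current_len = 0

    def append(self, key_layer, value_layer):
        # Copy the new keys/values and return the slices attention should use:
        # queries cover [start, end), keys and values cover [0, end).
        start = self.current_len
        self.current_len += key_layer.size(0)
        end = self.current_len
        self.key_memory[start:end] = key_layer
        self.value_memory[start:end] = value_layer
        mask = self.full_mask[..., start:end, :end]
        return self.key_memory[:end], self.value_memory[:end], mask

cache = ToyKVCache(max_len=8, batch=2, heads=4, head_dim=16)
k, v, m = cache.append(torch.randn(5, 2, 4, 16), torch.randn(5, 2, 4, 16))
print(k.shape, m.shape)   # torch.Size([5, 2, 4, 16]) torch.Size([1, 1, 5, 5])
k, v, m = cache.append(torch.randn(1, 2, 4, 16), torch.randn(1, 2, 4, 16))
print(k.shape, m.shape)   # torch.Size([6, 2, 4, 16]) torch.Size([1, 1, 1, 6])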
@@ -270,22 +320,6 @@ class ParallelAttention(MegatronModule):
         # change view to [b, np, sq, sk]
         attention_scores = matmul_result.view(*output_size)
 
-        # ==================================================
-        # Update attention mask for inference. [b, np, sq, sk]
-        # ==================================================
-
-        if get_key_value:
-            with torch.no_grad():
-                if layer_past is not None:
-                    attention_mask = attention_mask[
-                        ...,
-                        attention_scores.size(3) - 1,
-                        :attention_scores.size(3)].unsqueeze(2)
-                else:
-                    attention_mask = attention_mask[
-                        ...,
-                        :attention_scores.size(3),
-                        :attention_scores.size(3)]
-
         # ===========================
         # Attention probs and dropout
@@ -341,9 +375,6 @@ class ParallelAttention(MegatronModule):
 
         output, bias = self.dense(context_layer)
 
-        if get_key_value:
-            output = [output, present]
-
         return output, bias
@@ -430,21 +461,21 @@ class ParallelTransformerLayer(MegatronModule):
             output_layer_init_method)
 
     def forward(self, hidden_states, attention_mask,
-                encoder_output=None, enc_dec_attn_mask=None,
-                layer_past=None, get_key_value=False):
+                encoder_output=None,
+                enc_dec_attn_mask=None,
+                set_inference_key_value_memory=False,
+                inference_max_sequence_len=None):
         # hidden_states: [b, s, h]
 
         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)
         # Self attention.
         attention_output, attention_bias = \
-            self.self_attention(layernorm_output,
-                                attention_mask,
-                                layer_past=layer_past,
-                                get_key_value=get_key_value)
-
-        if get_key_value:
-            attention_output, presents = attention_output
+            self.self_attention(
+                layernorm_output,
+                attention_mask,
+                set_inference_key_value_memory=set_inference_key_value_memory,
+                inference_max_sequence_len=inference_max_sequence_len)
 
         # Residual connection.
         if self.apply_residual_connection_post_layernorm:
@@ -514,9 +545,6 @@ class ParallelTransformerLayer(MegatronModule):
                 residual,
                 self.hidden_dropout)
 
-        if get_key_value:
-            output = [output, presents]
-
         return output
@@ -659,18 +687,16 @@ class ParallelTransformer(MegatronModule):
            forward_step_func"""
         self.input_tensor = input_tensor
 
-    def forward(self, hidden_states, attention_mask, layer_past=None,
-                get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):
+    def forward(self, hidden_states, attention_mask,
+                encoder_output=None,
+                enc_dec_attn_mask=None,
+                set_inference_key_value_memory=False,
+                inference_max_sequence_len=None):
 
         # Checks.
-        if layer_past is not None:
-            assert get_key_value, \
-                'for not None values in layer_past, ' \
-                'expected get_key_value to be set'
-        if get_key_value:
+        if inference_max_sequence_len:
             assert self.activations_checkpoint_method is None, \
-                'get_key_value does not work with ' \
-                'activation checkpointing'
+                'inference does not work with activation checkpointing'
 
         if self.pre_process:
             # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
@@ -693,22 +719,15 @@ class ParallelTransformer(MegatronModule):
                 encoder_output,
                 enc_dec_attn_mask)
         else:
-            if get_key_value:
-                presents = []
             for index in range(self.num_layers):
                 layer = self._get_layer(index)
-                past = None
-                if layer_past is not None:
-                    past = layer_past[index]
-                hidden_states = layer(hidden_states,
-                                      attention_mask,
-                                      encoder_output=encoder_output,
-                                      enc_dec_attn_mask=enc_dec_attn_mask,
-                                      layer_past=past,
-                                      get_key_value=get_key_value)
-                if get_key_value:
-                    hidden_states, present = hidden_states
-                    presents.append(present)
+                hidden_states = layer(
+                    hidden_states,
+                    attention_mask,
+                    encoder_output=encoder_output,
+                    enc_dec_attn_mask=enc_dec_attn_mask,
+                    set_inference_key_value_memory=set_inference_key_value_memory,
+                    inference_max_sequence_len=inference_max_sequence_len)
 
         # Final layer norm.
         if self.post_process:
@@ -717,7 +736,5 @@ class ParallelTransformer(MegatronModule):
             output = self.final_layernorm(hidden_states)
         else:
             output = hidden_states
-        if get_key_value:
-            output = [output, presents]
 
         return output
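With each ParallelAttention owning its own cache, the surrounding transformer no longer threads layer_past tuples in or collects presents out; it only forwards the two inference flags to every layer. A minimal sketch of that dispatch pattern using placeholder layers (CachedLayer and LayerStack are illustrative stand-ins, not Megatron classes):

import torch
import torch.nn as nn

class CachedLayer(nn.Module):
    def forward(self, hidden_states, attention_mask,
                set_inference_key_value_memory=False,
                inference_max_sequence_len=None):
        # A real layer would attend using its own pre-allocated KV memory.
        return hidden_states

class LayerStack(nn.Module):
    def __init__(self, num_layers=4):
        super().__init__()
        self.layers = nn.ModuleList([CachedLayer() for _ in range(num_layers)])

    def forward(self, hidden_states, attention_mask,
                set_inference_key_value_memory=False,
                inference_max_sequence_len=None):
        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                attention_mask,
                set_inference_key_value_memory=set_inference_key_value_memory,
                inference_max_sequence_len=inference_max_sequence_len)
        return hidden_states

stack = LayerStack()
out = stack(torch.randn(5, 2, 32), None,
            set_inference_key_value_memory=True,
            inference_max_sequence_len=16)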
@@ -227,8 +227,8 @@ def switch(val1, val2, boolean):
 
 def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
-                 layer_past=None, get_key_value=None,
-                 forward_method_parallel_output=None):
+                 set_inference_key_value_memory=False,
+                 inference_max_sequence_len=None):
 
     # Hidden size changes when not using recompute, need to tell p2p_communicate
     # functions the correct size
@@ -243,20 +243,16 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
     unwrapped_model = unwrap_model(
         model, (torchDDP, LocalDDP, Float16Module))
     unwrapped_model.set_input_tensor(input_tensor)
-    output_tensor = model(tokens, position_ids, attention_mask,
-                          tokentype_ids=tokentype_ids,
-                          layer_past=layer_past,
-                          get_key_value=get_key_value,
-                          forward_method_parallel_output=forward_method_parallel_output)
-
-    if get_key_value:
-        output_tensor, layer_past = output_tensor
+    output_tensor = model(
+        tokens, position_ids, attention_mask,
+        tokentype_ids=tokentype_ids,
+        set_inference_key_value_memory=set_inference_key_value_memory,
+        inference_max_sequence_len=inference_max_sequence_len)
 
     send_forward(output_tensor)
 
     args.seq_length = orig_seq_length
-    if get_key_value:
-        return output_tensor, layer_past
     return output_tensor
@@ -279,7 +275,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
 
     counter = 0
-    layer_past = None
 
     batch_size = context_tokens.size(0)
     is_done = torch.zeros([batch_size]).byte().cuda()
     tokens = context_tokens
@@ -296,11 +291,15 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
     while context_length < maxlen:
         types2use = None
         if counter == 0:
+            # Allocate memory for the entire context.
+            set_inference_key_value_memory = True
             tokens2use = tokens[:, :context_length]
             positions2use = position_ids[:, :context_length]
             if type_ids is not None:
                 types2use = type_ids[:, :context_length]
         else:
+            # Set this to false so the memory is not reallocated.
+            set_inference_key_value_memory = False
             tokens2use = tokens[:, context_length - 1].view(
                 batch_size, -1)
             positions2use = position_ids[:, context_length - 1].view(
@@ -308,18 +307,20 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 if type_ids is not None:
                     types2use = type_ids[:, context_length - 1].view(
                         batch_size, -1)
-            output, layer_past = forward_step(model, tokens2use,
-                                              positions2use,
-                                              attention_mask,
-                                              layer_past=layer_past,
-                                              get_key_value=True,
-                                              tokentype_ids=types2use,
-                                              forward_method_parallel_output=False)
+            output = forward_step(
+                model, tokens2use,
+                positions2use,
+                attention_mask,
+                set_inference_key_value_memory=set_inference_key_value_memory,
+                inference_max_sequence_len=maxlen,
+                tokentype_ids=types2use)
 
             if mpu.is_pipeline_last_stage():
                 assert output is not None
+                output = output.float()
                 logits = output[:, -1].view(batch_size, -1).contiguous()
 
+            if mpu.is_pipeline_last_stage():
                 if args.greedy:
@@ -331,6 +332,10 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                     prev = torch.multinomial(log_probs, num_samples=1).view(-1)
 
                 started = context_lengths <= context_length
+
+                # Clamp the out of vocabulary tokens.
+                tokenizer = get_tokenizer()
+                prev = torch.clamp(prev, max=tokenizer.vocab_size - 1)
                 new_tokens = switch(
                     tokens[:, context_length].view(-1), prev, started)
                 tokens[:, context_length] = new_tokens
...
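Taken together, the sampling loop now allocates the cache once for the whole generation budget on its first iteration, feeds a single token on every later iteration, and clamps sampled ids to the vocabulary. A condensed, self-contained sketch of that control flow (dummy_forward_step returns random logits in place of the pipeline-parallel forward_step; all names here are illustrative):

import torch

vocab_size, batch, maxlen, context_length = 50, 2, 12, 5
tokens = torch.randint(0, vocab_size, (batch, maxlen))

def dummy_forward_step(tokens2use, set_inference_key_value_memory,
                       inference_max_sequence_len):
    # Placeholder returning per-position logits; the real forward_step calls
    # the pipeline-parallel model with the same keyword arguments.
    return torch.randn(tokens2use.size(0), tokens2use.size(1), vocab_size)

counter = 0
while context_length < maxlen:
    if counter == 0:
        # Allocate the key/value memory for the whole generation budget.
        set_kv = True
        tokens2use = tokens[:, :context_length]
    else:
        # Reuse the memory; only the newest token is fed through the model.
        set_kv = False
        tokens2use = tokens[:, context_length - 1].view(batch, -1)
    output = dummy_forward_step(tokens2use, set_kv, maxlen)
    logits = output[:, -1].float()
    prev = torch.argmax(logits, dim=-1)
    # Clamp out-of-vocabulary samples, mirroring the new safety check.
    prev = torch.clamp(prev, max=vocab_size - 1)
    tokens[:, context_length] = prev
    context_length += 1
    counter += 1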
@@ -189,7 +189,7 @@ def update_train_iters(args):
     print_rank_0('setting training iterations to {}'.format(args.train_iters))
 
 
-def get_model(model_provider_func):
+def get_model(model_provider_func, wrap_with_ddp=True):
     """Build the model."""
     args = get_args()
@@ -243,22 +243,24 @@ def get_model(model_provider_func):
     if args.fp16 or args.bf16:
         model = [Float16Module(model_module, args) for model_module in model]
 
-    if args.DDP_impl == 'torch':
-        i = torch.cuda.current_device()
-        model = [torchDDP(model_module, device_ids=[i], output_device=i,
-                          process_group=mpu.get_data_parallel_group())
-                 for model_module in model]
-        return model
-
-    if args.DDP_impl == 'local':
-        model = [LocalDDP(model_module,
-                          args.accumulate_allreduce_grads_in_fp32,
-                          args.use_contiguous_buffers_in_local_ddp)
-                 for model_module in model]
-        return model
-
-    raise NotImplementedError('Unknown DDP implementation specified: {}. '
-                              'Exiting.'.format(args.DDP_impl))
+    if wrap_with_ddp:
+        if args.DDP_impl == 'torch':
+            i = torch.cuda.current_device()
+            model = [torchDDP(model_module, device_ids=[i], output_device=i,
+                              process_group=mpu.get_data_parallel_group())
+                     for model_module in model]
+
+        elif args.DDP_impl == 'local':
+            model = [LocalDDP(model_module,
+                              args.accumulate_allreduce_grads_in_fp32,
+                              args.use_contiguous_buffers_in_local_ddp)
+                     for model_module in model]
+
+        else:
+            raise NotImplementedError('Unknown DDP implementation specified: '
+                                      '{}. Exiting.'.format(args.DDP_impl))
+
+    return model
 
 
 def get_learning_rate_scheduler(optimizer):
...
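The get_model change folds the two early returns into a single exit and gates all DDP wrapping behind wrap_with_ddp, which is what lets the text-generation entry point below build an unwrapped model. A reduced sketch of the resulting control flow, with string stand-ins for the wrapper classes (everything here is illustrative, not the project's API):

def wrap_model(model, wrap_with_ddp=True, ddp_impl='local'):
    # DDP wrapping happens only when requested; 'return model' runs exactly once.
    if wrap_with_ddp:
        if ddp_impl == 'torch':
            model = ['torchDDP({})'.format(m) for m in model]
        elif ddp_impl == 'local':
            model = ['LocalDDP({})'.format(m) for m in model]
        else:
            raise NotImplementedError(
                'Unknown DDP implementation specified: {}. Exiting.'.format(ddp_impl))
    return model

print(wrap_model(['module'], wrap_with_ddp=False))   # ['module']
print(wrap_model(['module'], wrap_with_ddp=True))    # ['LocalDDP(module)']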
@@ -65,7 +65,7 @@ if __name__ == "__main__":
         print("Interleaved pipeline schedule is not yet supported for text generation.")
         exit()
 
     # Set up model and load checkpoint
-    model = get_model(model_provider)
+    model = get_model(model_provider, wrap_with_ddp=False)
     if args.load is not None:
         _ = load_checkpoint(model, None, None)
...