Commit f5345dfa authored by Jared Casper

Merge branch 'inference_context_memory' into 'main'

made model stateless with respect to inference

See merge request ADLR/megatron-lm!348
parents d33460df efc750b6
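The gist of the change, illustrated below with placeholder names (this is a sketch of the pattern, not the actual Megatron-LM code): before this MR, each ParallelAttention kept its inference key/value cache on self, so the module carried state from one generation to the next and callers had to reset it through the allocate_key_value_memory flag. After this MR, the cache lives in inference_params.key_value_memory_dict, keyed by layer number, so all inference state travels with the InferenceParams object and the model itself stays stateless.

# Illustrative sketch only -- DemoAttention and the tensor shapes are made up;
# the real changes are in ParallelAttention (see the diff below).
import torch

class InferenceParams:
    """Container for the state needed during incremental decoding."""
    def __init__(self, max_batch_size, max_sequence_len):
        self.max_sequence_len = max_sequence_len
        self.max_batch_size = max_batch_size
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        self.key_value_memory_dict = {}

class DemoAttention:
    def __init__(self, layer_number, num_heads=4, head_dim=8):
        self.layer_number = layer_number
        self.num_heads = num_heads
        self.head_dim = head_dim

    def _allocate_memory(self, max_seq_len, batch_size):
        return torch.empty(max_seq_len, batch_size, self.num_heads, self.head_dim)

    def forward(self, key_layer, value_layer, inference_params):
        # Lazily allocate this layer's cache inside inference_params (not on self).
        if self.layer_number not in inference_params.key_value_memory_dict:
            key_mem = self._allocate_memory(inference_params.max_sequence_len,
                                            inference_params.max_batch_size)
            value_mem = self._allocate_memory(inference_params.max_sequence_len,
                                              inference_params.max_batch_size)
            inference_params.key_value_memory_dict[self.layer_number] = (
                key_mem, value_mem)
        else:
            key_mem, value_mem = \
                inference_params.key_value_memory_dict[self.layer_number]
        # Write the new tokens' keys/values at the current sequence offset
        # (batch_size_offset handling omitted for brevity).
        start = inference_params.sequence_len_offset
        end = start + key_layer.size(0)
        key_mem[start:end, :key_layer.size(1)] = key_layer
        value_mem[start:end, :value_layer.size(1)] = value_layer
        # Attention then reads the cache up to the current end of the sequence.
        return key_mem[:end], value_mem[:end]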
@@ -179,10 +179,6 @@ class ParallelAttention(MegatronModule):
             init_method=output_layer_init_method,
             skip_bias_add=True)
 
-        # Inference key-value memory
-        self.inference_key_memory = None
-        self.inference_value_memory = None
-
     def _allocate_memory(self, inference_max_sequence_len, batch_size):
         return torch.empty(
@@ -203,19 +199,18 @@ class ParallelAttention(MegatronModule):
         # Pre-allocate memory for key-values for inference.
         # =================================================
         if inference_params:
-            if inference_params.allocate_key_value_memory:
+            if self.layer_number not in inference_params.key_value_memory_dict:
                 inf_max_seq_len = inference_params.max_sequence_len
                 inf_max_batch_size = inference_params.max_batch_size
-                self.inference_key_memory = self._allocate_memory(
+                inference_key_memory = self._allocate_memory(
                     inf_max_seq_len, inf_max_batch_size)
-                self.inference_value_memory = self._allocate_memory(
+                inference_value_memory = self._allocate_memory(
                     inf_max_seq_len, inf_max_batch_size)
-        # This is added for safety. In case inference_params
-        # is not provided, make sure there is no potential memory left
-        # from previous inference.
-        else:
-            self.inference_key_memory = None
-            self.inference_value_memory = None
+                inference_params.key_value_memory_dict[self.layer_number] = (
+                    inference_key_memory, inference_value_memory)
+            else:
+                inference_key_memory, inference_value_memory = \
+                    inference_params.key_value_memory_dict[self.layer_number]
 
         # =====================
@@ -266,20 +261,18 @@ class ParallelAttention(MegatronModule):
         if inference_params:
             batch_start = inference_params.batch_size_offset
             batch_end = batch_start + key_layer.size(1)
-            assert batch_end <= self.inference_key_memory.size(1)
+            assert batch_end <= inference_key_memory.size(1)
             sequence_start = inference_params.sequence_len_offset
             sequence_end = sequence_start + key_layer.size(0)
-            assert sequence_end <= self.inference_key_memory.size(0)
+            assert sequence_end <= inference_key_memory.size(0)
             # Copy key and values.
-            self.inference_key_memory[sequence_start:sequence_end,
-                                      batch_start:batch_end,
-                                      ...] = key_layer
-            self.inference_value_memory[sequence_start:sequence_end,
-                                        batch_start:batch_end,
-                                        ...] = value_layer
-            key_layer = self.inference_key_memory[
+            inference_key_memory[sequence_start:sequence_end,
+                                 batch_start:batch_end, ...] = key_layer
+            inference_value_memory[sequence_start:sequence_end,
+                                   batch_start:batch_end, ...] = value_layer
+            key_layer = inference_key_memory[
                 :sequence_end, batch_start:batch_end, ...]
-            value_layer = self.inference_value_memory[
+            value_layer = inference_value_memory[
                 :sequence_end, batch_start:batch_end, ...]
...
@@ -40,7 +40,7 @@ class InferenceParams:
         self.max_batch_size = max_batch_size
         self.sequence_len_offset = 0
         self.batch_size_offset = 0
-        self.allocate_key_value_memory = True
+        self.key_value_memory_dict = {}
@@ -132,11 +132,6 @@ def _forward_step_helper(model, tokens, position_ids, attention_mask,
     # Send output to the next stage.
     send_to_next_pipeline_rank(output_tensor)
 
-    # Make sure we do not allocate context memory anymore.
-    if inference_params.allocate_key_value_memory:
-        inference_params.allocate_key_value_memory = False
-
     return output_tensor
...
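With the allocate_key_value_memory flag gone, a caller drives incremental decoding simply by reusing one InferenceParams object across forward steps and advancing sequence_len_offset; starting a new generation means constructing a fresh InferenceParams, which drops every layer's cache with it. A usage sketch continuing the placeholder classes from the illustration above (again, not the real Megatron-LM generation loop):

# Usage sketch (continues the placeholder classes above).
params = InferenceParams(max_batch_size=2, max_sequence_len=16)
attn = DemoAttention(layer_number=1)

for step_len in (4, 1, 1):  # a 4-token prompt, then two single-token decode steps
    k = torch.randn(step_len, 2, attn.num_heads, attn.head_dim)
    v = torch.randn(step_len, 2, attn.num_heads, attn.head_dim)
    keys, values = attn.forward(k, v, params)
    params.sequence_len_offset += step_len  # the caller advances the offset

# Dropping params releases every layer's cache; nothing is left on the module,
# so the next generation just builds a new InferenceParams.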