Commit 9be97c01 authored by dongcl

update ParallelAttention to core v0.12.0

parent 698bfd4d
@@ -5,6 +5,7 @@ from functools import wraps
 from megatron.training import get_args
 from megatron.core import tensor_parallel
+from megatron.core.utils import deprecate_inference_params
 from megatron.legacy.model.enums import AttnType
 from megatron.core.models.common.embeddings import apply_rotary_pos_emb
 from megatron.legacy.model.module import MegatronModule
@@ -92,20 +93,23 @@ class ParallelAttentionPatch(MegatronModule):
     Self-attention layer takes input with size [s, b, h]
     and returns output of the same size.
     """

+    # query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head) =>
+    # query_layer = query_layer.contiguous().view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head)
     def forward(self, hidden_states, attention_mask,
-                encoder_output=None, inference_params=None,
-                rotary_pos_emb=None):
+                encoder_output=None, inference_context=None,
+                rotary_pos_emb=None, *, inference_params=None):
         # hidden_states: [sq, b, h]
+        inference_context = deprecate_inference_params(inference_context, inference_params)

         # =================================================
         # Pre-allocate memory for key-values for inference.
         # =================================================
         is_first_step = False
-        if inference_params:
-            if self.layer_number not in inference_params.key_value_memory_dict:
-                inf_max_seq_len = inference_params.max_sequence_length
-                inf_max_batch_size = inference_params.max_batch_size
+        if inference_context:
+            if self.layer_number not in inference_context.key_value_memory_dict:
+                inf_max_seq_len = inference_context.max_sequence_length
+                inf_max_batch_size = inference_context.max_batch_size
                 inference_key_memory = self._allocate_memory(
                     inf_max_seq_len, inf_max_batch_size,
                     self.num_query_groups_per_partition)
@@ -113,12 +117,12 @@ class ParallelAttentionPatch(MegatronModule):
                     inf_max_seq_len, inf_max_batch_size,
                     self.num_query_groups_per_partition)
-                inference_params.key_value_memory_dict[self.layer_number] = (
+                inference_context.key_value_memory_dict[self.layer_number] = (
                     inference_key_memory, inference_value_memory)
                 is_first_step = True
             else:
                 inference_key_memory, inference_value_memory = \
-                    inference_params.key_value_memory_dict[self.layer_number]
+                    inference_context.key_value_memory_dict[self.layer_number]

         # =====================
         # Query, Key, and Value
         # =====================
@@ -188,13 +192,14 @@ class ParallelAttentionPatch(MegatronModule):
             else:
                 rotary_pos_emb = ((rotary_pos_emb,) * 2)

-        if inference_params:
-            batch_start = inference_params.batch_size_offset
+        if inference_context:
+            batch_start = inference_context.batch_size_offset
             batch_end = batch_start + key_layer.size(1)
             assert batch_end <= inference_key_memory.size(1)
-            sequence_start = inference_params.sequence_len_offset
+            sequence_start = inference_context.sequence_len_offset
             sequence_end = sequence_start + key_layer.size(0)
-            assert sequence_end <= inference_key_memory.size(0)
+            assert sequence_end <= inference_key_memory.size(0), ("Current sequence length is "
+                "longer than expected maximum sequence length! Increase inference_max_seq_length.")
             # Copy key and values.
             inference_key_memory[sequence_start:sequence_end,
                                  batch_start:batch_end, ...] = key_layer
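For reference, the updated forward() keeps inference_params only as a keyword-only compatibility argument and immediately folds it into inference_context via deprecate_inference_params. Below is a minimal sketch of that shim, assuming it simply prefers the new argument and warns when the legacy one is supplied; the exact implementation in megatron.core.utils may differ.

import warnings

def deprecate_inference_params(inference_context, inference_params):
    # Compatibility shim (sketch): prefer the new `inference_context`.
    # If a caller still passes the legacy `inference_params` kwarg, warn
    # and reuse it as the context object.
    if inference_context is None and inference_params is not None:
        warnings.warn(
            "`inference_params` is deprecated; pass `inference_context` instead.",
            DeprecationWarning,
        )
        return inference_params
    return inference_context

With this pattern, existing callers that pass inference_params=... keep working unchanged, while new code passes an inference_context directly.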