"vscode:/vscode.git/clone" did not exist on "143c6615a18cd9dbc1d84a56cbfcbe325fb9ac58"
Commit 9be97c01 authored by dongcl

update ParallelAttention to core v0.12.0

parent 698bfd4d
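In core v0.12.0 the `inference_params` argument is renamed to `inference_context`, and `megatron.core.utils.deprecate_inference_params` maps the old keyword onto the new one so existing callers keep working. The following is a rough illustration of that compatibility pattern only, not the upstream implementation; the diff below relies on the real helper.

import warnings

def deprecate_inference_params(inference_context, inference_params):
    # Illustrative shim: prefer the new argument, fall back to the deprecated one.
    if inference_context is None and inference_params is not None:
        warnings.warn(
            "'inference_params' is deprecated; pass 'inference_context' instead.",
            DeprecationWarning,
        )
        return inference_params
    return inference_context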
@@ -5,6 +5,7 @@ from functools import wraps
 from megatron.training import get_args
 from megatron.core import tensor_parallel
+from megatron.core.utils import deprecate_inference_params
 from megatron.legacy.model.enums import AttnType
 from megatron.core.models.common.embeddings import apply_rotary_pos_emb
 from megatron.legacy.model.module import MegatronModule
@@ -92,20 +93,23 @@ class ParallelAttentionPatch(MegatronModule):
     Self-attention layer takes input with size [s, b, h]
     and returns output of the same size.
     """
     # query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head) =>
     # query_layer = query_layer.contiguous().view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head)
     def forward(self, hidden_states, attention_mask,
-                encoder_output=None, inference_params=None,
-                rotary_pos_emb=None):
+                encoder_output=None, inference_context=None,
+                rotary_pos_emb=None, *, inference_params=None):
         # hidden_states: [sq, b, h]
+        inference_context = deprecate_inference_params(inference_context, inference_params)
         # =================================================
         # Pre-allocate memory for key-values for inference.
         # =================================================
         is_first_step = False
-        if inference_params:
-            if self.layer_number not in inference_params.key_value_memory_dict:
-                inf_max_seq_len = inference_params.max_sequence_length
-                inf_max_batch_size = inference_params.max_batch_size
+        if inference_context:
+            if self.layer_number not in inference_context.key_value_memory_dict:
+                inf_max_seq_len = inference_context.max_sequence_length
+                inf_max_batch_size = inference_context.max_batch_size
                 inference_key_memory = self._allocate_memory(
                     inf_max_seq_len, inf_max_batch_size,
                     self.num_query_groups_per_partition)
@@ -113,12 +117,12 @@ class ParallelAttentionPatch(MegatronModule):
                     inf_max_seq_len, inf_max_batch_size,
                     self.num_query_groups_per_partition)
-                inference_params.key_value_memory_dict[self.layer_number] = (
+                inference_context.key_value_memory_dict[self.layer_number] = (
                     inference_key_memory, inference_value_memory)
                 is_first_step = True
             else:
                 inference_key_memory, inference_value_memory = \
-                    inference_params.key_value_memory_dict[self.layer_number]
+                    inference_context.key_value_memory_dict[self.layer_number]
         # =====================
         # Query, Key, and Value
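The two hunks above lazily allocate per-layer key/value buffers on the first decoding step and cache them in `key_value_memory_dict`; later steps reuse the same tensors. A standalone sketch of that pattern follows, with an illustrative `SimpleInferenceContext` stand-in and assumed buffer shapes (these names are not the Megatron-Core classes):

import torch

class SimpleInferenceContext:
    # Toy stand-in for the inference context used above (illustrative only).
    def __init__(self, max_sequence_length, max_batch_size):
        self.max_sequence_length = max_sequence_length
        self.max_batch_size = max_batch_size
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        self.key_value_memory_dict = {}

def get_kv_memory(context, layer_number, num_heads, head_dim,
                  dtype=torch.float16, device="cpu"):
    # Allocate [max_seq, max_batch, heads, head_dim] buffers once per layer,
    # then return the cached pair on every later step.
    if layer_number not in context.key_value_memory_dict:
        shape = (context.max_sequence_length, context.max_batch_size,
                 num_heads, head_dim)
        key_memory = torch.empty(shape, dtype=dtype, device=device)
        value_memory = torch.empty(shape, dtype=dtype, device=device)
        context.key_value_memory_dict[layer_number] = (key_memory, value_memory)
    return context.key_value_memory_dict[layer_number]

Allocating the buffers once at the maximum sequence and batch size avoids repeated reallocation as the decode position advances.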
@@ -188,13 +192,14 @@ class ParallelAttentionPatch(MegatronModule):
             else:
                 rotary_pos_emb = ((rotary_pos_emb,) * 2)
-        if inference_params:
-            batch_start = inference_params.batch_size_offset
+        if inference_context:
+            batch_start = inference_context.batch_size_offset
             batch_end = batch_start + key_layer.size(1)
             assert batch_end <= inference_key_memory.size(1)
-            sequence_start = inference_params.sequence_len_offset
+            sequence_start = inference_context.sequence_len_offset
             sequence_end = sequence_start + key_layer.size(0)
-            assert sequence_end <= inference_key_memory.size(0)
+            assert sequence_end <= inference_key_memory.size(0), ("Current sequence length is "
+                "longer than expected maximum sequence length! Increase inference_max_seq_length.")
             # Copy key and values.
             inference_key_memory[sequence_start:sequence_end,
                                  batch_start:batch_end, ...] = key_layer
......
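The last hunk (truncated above) writes the current step's keys and values into the pre-allocated buffers at the running batch and sequence offsets, asserts that they still fit, and attention then reads the cached prefix. A condensed sketch of that update, assuming the same [seq, batch, heads, head_dim] layout and the illustrative context object from the previous sketch (the function name is hypothetical):

def update_kv_cache(context, layer_number, key_layer, value_layer):
    # Copy this step's K/V into the cache and return the full prefix so far.
    key_memory, value_memory = context.key_value_memory_dict[layer_number]

    batch_start = context.batch_size_offset
    batch_end = batch_start + key_layer.size(1)
    sequence_start = context.sequence_len_offset
    sequence_end = sequence_start + key_layer.size(0)

    # Guard against overrunning the pre-allocated maximum sequence length.
    assert sequence_end <= key_memory.size(0), (
        "Current sequence length exceeds the pre-allocated maximum; "
        "increase the inference max sequence length.")

    key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = key_layer
    value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer

    # Attention runs over everything cached up to the current position.
    return (key_memory[:sequence_end, batch_start:batch_end, ...],
            value_memory[:sequence_end, batch_start:batch_end, ...])

Between decoding steps the caller advances `sequence_len_offset` by the number of tokens just written, so the next call appends after the existing prefix.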