# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from abc import ABC, abstractmethod

import torch

from megatron.core import parallel_state, tensor_parallel
from megatron.core.models.common.rotary_pos_embedding import apply_rotary_pos_emb
from megatron.core.transformer.custom_layers.transformer_engine import (
    TEDotProductAttention,
    TELayerNormColumnParallelLinear,
    TERowParallelLinear,
)
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import divide


class Attention(MegatronModule, ABC):
    """Attention layer abstract class.

    This layer only contains common modules required for the "self attn" and
    "cross attn" specializations.
    """

    def __init__(
        self,
        config: TransformerConfig,
        layer_number: int = 1,
        attn_mask_type=AttnMaskType.padding,
    ):
        super().__init__(config=config)

        self.config = config
        self.layer_number = layer_number
        self.attn_mask_type = attn_mask_type

        # For normal attention without groups, num_query_groups == num_attention_heads,
        # so these two projection sizes will be the same.
        self.query_projection_size = self.config.kv_channels * self.config.num_attention_heads
        self.kv_projection_size = self.config.kv_channels * self.config.num_query_groups

        # Per attention head and per partition values.
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        self.hidden_size_per_attention_head = divide(
            self.query_projection_size, self.config.num_attention_heads
        )
        self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
        self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size)

        self.dot_product_attention = TEDotProductAttention(
            config=self.config, layer_number=self.layer_number, attn_mask_type=self.attn_mask_type
        )

        self.checkpoint_dot_product_attention = self.config.recompute_granularity == 'selective'

        # Output.
        self.linear_proj = TERowParallelLinear(
            self.query_projection_size,
            self.config.hidden_size,
            config=self.config,
            init_method=self.config.output_layer_init_method,
            bias=self.config.add_bias_linear,
            skip_bias_add=True,
        )

    def _checkpointed_attention_forward(
        self, query, key, value, attention_mask, rotary_pos_emb=None
    ):
        """Forward method with selective activation checkpointing."""

        def custom_forward(*inputs):
            query = inputs[0]
            key = inputs[1]
            value = inputs[2]
            attention_mask = inputs[3]
            output_ = self.dot_product_attention(query, key, value, attention_mask)
            return output_

        hidden_states = tensor_parallel.checkpoint(
            custom_forward, False, query, key, value, attention_mask, rotary_pos_emb
        )

        return hidden_states

    def _allocate_memory(self, inference_max_sequence_length, batch_size, dtype):
        """Allocate memory to store the KV cache during inference."""
        return torch.empty(
            inference_max_sequence_length,
            batch_size,
            self.num_query_groups_per_partition,
            self.hidden_size_per_attention_head,
            dtype=dtype,
            device=torch.cuda.current_device(),
        )

    def _adjust_key_value_for_inference(self, inference_params, key, value, rotary_pos_emb):
        """
        Saves the generated key and value tensors to the end of the buffers in inference_params.
        Returns the full-size keys and values from the provided inference_params, as well as the
        adjusted rotary_pos_emb.

        Returns a tuple: (key, value, rotary_pos_emb)
        """
        if inference_params is None:
            return key, value, rotary_pos_emb

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        is_first_step = False
        if self.layer_number not in inference_params.key_value_memory_dict:
            inf_max_seq_length = inference_params.max_sequence_length
            inf_max_batch_size = inference_params.max_batch_size
            inference_key_memory = self._allocate_memory(
                inf_max_seq_length, inf_max_batch_size, key.dtype
            )
            inference_value_memory = self._allocate_memory(
                inf_max_seq_length, inf_max_batch_size, value.dtype
            )
            inference_params.key_value_memory_dict[self.layer_number] = (
                inference_key_memory,
                inference_value_memory,
            )
            is_first_step = True
        else:
            # Get the pre-allocated buffers for this layer.
            inference_key_memory, inference_value_memory = inference_params.key_value_memory_dict[
                self.layer_number
            ]

        batch_start = inference_params.batch_size_offset
        batch_end = batch_start + key.size(1)
        assert batch_end <= inference_key_memory.size(1)
        sequence_start = inference_params.sequence_len_offset
        sequence_end = sequence_start + key.size(0)
        assert sequence_end <= inference_key_memory.size(0)

        # Copy keys and values into the cache, then return the full cached prefix.
        inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = key
        inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = value
        key = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
        value = inference_value_memory[:sequence_end, batch_start:batch_end, ...]

        # Adjust the rotary positional embeddings.
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            # Need to cross-check this condition during inference.
            # if not set_inference_key_value_memory:
            if not is_first_step:
                # During incremental decoding we compute one token at a time,
                # so select only the positional embedding of the last token in
                # the sequence.
                q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end]
            else:
                # In the first forward pass of inference we use the entire
                # provided prefix. q_pos_emb here has the rope embeddings of
                # the entire prefix + to-be-generated output, so we slice to
                # just the prefix.
                q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
            k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
            rotary_pos_emb = (q_pos_emb, k_pos_emb)

        return key, value, rotary_pos_emb

    @abstractmethod
    def get_query_key_value_tensors(self, hidden_states, key_value_states):
        """
        This method needs to be implemented based on whether the derived class
        is "self-attn" or "cross-attn".
        """

    def forward(
        self,
        hidden_states,
        attention_mask,
        key_value_states=None,
        inference_params=None,
        rotary_pos_emb=None,
    ):
        # hidden_states: [sq, b, h]

        # For self attention we just duplicate the rotary_pos_emb if it isn't already a tuple.
        if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
            rotary_pos_emb = (rotary_pos_emb,) * 2

        # =====================
        # Query, Key, and Value
        # =====================
        # Get the query, key and value tensors based on the type of attention -
        # self or cross attn.
        query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)

        # ===================================================
        # Adjust key, value, and rotary_pos_emb for inference
        # ===================================================
        key, value, rotary_pos_emb = self._adjust_key_value_for_inference(
            inference_params, key, value, rotary_pos_emb
        )

        # ================================================
        # relative positional embedding (rotary embedding)
        # ================================================
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            query = apply_rotary_pos_emb(query, q_pos_emb)
            key = apply_rotary_pos_emb(key, k_pos_emb)
            # TODO: positional embedding could also be applied to `value` so that it carries
            # absolute positional information; currently only the relative (rotary) positional
            # embedding takes effect.
            # value = apply_rotary_pos_emb(value, k_pos_emb)

        # ==================================
        # core attention computation
        # ==================================

        # Expand key and value: [sk, b, ng, hn] -> [sk, b, np, hn].
        # This is a no-op for normal attention, where ng == np. When using group query attention,
        # the keys and values are repeated along the head dimension to match the number of query
        # heads.
        if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1:
            key = key.repeat_interleave(
                self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
                dim=2,
            )
            value = value.repeat_interleave(
                self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
                dim=2,
            )

        if self.checkpoint_dot_product_attention:
            core_attn_out = self._checkpointed_attention_forward(query, key, value, attention_mask)
        else:
            core_attn_out = self.dot_product_attention(query, key, value, attention_mask)

        # ==================
        # Output. [sq, b, h]
        # ==================
        output, bias = self.linear_proj(core_attn_out)

        return output, bias


class SelfAttention(Attention):
    """Self-attention layer class.

    Self-attention layer takes input with size [s, b, h] and returns output of
    the same size.
    """

    def __init__(
        self, config: TransformerConfig, layer_number: int = 1, attn_mask_type=AttnMaskType.padding
    ):
        super().__init__(config=config, layer_number=layer_number, attn_mask_type=attn_mask_type)

        self.linear_qkv = TELayerNormColumnParallelLinear(
            self.config.hidden_size,
            self.query_projection_size + 2 * self.kv_projection_size,
            config=self.config,
            init_method=self.config.init_method,
            bias=self.config.add_bias_linear,
            skip_bias_add=False,
        )

    def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
        """
        Derives `query`, `key` and `value` tensors from `hidden_states`.
""" # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)] mixed_qkv, _ = self.linear_qkv(hidden_states) # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn] new_tensor_shape = mixed_qkv.size()[:-1] + ( self.num_query_groups_per_partition, ( (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2) * self.hidden_size_per_attention_head ), ) mixed_qkv = mixed_qkv.view(*new_tensor_shape) # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] (query, key, value) = torch.split( mixed_qkv, [ ( self.num_attention_heads_per_partition // self.num_query_groups_per_partition * self.hidden_size_per_attention_head ), self.hidden_size_per_attention_head, self.hidden_size_per_attention_head, ], dim=3, ) # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) return query, key, value class CrossAttention(Attention): """Cross-attention layer class Cross-attention layer takes input with size [s, b, h] and context with size [s, b, h] and returns output of the same size. """ def __init__( self, config: TransformerConfig, layer_number: int = 1, attn_mask_type=AttnMaskType.padding ): super().__init__(config=config, layer_number=layer_number, attn_mask_type=attn_mask_type) if self.config.num_query_groups != self.config.num_attention_heads: raise ValueError( f"Group query attention is not currently supported in cross attention." ) assert self.query_projection_size == self.kv_projection_size self.linear_q = TELayerNormColumnParallelLinear( self.config.hidden_size, self.query_projection_size, config=self.config, init_method=self.config.init_method, bias=self.config.add_bias_linear, skip_bias_add=False, ) self.linear_kv = TELayerNormColumnParallelLinear( self.config.hidden_size, 2 * self.kv_projection_size, config=self.config, init_method=self.config.init_method, bias=self.config.add_bias_linear, skip_bias_add=False, ) def get_query_key_value_tensors(self, hidden_states, key_value_states): """ Derives `query` tensor from `hidden_states`, and `key`/`value` tensors from `key_value_states`. """ # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] mixed_kv, _ = self.linear_kv(key_value_states) # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] new_tensor_shape = mixed_kv.size()[:-1] + ( self.num_attention_heads_per_partition, 2 * self.hidden_size_per_attention_head, ) mixed_kv = mixed_kv.view(*new_tensor_shape) # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] (key, value) = tensor_parallel.split_tensor_along_last_dim(mixed_kv, 2) # Attention head [sq, b, h] --> [sq, b, hp] query, _ = self.linear_q(hidden_states) # [sq, b, hp] --> [sq, b, np, hn] new_tensor_shape = query.size()[:-1] + ( self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, ) query = query.view(*new_tensor_shape) return query, key, value