# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from abc import ABC, abstractmethod

import torch

from megatron.core import parallel_state, tensor_parallel
from megatron.core.transformer.core_attention import CoreAttention
from megatron.core.utils import divide
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.enums import AttnType, AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.custom_layers.transformer_engine import (
    TECoreAttention,
    TEColumnParallelLinear,
    TERowParallelLinear,
)


class Attention(MegatronModule, ABC):
    """Attention layer abstract class.

    This layer only contains common modules required for the "self attn" and
    "cross attn" specializations.
    """

    def __init__(
        self,
        config: TransformerConfig,
        layer_number: int = 1,
        attn_mask_type=AttnMaskType.padding,
    ):
        super().__init__(config=config)

        self.config = config
        self.layer_number = layer_number
        self.attn_mask_type = attn_mask_type

        self.projection_size = self.config.kv_channels * self.config.num_attention_heads

        # Per attention head and per partition values.
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        self.hidden_size_per_attention_head = divide(
            self.projection_size, self.config.num_attention_heads
        )
        self.num_attention_heads_per_partition = divide(
            self.config.num_attention_heads, world_size
        )

        self.core_attention = TECoreAttention(
            config=self.config,
            layer_number=self.layer_number,
            attn_mask_type=self.attn_mask_type,
        )

        self.checkpoint_core_attention = self.config.recompute_granularity == 'selective'

        # Output.
        self.linear_proj = TERowParallelLinear(
            self.projection_size,
            self.config.hidden_size,
            config=self.config,
            init_method=self.config.output_layer_init_method,
            bias=self.config.add_bias_linear,
            skip_bias_add=True,
        )

    def _checkpointed_attention_forward(self, query, key, value, attention_mask):
        """Forward method with selective activation checkpointing."""

        def custom_forward(*inputs):
            query = inputs[0]
            key = inputs[1]
            value = inputs[2]
            attention_mask = inputs[3]
            output_ = self.core_attention(query, key, value, attention_mask)
            return output_

        hidden_states = tensor_parallel.checkpoint(
            custom_forward, False, query, key, value, attention_mask
        )

        return hidden_states

    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        """Allocate an empty [s, b, np, hn] buffer for the inference KV cache."""
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=self.config.params_dtype,
            device=torch.cuda.current_device(),
        )

    @abstractmethod
    def get_query_key_value_tensors(self, hidden_states, key_value_states):
        """
        This method needs to be implemented based on whether the derived class
        is "self-attn" or "cross-attn".
        """

    def forward(self, hidden_states, attention_mask, key_value_states=None, inference_params=None):
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        # @jcasper how should we do inference_params?
        # can do 1. args, 2. add inference params to TransformerConfig
        # 3. create another config object 4. something else?
        if inference_params:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(inf_max_seq_len, inf_max_batch_size)
                inference_value_memory = self._allocate_memory(inf_max_seq_len, inf_max_batch_size)
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory,
                    inference_value_memory,
                )
            else:
                inference_key_memory, inference_value_memory = inference_params.key_value_memory_dict[
                    self.layer_number
                ]

        # =====================
        # Query, Key, and Value
        # =====================
        # Get the query, key and value tensors based on the type of attention -
        # self or cross attn.
        query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)

        # ==================================
        # Adjust key and value for inference
        # ==================================
        if inference_params:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            # Copy key and values.
            inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = key
            inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = value
            key = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
            value = inference_value_memory[:sequence_end, batch_start:batch_end, ...]

        # ==================================
        # core attention computation
        # ==================================
        if self.checkpoint_core_attention:
            core_attn_out = self._checkpointed_attention_forward(query, key, value, attention_mask)
        else:
            core_attn_out = self.core_attention(query, key, value, attention_mask)

        # =================
        # Output. [sq, b, h]
        # =================
        output, bias = self.linear_proj(core_attn_out)

        return output, bias
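

# The helper below is an illustrative sketch only and is NOT part of the
# Megatron-LM API.  It mimics, with plain torch tensors and made-up sizes, the
# KV-cache bookkeeping that ``Attention.forward`` performs when
# ``inference_params`` is supplied: a [s, b, np, hn] buffer is allocated once
# per layer, each decode step writes its new key/value slice into that buffer,
# and attention then reads back the full prefix.
def _kv_cache_update_sketch():
    max_seq_len, max_batch, np_, hn = 16, 2, 4, 8       # assumed toy sizes
    key_memory = torch.zeros(max_seq_len, max_batch, np_, hn)
    value_memory = torch.zeros(max_seq_len, max_batch, np_, hn)

    # One decode step: a single new token for the whole batch.
    sequence_start, batch_start = 3, 0                   # offsets tracked by inference_params
    key = torch.randn(1, max_batch, np_, hn)             # [sq=1, b, np, hn]
    value = torch.randn(1, max_batch, np_, hn)

    sequence_end = sequence_start + key.size(0)
    batch_end = batch_start + key.size(1)

    # Write the new entries, then read back everything up to the current position.
    key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = key
    value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = value
    key = key_memory[:sequence_end, batch_start:batch_end, ...]
    value = value_memory[:sequence_end, batch_start:batch_end, ...]
    return key.shape, value.shape                        # both (4, 2, 4, 8)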


class SelfAttention(Attention):
    """Self-attention layer class

    Self-attention layer takes input with size [s, b, h] and returns output of
    the same size.
    """

    def __init__(
        self,
        config: TransformerConfig,
        layer_number: int = 1,
        attn_mask_type=AttnMaskType.padding,
    ):
        super().__init__(config=config, layer_number=layer_number, attn_mask_type=attn_mask_type)

        self.linear_qkv = TEColumnParallelLinear(
            self.config.hidden_size,
            3 * self.projection_size,
            config=self.config,
            init_method=self.config.init_method,
            bias=self.config.add_bias_linear,
            skip_bias_add=False,
        )

    def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
        """
        Derives `query`, `key` and `value` tensors from `hidden_states`.
        """
        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_qkv, _ = self.linear_qkv(hidden_states)

        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
        new_tensor_shape = mixed_qkv.size()[:-1] + (
            self.num_attention_heads_per_partition,
            3 * self.hidden_size_per_attention_head,
        )
        mixed_qkv = mixed_qkv.view(*new_tensor_shape)

        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
        (query, key, value) = tensor_parallel.split_tensor_along_last_dim(mixed_qkv, 3)

        return query, key, value
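

# Illustrative sketch only (not part of Megatron-LM): the shape manipulation
# that ``SelfAttention.get_query_key_value_tensors`` applies to the fused QKV
# projection, reproduced with plain torch tensors and made-up sizes.
# ``torch.chunk`` stands in for ``tensor_parallel.split_tensor_along_last_dim``.
def _qkv_split_sketch():
    sq, b, np_, hn = 5, 2, 4, 8                    # assumed toy sizes
    mixed_qkv = torch.randn(sq, b, np_ * 3 * hn)   # linear_qkv output: [sq, b, (np * 3 * hn)]

    # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
    mixed_qkv = mixed_qkv.view(sq, b, np_, 3 * hn)

    # [sq, b, np, 3 * hn] --> 3 x [sq, b, np, hn]
    query, key, value = torch.chunk(mixed_qkv, 3, dim=-1)
    return query.shape, key.shape, value.shape     # each (5, 2, 4, 8)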
""" def __init__(self, config: TransformerConfig, layer_number: int = 1, attn_mask_type=AttnMaskType.padding): super().__init__( config=config, layer_number=layer_number, attn_mask_type=attn_mask_type ) self.linear_q = TEColumnParallelLinear( self.config.hidden_size, self.projection_size, config=self.config, init_method=self.config.init_method, bias=self.config.add_bias_linear, skip_bias_add=False ) self.linear_kv = TEColumnParallelLinear( self.config.hidden_size, 2 * self.projection_size, config=self.config, init_method=self.config.init_method, bias=self.config.add_bias_linear, skip_bias_add=False ) def get_query_key_value_tensors(self, hidden_states, key_value_states): """ Derives `query` tensor from `hidden_states`, and `key`/`value` tensors from `key_value_states`. """ # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] mixed_kv, _ = self.linear_kv(key_value_states) # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] new_tensor_shape = mixed_kv.size()[:-1] + ( self.num_attention_heads_per_partition, 2 * self.hidden_size_per_attention_head, ) mixed_kv = mixed_kv.view(*new_tensor_shape) # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] (key, value) = tensor_parallel.split_tensor_along_last_dim(mixed_kv, 2) # Attention head [sq, b, h] --> [sq, b, hp] query, _ = self.linear_q(hidden_states) # [sq, b, hp] --> [sq, b, np, hn] new_tensor_shape = query.size()[:-1] + ( self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, ) query = query.view(*new_tensor_shape) return query, key, value