import torch
import torch.nn.functional as F

from megatron.training import get_args
from megatron.core import tensor_parallel
from megatron.legacy.model.enums import AttnType
from megatron.core.models.common.embeddings import apply_rotary_pos_emb
from megatron.legacy.model.module import MegatronModule
from megatron.legacy.model.transformer import ParallelMLP
from megatron.legacy.model.utils import (
    erf_gelu,
    openai_gelu,
)

try:
    from einops import rearrange
except ImportError:
    rearrange = None


class ParallelMLPPatch(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(self, config, is_expert=False):
        super(ParallelMLP, self).__init__()
        args = get_args()

        self.add_bias = config.add_bias_linear

        ffn_hidden_size = config.ffn_hidden_size
        if config.gated_linear_unit:
            ffn_hidden_size *= 2

        # Project to 4h. If using swiglu double the output width,
        # see https://arxiv.org/pdf/2002.05202.pdf
        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
            config.hidden_size,
            ffn_hidden_size,
            config=config,
            init_method=config.init_method,
            bias=self.add_bias,
            gather_output=False,
            skip_bias_add=True,
            is_expert=is_expert,
        )

        self.bias_gelu_fusion = False
        self.activation_func = None
        self.swiglu = args.swiglu

        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu
        elif args.swiglu:
            @torch.compile(mode="max-autotune-no-cudagraphs")
            def swiglu(x):
                x = torch.chunk(x, 2, dim=-1)
                return F.silu(x[0]) * x[1]
            self.activation_func = swiglu
        elif args.squared_relu:
            def squared_relu(x):
                return torch.pow(F.relu(x), 2)
            self.activation_func = squared_relu
        else:
            self.bias_gelu_fusion = args.bias_gelu_fusion
            self.activation_func = F.gelu

        # Project back to h.
        self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            config=config,
            init_method=config.output_layer_init_method,
            bias=self.add_bias,
            skip_bias_add=True,
            input_is_parallel=True,
            is_expert=is_expert,
        )
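
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch itself): with --swiglu the
# h->4h projection above emits 2 * ffn_hidden_size features; the swiglu
# activation splits them in half and gates one half with the other, so its
# output width is ffn_hidden_size again, matching dense_4h_to_h's input.
# The toy sizes below are made up for the example.
#
#     >>> import torch
#     >>> import torch.nn.functional as F
#     >>> x = torch.randn(5, 1, 2 * 11008)   # [sq, b, 2 * ffn_hidden_size]
#     >>> a, b = torch.chunk(x, 2, dim=-1)   # two [sq, b, ffn_hidden_size] halves
#     >>> out = F.silu(a) * b                # gated output: [sq, b, ffn_hidden_size]
#     >>> out.shape
#     torch.Size([5, 1, 11008])
# ---------------------------------------------------------------------------
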

class ParallelAttentionPatch(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None,
                rotary_pos_emb=None):
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        is_first_step = False
        if inference_params:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_length
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size,
                    self.num_query_groups_per_partition)
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size,
                    self.num_query_groups_per_partition)
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
                is_first_step = True
            else:
                inference_key_memory, inference_value_memory = \
                    inference_params.key_value_memory_dict[self.layer_number]

        # =====================
        # Query, Key, and Value
        # =====================
        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + (
                self.num_query_groups_per_partition,
                (
                    (self.num_attention_heads_per_partition
                     // self.num_query_groups_per_partition + 2)
                    * self.hidden_size_per_attention_head
                ),
            )
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, ng, (np/ng + 2) * hn]
            # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
            (query_layer, key_layer, value_layer) = torch.split(
                mixed_x_layer,
                [
                    (
                        self.num_attention_heads_per_partition
                        // self.num_query_groups_per_partition
                        * self.hidden_size_per_attention_head
                    ),
                    self.hidden_size_per_attention_head,
                    self.hidden_size_per_attention_head,
                ],
                dim=3)

            # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
            query_layer = query_layer.contiguous().view(
                query_layer.size(0), query_layer.size(1), -1,
                self.hidden_size_per_attention_head)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(
                mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        # duplicate the pos_emb for self attention
        if rotary_pos_emb is not None:
            if not isinstance(rotary_pos_emb, tuple):
                rotary_pos_emb = ((rotary_pos_emb,) * 2)

        if inference_params:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            # Copy key and values.
            inference_key_memory[sequence_start:sequence_end,
                                 batch_start:batch_end, ...] = key_layer
            inference_value_memory[sequence_start:sequence_end,
                                   batch_start:batch_end, ...] = value_layer
            key_layer = inference_key_memory[
                :sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[
                :sequence_end, batch_start:batch_end, ...]

            # adjust the key rotary positional embedding
            if rotary_pos_emb is not None:
                q_pos_emb, k_pos_emb = rotary_pos_emb
                # need to cross check this condition during inference
                # if not set_inference_key_value_memory:
                if not is_first_step:
                    # In inference, we compute one token at a time.
                    # Select the correct positional embedding
                    # (only the last token in the sequence)
                    q_pos_emb = q_pos_emb[sequence_end - 1:sequence_end]
                else:
                    # In the first forward pass of inference,
                    # we use the entire provided prefix.
                    # q_pos_emb here has the rope embeddings of the entire
                    # prefix + to-be-generated output so
                    # we slice to just the prefix.
                    q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
                k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
                rotary_pos_emb = (q_pos_emb, k_pos_emb)

        # ==================================
        # core attention computation
        # ==================================

        # expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn]
        if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1:
            key_layer = key_layer.repeat_interleave(
                self.num_attention_heads_per_partition
                // self.num_query_groups_per_partition,
                dim=2)
            value_layer = value_layer.repeat_interleave(
                self.num_attention_heads_per_partition
                // self.num_query_groups_per_partition,
                dim=2)

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.config)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.config)
            # TODO, can apply positional embedding to value_layer so it has
            # absolute positional embedding.
            # otherwise, only relative positional embedding takes effect
            # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)

        if not self.use_flash_attn:
            if self.checkpoint_core_attention:
                context_layer = self._checkpointed_attention_forward(
                    query_layer, key_layer, value_layer, attention_mask)
            else:
                context_layer = self.core_attention(
                    query_layer, key_layer, value_layer, attention_mask)
        else:
            q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
                       for x in (query_layer, key_layer, value_layer)]
            if not self.sequence_parallel:
                with tensor_parallel.get_cuda_rng_tracker().fork():
                    context_layer = self.core_attention_flash(q, k, v)
            else:
                context_layer = self.core_attention_flash(q, k, v)
            context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()

        # ==================
        # Output. [sq, b, h]
        # ==================

        output, bias = self.dense(context_layer)

        return output, bias
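
# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not part of the original module): the *Patch
# classes above are written so that their methods can be grafted onto the
# corresponding legacy Megatron classes. This is why ParallelMLPPatch.__init__
# calls super(ParallelMLP, self).__init__(): it only resolves correctly once
# the method is bound to ParallelMLP instances. The exact patching entry
# point below is hypothetical.
#
#     from megatron.legacy.model import transformer
#     transformer.ParallelMLP.__init__ = ParallelMLPPatch.__init__
#     transformer.ParallelAttention.forward = ParallelAttentionPatch.forward
# ---------------------------------------------------------------------------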