# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Transformer."""
from contextlib import nullcontext
from typing import Callable, Optional, Tuple, Union
import os

from keras import backend, layers, initializers
import tensorflow as tf

from transformer_engine.tensorflow.module import (
    LayerNorm,
    LayerNormDense,
    LayerNormMLP,
    Dense,
)
from transformer_engine.tensorflow.softmax import FusedScaleMaskSoftmax
from transformer_engine.tensorflow.constants import (
    AttnMaskTypes,
    AttnTypes,
    LayerTypes,
)
from transformer_engine.tensorflow.utils import (
    divide,
    attention_mask_func,
)
from transformer_engine.tensorflow.jit import (
    get_bias_dropout_add,
    bias_dropout_add_fused_train,
    bias_dropout_add_fused_inference,
)


class CoreAttention(tf.keras.Model):  # pylint: disable=too-few-public-methods
    """Parallel attention w/o QKV and Proj Gemms
    BMM1 -> softmax + dropout -> BMM2
    """

    def __init__(
        self,
        num_attention_heads: int,
        kv_channels: int,
        attention_dropout: float,
        layer_number: Optional[int] = None,
        apply_query_key_layer_scaling: bool = True,
        attention_softmax_in_fp32: bool = False,
        attn_mask_type: str = "causal",
    ) -> None:
        super().__init__()
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        if layer_number is None:
            self.apply_query_key_layer_scaling = False
        else:
            self.layer_number = max(1, layer_number)
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.attn_mask_type = attn_mask_type

        projection_size = kv_channels * num_attention_heads
        assert (
            attn_mask_type in AttnMaskTypes
        ), f"attn_mask_type {attn_mask_type} not supported"

        # Per attention head and per partition values.
        self.hidden_size_per_partition = divide(projection_size, 1)
        self.hidden_size_per_attention_head = divide(
            projection_size, num_attention_heads
        )

        self.attention_dropout_ctx = nullcontext

        coeff = None
        self.norm_factor = tf.math.sqrt(
            float(self.hidden_size_per_attention_head))
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.attn_mask_type,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff,
        )

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = layers.Dropout(attention_dropout)

    def __call__(
        self,
        query_layer: tf.Tensor,
        key_layer: tf.Tensor,
        value_layer: tf.Tensor,
        attention_mask: tf.Tensor,
    ) -> tf.Tensor:
        """core attention fprop"""
        # [b, np, sq, sk]
        output_size = (
            query_layer.shape[1],
            query_layer.shape[2],
            query_layer.shape[0],
            key_layer.shape[0],
        )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        new_q_shape = (output_size[2], output_size[0] * output_size[1], -1)
        query_layer = tf.reshape(query_layer, new_q_shape)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        new_k_shape = (output_size[3], output_size[0] * output_size[1], -1)
        key_layer = tf.reshape(key_layer, new_k_shape)

        norm_factor = self._maybe_cast_inputs(self.norm_factor)

        # Raw attention scores.
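        # BMM1: batched matmul of Q and K^T, divided by norm_factor
        # (sqrt(hn), further multiplied by the layer number when
        # apply_query_key_layer_scaling is enabled).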
        # [b * np, sq, sk]
        matmul_result = (
            tf.matmul(
                tf.transpose(query_layer, perm=(1, 0, 2)),  # [b * np, sq, hn]
                tf.transpose(key_layer, perm=(1, 2, 0)),  # [b * np, hn, sk]
            )
            / norm_factor
        )

        # change view to [b, np, sq, sk]
        attention_scores = tf.reshape(matmul_result, output_size)

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with self.attention_dropout_ctx():
            attention_probs = self.attention_dropout(attention_probs)

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        output_size = (
            value_layer.shape[1],
            value_layer.shape[2],
            query_layer.shape[0],
            value_layer.shape[3],
        )

        # change view [sk, b * np, hn]
        new_v_shape = (value_layer.shape[0],
                       output_size[0] * output_size[1], -1)
        value_layer = tf.reshape(value_layer, new_v_shape)

        # change view [b * np, sq, sk]
        new_attn_shape = (output_size[0] * output_size[1], output_size[2], -1)
        attention_probs = tf.reshape(attention_probs, new_attn_shape)

        # matmul: [b * np, sq, hn]
        context_layer = tf.matmul(
            attention_probs,  # [b * np, sq, sk]
            tf.transpose(value_layer, perm=(1, 0, 2)),  # [b * np, sk, hn]
        )

        # change view [b, np, sq, hn]
        context_layer = tf.reshape(context_layer, output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = tf.transpose(context_layer, perm=(2, 0, 1, 3))

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = (
            *context_layer.shape[:-2],
            self.hidden_size_per_partition,
        )
        context_layer = tf.reshape(context_layer, new_context_layer_shape)

        return context_layer


class MultiHeadAttention(layers.Layer):
    """Parallel attention w/ QKV and Proj Gemms
    BMM1 -> softmax + dropout -> BMM2
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        kv_channels: int,
        attention_dropout: float,
        layernorm_epsilon: float = 1e-3,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        layer_number: Optional[int] = None,
        apply_query_key_layer_scaling: bool = True,
        attention_softmax_in_fp32: bool = False,
        attn_mask_type: str = "causal",
        return_layernorm_output: bool = False,
        input_layernorm: bool = False,
        attention_type: str = "self",
        fuse_qkv_params: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.layer_number = layer_number
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type
        self.return_layernorm_output = return_layernorm_output
        self.init_method = init_method
        self.fuse_qkv_params = fuse_qkv_params
        # We only support zero-initializer for bias weights.
        self.bias_initializer = initializers.get("zeros")

        assert (
            attention_type in AttnTypes
        ), f"attention_type {attention_type} not supported"

        self.hidden_size_per_attention_head = kv_channels
        self.num_attention_heads_per_partition = divide(num_attention_heads, 1)

        if self.attention_type == "self":
            if self.input_layernorm:
                self.layernorm_qkv = LayerNormDense(
                    3 * hidden_size,
                    epsilon=layernorm_epsilon,
                    kernel_initializer=init_method,
                    use_bias=True,
                    return_bias=False,
                    return_layernorm_output=return_layernorm_output,
                    skip_weight_param_allocation=not fuse_qkv_params,
                )
            else:
                self.qkv = Dense(
                    3 * hidden_size,
                    kernel_initializer=init_method,
                    use_bias=True,
                    return_bias=False,
                    skip_weight_param_allocation=not fuse_qkv_params,
                )
        else:
            if self.input_layernorm:
                self.layernorm_query = LayerNormDense(
                    hidden_size,
                    epsilon=layernorm_epsilon,
                    kernel_initializer=init_method,
                    use_bias=True,
                    return_bias=False,
                    return_layernorm_output=return_layernorm_output,
                    skip_weight_param_allocation=not fuse_qkv_params,
                )
            else:
                self.query_layer = Dense(
                    hidden_size,
                    kernel_initializer=init_method,
                    use_bias=True,
                    return_bias=False,
                    skip_weight_param_allocation=not fuse_qkv_params,
                )
            self.key_value = Dense(
                2 * hidden_size,
                kernel_initializer=init_method,
                use_bias=True,
                return_bias=False,
                skip_weight_param_allocation=not fuse_qkv_params,
            )

        # Core Self attention.
        self.core_attention = CoreAttention(
            num_attention_heads,
            kv_channels,
            attention_dropout,
            layer_number=layer_number,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            attention_softmax_in_fp32=attention_softmax_in_fp32,
            attn_mask_type=attn_mask_type,
        )

        # Linear
        self.proj = Dense(
            hidden_size,
            kernel_initializer=output_layer_init_method,
            use_bias=False,
            return_bias=True,
        )

    def build(self, input_shape):
        """One-time allocation of the variables."""
        input_shape = tf.TensorShape(input_shape)
        last_dim = tf.compat.dimension_value(input_shape[-1])
        if last_dim is None:
            raise ValueError(
                "The last dimension of the inputs to a Dense layer should be "
                f"defined. Found None. Full input shape received: {input_shape}"
            )

        if not self.fuse_qkv_params:
            self.set_qkv_params(
                last_dim,
                3 * self.hidden_size,
                use_bias=True,
            )

    def set_qkv_params(
        self,
        in_features,
        out_features,
        use_bias: bool = False,
    ) -> None:
        """Initialize separate Parameters for query, key, and value tensors."""
        assert (
            out_features % 3 == 0
        ), f"3 way QKV split with dimension {out_features} not possible."
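        # Self-attention allocates one fused [in_features, 3 * qkv_dim] kernel
        # (plus a single bias); cross-attention keeps a separate query kernel
        # and a fused key-value kernel.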
        qkv_dim = out_features // 3

        if self.attention_type == "self":
            self.qkv_weight = self.add_weight(
                name="qkv_kernel",
                shape=(in_features, out_features),
                initializer=self.init_method,
                trainable=True,
            )
            self.qkv_bias = None
            if use_bias:
                self.qkv_bias = self.add_weight(
                    name="qkv_bias",
                    shape=(out_features,),
                    initializer=self.bias_initializer,
                    trainable=True,
                )
        else:
            self.q_weight = self.add_weight(
                name="q_kernel",
                shape=(in_features, qkv_dim),
                initializer=self.init_method,
                trainable=True,
            )
            self.kv_weight = self.add_weight(
                name="kv_kernel",
                shape=(in_features, 2 * qkv_dim),
                initializer=self.init_method,
                trainable=True,
            )
            self.q_bias = None
            self.kv_bias = None
            if use_bias:
                self.q_bias = self.add_weight(
                    name="q_bias",
                    shape=(qkv_dim,),
                    initializer=self.bias_initializer,
                    trainable=True,
                )
                self.kv_bias = self.add_weight(
                    name="kv_bias",
                    shape=(2 * qkv_dim,),
                    initializer=self.bias_initializer,
                    trainable=True,
                )

    def _get_training_value(self, training=None):
        if training is None:
            training = backend.learning_phase()
        if isinstance(training, int):
            training = bool(training)
        if not self.trainable:
            # When the layer is not trainable, it overrides the value passed
            # from model.
            training = False
        return training

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        encoder_output: Optional[tf.Tensor] = None,
        training: bool = None,
    ) -> Tuple[Union[tf.Tensor, None], ...]:
        """MultiHeadAttention FWD"""
        training = self._get_training_value(training)

        # hidden_states: [sq, b, h]
        if attention_mask is not None:
            assert (
                attention_mask.dtype == tf.bool
            ), "Attention mask must be a boolean tensor"

        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == "self":
            qkv_weight = self.qkv_weight if not self.fuse_qkv_params else None
            qkv_bias = self.qkv_bias if not self.fuse_qkv_params else None
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            if self.input_layernorm:
                layernorm_qkv_outputs = self.layernorm_qkv(
                    hidden_states,
                    kernel=qkv_weight,
                    bias=qkv_bias,
                    training=training,
                )
                if self.return_layernorm_output:
                    mixed_x_layer, layernorm_output = layernorm_qkv_outputs
                else:
                    mixed_x_layer = layernorm_qkv_outputs
            else:
                mixed_x_layer = self.qkv(
                    hidden_states,
                    kernel=qkv_weight,
                    bias=qkv_bias,
                    training=training,
                )

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = (
                *mixed_x_layer.shape[:-1],
                self.num_attention_heads_per_partition,
                3 * self.hidden_size_per_attention_head,
            )
            mixed_x_layer = tf.reshape(mixed_x_layer, new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            query_layer, key_layer, value_layer = tf.split(
                mixed_x_layer, num_or_size_splits=3, axis=-1
            )
        else:
            kv_weight = self.kv_weight if not self.fuse_qkv_params else None
            kv_bias = self.kv_bias if not self.fuse_qkv_params else None
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer = self.key_value(
                encoder_output,
                kernel=kv_weight,
                bias=kv_bias,
                training=training,
            )

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = (
                *mixed_kv_layer.shape[:-1],
                self.num_attention_heads_per_partition,
                2 * self.hidden_size_per_attention_head,
            )
            mixed_kv_layer = tf.reshape(mixed_kv_layer, new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            key_layer, value_layer = tf.split(
                mixed_kv_layer, num_or_size_splits=2, axis=-1
            )

            # Attention head [sq, b, h] --> [sq, b, hp]
            if self.input_layernorm:
                layernorm_query_outputs = self.layernorm_query(
                    hidden_states,
                    kernel=self.q_weight,
                    bias=self.q_bias,
                    training=training,
                )
                if self.return_layernorm_output:
                    query_layer, layernorm_output = layernorm_query_outputs
                else:
                    query_layer = layernorm_query_outputs
            else:
                query_layer = self.query_layer(
                    hidden_states,
                    kernel=self.q_weight,
                    bias=self.q_bias,
                    training=training,
                )

            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = (
                *query_layer.shape[:-1],
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
            query_layer = tf.reshape(query_layer, new_tensor_shape)

        # ==================================
        # core attention computation
        # ==================================

        context_layer = self.core_attention(
            query_layer, key_layer, value_layer, attention_mask
        )

        # =================
        # Output. [sq, b, h]
        # =================

        attention_output, attention_bias = self.proj(
            context_layer,
            training=training,
        )

        if self.input_layernorm and self.return_layernorm_output:
            return attention_output, attention_bias, layernorm_output
        return attention_output, attention_bias


class DropPath(tf.keras.Model):  # pylint: disable=too-few-public-methods
    """Drop paths (Stochastic Depth) per sample
    (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob: float = 0.0) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def __call__(self, hidden_state: tf.Tensor, training: bool) -> tf.Tensor:
        """DropPath FWD"""
        if self.drop_prob == 0.0 or not training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        # work with diff dim tensors, not just 2D ConvNets
        shape = (hidden_state.shape[0],) + (1,) * (len(hidden_state.shape) - 1)
        # TODO(kaixih): We set the seed mainly for debugging purpose. Should
        # allow users to turn it off.
        random_tensor = tf.random.stateless_uniform(shape, seed=[1, 0])
        random_mask = tf.cast(random_tensor <= keep_prob,
                              dtype=hidden_state.dtype)
        output = (hidden_state / keep_prob) * random_mask
        return output


class TransformerLayer(tf.keras.Model):  # pylint: disable=too-few-public-methods
    """
    TransformerLayer is made up of an attention block and a feedforward
    network (MLP). This standard layer is based on the paper
    "Attention Is All You Need".

    Parameters
    ----------
    hidden_size : int
        size of each input sample.
    ffn_hidden_size : int
        intermediate size to which input samples are projected.
    num_attention_heads : int
        number of attention heads in the transformer layer.
    layernorm_epsilon : float, default = 1e-5
        a value added to the denominator of layer normalization for
        numerical stability.
    hidden_dropout: float, default = 0.1
        dropout probability for the dropout op after FC2 layer.
    attention_dropout: float, default = 0.1
        dropout probability for the dropout op during multi-head attention.
    init_method : Callable, default = `None`
        used for initializing weights of QKV and FC1 weights in the following
        way: `init_method(weight)`. When set to `None`, defaults to
        `tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.023)`.
    output_layer_init_method : Callable, default = `None`
        used for initializing weights of PROJ and FC2 in the following way:
        `output_layer_init_method(weight)`. When set to `None`, defaults to
        `tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.023)`.
    apply_residual_connection_post_layernorm : bool, default = `False`
        if set to `True`, residual connections are taken from the output of
        layer norm (default is taken from input of layer norm)
    layer_number: int, default = `None`
        layer number of the current `TransformerLayer` when multiple such
        modules are concatenated to form a transformer block.
    apply_query_key_layer_scaling: bool, default = `True`
        apply query-key layer scaling during BMM1 by a factor of
        `layer_number`
    output_layernorm: bool, default = `False`
        if set to `True`, layer normalization is applied on the output side,
        after the final dropout-add. default behavior is to apply layer
        normalization on the input side, before the QKV transformation.
    attention_softmax_in_fp32: bool, default = `False`
        if set to `True`, softmax is executed in tf.float32 dtype (single
        precision)
    layer_type: {'encoder', 'decoder'}, default = `encoder`
        if set to `decoder`, an additional cross-attn block is added after
        self-attn. This can be used for structures like `T5` Transformer in
        conjunction with the `encoder` option.
    kv_channels: int, default = `None`
        number of key-value channels. defaults to
        `hidden_size / num_attention_heads` if `None`.
    self_attn_mask_type: {'causal', 'padding'}, default = `causal`
        type of attention mask passed into softmax operation.

    Optimization parameters
    -----------------------
    drop_path_rate: float, default = 0.0
        when > 0.0, applies stochastic depth per sample in the main path of
        the residual block.
    fuse_qkv_params: bool, default = `False`
        if set to `True`, `TransformerLayer` module exposes a single fused
        parameter for query-key-value. This enables optimizations such as QKV
        fusion without concatenations/splits.
    """

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        num_attention_heads: int,
        layernorm_epsilon: float = 1e-5,
        hidden_dropout: float = 0.1,
        attention_dropout: float = 0.1,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        layer_number: Optional[int] = None,
        kv_channels: Optional[int] = None,
        self_attn_mask_type: str = "causal",
        apply_query_key_layer_scaling: bool = True,
        attention_softmax_in_fp32: bool = False,
        apply_residual_connection_post_layernorm: bool = False,
        output_layernorm: bool = False,
        layer_type: str = "encoder",
        drop_path_rate: float = 0.0,
        fuse_qkv_params: bool = False,
    ) -> None:
        super().__init__()

        bias_dropout_fusion = \
            bool(int(os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1")))
        self.layer_number = layer_number
        self.output_layernorm = output_layernorm
        self.layer_type = layer_type
        self.apply_residual_connection_post_layernorm = (
            apply_residual_connection_post_layernorm
        )
        assert (
            self_attn_mask_type in AttnMaskTypes
        ), f"self_attn_mask_type {self_attn_mask_type} not supported"
        assert layer_type in LayerTypes, \
            f"layer_type {layer_type} not supported"

        self.kv_channels = (
            kv_channels if kv_channels else (hidden_size // num_attention_heads)
        )

        if init_method is None:
            init_method = initializers.RandomNormal(mean=0.0, stddev=0.023)
        if output_layer_init_method is None:
            output_layer_init_method = initializers.RandomNormal(mean=0.0,
                                                                 stddev=0.023)

        attention_args = (
            hidden_size,
            num_attention_heads,
            self.kv_channels,
            attention_dropout,
            layernorm_epsilon,
            init_method,
            output_layer_init_method,
        )
        common_attention_kwargs = {
            "layer_number": layer_number,
            "apply_query_key_layer_scaling": apply_query_key_layer_scaling,
            "attention_softmax_in_fp32": attention_softmax_in_fp32,
            "return_layernorm_output": apply_residual_connection_post_layernorm,
            "fuse_qkv_params": fuse_qkv_params,
        }

        self.self_attention = MultiHeadAttention(
            *attention_args,
            **common_attention_kwargs,
            attn_mask_type=self_attn_mask_type,
            input_layernorm=not output_layernorm,
            attention_type="self",
        )

        if layer_type == "decoder":
            self.inter_attention = MultiHeadAttention(
                *attention_args,
                **common_attention_kwargs,
                attn_mask_type="padding",
                input_layernorm=True,
                attention_type="cross",
            )

        # LayerNorm -> gelu(Linear + Bias) -> Linear
        self.layernorm_mlp = LayerNormMLP(
            hidden_size,
            ffn_hidden_size,
            epsilon=layernorm_epsilon,
            kernel_initializer=init_method,
            ffn_kernel_initializer=output_layer_init_method,
            use_bias=False,
            return_bias=True,
            return_layernorm_output=apply_residual_connection_post_layernorm,
        )

        self.hidden_dropout = hidden_dropout
        self.bias_dropout_fusion = bias_dropout_fusion
        self.drop_path = (DropPath(drop_path_rate)
                          if drop_path_rate > 0.0 else None)

        if self.output_layernorm:
            self.layernorm = LayerNorm(
                epsilon=layernorm_epsilon,
            )

    def _get_training_value(self, training=None):
        if training is None:
            training = backend.learning_phase()
        if isinstance(training, int):
            training = bool(training)
        if not self.trainable:
            # When the layer is not trainable, it overrides the value passed
            # from model.
            training = False
        return training

    def __call__(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        encoder_output: Optional[tf.Tensor] = None,
        enc_dec_attn_mask: Optional[tf.Tensor] = None,
        training: bool = None,
    ) -> tf.Tensor:
        """
        Transformer Layer: attention block and a feedforward network (MLP)

        Parameters
        ----------
        hidden_states : tf.Tensor
            Input tensor.
        attention_mask : tf.Tensor
            Boolean tensor used to mask out self-attention softmax input.
        encoder_output : tf.Tensor
            Output of the encoder block to be fed into the decoder block if
            using `layer_type="decoder"`.
        enc_dec_attn_mask : tf.Tensor
            Boolean tensor used to mask out inter-attention softmax input if
            using `layer_type="decoder"`.
        """
        if attention_mask is not None:
            assert (
                attention_mask.dtype == tf.bool
            ), "Attention mask must be a boolean tensor"

        # Theoretically, the input dtype can be handled by the autocast during
        # the layer call. However, we may use the input (hidden_states) in the
        # residual connection before the layer is called. So, we convert it
        # ahead of time. As for the other input (encoder_output), we can leave
        # the conversion to the inter_attention layer, since it won't be used
        # in the residual connection.
        hidden_states = self._maybe_cast_inputs(hidden_states)

        # Self attention.
        self_attention_outputs = self.self_attention(
            hidden_states,
            attention_mask,
            training=training,
        )

        if (self.apply_residual_connection_post_layernorm
                and not self.output_layernorm):
            attention_output, attention_bias, residual = self_attention_outputs
        else:
            attention_output, attention_bias = self_attention_outputs
            residual = hidden_states

        # Set BDA func.
        if self.bias_dropout_fusion:
            if training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(training)

        # Bias dropout add.
        attention_bias = tf.cast(attention_bias, dtype=self.compute_dtype)
        if self.drop_path is None:
            bda_output = bias_dropout_add_func(
                attention_output,
                attention_bias,
                residual,
                self.hidden_dropout,
            )
        else:
            # TODO(kaixih): Use stateless_dropout and specify the seed
            # mainly for debugging purpose. Should allow random seed.
            out = (
                tf.nn.experimental.stateless_dropout(
                    attention_output + attention_bias,
                    rate=self.hidden_dropout,
                    seed=[1, 0],
                )
                if training
                else attention_output + attention_bias
            )
            bda_output = residual + self.drop_path(out, training)

        # Cross attention.
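        # Only decoder layers run this block: they attend over encoder_output
        # with the padding-mask inter-attention, using the self-attention
        # bias-dropout-add output above (or the inter-attention layernorm
        # output when apply_residual_connection_post_layernorm=True) as the
        # residual.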
if self.layer_type == "decoder": inter_attention_outputs = self.inter_attention( bda_output, enc_dec_attn_mask, encoder_output=encoder_output, training=training, ) if self.apply_residual_connection_post_layernorm: attention_output, attention_bias, residual = \ inter_attention_outputs else: attention_output, attention_bias = inter_attention_outputs residual = bda_output attention_bias = tf.cast(attention_bias, dtype=self.compute_dtype) bda_output = bias_dropout_add_func( attention_output, attention_bias, residual, self.hidden_dropout, ) # MLP. mlp_outputs = self.layernorm_mlp( bda_output, training=training, ) if self.apply_residual_connection_post_layernorm: mlp_output, mlp_bias, residual = mlp_outputs else: mlp_output, mlp_bias = mlp_outputs residual = bda_output # Bias dropout add. mlp_bias = tf.cast(mlp_bias, dtype=self.compute_dtype) if self.drop_path is None: output = bias_dropout_add_func( mlp_output, mlp_bias, residual, self.hidden_dropout, ) else: # TODO(kaixih): Use stateless_dropout and specify the seed # mainly for debugging purpose. Should allow random seed. output = ( tf.nn.experimental.stateless_dropout( mlp_output + mlp_bias, rate=self.hidden_dropout, seed=[1, 0], ) if training else mlp_output + mlp_bias ) output = residual + self.drop_path(output, training) # For BERT like architectures. if self.output_layernorm: output = self.layernorm(output) # output: [b, s, h] return output