Unverified commit 965cc3ee authored by Ayushman Kumar, committed by GitHub

Merge pull request #7 from tensorflow/master

updated
parents 1f3247f4 1f685c54
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Talking Head Attention layer."""
# pylint: disable=g-classes-have-attributes
import math
import gin
import tensorflow as tf
from official.nlp.modeling.layers import dense_einsum
from official.nlp.modeling.layers import masked_softmax
@tf.keras.utils.register_keras_serializable(package="Text")
@gin.configurable
class TalkingHeadsAttention(tf.keras.layers.Layer):
"""Implements Talking-Heads Attention.
https://arxiv.org/abs/2003.02436
Arguments:
num_heads: Number of attention heads.
head_size: Size of each attention head.
dropout_rate: Dropout probability.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer biases.
"""
def __init__(self,
num_heads,
head_size,
dropout_rate=0.0,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs):
super(TalkingHeadsAttention, self).__init__(**kwargs)
self._num_heads = num_heads
self._head_size = head_size
self._dropout_rate = dropout_rate
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
self._query_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._head_size),
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="query")
self._key_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._head_size),
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="key")
self._value_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._head_size),
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="value")
self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[1])
self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
def build(self, input_shape):
super(TalkingHeadsAttention, self).build(input_shape)
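    # Talking-heads projections: two trainable [num_heads, num_heads] matrices
    # that mix attention logits across heads before the softmax and attention
    # probabilities across heads after it.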
self._pre_softmax_weight = self.add_weight(
"pre_softmax_weight",
shape=(self._num_heads, self._num_heads),
initializer=self._kernel_initializer,
regularizer=self._kernel_regularizer,
constraint=self._kernel_constraint,
dtype=self.dtype,
trainable=True)
self._post_softmax_weight = self.add_weight(
"post_softmax_weight",
shape=(self._num_heads, self._num_heads),
initializer=self._kernel_initializer,
regularizer=self._kernel_regularizer,
constraint=self._kernel_constraint,
dtype=self.dtype,
trainable=True)
def get_config(self):
config = {
"num_heads":
self._num_heads,
"head_size":
self._head_size,
"dropout_rate":
self._dropout_rate,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
tf.keras.initializers.serialize(self._bias_initializer),
"kernel_regularizer":
tf.keras.regularizers.serialize(self._kernel_regularizer),
"bias_regularizer":
tf.keras.regularizers.serialize(self._bias_regularizer),
"activity_regularizer":
tf.keras.regularizers.serialize(self._activity_regularizer),
"kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint)
}
base_config = super(TalkingHeadsAttention, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
from_tensor = inputs[0]
to_tensor = inputs[1]
attention_mask = inputs[2] if len(inputs) == 3 else None
# Scalar dimensions referenced here:
# B = batch size (number of sequences)
# F = `from_tensor` sequence length
# T = `to_tensor` sequence length
# N = L = `num_attention_heads`
# H = `size_per_head`
# `query_tensor` = [B, F, N ,H]
query_tensor = self._query_dense(from_tensor)
# `key_tensor` = [B, T, N, H]
key_tensor = self._key_dense(to_tensor)
# `value_tensor` = [B, T, N, H]
value_tensor = self._value_dense(to_tensor)
# Take the dot product between "query" and "key" to get the raw
# attention scores.
attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._head_size)))
# Apply talking heads before softmax.
attention_scores = tf.einsum("BNFT,NL->BLFT", attention_scores,
self._pre_softmax_weight)
# Normalize the attention scores to probabilities.
# `attention_probs` = [B, N, F, T]
attention_probs = self._masked_softmax([attention_scores, attention_mask])
# Apply talking heads after softmax.
attention_probs = tf.einsum("BNFT,NL->BLFT", attention_probs,
self._post_softmax_weight)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self._dropout(attention_probs)
# `context_layer` = [B, F, N, H]
return tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_tensor)
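A minimal usage sketch (not part of the original file; the shapes follow the B/F/T/N/H comments in `call` above, and the tensor values are illustrative):

layer = TalkingHeadsAttention(num_heads=8, head_size=64)
from_tensor = tf.random.uniform((2, 16, 512))   # [B, F, hidden], query side
to_tensor = tf.random.uniform((2, 20, 512))     # [B, T, hidden], key/value side
attention_mask = tf.ones((2, 16, 20))           # [B, F, T], 1 = may attend
context = layer([from_tensor, to_tensor, attention_mask])
print(context.shape)                            # (2, 16, 8, 64) = [B, F, N, H]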
@@ -23,6 +23,7 @@ import tensorflow as tf
 from official.nlp.modeling.layers import attention
 from official.nlp.modeling.layers import dense_einsum
+from official.nlp.modeling.layers.util import tf_function_if_eager

 @tf.keras.utils.register_keras_serializable(package="Text")
@@ -142,10 +143,8 @@ class Transformer(tf.keras.layers.Layer):
         kernel_constraint=self._kernel_constraint,
         bias_constraint=self._bias_constraint,
         name="intermediate")
-    # Use float32 in intermediate gelu activation for numeric stability.
-    # TODO(b/149117297): investigate gelu numeric stability.
     self._intermediate_activation_layer = tf.keras.layers.Activation(
-        self._intermediate_activation, dtype=tf.float32)
+        self._intermediate_activation)
     self._output_dense = dense_einsum.DenseEinsum(
         output_shape=hidden_size,
         kernel_initializer=self._kernel_initializer,
@@ -221,3 +220,10 @@ class Transformer(tf.keras.layers.Layer):
     layer_output = self._output_layer_norm(layer_output + attention_output)

     return layer_output
+
+
+class CompiledTransformer(Transformer):
+
+  @tf_function_if_eager(experimental_compile=True)
+  def call(self, inputs):
+    return super(CompiledTransformer, self).call(inputs)
@@ -122,7 +122,8 @@ class AlbertTransformerEncoder(network.Network):
     self._position_embedding_layer = layers.PositionEmbedding(
         initializer=initializer,
         use_dynamic_slicing=True,
-        max_sequence_length=max_sequence_length)
+        max_sequence_length=max_sequence_length,
+        name='position_embedding')
     position_embeddings = self._position_embedding_layer(word_embeddings)
     type_embeddings = (
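For context on the `CompiledTransformer` addition above: `tf_function_if_eager` is assumed to apply `tf.function` (forwarding arguments such as `experimental_compile=True`, which requests XLA compilation) only when the program is executing eagerly, roughly as in the sketch below; the actual helper in `official/nlp/modeling/layers/util.py` may differ in detail.

import tensorflow as tf


def tf_function_if_eager(**kwargs):
  """Returns a decorator that applies tf.function(**kwargs) only in eager mode."""
  def decorator(func):
    if tf.executing_eagerly():
      return tf.function(func, **kwargs)
    return func
  return decorator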