# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based positional embedding layer."""
# pylint: disable=g-classes-have-attributes
import math

import tensorflow as tf

from official.modeling import tf_utils


@tf.keras.utils.register_keras_serializable(package="Text")
class RelativePositionEmbedding(tf.keras.layers.Layer):
  """Creates a positional embedding.

  This layer calculates the position encoding as a mix of sine and cosine
  functions with geometrically increasing wavelengths. Defined and formalized
  in "Attention is All You Need", section 3.5.
  (https://arxiv.org/abs/1706.03762).

  Args:
    hidden_size: Size of the hidden layer.
    min_timescale: Minimum scale that will be applied at each position.
    max_timescale: Maximum scale that will be applied at each position.
  """

  def __init__(self,
               hidden_size,
               min_timescale=1.0,
               max_timescale=1.0e4,
               **kwargs):
    # We need to have a default dtype of float32, since the inputs (which Keras
    # usually uses to infer the dtype) will always be int32.
    # We compute the positional encoding in float32 even if the model uses
    # float16, as many of the ops used, like log and exp, are numerically
    # unstable in float16.
    if "dtype" not in kwargs:
      kwargs["dtype"] = "float32"

    super(RelativePositionEmbedding, self).__init__(**kwargs)
    self._hidden_size = hidden_size
    self._min_timescale = min_timescale
    self._max_timescale = max_timescale

  def get_config(self):
    config = {
        "hidden_size": self._hidden_size,
        "min_timescale": self._min_timescale,
        "max_timescale": self._max_timescale,
    }
    base_config = super(RelativePositionEmbedding, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs, length=None):
    """Implements call() for the layer.

    Args:
      inputs: A tensor whose second dimension will be used as `length`. If
        `None`, the other `length` argument must be specified.
      length: An optional integer specifying the number of positions. If both
        `inputs` and `length` are specified, `length` must be equal to the
        second dimension of `inputs`.

    Returns:
      A tensor of shape `[length, hidden_size]`.
""" if inputs is None and length is None: raise ValueError("If inputs is None, `length` must be set in " "RelativePositionEmbedding().") if inputs is not None: input_shape = tf_utils.get_shape_list(inputs) if length is not None and length != input_shape[1]: raise ValueError( "If inputs is not None, `length` must equal to input_shape[1].") length = input_shape[1] position = tf.cast(tf.range(length), tf.float32) num_timescales = self._hidden_size // 2 min_timescale, max_timescale = self._min_timescale, self._max_timescale log_timescale_increment = ( math.log(float(max_timescale) / float(min_timescale)) / (tf.cast(num_timescales, tf.float32) - 1)) inv_timescales = min_timescale * tf.exp( tf.cast(tf.range(num_timescales), tf.float32) * -log_timescale_increment) scaled_time = tf.expand_dims(position, 1) * tf.expand_dims( inv_timescales, 0) position_embeddings = tf.concat( [tf.sin(scaled_time), tf.cos(scaled_time)], axis=1) return position_embeddings