# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based positional embedding layer."""
# pylint: disable=g-classes-have-attributes

import math

import tensorflow as tf

from official.modeling import tf_utils


@tf.keras.utils.register_keras_serializable(package="Text")
class RelativePositionEmbedding(tf.keras.layers.Layer):
  """Creates a positional embedding.

  This layer calculates the position encoding as a mix of sine and cosine
  functions with geometrically increasing wavelengths, as defined and
  formalized in "Attention is All You Need", section 3.5
  (https://arxiv.org/abs/1706.03762).

  Args:
    hidden_size: Size of the hidden layer.
    min_timescale: Minimum scale that will be applied at each position.
    max_timescale: Maximum scale that will be applied at each position.
  """

  def __init__(self,
               hidden_size,
               min_timescale=1.0,
               max_timescale=1.0e4,
               **kwargs):
    # We need to have a default dtype of float32, since the inputs (which Keras
    # usually uses to infer the dtype) will always be int32.
    # We compute the positional encoding in float32 even if the model uses
    # float16, as many of the ops used, like log and exp, are numerically
    # unstable in float16.
    if "dtype" not in kwargs:
      kwargs["dtype"] = "float32"

    super(RelativePositionEmbedding, self).__init__(**kwargs)
    self._hidden_size = hidden_size
    self._min_timescale = min_timescale
    self._max_timescale = max_timescale
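
  # Note (an illustrative sketch, not in the original file): because of the
  # float32 default above, the encoding stays in float32 even under a
  # mixed-precision policy, unless a dtype is passed explicitly, e.g.
  #
  #   tf.keras.mixed_precision.set_global_policy("mixed_float16")
  #   layer = RelativePositionEmbedding(hidden_size=64)  # still float32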

  def get_config(self):
    config = {
        "hidden_size": self._hidden_size,
        "min_timescale": self._min_timescale,
        "max_timescale": self._max_timescale,
    }
    base_config = super(RelativePositionEmbedding, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
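
  # A config round-trip sketch (an assumption about typical usage, not part of
  # the original file): since `get_config` returns every constructor argument,
  # an equivalent layer can be rebuilt through the standard Keras path, e.g.
  #
  #   layer = RelativePositionEmbedding(hidden_size=64)
  #   restored = RelativePositionEmbedding.from_config(layer.get_config())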

  def call(self, inputs, length=None):
    """Implements call() for the layer.

    Args:
      inputs: A tensor whose second dimension will be used as `length`. If
        `None`, the other `length` argument must be specified.
      length: An optional integer specifying the number of positions. If both
        `inputs` and `length` are specified, `length` must be equal to the
        second dimension of `inputs`.

    Returns:
      A tensor with shape [length, hidden_size].
    """
    if inputs is None and length is None:
      raise ValueError("If inputs is None, `length` must be set in "
                       "RelativePositionEmbedding().")
    if inputs is not None:
      input_shape = tf_utils.get_shape_list(inputs)
      if length is not None and length != input_shape[1]:
        raise ValueError(
            "If inputs is not None, `length` must equal to input_shape[1].")
      length = input_shape[1]
    position = tf.cast(tf.range(length), tf.float32)
    num_timescales = self._hidden_size // 2
    min_timescale, max_timescale = self._min_timescale, self._max_timescale
    # Geometric spacing of the wavelengths: the i-th inverse timescale equals
    # min_timescale * (max_timescale / min_timescale) ** (-i / (num - 1)).
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale)) /
        (tf.cast(num_timescales, tf.float32) - 1))
    inv_timescales = min_timescale * tf.exp(
        tf.cast(tf.range(num_timescales), tf.float32) *
        -log_timescale_increment)
    # Outer product of positions and inverse timescales yields a
    # [length, num_timescales] grid; the sine and cosine halves are
    # concatenated into the final [length, hidden_size] encoding.
    scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(
        inv_timescales, 0)
    position_embeddings = tf.concat(
        [tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
    return position_embeddings
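

# A quick sanity-check sketch, not part of the original module; it assumes the
# file is run directly as a script with TensorFlow 2.x installed. With the
# default min_timescale=1.0, the first sine channel reduces to sin(position)
# and channel `num_timescales` to cos(position).
if __name__ == "__main__":
  layer = RelativePositionEmbedding(hidden_size=8)
  dummy_inputs = tf.zeros([1, 4, 8])  # hypothetical [batch, length, hidden]
  encoding = layer(dummy_inputs)  # shape [4, 8]
  positions = tf.range(4, dtype=tf.float32)
  tf.debugging.assert_near(encoding[:, 0], tf.sin(positions))
  tf.debugging.assert_near(encoding[:, 4], tf.cos(positions))
  print("position encoding shape:", encoding.shape)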