# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of multiheaded attention and self-attention layers."""
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
16
17
import math

import tensorflow as tf


class Attention(tf.keras.layers.Layer):
  """Multi-headed attention layer."""

  def __init__(self, hidden_size, num_heads, attention_dropout):
    """Initialize Attention.

    Args:
      hidden_size: int, output dim of hidden layer.
      num_heads: int, number of heads to repeat the same attention structure.
      attention_dropout: float, dropout rate inside attention for training.
    """
    if hidden_size % num_heads:
      raise ValueError(
          "Hidden size ({}) must be divisible by the number of heads ({})."
          .format(hidden_size, num_heads))

    super(Attention, self).__init__()
    self.hidden_size = hidden_size
    self.num_heads = num_heads
    self.attention_dropout = attention_dropout

  def build(self, input_shape):
    """Builds the layer."""
    # Layers for linearly projecting the queries, keys, and values.
    size_per_head = self.hidden_size // self.num_heads

    def _glorot_initializer(fan_in, fan_out):
      limit = math.sqrt(6.0 / (fan_in + fan_out))
      return tf.keras.initializers.RandomUniform(minval=-limit, maxval=limit)

    attention_initializer = _glorot_initializer(input_shape.as_list()[-1],
                                                self.hidden_size)
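    # Einsum dimension letters used below: B = batch, T = sequence length,
    # E = hidden/embedding size, N = num_heads, H = size_per_head.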
    self.query_dense_layer = tf.keras.layers.experimental.EinsumDense(
        "BTE,ENH->BTNH",
        output_shape=(None, self.num_heads, size_per_head),
        kernel_initializer=attention_initializer,
        bias_axes=None,
        name="query")
    self.key_dense_layer = tf.keras.layers.experimental.EinsumDense(
        "BTE,ENH->BTNH",
        output_shape=(None, self.num_heads, size_per_head),
        kernel_initializer=attention_initializer,
        bias_axes=None,
        name="key")
    self.value_dense_layer = tf.keras.layers.experimental.EinsumDense(
        "BTE,ENH->BTNH",
        output_shape=(None, self.num_heads, size_per_head),
        kernel_initializer=attention_initializer,
        bias_axes=None,
        name="value")

    output_initializer = _glorot_initializer(self.hidden_size, self.hidden_size)
    self.output_dense_layer = tf.keras.layers.experimental.EinsumDense(
        "BTNH,NHE->BTE",
        output_shape=(None, self.hidden_size),
        kernel_initializer=output_initializer,
        bias_axes=None,
        name="output_transform")
    super(Attention, self).build(input_shape)

  def get_config(self):
    return {
        "hidden_size": self.hidden_size,
        "num_heads": self.num_heads,
        "attention_dropout": self.attention_dropout,
    }

  def call(self,
           query_input,
           source_input,
           bias,
           training,
           cache=None,
           decode_loop_step=None):
    """Apply attention mechanism to query_input and source_input.

    Args:
      query_input: A tensor with shape [batch_size, length_query, hidden_size].
      source_input: A tensor with shape [batch_size, length_source,
        hidden_size].
      bias: A tensor with shape [batch_size, 1, length_query, length_source],
        the attention bias that will be added to the result of the dot product.
      training: A bool, whether in training mode or not.
      cache: (Used during prediction) A dictionary with tensors containing
        results of previous attentions. The dictionary must have the items:
            {"k": tensor with shape [batch_size, i, heads, dim_per_head],
             "v": tensor with shape [batch_size, i, heads, dim_per_head]} where
               i is the current decoded length for non-padded decode, or max
               sequence length for padded decode.
      decode_loop_step: An integer, step number of the decoding loop. Used only
        for autoregressive inference on TPU.

    Returns:
      Attention layer output with shape [batch_size, length_query, hidden_size]
    """
    # Linearly project the query, key and value using different learned
    # projections. Splitting heads is automatically done during the linear
    # projections --> [batch_size, length, num_heads, dim_per_head].
    query = self.query_dense_layer(query_input)
    key = self.key_dense_layer(source_input)
    value = self.value_dense_layer(source_input)

    if cache is not None:
      # Combine cached keys and values with new keys and values.
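      # For padded decoding (decode_loop_step set), the cache is preallocated
      # to the maximum sequence length, so the new key/value is scattered into
      # position `decode_loop_step` with a one-hot mask instead of concatenated.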
      if decode_loop_step is not None:
        cache_k_shape = cache["k"].shape.as_list()
        indices = tf.reshape(
            tf.one_hot(decode_loop_step, cache_k_shape[1], dtype=key.dtype),
            [1, cache_k_shape[1], 1, 1])
        key = cache["k"] + key * indices
        cache_v_shape = cache["v"].shape.as_list()
        indices = tf.reshape(
            tf.one_hot(decode_loop_step, cache_v_shape[1], dtype=value.dtype),
            [1, cache_v_shape[1], 1, 1])
        value = cache["v"] + value * indices
      else:
        key = tf.concat([tf.cast(cache["k"], key.dtype), key], axis=1)
        value = tf.concat([tf.cast(cache["v"], value.dtype), value], axis=1)

      # Update cache
      cache["k"] = key
      cache["v"] = value

    # Scale query to prevent the dot product between query and key from growing
    # too large.
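    # With roughly unit-variance activations, scaling by 1/sqrt(depth) keeps
    # the logit variance independent of the per-head dimension.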
    depth = (self.hidden_size // self.num_heads)
    query *= depth**-0.5

    # Calculate dot product attention
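    # logits[b, n, f, t] is the scaled dot product between query position f
    # and key position t for head n; the statements below compute
    # softmax(Q K^T / sqrt(depth) + bias) V.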
    logits = tf.einsum("BTNH,BFNH->BNFT", key, query)
    logits += bias
    # Note that softmax internally performs math operations using float32
    # for numeric stability. When training with float16, we keep the input
    # and output in float16 for better performance.
    weights = tf.nn.softmax(logits, name="attention_weights")
    if training:
      weights = tf.nn.dropout(weights, rate=self.attention_dropout)
    attention_output = tf.einsum("BNFT,BTNH->BFNH", weights, value)

    # Run the outputs through another linear projection layer. Recombining heads
    # is automatically done --> [batch_size, length, hidden_size]
    attention_output = self.output_dense_layer(attention_output)
    return attention_output


class SelfAttention(Attention):
  """Multiheaded self-attention layer."""

  def call(self,
           query_input,
           bias,
           training,
           cache=None,
           decode_loop_step=None):
    return super(SelfAttention, self).call(query_input, query_input, bias,
                                           training, cache, decode_loop_step)
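

if __name__ == "__main__":
  # Minimal usage sketch with arbitrary example sizes: run a forward pass of
  # SelfAttention over a [batch_size, length, hidden_size] input, using an
  # all-zero attention bias (no positions masked) and no decoding cache.
  layer = SelfAttention(hidden_size=64, num_heads=8, attention_dropout=0.1)
  inputs = tf.ones([2, 10, 64])              # [batch_size, length, hidden_size]
  attention_bias = tf.zeros([2, 1, 10, 10])  # [batch_size, 1, length, length]
  outputs = layer(inputs, attention_bias, training=False)
  print(outputs.shape)  # (2, 10, 64)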