# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based attention layer."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import math
import tensorflow as tf

from official.nlp.modeling.layers import dense_einsum
from official.nlp.modeling.layers import masked_softmax


@tf.keras.utils.register_keras_serializable(package="Text")
class MultiHeadAttention(tf.keras.layers.Layer):
  """MultiHeadAttention layer.

  This is an implementation of multi-headed attention based on "Attention
  Is All You Need". If `query`, `key`, and `value` are the same, then
  this is self-attention. Each timestep in `query` attends to the
  corresponding sequence in `key`, and returns a fixed-width vector.

  This layer first projects `query`, `key` and `value`. These are
  (effectively) a list of tensors of length `num_attention_heads`, where the
  corresponding shapes are [batch_size, query_seq_length, key_size],
  [batch_size, seq_length, key_size], [batch_size, seq_length, value_size].

  Then, the query and key tensors are dot-producted and scaled. These are
  softmaxed to obtain attention probabilities. The value tensors are then
  interpolated by these probabilities, then concatenated back to a single
  tensor.
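
  In the notation of the paper, for each head this layer computes
  `softmax(Q K^T / sqrt(key_size)) V`, where `Q`, `K` and `V` are the
  projected query, key and value tensors.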

  Finally, the result tensor with the last dimension as `value_size` can take
  a linear projection and be returned.

  Arguments:
    num_heads: Number of attention heads.
    key_size: Size of each attention head for query and key.
    value_size: Size of each attention head for value.
    dropout_rate: Dropout probability.
    use_bias: Boolean, whether the dense layers use bias vectors.
    output_shape: The expected shape of an output tensor, besides the batch and
      sequence dims. If not specified, projects back to the key feature dim.
    kernel_initializer: Initializer for dense layer kernels.
    bias_initializer: Initializer for dense layer biases.
    kernel_regularizer: Regularizer for dense layer kernels.
    bias_regularizer: Regularizer for dense layer biases.
    activity_regularizer: Regularizer for dense layer activity.
    kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.
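
  Example (a minimal usage sketch; the head count, sizes and input shapes
  below are illustrative assumptions, not requirements):

    layer = MultiHeadAttention(num_heads=8, key_size=64)
    # Cross-attention: 20 query positions attend over 60 value positions.
    query = tf.keras.Input(shape=(20, 256))
    value = tf.keras.Input(shape=(60, 256))
    output = layer([query, value])  # output shape: (None, 20, 256)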
  """

  def __init__(self,
               num_heads,
               key_size,
               value_size=None,
               dropout_rate=0.0,
               use_bias=True,
               output_shape=None,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               **kwargs):
    super(MultiHeadAttention, self).__init__(**kwargs)
    self._num_heads = num_heads
    self._key_size = key_size
    self._value_size = value_size if value_size else key_size
    self._dropout_rate = dropout_rate
    self._use_bias = use_bias
    self._output_shape = output_shape
    self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf.keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    self._activity_regularizer = tf.keras.regularizers.get(
        activity_regularizer)
    self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf.keras.constraints.get(bias_constraint)

    self._query_dense = dense_einsum.DenseEinsum(
        output_shape=(self._num_heads, self._key_size),
        use_bias=self._use_bias,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="query")

    self._key_dense = dense_einsum.DenseEinsum(
        output_shape=(self._num_heads, self._key_size),
        use_bias=self._use_bias,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="key")

    self._value_dense = dense_einsum.DenseEinsum(
        output_shape=(self._num_heads, self._value_size),
        use_bias=self._use_bias,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="value")

    self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[1])

    self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)

  def get_config(self):
    config = {
        "num_heads":
            self._num_heads,
        "key_size":
            self._key_size,
        "value_size":
            self._value_size,
        "dropout_rate":
            self._dropout_rate,
        "use_bias":
            self._use_bias,
        "output_shape":
            self._output_shape,
        "kernel_initializer":
            tf.keras.initializers.serialize(self._kernel_initializer),
        "bias_initializer":
            tf.keras.initializers.serialize(self._bias_initializer),
        "kernel_regularizer":
            tf.keras.regularizers.serialize(self._kernel_regularizer),
        "bias_regularizer":
            tf.keras.regularizers.serialize(self._bias_regularizer),
        "activity_regularizer":
            tf.keras.regularizers.serialize(self._activity_regularizer),
        "kernel_constraint":
            tf.keras.constraints.serialize(self._kernel_constraint),
        "bias_constraint":
            tf.keras.constraints.serialize(self._bias_constraint)
    }
    base_config = super(MultiHeadAttention, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def build(self, input_shape):
    if self._output_shape:
      output_shape = self._output_shape
    else:
      input_shape = tf.TensorShape(input_shape[0])
      output_shape = input_shape[-1]
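    # The output projection consumes the two trailing head dimensions
    # [..., N, H] (num_summed_dimensions=2) and maps them to `output_shape`.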
    self._output_dense = dense_einsum.DenseEinsum(
        output_shape=output_shape,
        num_summed_dimensions=2,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="attention_output")
    super(MultiHeadAttention, self).build(input_shape)

  def call(self, inputs, attention_mask=None):
    """Implements the forward pass.

    Size glossary:
      * Number of heads (H): the number of attention heads.
      * Value size (V): the size of each value embedding per head.
      * Key size (K): the size of each key embedding per head. Equally, the size
          of each query embedding per head. Typically K <= V.
      * Batch size (B).
      * Query (target) sequence length (T).
      * Value (source) sequence length (S).

    Args:
      inputs: List of the following tensors:
        * query: Query `Tensor` of shape `[B, T, dim]`.
        * value: Value `Tensor` of shape `[B, S, dim]`.
        * key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will
          use `value` for both `key` and `value`, which is the most common case.
      attention_mask: A boolean mask of shape `[B, T, S]` that prevents
        attention to certain positions.

    Returns:
      attention_output: The result of the computation, of shape `[B, T, E]`,
        where `E` is the query input's last dimension if `output_shape` is not
        specified, else `output_shape`.
    """
    inputs_len = len(inputs)
    if inputs_len > 3 or inputs_len < 2:
      raise ValueError(
          "Expects inputs list of length 2 or 3, namely [query, value] or "
          "[query, value, key]. "
          "Given length: %d" % inputs_len)
    query = inputs[0]
    value = inputs[1]
    key = inputs[2] if inputs_len == 3 else value

    # Scalar dimensions referenced here (see the size glossary above):
    #   N = `num_attention_heads`
    #   H = `size_per_head` (`key_size` for query/key, `value_size` for value)
    # `query_tensor` = [B, T, N, H]
    query_tensor = self._query_dense(query)

    # `key_tensor` = [B, S, N, H]
    key_tensor = self._key_dense(key)

    # `value_tensor` = [B, S, N, H]
    value_tensor = self._value_dense(value)

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
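    # `attention_scores` = [B, N, T, S]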
    attention_scores = tf.einsum("BSNH,BTNH->BNTS", key_tensor, query_tensor)
    attention_scores = tf.multiply(attention_scores,
                                   1.0 / math.sqrt(float(self._key_size)))

    # Normalize the attention scores to probabilities.
    # `attention_probs` = [B, N, T, S]
    attention_probs = self._masked_softmax([attention_scores, attention_mask])

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = self._dropout(attention_probs)

    # `attention_output` = [B, T, N, H]
    attention_output = tf.einsum("BNTS,BSNH->BTNH", attention_probs,
                                 value_tensor)

    attention_output = self._output_dense(attention_output)
    return attention_output


@tf.keras.utils.register_keras_serializable(package="Text")
class CachedAttention(MultiHeadAttention):
  """Attention layer with cache used for auto-agressive decoding.

  Arguments are the same as for the `MultiHeadAttention` layer.
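
  The cache is a dict with "key" and "value" entries of shape
  `[batch_size, sequence_length, num_heads, head_size]`; each call returns
  the attention output together with the updated cache. A minimal decoding
  sketch (the sizes, the zero-length cache and the all-ones mask below are
  illustrative assumptions):

    layer = CachedAttention(num_heads=8, key_size=64)
    cache = {
        "key": tf.zeros([1, 0, 8, 64]),
        "value": tf.zeros([1, 0, 8, 64]),
    }
    token = tf.random.uniform([1, 1, 512])  # one decoded position
    mask = tf.ones([1, 1, 1])               # [batch, from_length, to_length]
    output, cache = layer([token, token, mask, cache])
    # cache["key"] and cache["value"] now have sequence length 1.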
  """

  def _update_cache(self, key_tensor, value_tensor, cache, decode_loop_step):
    """Updates cache states and gets full-length key/value tensors."""
    # Combines cached keys and values with new keys and values.
    if decode_loop_step is not None:
      # TPU special case.
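      # On TPU the cache has a fixed (maximum) decode length, so instead of
      # concatenating, the new key/value for this step is scattered into the
      # cache at position `decode_loop_step` via a one-hot multiply-add.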
      key_seq_dim = cache["key"].shape.as_list()[1]
      indices = tf.reshape(
          tf.one_hot(decode_loop_step, key_seq_dim, dtype=key_tensor.dtype),
          [1, key_seq_dim, 1, 1])
      key_tensor = cache["key"] + key_tensor * indices
      value_seq_dim = cache["value"].shape.as_list()[1]
      indices = tf.reshape(
          tf.one_hot(decode_loop_step, value_seq_dim, dtype=value_tensor.dtype),
          [1, value_seq_dim, 1, 1])
      value_tensor = cache["value"] + value_tensor * indices
    else:
      key_tensor = tf.concat(
          [tf.cast(cache["key"], key_tensor.dtype), key_tensor], axis=1)
      value_tensor = tf.concat(
          [tf.cast(cache["value"], value_tensor.dtype), value_tensor], axis=1)

    # Update cache
    cache["key"] = key_tensor
    cache["value"] = value_tensor

    return key_tensor, value_tensor

  def call(self, inputs, decode_loop_step=None):
    """Implements the forward pass, updating the cache when one is provided.

    Args:
      inputs: List of `[from_tensor, to_tensor]`, optionally followed by an
        `attention_mask` and a `cache` dict holding "key" and "value" tensors.
      decode_loop_step: The current step of the decoding loop. Only needed for
        auto-regressive decoding on TPU, where the cache has a static length
        and is updated in place at this step.

    Returns:
      A tuple `(attention_output, cache)`.
    """
    from_tensor = inputs[0]
    to_tensor = inputs[1]
    attention_mask = inputs[2] if len(inputs) >= 3 else None
    cache = inputs[3] if len(inputs) >= 4 else None
    # Scalar dimensions referenced here:
    #   B = batch size (number of sequences)
    #   F = `from_tensor` sequence length
    #   T = `to_tensor` sequence length
    #   N = `num_attention_heads`
    #   H = `size_per_head`
    # `query_tensor` = [B, F, N, H]
    query_tensor = self._query_dense(from_tensor)

    # `key_tensor` = [B, T, N, H]
    key_tensor = self._key_dense(to_tensor)

    # `value_tensor` = [B, T, N, H]
    value_tensor = self._value_dense(to_tensor)

    if cache:
      key_tensor, value_tensor = self._update_cache(key_tensor, value_tensor,
                                                    cache, decode_loop_step)

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
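    # `attention_scores` = [B, N, F, T]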
    attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
    attention_scores = tf.multiply(attention_scores,
                                   1.0 / math.sqrt(float(self._key_size)))

    # Normalize the attention scores to probabilities.
    # `attention_probs` = [B, N, F, T]
    attention_probs = self._masked_softmax([attention_scores, attention_mask])

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = self._dropout(attention_probs)
    # `attention_output` = [B, F, N, H]
    attention_output = tf.einsum("BNFT,BTNH->BFNH", attention_probs,
                                 value_tensor)
    attention_output = self._output_dense(attention_output)
    return attention_output, cache