# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based attention layer."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import collections
import math
import string

import numpy as np
import tensorflow as tf

from official.nlp.modeling.layers import masked_softmax

# Lowercase letters used as axis labels when assembling einsum equations.
_CHR_IDX = string.ascii_lowercase
# Shorthand for the experimental Keras einsum-based dense projection layer.
EinsumDense = tf.keras.layers.experimental.EinsumDense


def _build_attention_equation(qkv_rank, attn_axes):
  """Builds einsum equations for the attention computation.

  Query, key, value inputs after projection are expected to have the shape as:
  (bs, <non-attention dims>, <attention dims>, num_heads, channels).
  bs and <non-attention dims> are treated as <batch dims>.
  The attention operations can be generalized:
  (1) Query-key dot product. Note the operand order of the returned equation:
  the key comes first, then the query:
  (<batch dims>, <key attention dims>, num_heads, channels), (<batch dims>,
  <query attention dims>, num_heads, channels) -> (<batch dims>,
  num_heads, <query attention dims>, <key attention dims>)
  (2) Combination:
  (<batch dims>, num_heads, <query attention dims>, <key attention dims>),
  (<batch dims>, <value attention dims>, num_heads, channels) -> (<batch dims>,
  <query attention dims>, num_heads, channels)

  Args:
    qkv_rank: the rank of query, key, value tensors.
    attn_axes: a list/tuple of axes, [1, rank), that will do attention.

  Returns:
    A `(dot_product_equation, combine_equation)` tuple of einsum strings.
  """
  # Normalize to a tuple so that list inputs (allowed by the docstring) can be
  # concatenated with the tuple below; `list + tuple` raises TypeError.
  attn_axes = tuple(attn_axes)
  target_notation = _CHR_IDX[:qkv_rank]
  # `batch_dims` includes the head dim.
  batch_dims = tuple(np.delete(range(qkv_rank), attn_axes + (qkv_rank - 1,)))

  # The source (key/value) side reuses the target letters for batch dims and
  # the channel dim, but gets fresh letters for its attention axes.
  letter_offset = qkv_rank
  source_notation = ""
  for i in range(qkv_rank):
    if i in batch_dims or i == qkv_rank - 1:
      source_notation += target_notation[i]
    else:
      source_notation += _CHR_IDX[letter_offset]
      letter_offset += 1

  product_notation = "".join([target_notation[i] for i in batch_dims] +
                             [target_notation[i] for i in attn_axes] +
                             [source_notation[i] for i in attn_axes])
  dot_product_equation = "%s,%s->%s" % (source_notation, target_notation,
                                        product_notation)
  combine_equation = "%s,%s->%s" % (product_notation, source_notation,
                                    target_notation)
  return dot_product_equation, combine_equation


def _build_proj_equation(free_dims, bound_dims, output_dims):
  """Builds an einsum equation for projections inside multi-head attention.

  Axes are labelled with consecutive letters of `_CHR_IDX`. The free axes come
  first, then the bound axes, then the output axes introduced by the kernel.

  Args:
    free_dims: number of input axes that appear unchanged in the output.
    bound_dims: number of input axes shared with (contracted by) the kernel.
    output_dims: number of kernel axes that appear in the output.

  Returns:
    A tuple `(equation, bias_axes, output_rank)`:
      * `equation` is the einsum string "input,kernel->output".
      * `bias_axes` holds the labels of the kernel-introduced output axes.
      * `output_rank` is the output rank, not counting the batch dimension.
  """
  def _letters(start, count):
    # Index (rather than slice) so an out-of-range label still raises.
    return "".join(_CHR_IDX[pos] for pos in range(start, start + count))

  free = _letters(0, free_dims)
  bound = _letters(free_dims, bound_dims)
  produced = _letters(free_dims + bound_dims, output_dims)

  input_str = free + bound
  kernel_str = bound + produced
  output_str = free + produced
  equation = "%s,%s->%s" % (input_str, kernel_str, output_str)
  # The batch dimension is not counted in the reported output rank.
  return equation, produced, len(output_str) - 1


def _get_output_shape(output_rank, known_last_dims):
  return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)


@tf.keras.utils.register_keras_serializable(package="Text")
class MultiHeadAttention(tf.keras.layers.Layer):
  """MultiHeadAttention layer.

  This is an implementation of multi-headed attention based on "Attention
  is all you Need". If `query`, `key`, `value` are the same, then
  this is self-attention. Each timestep in `query` attends to the
  corresponding sequence in `key`, and returns a fixed-width vector.

  This layer first projects `query`, `key` and `value`. These are
  (effectively) a list of tensors of length `num_attention_heads`, where the
  corresponding shapes are [batch_size, query_seq_length, key_size],
  [batch_size, seq_length, key_size], [batch_size, seq_length, value_size].

  Then, the query and key tensors are dot-producted and scaled. These are
  softmaxed to obtain attention probabilities. The value tensors are then
  interpolated by these probabilities, then concatenated back to a single
  tensor.

  Finally, the result tensor with the last dimension as value_size can take a
  linear projection and return.

  Arguments:
    num_heads: Number of attention heads.
    key_size: Size of each attention head for query and key.
    value_size: Size of each attention head for value. Falls back to
      `key_size` when not given (or falsy).
    dropout_rate: Dropout probability applied to the attention probabilities.
    use_bias: Boolean, whether the dense layers use bias vectors/matrices.
    output_shape: The expected shape of an output tensor, besides the batch and
      sequence dims. If not specified, projects back to the key feature dim.
    kernel_initializer: Initializer for dense layer kernels.
    bias_initializer: Initializer for dense layer biases.
    kernel_regularizer: Regularizer for dense layer kernels.
    bias_regularizer: Regularizer for dense layer biases.
    activity_regularizer: Regularizer for dense layer activity.
    kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.
  """

  def __init__(self,
               num_heads,
               key_size,
               value_size=None,
               dropout_rate=0.0,
               use_bias=True,
               output_shape=None,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               **kwargs):
    super(MultiHeadAttention, self).__init__(**kwargs)
    self._num_heads = num_heads
    self._key_size = key_size
    self._value_size = value_size if value_size else key_size
    self._dropout_rate = dropout_rate
    self._use_bias = use_bias
    self._output_shape = output_shape
    self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf.keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    # `activity_regularizer` is consumed by this signature, so it is never
    # forwarded to the base Layer through `**kwargs`. Store it explicitly so
    # that `build()` and `get_config()` see the value the caller passed instead
    # of silently dropping it.
    self._activity_regularizer = tf.keras.regularizers.get(
        activity_regularizer)
    self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf.keras.constraints.get(bias_constraint)

    # Mask is expanded at axis 1 so a [B, T, S] mask broadcasts over heads.
    self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[1])
    self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)

  def get_config(self):
    config = {
        "num_heads":
            self._num_heads,
        "key_size":
            self._key_size,
        "value_size":
            self._value_size,
        "dropout_rate":
            self._dropout_rate,
        "use_bias":
            self._use_bias,
        "output_shape":
            self._output_shape,
        "kernel_initializer":
            tf.keras.initializers.serialize(self._kernel_initializer),
        "bias_initializer":
            tf.keras.initializers.serialize(self._bias_initializer),
        "kernel_regularizer":
            tf.keras.regularizers.serialize(self._kernel_regularizer),
        "bias_regularizer":
            tf.keras.regularizers.serialize(self._bias_regularizer),
        "activity_regularizer":
            tf.keras.regularizers.serialize(self._activity_regularizer),
        "kernel_constraint":
            tf.keras.constraints.serialize(self._kernel_constraint),
        "bias_constraint":
            tf.keras.constraints.serialize(self._bias_constraint)
    }
    base_config = super(MultiHeadAttention, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def build(self, input_shape):
    """Creates the query/key/value/output projections from the input shapes.

    Args:
      input_shape: a list of 2 or 3 shapes, `[query, value]` or
        `[query, value, key]`; `key` defaults to `value`.

    Raises:
      ValueError: if `input_shape` does not have 2 or 3 entries.
    """
    inputs_len = len(input_shape)
    if inputs_len > 3 or inputs_len < 2:
      raise ValueError(
          "Expects inputs list of length 2 or 3, namely [query, value] or "
          "[query, value, key]. "
          "Given length: %d" % inputs_len)
    tensor_shapes = tf.nest.map_structure(tf.TensorShape, input_shape)
    query_shape = tensor_shapes[0]
    value_shape = tensor_shapes[1]
    key_shape = tensor_shapes[2] if inputs_len == 3 else value_shape

    common_kwargs = dict(
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint)

    # Each projection keeps all but the last input axis, contracts the last
    # axis, and appends [num_heads, size_per_head].
    free_dims = query_shape.rank - 1
    einsum_equation, bias_axes, output_rank = _build_proj_equation(
        free_dims, bound_dims=1, output_dims=2)
    self._query_dense = EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank,
                                       [self._num_heads, self._key_size]),
        bias_axes=bias_axes if self._use_bias else None,
        name="query",
        **common_kwargs)
    einsum_equation, bias_axes, output_rank = _build_proj_equation(
        key_shape.rank - 1, bound_dims=1, output_dims=2)
    self._key_dense = EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank,
                                       [self._num_heads, self._key_size]),
        bias_axes=bias_axes if self._use_bias else None,
        name="key",
        **common_kwargs)
    einsum_equation, bias_axes, output_rank = _build_proj_equation(
        value_shape.rank - 1, bound_dims=1, output_dims=2)
    self._value_dense = EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank,
                                       [self._num_heads, self._value_size]),
        bias_axes=bias_axes if self._use_bias else None,
        name="value",
        **common_kwargs)
    # `output_rank + 1` adds back the batch dim excluded by
    # `_build_proj_equation`; attention runs over the sequence axis (1).
    self._dot_product_equation, self._combine_equation = (
        _build_attention_equation(output_rank + 1, attn_axes=(1,)))

    if self._output_shape:
      if not isinstance(self._output_shape, collections.abc.Sized):
        output_shape = [self._output_shape]
      else:
        output_shape = self._output_shape
    else:
      output_shape = [query_shape[-1]]
    # The output projection contracts the [num_heads, size_per_head] axes.
    einsum_equation, bias_axes, output_rank = _build_proj_equation(
        free_dims, bound_dims=2, output_dims=len(output_shape))
    self._output_dense = EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank, output_shape),
        bias_axes=bias_axes if self._use_bias else None,
        name="attention_output",
        **common_kwargs)
    super(MultiHeadAttention, self).build(input_shape)

  def call(self, inputs, attention_mask=None):
    """Implements the forward pass.

    Size glossary:
      * Number of heads (N): the number of attention heads.
      * Value size (V): the size of each value embedding per head.
      * Key size (K): the size of each key embedding per head. Equally, the size
          of each query embedding per head. Typically K <= V.
      * Batch size (B).
      * Query (target) sequence length (T).
      * Value (source) sequence length (S).

    Args:
      inputs: List of the following tensors:
        * query: Query `Tensor` of shape `[B, T, dim]`.
        * value: Value `Tensor` of shape `[B, S, dim]`.
        * key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will
          use `value` for both `key` and `value`, which is the most common case.
      attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
        attention to certain positions.

    Returns:
      attention_output: The result of the computation, of shape [B, T, E],
        where `E` is the query input last dimension, or [B, T] followed by
        `output_shape` when `output_shape` was given to the constructor.

    Raises:
      ValueError: if `inputs` does not have 2 or 3 entries.
    """
    inputs_len = len(inputs)
    if inputs_len > 3 or inputs_len < 2:
      raise ValueError(
          "Expects inputs list of length 2 or 3, namely [query, value] or "
          "[query, value, key]. "
          "Given length: %d" % inputs_len)
    query = inputs[0]
    value = inputs[1]
    key = inputs[2] if inputs_len == 3 else value

    # `query_tensor` = [B, T, N, K]
    query_tensor = self._query_dense(query)

    # `key_tensor` = [B, S, N, K]
    key_tensor = self._key_dense(key)

    # `value_tensor` = [B, S, N, V]
    value_tensor = self._value_dense(value)

    # Take the dot product between "query" and "key" to get the raw
    # attention scores. The einsum equation expects the key operand first.
    attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
                                 query_tensor)
    attention_scores = tf.multiply(attention_scores,
                                   1.0 / math.sqrt(float(self._key_size)))

    # Normalize the attention scores to probabilities.
    # `attention_probs` = [B, N, T, S]
    attention_probs = self._masked_softmax([attention_scores, attention_mask])

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = self._dropout(attention_probs)

    # `context_layer` = [B, T, N, V]
    attention_output = tf.einsum(self._combine_equation, attention_probs,
                                 value_tensor)

    attention_output = self._output_dense(attention_output)
    return attention_output


@tf.keras.utils.register_keras_serializable(package="Text")
class CachedAttention(MultiHeadAttention):
  """Attention layer with cache used for auto-regressive decoding.

  Arguments are the same as `MultiHeadAttention` layer.
  """

  def _update_cache(self, key_tensor, value_tensor, cache, decode_loop_step):
    """Updates cache states and gets full-length key/value tensors.

    Args:
      key_tensor: projected key tensor for the current step(s).
      value_tensor: projected value tensor for the current step(s).
      cache: dict with "key" and "value" entries. It is updated in place with
        the combined tensors.
      decode_loop_step: if not None, the new key/value are written into
        position `decode_loop_step` of a fixed-length cache along axis 1 (the
        TPU special case). Otherwise they are concatenated after the cached
        entries along axis 1.

    Returns:
      A `(key_tensor, value_tensor)` tuple covering cached and new positions.
    """
    # Combines cached keys and values with new keys and values.
    if decode_loop_step is not None:
      # TPU special case: a one-hot mask over the cache's sequence axis places
      # the new entry at `decode_loop_step` and adds it to the cache.
      # NOTE(review): this assumes that cache slot is still zero, and that the
      # new key/value broadcast against a [1, seq, 1, 1] mask. TODO confirm.
      key_seq_dim = cache["key"].shape.as_list()[1]
      indices = tf.reshape(
          tf.one_hot(decode_loop_step, key_seq_dim, dtype=key_tensor.dtype),
          [1, key_seq_dim, 1, 1])
      key_tensor = cache["key"] + key_tensor * indices
      value_seq_dim = cache["value"].shape.as_list()[1]
      indices = tf.reshape(
          tf.one_hot(decode_loop_step, value_seq_dim, dtype=value_tensor.dtype),
          [1, value_seq_dim, 1, 1])
      value_tensor = cache["value"] + value_tensor * indices
    else:
      key_tensor = tf.concat(
          [tf.cast(cache["key"], key_tensor.dtype), key_tensor], axis=1)
      value_tensor = tf.concat(
          [tf.cast(cache["value"], value_tensor.dtype), value_tensor], axis=1)

    # Update cache
    cache["key"] = key_tensor
    cache["value"] = value_tensor

    return key_tensor, value_tensor

  def call(self,
           inputs,
           attention_mask=None,
           cache=None,
           decode_loop_step=None):
    """Runs attention, optionally extending a key/value cache.

    Args:
      inputs: list whose first entry is `from_tensor` (projected into the
        query) and whose second entry is `to_tensor` (projected into both key
        and value). Any further entries are ignored.
      attention_mask: optional mask, passed to the masked softmax together with
        the attention scores.
      cache: optional dict with "key" and "value" entries holding previously
        projected keys/values. When it is truthy, it is combined with the new
        key/value and updated in place.
      decode_loop_step: optional decode step index, forwarded to
        `_update_cache`.

    Returns:
      A `(attention_output, cache)` tuple.
    """
    from_tensor = inputs[0]
    to_tensor = inputs[1]

    # Scalar dimensions referenced here:
    #   B = batch size (number of sequences)
    #   F = `from_tensor` sequence length
    #   T = `to_tensor` sequence length (including cached positions after the
    #       cache update)
    #   N = `num_attention_heads`
    #   H = `size_per_head`
    # `query_tensor` = [B, F, N ,H]
    query_tensor = self._query_dense(from_tensor)

    # `key_tensor` = [B, T, N, H]
    key_tensor = self._key_dense(to_tensor)

    # `value_tensor` = [B, T, N, H]
    value_tensor = self._value_dense(to_tensor)

    if cache:
      key_tensor, value_tensor = self._update_cache(key_tensor, value_tensor,
                                                    cache, decode_loop_step)

    # Take the dot product between "query" and "key" to get the raw
    # attention scores. The einsum equation expects the key operand first.
    attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
                                 query_tensor)
    attention_scores = tf.multiply(attention_scores,
                                   1.0 / math.sqrt(float(self._key_size)))

    # Normalize the attention scores to probabilities.
    # `attention_probs` = [B, N, F, T]
    attention_probs = self._masked_softmax([attention_scores, attention_mask])

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = self._dropout(attention_probs)
    # `context_layer` = [B, F, N, H]
    attention_output = tf.einsum(self._combine_equation, attention_probs,
                                 value_tensor)
    attention_output = self._output_dense(attention_output)
    return attention_output, cache