# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based attention layer."""
# pylint: disable=g-classes-have-attributes

import math
import string

import tensorflow as tf


EinsumDense = tf.keras.layers.experimental.EinsumDense
_CHR_IDX = string.ascii_lowercase


MultiHeadAttention = tf.keras.layers.MultiHeadAttention
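# `MultiHeadAttention` is aliased from core Keras; `CachedAttention` below
# extends it with a key/value cache for auto-regressive decoding.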


@tf.keras.utils.register_keras_serializable(package="Text")
class CachedAttention(tf.keras.layers.MultiHeadAttention):
  """Attention layer with cache used for auto-agressive decoding.

  Arguments are the same as `MultiHeadAttention` layer.
  """

  def _update_cache(self, key, value, cache, decode_loop_step):
    """Updates cache states and gets full-length key/value tensors."""
    # Combines cached keys and values with new keys and values.
    if decode_loop_step is not None:
      # TPU special case.
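      # Inside an XLA/TPU while loop the cache cannot be updated with a
      # dynamic slice assignment, so this step's key/value is scattered into
      # the pre-allocated cache via a one-hot mask over the sequence axis.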
      key_seq_dim = cache["key"].shape.as_list()[1]
      indices = tf.reshape(
          tf.one_hot(decode_loop_step, key_seq_dim, dtype=key.dtype),
          [1, key_seq_dim, 1, 1])
      key = cache["key"] + key * indices
      value_seq_dim = cache["value"].shape.as_list()[1]
      indices = tf.reshape(
          tf.one_hot(decode_loop_step, value_seq_dim, dtype=value.dtype),
          [1, value_seq_dim, 1, 1])
      value = cache["value"] + value * indices
    else:
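      # Without a decode loop step the cache simply grows: the new key/value
      # is appended along the sequence axis. The cast guards against a cache
      # kept in a different dtype (e.g. under mixed precision).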
      key = tf.concat([tf.cast(cache["key"], key.dtype), key], axis=1)
      value = tf.concat([tf.cast(cache["value"], value.dtype), value], axis=1)

    # Update cache
    cache["key"] = key
    cache["value"] = value

    return key, value

  def call(self,
           query,
           value,
           key=None,
           attention_mask=None,
           cache=None,
           decode_loop_step=None,
           return_attention_scores=False):
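    # The parent MultiHeadAttention builds its internal EinsumDense
    # projections lazily from the shapes of the first inputs it sees.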
    if not self._built_from_signature:
      self._build_from_signature(query=query, value=value, key=key)
    if key is None:
      key = value

    # Scalar dimensions referenced here:
    #   B = batch size (number of sequences)
    #   F = `from_tensor` sequence length
    #   T = `to_tensor` sequence length
    #   N = `num_attention_heads`
    #   H = `size_per_head`
    # `query` = [B, F, N, H]
    query = self._query_dense(query)

    # `key` = [B, T, N, H]
    key = self._key_dense(key)

    # `value` = [B, T, N, H]
    value = self._value_dense(value)

    if cache:
      key, value = self._update_cache(key, value, cache, decode_loop_step)

    query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_dim)))
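    # The 1/sqrt(key_dim) factor above is the standard scaled dot-product
    # attention normalization, applied to the query before the score einsum.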

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    attention_scores = tf.einsum(self._dot_product_equation, key, query)

    # Normalize the attention scores to probabilities.
    # `attention_scores` = [B, N, F, T]
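    # `_masked_softmax` is inherited from MultiHeadAttention; positions ruled
    # out by `attention_mask` receive effectively zero weight.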
    attention_scores = self._masked_softmax(attention_scores, attention_mask)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_scores = self._dropout_layer(attention_scores)
    # `context_layer` = [B, F, N, H]
    attention_output = tf.einsum(self._combine_equation, attention_scores,
                                 value)
    attention_output = self._output_dense(attention_output)
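    # Unlike the parent layer, `call` also returns the cache so a decoding
    # loop can thread it through successive steps.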
    if return_attention_scores:
      return attention_output, attention_scores, cache
    return attention_output, cache