# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Keras-based attention layer."""
# pylint: disable=g-classes-have-attributes
import math

import tensorflow as tf

EinsumDense = tf.keras.layers.experimental.EinsumDense
MultiHeadAttention = tf.keras.layers.MultiHeadAttention


@tf.keras.utils.register_keras_serializable(package="Text")
class CachedAttention(tf.keras.layers.MultiHeadAttention):
  """Attention layer with cache used for autoregressive decoding.

  Arguments are the same as `tf.keras.layers.MultiHeadAttention` layer.
  """

  def _update_cache(self, key, value, cache, decode_loop_step):
    """Updates cache states and gets full-length key/value tensors."""
    # Combines cached keys and values with new keys and values.
    if decode_loop_step is not None:
      # TPU special case.
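      # The cache is pre-allocated to a fixed sequence length here, so the new
      # key/value is written into position `decode_loop_step` via a one-hot
      # mask instead of growing the cache by concatenation.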
      key_seq_dim = cache["key"].shape.as_list()[1]
      indices = tf.reshape(
          tf.one_hot(decode_loop_step, key_seq_dim, dtype=key.dtype),
          [1, key_seq_dim, 1, 1])
      key = cache["key"] + key * indices
      value_seq_dim = cache["value"].shape.as_list()[1]
      indices = tf.reshape(
          tf.one_hot(decode_loop_step, value_seq_dim, dtype=value.dtype),
          [1, value_seq_dim, 1, 1])
      value = cache["value"] + value * indices
    else:
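      # Without a decode loop step, the cache simply grows: new keys/values
      # are appended along the sequence axis.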
      key = tf.concat([tf.cast(cache["key"], key.dtype), key], axis=1)
      value = tf.concat([tf.cast(cache["value"], value.dtype), value], axis=1)

    # Update cache
    cache["key"] = key
    cache["value"] = value

    return key, value

  def call(self,
           query,
           value,
           key=None,
           attention_mask=None,
           cache=None,
           decode_loop_step=None,
           return_attention_scores=False):
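    """Runs attention, updating and returning the decoding `cache` if given."""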
    if not self._built_from_signature:
      self._build_from_signature(query=query, value=value, key=key)
    if key is None:
      key = value

    # Scalar dimensions referenced here:
    #   B = batch size (number of sequences)
    #   F = `from_tensor` sequence length
    #   T = `to_tensor` sequence length
    #   N = `num_attention_heads`
    #   H = `size_per_head`
    # `query` = [B, F, N, H]
    query = self._query_dense(query)

    # `key` = [B, T, N, H]
    key = self._key_dense(key)

    # `value` = [B, T, N, H]
    value = self._value_dense(value)

    if cache:
      key, value = self._update_cache(key, value, cache, decode_loop_step)

    query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_dim)))

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    attention_scores = tf.einsum(self._dot_product_equation, key, query)

    # Normalize the attention scores to probabilities.
    # `attention_scores` = [B, N, F, T]
    attention_scores = self._masked_softmax(attention_scores, attention_mask)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_scores = self._dropout_layer(attention_scores)
    # `context_layer` = [B, F, N, H]
    attention_output = tf.einsum(self._combine_equation, attention_scores,
                                 value)
    attention_output = self._output_dense(attention_output)
    if return_attention_scores:
      return attention_output, attention_scores, cache
    return attention_output, cache
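

# A minimal usage sketch, not part of the original module: the shapes below
# (batch of 2, hidden size 16, an initially empty cache) are illustrative
# assumptions, shown only to demonstrate how `CachedAttention` threads the
# key/value cache through successive decoding steps.
if __name__ == "__main__":
  batch_size, num_heads, key_dim, hidden = 2, 4, 8, 16
  layer = CachedAttention(num_heads=num_heads, key_dim=key_dim)

  # The cache stores already-projected keys/values with shape [B, T, N, H];
  # starting at T=0 lets the non-TPU branch grow it by concatenation.
  cache = {
      "key": tf.zeros([batch_size, 0, num_heads, key_dim]),
      "value": tf.zeros([batch_size, 0, num_heads, key_dim]),
  }

  # Feed one token per step; each call returns the output and the updated cache.
  for _ in range(3):
    step_input = tf.random.uniform([batch_size, 1, hidden])
    output, cache = layer(step_input, step_input, cache=cache)

  print(output.shape)        # (2, 1, 16)
  print(cache["key"].shape)  # (2, 3, 4, 8)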