# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of multiheaded attention and self-attention layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


class Attention(tf.keras.layers.Layer):
  """Multi-headed attention layer."""

  def __init__(self, hidden_size, num_heads, attention_dropout):
    """Initialize Attention.

    Args:
      hidden_size: int, output dim of hidden layer.
      num_heads: int, number of heads to repeat the same attention structure.
      attention_dropout: float, dropout rate inside attention for training.
    """
    if hidden_size % num_heads:
      raise ValueError(
          "Hidden size ({}) must be divisible by the number of heads ({})."
          .format(hidden_size, num_heads))

    super(Attention, self).__init__()
    self.hidden_size = hidden_size
    self.num_heads = num_heads
    self.attention_dropout = attention_dropout

  def build(self, input_shape):
    """Builds the layer."""
    # Layers for linearly projecting the queries, keys, and values.
    self.q_dense_layer = tf.keras.layers.Dense(
        self.hidden_size, use_bias=False, name="q")
    self.k_dense_layer = tf.keras.layers.Dense(
        self.hidden_size, use_bias=False, name="k")
    self.v_dense_layer = tf.keras.layers.Dense(
        self.hidden_size, use_bias=False, name="v")
    self.output_dense_layer = tf.keras.layers.Dense(
        self.hidden_size, use_bias=False, name="output_transform")
    super(Attention, self).build(input_shape)

  def get_config(self):
    return {
        "hidden_size": self.hidden_size,
        "num_heads": self.num_heads,
        "attention_dropout": self.attention_dropout,
    }

  def split_heads(self, x):
    """Split x into different heads, and transpose the resulting value.

    The tensor is transposed to ensure the inner dimensions hold the correct
    values during the matrix multiplication.

    Args:
      x: A tensor with shape [batch_size, length, hidden_size]

    Returns:
      A tensor with shape [batch_size, num_heads, length, hidden_size/num_heads]
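
    For example (shapes are illustrative), with hidden_size=512 and
    num_heads=8, an input of shape [32, 100, 512] is reshaped to
    [32, 100, 8, 64] and then transposed to [32, 8, 100, 64].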
    """
    with tf.name_scope("split_heads"):
      batch_size = tf.shape(x)[0]
      length = tf.shape(x)[1]

      # Calculate depth of last dimension after it has been split.
      depth = (self.hidden_size // self.num_heads)

      # Split the last dimension
      x = tf.reshape(x, [batch_size, length, self.num_heads, depth])

      # Transpose the result
      return tf.transpose(x, [0, 2, 1, 3])

  def combine_heads(self, x):
    """Combine tensor that has been split.

    Args:
      x: A tensor [batch_size, num_heads, length, hidden_size/num_heads]

    Returns:
      A tensor with shape [batch_size, length, hidden_size]
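
    For example (the inverse of split_heads, shapes illustrative), a
    [32, 8, 100, 64] input with hidden_size=512 becomes [32, 100, 512].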
    """
    with tf.name_scope("combine_heads"):
      batch_size = tf.shape(x)[0]
      length = tf.shape(x)[2]
      x = tf.transpose(x, [0, 2, 1, 3])  # --> [batch, length, num_heads, depth]
      return tf.reshape(x, [batch_size, length, self.hidden_size])

  def call(self, x, y, bias, training, cache=None, decode_loop_step=None):
    """Apply attention mechanism to x and y.

    Args:
      x: A tensor with shape [batch_size, length_x, hidden_size].
      y: A tensor with shape [batch_size, length_y, hidden_size].
      bias: A tensor, the attention bias that will be added to the result of
        the dot product.
      training: A bool, whether in training mode or not.
      cache: (Used during prediction) A dictionary with tensors containing
        results of previous attentions. The dictionary must have the items:
            {"k": tensor with shape [batch_size, i, key_channels],
             "v": tensor with shape [batch_size, i, value_channels]}
        where i is the current decoded length.
      decode_loop_step: An integer, step number of the decoding loop. Used only
        for autoregressive inference on TPU.

    Returns:
      Attention layer output with shape [batch_size, length_x, hidden_size]
    """
    # Linearly project the query, key, and value using different learned
    # projections. This is in preparation for splitting them into multiple
    # heads: unlike regular attention (which uses a single query, key, and
    # value), multi-head attention uses several of each in parallel.
    query = self.q_dense_layer(x)
    key = self.k_dense_layer(y)
    value = self.v_dense_layer(y)

    if cache is not None:
      # Combine cached keys and values with new keys and values.
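      # On TPU (decode_loop_step is not None), the cache is assumed to be
      # preallocated to the full decode length, so the new key/value entries
      # are written into position `decode_loop_step` with a one-hot mask
      # instead of being concatenated; this keeps tensor shapes static for XLA.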
      if decode_loop_step is not None:
        cache_k_shape = cache["k"].shape.as_list()
        indices = tf.reshape(
            tf.one_hot(decode_loop_step, cache_k_shape[1], dtype=key.dtype),
            [1, cache_k_shape[1], 1])
        key = cache["k"] + key * indices
        cache_v_shape = cache["v"].shape.as_list()
        indices = tf.reshape(
            tf.one_hot(decode_loop_step, cache_v_shape[1], dtype=value.dtype),
            [1, cache_v_shape[1], 1])
        value = cache["v"] + value * indices
      else:
        key = tf.concat([tf.cast(cache["k"], key.dtype), key], axis=1)
        value = tf.concat([tf.cast(cache["v"], value.dtype), value], axis=1)

      # Update cache
      cache["k"] = key
      cache["v"] = value

    # Split query, key, value into heads.
    query = self.split_heads(query)
    key = self.split_heads(key)
    value = self.split_heads(value)

    # Scale query to prevent the dot product between query and key from growing
    # too large.
    depth = (self.hidden_size // self.num_heads)
    query *= depth ** -0.5

    # Calculate dot product attention
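    # Together with the query scaling above, the lines below compute, for each
    # head, softmax(Q K^T / sqrt(depth) + bias) V, i.e. scaled dot-product
    # attention.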
    logits = tf.matmul(query, key, transpose_b=True)
    logits += bias
    # Note that softmax internally performs math operations using float32
    # for numeric stability. When training with float16, we keep the input
    # and output in float16 for better performance.
    weights = tf.nn.softmax(logits, name="attention_weights")
    if training:
      weights = tf.nn.dropout(weights, rate=self.attention_dropout)
    attention_output = tf.matmul(weights, value)

    # Recombine heads --> [batch_size, length, hidden_size]
    attention_output = self.combine_heads(attention_output)

    # Run the combined outputs through another linear projection layer.
    attention_output = self.output_dense_layer(attention_output)
    return attention_output


class SelfAttention(Attention):
  """Multiheaded self-attention layer."""

  def call(self, x, bias, training, cache=None, decode_loop_step=None):
    return super(SelfAttention, self).call(x, x, bias, training, cache,
                                           decode_loop_step)
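

# A minimal usage sketch (illustrative only, not part of the original module;
# shapes and hyperparameters are arbitrary):
#
#   layer = SelfAttention(hidden_size=64, num_heads=4, attention_dropout=0.1)
#   x = tf.ones([2, 10, 64])        # [batch_size, length, hidden_size]
#   bias = tf.zeros([2, 1, 1, 10])  # additive attention bias, broadcast over heads
#   output = layer(x, bias, training=False)  # -> shape [2, 10, 64]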