attention_layer.py 6.37 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of multiheaded attention and self-attention layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
22
from official.nlp.modeling import layers
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48


class Attention(tf.keras.layers.Layer):
  """Multi-headed attention layer.

  Queries are projected from `query_input`, keys and values from
  `source_input`, each split into `num_heads` heads of size
  `hidden_size // num_heads`. Scaled dot-product attention is computed per
  head and the heads are then projected back to `hidden_size`.
  """

  def __init__(self, hidden_size, num_heads, attention_dropout):
    """Initialize Attention.

    Args:
      hidden_size: int, output dim of hidden layer.
      num_heads: int, number of heads to repeat the same attention structure.
        Must be positive and evenly divide `hidden_size`.
      attention_dropout: float, dropout rate inside attention for training.

    Raises:
      ValueError: if `num_heads` is not positive, or if `hidden_size` is not
        divisible by `num_heads`.
    """
    # Check the head count first so that zero heads yields a clear ValueError
    # instead of a ZeroDivisionError from the modulo below.
    if num_heads <= 0:
      raise ValueError(
          "Number of heads ({}) must be positive.".format(num_heads))
    if hidden_size % num_heads:
      raise ValueError(
          "Hidden size ({}) must be divisible by the number of heads ({})."
          .format(hidden_size, num_heads))

    super(Attention, self).__init__()
    self.hidden_size = hidden_size
    self.num_heads = num_heads
    self.attention_dropout = attention_dropout

  def build(self, input_shape):
    """Builds the query/key/value projections and the output projection.

    Args:
      input_shape: Shape of the layer input. It is only forwarded to the base
        class; the projection sizes come from the constructor arguments.
    """
    # Layers for linearly projecting the queries, keys, and values. Each one
    # produces a (num_heads, size_per_head) output per position, so the head
    # split happens inside the projection.
    size_per_head = self.hidden_size // self.num_heads
    self.query_dense_layer = layers.DenseEinsum(
        output_shape=(self.num_heads, size_per_head),
        kernel_initializer="glorot_uniform",
        use_bias=False,
        name="query")
    self.key_dense_layer = layers.DenseEinsum(
        output_shape=(self.num_heads, size_per_head),
        kernel_initializer="glorot_uniform",
        use_bias=False,
        name="key")
    self.value_dense_layer = layers.DenseEinsum(
        output_shape=(self.num_heads, size_per_head),
        kernel_initializer="glorot_uniform",
        use_bias=False,
        name="value")
    # Sums over the two trailing (heads, per-head) dimensions, which
    # recombines the heads into a single hidden_size vector.
    self.output_dense_layer = layers.DenseEinsum(
        output_shape=self.hidden_size,
        num_summed_dimensions=2,
        kernel_initializer="glorot_uniform",
        use_bias=False,
        name="output_transform")
    super(Attention, self).build(input_shape)

  def get_config(self):
    """Returns the constructor arguments needed to recreate this layer."""
    return {
        "hidden_size": self.hidden_size,
        "num_heads": self.num_heads,
        "attention_dropout": self.attention_dropout,
    }

  def call(self, query_input, source_input, bias, training, cache=None,
           decode_loop_step=None):
    """Apply attention mechanism to query_input and source_input.

    Args:
      query_input: A tensor with shape [batch_size, length_query, hidden_size].
      source_input: A tensor with shape [batch_size, length_source,
        hidden_size].
      bias: A tensor with shape [batch_size, 1, length_query, length_source],
        the attention bias that will be added to the result of the dot product.
      training: A bool, whether in training mode or not.
      cache: (Used during prediction) A dictionary with tensors containing
        results of previous attentions. The dictionary must have the items:
            {"k": tensor with shape [batch_size, i, heads, dim_per_head],
             "v": tensor with shape [batch_size, i, heads, dim_per_head]}
        where i is the current decoded length for non-padded decode, or max
        sequence length for padded decode. The dictionary is updated in place
        with the combined keys and values.
      decode_loop_step: An integer, step number of the decoding loop. Used only
        for autoregressive inference on TPU.

    Returns:
      Attention layer output with shape [batch_size, length_query, hidden_size]
    """
    # Linearly project the query, key and value using different learned
    # projections. Splitting heads is automatically done during the linear
    # projections --> [batch_size, length, num_heads, dim_per_head].
    query = self.query_dense_layer(query_input)
    key = self.key_dense_layer(source_input)
    value = self.value_dense_layer(source_input)

    if cache is not None:
      # Combine cached keys and values with new keys and values.
      if decode_loop_step is not None:
        # Padded decode: the cache has a fixed length along axis 1. A one-hot
        # mask at `decode_loop_step` places the new key/value into that slot.
        # NOTE(review): this assumes the slot is still zero in the cache;
        # confirm with the caller.
        # Cast the cached tensors as the non-padded branch does, so a cache
        # held in a different dtype does not break the addition.
        cached_k = tf.cast(cache["k"], key.dtype)
        cache_k_shape = cached_k.shape.as_list()
        indices = tf.reshape(
            tf.one_hot(decode_loop_step, cache_k_shape[1], dtype=key.dtype),
            [1, cache_k_shape[1], 1, 1])
        key = cached_k + key * indices
        cached_v = tf.cast(cache["v"], value.dtype)
        cache_v_shape = cached_v.shape.as_list()
        indices = tf.reshape(
            tf.one_hot(decode_loop_step, cache_v_shape[1], dtype=value.dtype),
            [1, cache_v_shape[1], 1, 1])
        value = cached_v + value * indices
      else:
        # Non-padded decode: append the new keys/values along the length axis.
        key = tf.concat([tf.cast(cache["k"], key.dtype), key], axis=1)
        value = tf.concat([tf.cast(cache["v"], value.dtype), value], axis=1)

      # Update cache
      cache["k"] = key
      cache["v"] = value

    # Scale query to prevent the dot product between query and key from growing
    # too large.
    depth = self.hidden_size // self.num_heads
    query *= depth ** -0.5

    # Calculate dot product attention. Logits are [batch, heads, query_len,
    # source_len], which matches the layout of `bias`.
    logits = tf.einsum("BTNH,BFNH->BNFT", key, query)
    logits += bias
    # Note that softmax internally performs math operations using float32
    # for numeric stability. When training with float16, we keep the input
    # and output in float16 for better performance.
    weights = tf.nn.softmax(logits, name="attention_weights")
    if training:
      weights = tf.nn.dropout(weights, rate=self.attention_dropout)
    attention_output = tf.einsum("BNFT,BTNH->BFNH", weights, value)

    # Run the outputs through another linear projection layer. Recombining heads
    # is automatically done --> [batch_size, length, hidden_size]
    attention_output = self.output_dense_layer(attention_output)
    return attention_output


class SelfAttention(Attention):
  """Multiheaded self-attention layer.

  A specialization of `Attention` in which the same tensor supplies the
  queries as well as the keys and values.
  """

  def call(self, query_input, bias, training, cache=None,
           decode_loop_step=None):
    """Runs attention with `query_input` used as its own source.

    Args:
      query_input: Input tensor; used both as the query and as the source.
      bias: Attention bias added to the attention logits.
      training: A bool, whether in training mode or not.
      cache: Optional dictionary of previously computed keys/values, forwarded
        unchanged to `Attention.call`.
      decode_loop_step: Optional decode step number, forwarded unchanged to
        `Attention.call`.

    Returns:
      Whatever `Attention.call` returns for these arguments.
    """
    parent = super(SelfAttention, self)
    return parent.call(
        query_input=query_input,
        source_input=query_input,
        bias=bias,
        training=training,
        cache=cache,
        decode_loop_step=decode_loop_step)