attention.py 15.8 KB
Newer Older
Hongkun Yu's avatar
Hongkun Yu committed
1
# Lint as: python3
Hongkun Yu's avatar
Hongkun Yu committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based attention layer."""
17
# pylint: disable=g-classes-have-attributes
Hongkun Yu's avatar
Hongkun Yu committed
18
import math
Allen Wang's avatar
Allen Wang committed
19
import string
Hongkun Yu's avatar
Hongkun Yu committed
20

Hongkun Yu's avatar
Hongkun Yu committed
21
22
23
import tensorflow as tf


Hongkun Yu's avatar
Hongkun Yu committed
24
EinsumDense = tf.keras.layers.experimental.EinsumDense
25
MultiHeadAttention = tf.keras.layers.MultiHeadAttention
Allen Wang's avatar
Allen Wang committed
26
_CHR_IDX = string.ascii_lowercase
27
28


29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def _large_compatible_negative(tensor_type):
  """Returns a large negative value that is representable in `tensor_type`.

  The masking constant normally used in this module, -1e9, lies outside the
  range of `tf.float16`, so that dtype gets its own minimum value instead.

  Args:
    tensor_type: the dtype the returned value has to be compatible with.

  Returns:
    `tf.float16.min` when `tensor_type` is `tf.float16`, otherwise -1e9.
  """
  is_half_precision = tensor_type == tf.float16
  return tf.float16.min if is_half_precision else -1e9


46
@tf.keras.utils.register_keras_serializable(package="Text")
class CachedAttention(tf.keras.layers.MultiHeadAttention):
  """Attention layer with cache used for auto-regressive decoding.

  Arguments are the same as `MultiHeadAttention` layer.
  """

  def _update_cache(self, key, value, cache, decode_loop_step):
    """Updates cache states and gets full-length key/value tensors.

    Args:
      key: Projected key tensor for the current call.
      value: Projected value tensor for the current call.
      cache: Dict with "key" and "value" entries. Both entries are overwritten
        in place with the combined tensors.
      decode_loop_step: Optional decoding step index. When not None, the new
        key/value are multiplied by a one-hot mask over axis 1 (hot at
        `decode_loop_step`) and added to the cached tensors. When None, the
        new key/value are concatenated after the cached tensors along axis 1.

    Returns:
      A `(key, value)` tuple holding the combined tensors.
    """
    # Combines cached keys and values with new keys and values.
    if decode_loop_step is not None:
      # TPU special case.
      key_seq_dim = cache["key"].shape.as_list()[1]
      indices = tf.reshape(
          tf.one_hot(decode_loop_step, key_seq_dim, dtype=key.dtype),
          [1, key_seq_dim, 1, 1])
      key = cache["key"] + key * indices
      value_seq_dim = cache["value"].shape.as_list()[1]
      indices = tf.reshape(
          tf.one_hot(decode_loop_step, value_seq_dim, dtype=value.dtype),
          [1, value_seq_dim, 1, 1])
      value = cache["value"] + value * indices
    else:
      key = tf.concat([tf.cast(cache["key"], key.dtype), key], axis=1)
      value = tf.concat([tf.cast(cache["value"], value.dtype), value], axis=1)

    # Update cache
    cache["key"] = key
    cache["value"] = value

    return key, value

  def call(self,
           query,
           value,
           key=None,
           attention_mask=None,
           cache=None,
           decode_loop_step=None,
           return_attention_scores=False):
    """Runs attention, optionally combining key/value with a decoding cache.

    Args:
      query: Query input, projected by the query dense layer.
      value: Value input, projected by the value dense layer.
      key: Optional key input. When None, `value` is used as the key as well.
      attention_mask: Optional mask forwarded to the masked softmax.
      cache: Optional dict with "key" and "value" entries. When truthy it is
        updated in place by `_update_cache`; a None or empty cache is skipped.
      decode_loop_step: Optional decoding step index forwarded to
        `_update_cache`.
      return_attention_scores: Whether the attention scores are returned too.

    Returns:
      `(attention_output, cache)`, or
      `(attention_output, attention_scores, cache)` when
      `return_attention_scores` is True.
    """
    if not self._built_from_signature:
      self._build_from_signature(query=query, value=value, key=key)
    if key is None:
      key = value

    # Scalar dimensions referenced here:
    #   B = batch size (number of sequences)
    #   F = `from_tensor` sequence length
    #   T = `to_tensor` sequence length
    #   N = `num_attention_heads`
    #   H = `size_per_head`
    # `query` = [B, F, N ,H]
    query = self._query_dense(query)

    # `key` = [B, T, N, H]
    key = self._key_dense(key)

    # `value` = [B, T, N, H]
    value = self._value_dense(value)

    if cache:
      key, value = self._update_cache(key, value, cache, decode_loop_step)

    query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_dim)))

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    attention_scores = tf.einsum(self._dot_product_equation, key, query)

    # Normalize the attention scores to probabilities.
    # `attention_scores` = [B, N, F, T]
    attention_scores = self._masked_softmax(attention_scores, attention_mask)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_scores = self._dropout_layer(attention_scores)
    # `context_layer` = [B, F, N, H]
    attention_output = tf.einsum(self._combine_equation, attention_scores,
                                 value)
    attention_output = self._output_dense(attention_output)
    if return_attention_scores:
      return attention_output, attention_scores, cache
    return attention_output, cache
Allen Wang's avatar
Allen Wang committed
129
130
131
132
133


def _rel_shift(x, klen=-1):
  """Performs relative shift to form the relative attention score.

  Args:
    x: Rank-4 tensor. The shift is applied over its last two axes; the first
      two axes are carried through unchanged.
    klen: Number of entries to keep along the last axis of the result. The
      default of -1 keeps everything that remains after the shift.

  Returns:
    A rank-4 tensor whose first three dimensions match those of `x` and whose
    last dimension is `klen` (or the last dimension of `x` minus one when
    `klen` is -1).
  """
  # Move the two axes being shifted to the front.
  x = tf.transpose(x, perm=[2, 3, 0, 1])
  x_size = tf.shape(x)

  # Swap the two leading dimensions in the reshape, drop the first row, and
  # reshape back: this realigns every row by one position.
  x = tf.reshape(x, [x_size[1], x_size[0], x_size[2], x_size[3]])
  x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
  x = tf.reshape(x, [x_size[0], x_size[1] - 1, x_size[2], x_size[3]])
  x = tf.slice(x, [0, 0, 0, 0], [-1, klen, -1, -1])

  # Restore the original axis order.
  x = tf.transpose(x, perm=[2, 3, 0, 1])

  return x


def _build_proj_equation(free_dims, bound_dims, output_dims):
  """Builds an einsum equation for projections inside multi-head attention.

  Axis letters are taken consecutively from `_CHR_IDX`: first the free axes,
  then the bound axes, then the output axes.

  Args:
    free_dims: number of leading input axes that are copied to the output.
    bound_dims: number of input axes shared with the kernel and absent from
      the output.
    output_dims: number of axes the kernel contributes to the output.

  Returns:
    A tuple `(equation, bias_axes, output_rank)`: the einsum equation string,
    the letters assigned to the output axes, and the length of the output
    subscript.
  """
  bound_start = free_dims
  bound_stop = bound_start + bound_dims
  output_stop = bound_stop + output_dims

  free_axes = "".join(_CHR_IDX[idx] for idx in range(free_dims))
  bound_axes = "".join(
      _CHR_IDX[idx] for idx in range(bound_start, bound_stop))
  output_axes = "".join(
      _CHR_IDX[idx] for idx in range(bound_stop, output_stop))

  input_spec = free_axes + bound_axes
  kernel_spec = bound_axes + output_axes
  output_spec = free_axes + output_axes
  equation = "%s,%s->%s" % (input_spec, kernel_spec, output_spec)

  return equation, output_axes, len(output_spec)


def _get_output_shape(output_rank, known_last_dims):
  return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)


@tf.keras.utils.register_keras_serializable(package="Text")
class MultiHeadRelativeAttention(MultiHeadAttention):
  """A multi-head attention layer with relative attention + position encoding.

  This layer shares the same input/output projections as the common
  MultiHeadAttention layer.

  When it calculates attention logits, position encoding is projected to form
  relative keys. The logits are composed by shifted relative logits and content
  logits.

  **Note: This layer is currently experimental.**

  Arguments:
    num_heads: The number of attention heads.
    key_dim: Size of each attention head for query and key.
    value_dim: Size of attention head for value.
    dropout: Dropout probability for attention.
    use_bias: Boolean, whether the dense layers use bias vectors/matrices.
    kernel_initializer: Initializer for dense layer kernels.
    bias_initializer: Initializer for dense layer biases.
  Call args:
    query: Query `Tensor` of shape `[B, T, dim]`.
    value: Value `Tensor` of shape `[B, S, dim]`.
    content_attention_bias: Bias `Tensor` for content based attention of shape
      `[num_heads, dim]`.
    positional_attention_bias: Bias `Tensor` for position based attention of
      shape `[num_heads, dim]`.
    key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
      `value` for both `key` and `value`, which is the most common case.
    relative_position_encoding: Relative positional encoding `Tensor` of shape
      `[B, L, dim]`.
    segment_matrix: Optional `Tensor` representing segmentation IDs used in
      XLNet.
    segment_encoding: Optional `Tensor` representing the segmentation
      encoding as used in XLNet.
    segment_attention_bias: Optional trainable bias parameter added to the
      query head when calculating the segment-based attention score used in
      XLNet.
    state: Optional `Tensor` of shape [B, M, E] where M is the length of the
      state or memory.
      If passed, this is also attended over as in Transformer XL.
    attention_mask: Optional mask. It is multiplied by a large negative number
      and added to the attention logits, so nonzero entries suppress attention
      to the corresponding positions.
  """

  def __init__(self,
               kernel_initializer="variance_scaling",
               **kwargs):
    """Creates the layer, defaulting kernels to variance scaling.

    Args:
      kernel_initializer: Initializer for dense layer kernels.
      **kwargs: Forwarded unchanged to `MultiHeadAttention`.
    """
    super().__init__(kernel_initializer=kernel_initializer,
                     **kwargs)

  def _build_from_signature(self, query, value, key=None):
    """Builds the parent projections plus the position-encoding projection.

    Args:
      query: Query tensor or shape, forwarded to the parent builder.
      value: Value tensor or shape. Also supplies the key shape when `key` is
        None.
      key: Optional key tensor or shape.
    """
    super()._build_from_signature(
        query=query,
        value=value,
        key=key)

    # Each input may be a tensor (has `.shape`) or a bare shape given as a
    # `TensorShape`, list or tuple. Normalise to `TensorShape` so that `.rank`
    # below is available in every case.
    key_source = value if key is None else key
    key_shape = tf.TensorShape(getattr(key_source, "shape", key_source))

    common_kwargs = dict(
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint)

    with tf.init_scope():
      # The encoding projection's equation is derived from the key's rank.
      einsum_equation, _, output_rank = _build_proj_equation(
          key_shape.rank - 1, bound_dims=1, output_dims=2)
      self._encoding_dense = EinsumDense(
          einsum_equation,
          output_shape=_get_output_shape(output_rank - 1,
                                         [self._num_heads, self._key_dim]),
          bias_axes=None,
          name="encoding",
          **common_kwargs)

  def compute_attention(self,
                        query,
                        key,
                        value,
                        position,
                        content_attention_bias,
                        positional_attention_bias,
                        segment_matrix=None,
                        segment_encoding=None,
                        segment_attention_bias=None,
                        attention_mask=None):
    """Computes the attention.

    This function defines the computation inside `call` with projected
    multihead Q, K, V, R inputs.

    Args:
      query: Projected query `Tensor` of shape `[B, T, N, key_dim]`.
      key: Projected key `Tensor` of shape `[B, S + M, N, key_dim]`.
      value: Projected value `Tensor` of shape `[B, S + M, N, key_dim]`.
      position: Projected position `Tensor` of shape `[B, L, N, key_dim]`.
      content_attention_bias: Trainable bias parameter added to the query head
        when calculating the content-based attention score.
      positional_attention_bias: Trainable bias parameter added to the query
        head when calculating the position-based attention score.
      segment_matrix: Optional `Tensor` representing segmentation IDs used in
        XLNet. The segment-based term is only computed when this is not None,
        and it then also uses `segment_encoding` and `segment_attention_bias`.
      segment_encoding: Optional trainable `Tensor` representing the
        segmentation encoding as used in XLNet.
      segment_attention_bias: Optional trainable bias parameter added to the
        query head when calculating the segment-based attention score used in
        XLNet.
      attention_mask: (default None) Optional mask that is added to attention
        logits. If state is not None, the mask source sequence dimension should
        extend M.

    Returns:
      attention_output: Multi-headed output of attention computation of shape
        `[B, S, N, key_dim]`.

    """
    content_attention = tf.einsum(self._dot_product_equation,
                                  key,
                                  query + content_attention_bias)
    positional_attention = tf.einsum(self._dot_product_equation,
                                     position,
                                     query + positional_attention_bias)
    # Shift the positional logits and trim them to the content logits' last
    # dimension so the two can be summed.
    positional_attention = _rel_shift(
        positional_attention, klen=tf.shape(content_attention)[3])

    if segment_matrix is not None:
      segment_attention = tf.einsum("bind,snd->bnis",
                                    query + segment_attention_bias,
                                    segment_encoding)
      target_shape = tf.shape(positional_attention)
      # Where `segment_matrix` is True take the entries from index 1 onwards
      # on the last axis, elsewhere the entry at index 0.
      segment_attention = tf.where(
          tf.broadcast_to(tf.expand_dims(segment_matrix, 1), target_shape),
          tf.broadcast_to(segment_attention[:, :, :, 1:], target_shape),
          tf.broadcast_to(segment_attention[:, :, :, :1], target_shape))
      attention_sum = (
          content_attention + positional_attention + segment_attention)
    else:
      attention_sum = content_attention + positional_attention

    attention_scores = tf.multiply(
        attention_sum, 1.0 / math.sqrt(float(self._key_dim)))

    # `attention_scores`: `[B, N, S, S + M]`
    if attention_mask is not None:
      attention_scores += (_large_compatible_negative(attention_scores.dtype)
                           * attention_mask)

    attention_scores = tf.nn.softmax(attention_scores, 3)
    attention_output = self._dropout_layer(attention_scores)

    attention_output = tf.einsum(self._combine_equation,
                                 attention_output,
                                 value)
    return attention_output

  def call(self,
           query,
           value,
           content_attention_bias,
           positional_attention_bias,
           key=None,
           relative_position_encoding=None,
           segment_matrix=None,
           segment_encoding=None,
           segment_attention_bias=None,
           state=None,
           attention_mask=None):
    """Compute multi-head relative attention over inputs.

    Size glossary:
      * Number of heads (H): the number of attention heads.
      * Value size (V): the size of each value embedding per head.
      * Key size (K): the size of each key embedding per head. Equally, the size
        of each query embedding per head. Typically K <= V.
      * Batch dimensions (B).
      * Query (target) attention axes shape (T).
      * Value (source) attention axes shape (S), the rank must match the target.
      * Encoding length (L): The relative positional encoding length.

    Args:
      query: attention input.
      value: attention input.
      content_attention_bias: A trainable bias parameter added to the query
        head when calculating the content-based attention score.
      positional_attention_bias: A trainable bias parameter added to the query
        head when calculating the position-based attention score.
      key: attention input.
      relative_position_encoding: relative positional encoding for key and
        value. Although the default is None, the value is passed directly to
        the encoding projection.
      segment_matrix: Optional `Tensor` representing segmentation IDs used in
        XLNet.
      segment_encoding: Optional `Tensor` representing the segmentation
        encoding as used in XLNet.
      segment_attention_bias: Optional trainable bias parameter added to the
        query head when calculating the segment-based attention score used in
        XLNet.
      state: (default None) optional state. If passed, this is also attended
        over as in TransformerXL.
      attention_mask: (default None) Optional mask that is added to attention
        logits. If state is not None, the mask source sequence dimension should
        extend M.

    Returns:
      attention_output: The result of the computation, of shape [B, T, E],
        where `T` is for target sequence shapes and `E` is the query input last
        dimension if `output_shape` is `None`. Otherwise, the multi-head outputs
        are projected to the shape specified by `output_shape`.
    """
    if not self._built_from_signature:
      self._build_from_signature(query, value, key=key)
    if key is None:
      key = value
    # A state with more than one dimension is prepended to key and value
    # along axis 1 so that it is attended over as well.
    if state is not None and state.shape.ndims > 1:
      value = tf.concat([state, value], 1)
      key = tf.concat([state, key], 1)

    # `query` = [B, T, N ,H]
    query = self._query_dense(query)

    # `key` = [B, S + M, N, H]
    key = self._key_dense(key)

    # `value` = [B, S + M, N, H]
    value = self._value_dense(value)

    # `position` = [B, L, N, H]
    position = self._encoding_dense(relative_position_encoding)

    attention_output = self.compute_attention(
        query=query,
        key=key,
        value=value,
        position=position,
        content_attention_bias=content_attention_bias,
        positional_attention_bias=positional_attention_bias,
        segment_matrix=segment_matrix,
        segment_encoding=segment_encoding,
        segment_attention_bias=segment_attention_bias,
        attention_mask=attention_mask)

    # `attention_output` = [B, S, N, H]
    attention_output = self._output_dense(attention_output)

    return attention_output