kernel_attention.py 27.8 KB
Newer Older
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
1
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Keras-based kernel attention layer."""

import functools
import math
import tensorflow as tf

21
22
from official.modeling import tf_utils

23
24
25
# Small constant added to attention denominators to avoid division by zero.
_NUMERIC_STABLER = 1e-6


Frederick Liu's avatar
Frederick Liu committed
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class KernelMask(tf.keras.layers.Layer):
  """Creates the float mask consumed by KernelAttention.

  Call args:
    inputs: A 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
      Only its dtype is used.
    mask: A Tensor of shape [batch_size, from_seq_length]. KernelAttention
      multiplies the transformed keys by this mask, so positions where it is 0
      are not attended to.

  Returns:
    `mask` cast to the dtype of `inputs`, i.e. a Tensor of shape
    [batch_size, from_seq_length] that KernelAttention takes as mask.
  """

  def call(self, inputs, mask):
    target_dtype = inputs.dtype
    return tf.cast(mask, target_dtype)


Avi Dubey's avatar
Avi Dubey committed
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
def pad_to_chunk_length(tensor, axis, chunk_length, pad="right"):
  """Pads a tensor so that shape[axis] is divisible by chunk_length.

  Args:
    tensor: Input tensor to pad.
    axis: Axis to pad along. Negative values count from the last axis.
    chunk_length: The output tensor will have shape[axis] divisible by
      chunk_length.
    pad: Pad the input tensor across the axis from the left if pad="left" or
      from the right if pad="right". Any other value (e.g. None) applies no
      padding; in that case the axis dimension of the input tensor must already
      be divisible by chunk_length.

  Returns:
    Padded tensor with shape[axis] divisible by chunk_length. Only `axis` is
    padded (with tf.pad's default constant zeros).

  Raises:
    ValueError: If no padding is requested and shape[axis] is not divisible by
      chunk_length.
  """
  shape = tf.shape(tensor)
  rank = tf.rank(tensor)
  if axis < 0:
    axis += rank
  axis_length = shape[axis]
  # Elements needed to reach the next multiple of chunk_length (0 if already a
  # multiple); floor-mod keeps the result non-negative.
  pad_length = -axis_length % chunk_length
  if pad == "right":
    axis_pad_width = [[0, pad_length]]
  elif pad == "left":
    axis_pad_width = [[pad_length, 0]]
  else:
    if pad_length != 0:
      raise ValueError("When padding is not set, the axis dimension "
                       "has to be divisible by the chunk_length.")
    return tensor
  # Paddings matrix of shape [rank, 2]: zero everywhere except row `axis`.
  pad_width = tf.concat(
      [tf.zeros([axis, 2], dtype=tf.int32), axis_pad_width,
       tf.zeros([rank - axis - 1, 2], dtype=tf.int32)], axis=0)
  return tf.pad(tensor, pad_width)


def split_tensor_into_chunks(tensor, axis, chunk_length):
  """Reshapes tensor by splitting one axis into [num_chunks, chunk_length].

  Args:
    tensor: Input tensor.
    axis: Reshape tensor along this axis. Negative values count from the last
      axis.
    chunk_length: Split the axis into [axis/chunk_length, chunk_length]. The
      axis dimension must be divisible by chunk_length, otherwise the reshape
      fails.

  Returns:
    Reshaped tensor whose rank is one higher than the input's.
  """
  shape = tf.shape(tensor)
  if axis < 0:
    # Normalize a negative axis so that `axis + 1` below is correct: with the
    # raw value, axis=-1 would make `shape[(axis + 1):]` equal `shape[0:]`,
    # i.e. the whole shape.
    axis += tf.rank(tensor)
  num_chunks = shape[axis] // chunk_length
  new_shape = tf.concat(
      [shape[:axis], [num_chunks, chunk_length], shape[(axis + 1):]], axis=0)
  return tf.reshape(tensor, new_shape)


def windowed_causal_performer_attention(query_matrix,
                                        key_matrix,
                                        value_matrix,
                                        chunk_length,
                                        window_length,
                                        pad="right"):
  """Applies windowed causal kernel attention with query, key, value tensors.

  We partition the T-length input sequence into N chunks, each of chunk_length
  tokens (thus: T = N * chunk_length). Within each chunk, we apply bidirectional
  (non-causal) Performers' implicit attention and we model relationships between
  different chunks using Performers' causal attention. We consider a windowed
  causal variant of performer, where the current chunk attends only to a window
  of the window_length most recent chunks, the current chunk included.

  Below is an example with T=9, chunk_length=3, window_length=2. 1 indicates
  attention is computed between the pair while 0 indicates attention is not
  computed between the pairs:
  111000000
  111000000
  111000000
  111111000
  111111000
  111111000
  000111111
  000111111
  000111111

  User can ensure sequence_length is divisible by chunk_length or use
  pad="left"/"right" to pad the sequence length either at the top or bottom
  respectively and make it divisible by chunk_length. Padded positions are
  removed from the output.

  Args:
    query_matrix: Kernel query `Tensor` of shape `[B, T, N, dim]`.
    key_matrix: Kernel key `Tensor` of shape `[B, T, N, dim]`.
    value_matrix: Value `Tensor` of shape `[B, T, N, out_dim]`.
    chunk_length: Length of each chunk in tokens.
    window_length: Length of attention window in chunks, including the current
      chunk. Must be a positive integer.
    pad: Pad the query, value and key input tensors across the T dimension from
      left if pad="left", right if pad="right", or apply no padding if pad=None.
      In the latter case, the T dimension of the input tensors must be divisible
      by the chunk_length.

  Returns:
    Window causal performer attention of shape `[B, T, N, out_dim]`.
  """

  old_shape = tf.shape(value_matrix)

  query_matrix = pad_to_chunk_length(query_matrix, -3, chunk_length, pad)
  key_matrix = pad_to_chunk_length(key_matrix, -3, chunk_length, pad)
  value_matrix = pad_to_chunk_length(value_matrix, -3, chunk_length, pad)

  new_shape = tf.shape(value_matrix)
  chunked_query_matrix = split_tensor_into_chunks(
      query_matrix, -3,
      chunk_length)  # [-1, T//chunk_length, chunk_length, N, dim]
  chunked_key_matrix = split_tensor_into_chunks(
      key_matrix, -3,
      chunk_length)  # [-1, T//chunk_length, chunk_length, N, dim]
  chunked_value_matrix = split_tensor_into_chunks(
      value_matrix, -3,
      chunk_length)  # [-1, T//chunk_length, chunk_length, N, out_dim]

  # In the einsums below: B=batch, N=chunk index, C=position within a chunk,
  # H=head, D=key feature dim, O=value feature dim.
  # Per-chunk sums of k^T v, then a sliding sum over the last `window_length`
  # chunks computed as cumsum minus cumsum shifted by `window_length` chunks.
  kp_v = tf.einsum("BNCHD,BNCHO->BNHDO", chunked_key_matrix,
                   chunked_value_matrix)
  kp_v_cumsum = tf.cumsum(kp_v, axis=-4)
  kp_v_winsum = kp_v_cumsum - tf.pad(
      kp_v_cumsum,
      [[0, 0], [window_length, 0], [0, 0], [0, 0], [0, 0]])[:, :-window_length]
  numerator = tf.einsum("BNCHD,BNHDO->BNCHO", chunked_query_matrix, kp_v_winsum)

  # Same windowed sum for the normalizer (sum of keys).
  k_sum = tf.reduce_sum(chunked_key_matrix, axis=-3)
  k_cumsum = tf.cumsum(k_sum, axis=-3)
  k_winsum = k_cumsum - tf.pad(k_cumsum, [[0, 0], [window_length, 0], [0, 0],
                                          [0, 0]])[:, :-window_length]
  denominator = tf.einsum("BNCHD,BNHD->BNCH", chunked_query_matrix, k_winsum)
  denominator = tf.expand_dims(denominator, -1) + _NUMERIC_STABLER

  attention = numerator / denominator
  attention = tf.reshape(attention, new_shape)

  # Drop the padded positions along T: pad_to_chunk_length appends them for
  # pad="right" but prepends them for pad="left", so the start offset differs.
  old_length = old_shape[-3]
  if pad == "left":
    offset = new_shape[-3] - old_length
  else:
    offset = 0
  attention = attention[..., offset:offset + old_length, :, :]

  return attention


186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
def create_projection_matrix(m, d, seed=None):
  r"""Constructs the matrix of random projections.

  Constructs a matrix of random orthogonal projections. Each projection vector
  has a direction chosen uniformly at random and a length taken from the
  \chi(d) distribution.

  Args:
    m: number of random projections.
    d: dimensionality of each random projection.
    seed: random seed (a shape-[2] integer tensor) used to construct the
      projections with the stateless random API. If None, the stateful random
      API is used for both the directions and the lengths.

  Returns:
    The matrix of random projections of the shape [m, d].
  """
  # Number of d x d orthogonal blocks; the last one may be only partially used.
  nb_blocks = math.ceil(m / d)
  block_list = tf.TensorArray(tf.float32, size=nb_blocks)
  stateful = seed is None
  current_seed = seed
  for i in range(nb_blocks):
    if stateful:
      unstructured_block = tf.random.normal((d, d))
    else:
      unstructured_block = tf.random.stateless_normal((d, d), seed=current_seed)
      # Derive a fresh seed so every block (and the norms below) differ.
      current_seed = tf.random.stateless_uniform([2],
                                                 seed=current_seed,
                                                 minval=None,
                                                 dtype=tf.int32)
    # Q of a square QR is orthogonal, so rows of Q^T are orthonormal.
    q, _ = tf.linalg.qr(unstructured_block)
    block_list = block_list.write(i, tf.transpose(q))
  final_matrix = block_list.concat()[:m]
  # Row lengths: norms of d-dimensional standard normal vectors (chi(d)).
  if stateful:
    multiplier = tf.norm(tf.random.normal((m, d)), axis=1)
  else:
    multiplier = tf.norm(
        tf.random.stateless_normal((m, d), seed=current_seed), axis=1)
  return tf.linalg.matmul(tf.linalg.diag(multiplier), final_matrix)


Frederick Liu's avatar
Frederick Liu committed
232
def _generalized_kernel(x, projection_matrix, f, h):
  """Generalized kernel in RETHINKING ATTENTION WITH PERFORMERS.

  Args:
    x: The feature being transformed with shape [B, T, N, H].
    projection_matrix: The matrix with shape [M, H] that we project x to, where
      M is the number of projections, or None to skip the projection.
    f: A non-linear function applied on x or projected x.
    h: A multiplier which is a function of the (unprojected) x. It multiplies
      the output of f in both the projected and unprojected cases.

  Returns:
    Transformed feature: h(x) * f(x) when projection_matrix is None, otherwise
    h(x) * f(x_projected) / sqrt(M).
  """
  if projection_matrix is None:
    return h(x) * f(x)
  # Match x's dtype so non-float32 inputs do not hit a dtype mismatch; this is
  # a no-op for float32.
  projection_matrix = tf.cast(projection_matrix, x.dtype)
  x_projected = tf.einsum("BTNH,MH->BTNM", x, projection_matrix)
  num_projections = tf.cast(tf.shape(projection_matrix)[0], x.dtype)
  return h(x) * f(x_projected) / tf.math.sqrt(num_projections)


255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
def expplus(data_orig,
            other_data,
            is_query,
            projection_matrix=None,
            numerical_stabilizer=0.000001,
            normalize_data=True,
            numerical_renormalizer=True,
            extra_renormalize_exp_fun=False):
  """FAVOR++ mechanism from the CRT paper: https://arxiv.org/abs/2205.15317 .

  Args:
    data_orig: data tensor of shape [B,T,H,D] for which random features are to
      be computed
    other_data: additional tensor of the shape [B,F,H,D] used to collect stats
      to determine the exact instantiation of the random feature mechanism
    is_query: boolean indicating whether <data_orig> tensor is a query tensor
    projection_matrix: tensor of the shape [M,D] encoding random projections for
      random features (M stands for the number of random features)
    numerical_stabilizer: numerical stabilizer for the kernel features
    normalize_data: whether to sqrt-d-normalize queries/keys as in the regular
      attention
    numerical_renormalizer: whether to apply additional renormalization for
      numerical stability
    extra_renormalize_exp_fun: extra renormalizer for the exponential mapping
      applied to construct random features

  Returns:
    Random feature map tensor for the unbiased softmax-kernel estimation, or
    `data_orig` unchanged when `projection_matrix` is None.
  """

  data = data_orig
  if projection_matrix is None:
    return data_orig
  projection_matrix = tf.cast(projection_matrix, data.dtype)
  if normalize_data:
    data_normalizer = 1.0 / tf.math.sqrt(
        (tf.math.sqrt(tf.dtypes.cast(data.shape[-1], data.dtype))))
  else:
    # Project each row of `data` onto the unit sphere instead.
    # NOTE(review): `other_data` is not normalized in this branch — confirm
    # this asymmetry is intended.
    data_normalizer = 1.0
    lengths = tf.math.square(data)
    lengths = tf.reduce_sum(lengths, axis=tf.keras.backend.ndim(data) - 1)
    lengths = tf.expand_dims(lengths, axis=tf.keras.backend.ndim(data) - 1)
    lengths = tf.math.sqrt(lengths)
    data /= lengths
  ratio = 1.0 / tf.math.sqrt(
      tf.dtypes.cast(projection_matrix.shape[0], data.dtype))
  data_dash = tf.einsum("blhd,md->blhm", data_normalizer * data,
                        projection_matrix)
  diag_data = tf.math.square(data)
  diag_data = tf.math.reduce_sum(
      diag_data, axis=tf.keras.backend.ndim(data) - 1)
  diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
  diag_data = tf.expand_dims(diag_data, axis=tf.keras.backend.ndim(data) - 1)

  # Calculating coefficients A, B of the FAVOR++ mechanism.
  # `ave` is the mean of ||x + y||^2 over all pairs (x from data, y from
  # other_data) of normalized rows:
  #   mean||x||^2 + mean||y||^2 + 2 * mean(x) . mean(y).
  # Each mean must use its own sequence length; with mismatched lengths the
  # (key, query) and (query, key) calls would otherwise disagree on A.
  _, l, _, _ = tf_utils.get_shape_list(data_orig)
  _, l_other, _, _ = tf_utils.get_shape_list(other_data)

  l = tf.cast(l, dtype=tf.float32)
  l_other = tf.cast(l_other, dtype=tf.float32)
  first_sum_of_squares = tf.math.square(data)
  first_sum_of_squares = tf.math.reduce_sum(
      first_sum_of_squares, axis=(1, -1), keepdims=True)
  first_sum_of_squares *= (data_normalizer * data_normalizer)
  first_sum_of_squares /= l  # data.shape[1]
  second_sum_of_squares = tf.math.square(other_data)
  second_sum_of_squares = tf.math.reduce_sum(
      second_sum_of_squares, axis=(1, -1), keepdims=True)
  second_sum_of_squares *= (data_normalizer * data_normalizer)
  second_sum_of_squares /= l_other  # other_data.shape[1]
  data_sum = tf.math.reduce_sum(data, axis=(1,), keepdims=True)
  other_data_sum = tf.math.reduce_sum(other_data, axis=(1,), keepdims=True)
  d_prod = tf.einsum("blhd,blhd->blh", data_sum, other_data_sum)
  d_prod = tf.expand_dims(d_prod, axis=-1)
  d_prod *= (data_normalizer * data_normalizer)
  d_prod *= (2.0 / (l * l_other))
  ave = first_sum_of_squares + second_sum_of_squares + d_prod
  dim = projection_matrix.shape[-1]
  a_coeff = (1.0 / (4.0 * ave)) * (
      tf.math.sqrt((2.0 * ave + dim) *
                   (2.0 * ave + dim) + 8.0 * dim * ave) - 2.0 * ave - dim)
  a_coeff = (1.0 - 1.0 / a_coeff) / 8.0
  b_coeff = tf.math.sqrt(1.0 - 4.0 * a_coeff)
  d_coeff = tf.math.pow(1.0 - 4.0 * a_coeff, dim / 4.0)
  # The coefficients are data-dependent constants; no gradients through them.
  a_coeff = tf.stop_gradient(a_coeff)
  b_coeff = tf.stop_gradient(b_coeff)
  d_coeff = tf.stop_gradient(d_coeff)

  # Calculating diag_omega for the FAVOR++ mechanism: squared norm of each
  # projection row, broadcast to [1, 1, 1, M].
  diag_omega = tf.math.square(projection_matrix)
  diag_omega = tf.math.reduce_sum(
      diag_omega, axis=tf.keras.backend.ndim(projection_matrix) - 1)
  diag_omega = tf.expand_dims(diag_omega, axis=0)
  diag_omega = tf.expand_dims(diag_omega, axis=0)
  diag_omega = tf.expand_dims(diag_omega, axis=0)
  diag_omega = a_coeff * diag_omega

  if numerical_renormalizer:
    # Subtract a max before exponentiating to avoid overflow: per position for
    # queries, a single global max for keys.
    if is_query:
      last_dims_t = (len(data_dash.shape) - 1,)
      stab = b_coeff * tf.math.reduce_max(
          data_dash, axis=last_dims_t, keepdims=True)
    else:
      stab = b_coeff * tf.math.reduce_max(data_dash, keepdims=True)
    if extra_renormalize_exp_fun:
      extra_stab = tf.reduce_max(diag_data, axis=1, keepdims=True)
      stab = tf.math.maximum(stab, extra_stab)
    data_dash = ratio * d_coeff * (
        tf.math.exp(b_coeff * data_dash - stab - diag_data + diag_omega) +
        numerical_stabilizer)
  else:
    data_dash = ratio * d_coeff * (
        tf.math.exp(b_coeff * data_dash - diag_data + diag_omega) +
        numerical_stabilizer)

  return data_dash


372
373
374
375
376
377
378
379
380
# Maps a feature_transform name to a feature map called as
# fn(x, projection_matrix). "expplus" is not listed here: it needs the other
# (query or key) tensor as well and is dispatched to `expplus` directly.
# pylint: disable=g-long-lambda
_TRANSFORM_MAP = {
    "elu":
        functools.partial(
            _generalized_kernel,
            f=lambda x: tf.keras.activations.elu(x) + 1,
            h=lambda x: 1),
    "relu":
        functools.partial(
            _generalized_kernel,
            # Improve numerical stability and avoid NaNs in some cases by adding
            # a tiny epsilon.
            f=lambda x: tf.keras.activations.relu(x) + 1e-3,
            h=lambda x: 1),
    "square":
        functools.partial(_generalized_kernel, f=tf.math.square, h=lambda x: 1),
    "exp":
        functools.partial(
            _generalized_kernel,
            # Avoid exp explosion by shifting.
            f=lambda x: tf.math.exp(x - tf.math.reduce_max(
                x, axis=[1, 2, 3], keepdims=True)),
            h=lambda x: tf.math.exp(-0.5 * tf.math.reduce_sum(
                tf.math.square(x), axis=-1, keepdims=True)),
        ),
    "expmod":
        functools.partial(
            _generalized_kernel,
            # Avoid exp explosion by shifting.
            f=lambda x: tf.math.exp(x - tf.math.reduce_max(
                x, axis=[1, 2, 3], keepdims=True)),
            # Constant multiplier exp(-sqrt(d) / 2), computed in x's dtype so
            # it can multiply f(...) without a dtype mismatch.
            h=lambda x: tf.math.exp(-0.5 * tf.math.sqrt(
                tf.cast(tf.shape(x)[-1], x.dtype))),
        ),
    "identity":
        functools.partial(_generalized_kernel, f=lambda x: x, h=lambda x: 1)
}
# pylint: enable=g-long-lambda


class KernelAttention(tf.keras.layers.MultiHeadAttention):
  """A variant of efficient transformers which replaces softmax with kernels.

  This module combines ideas from the following papers:

  Rethinking Attention with Performers
  (https://arxiv.org/abs/2009.14794)
  - exp (Lemma 1, positive), relu
  - random/deterministic projection

  Chefs' Random Tables: Non-Trigonometric Random Features
  (https://arxiv.org/abs/2205.15317)
  - expplus (OPRF mechanism)

  Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
  (https://arxiv.org/abs/2006.16236)
  - elu

  with the theory of approximating angular Performer kernels from go/performer.

  The module enables computing efficient attention in both: long sequence and
  shorter sequence regimes. In the former setting, the attention matrix is never
  explicitly computed and instead its low-rank decomposition obtained with given
  kernel feature maps is leveraged to conduct attention module calculations
  (see: https://arxiv.org/abs/2006.16236). In the latter setting, attention
  matrix is constructed, but kernel features providing dimensionality reduction
  are applied, resulting in more efficient computation of the attention matrix.
  """

  def __init__(self,
               feature_transform="exp",
               num_random_features=256,
               seed=0,
               redraw=False,
               is_short_seq=False,
               begin_kernel=0,
               scale=None,
               scale_by_length=False,
               use_windowed_causal=False,
               chunk_length=1,
               window_length=3,
               **kwargs):
    r"""Constructor of KernelAttention.

    Args:
      feature_transform: A non-linear transform of the keys and queries.
        Possible transforms are "elu", "relu", "square", "exp", "expplus",
        "expmod", "identity".
      num_random_features: Number of random features to be used for projection.
        If num_random_features <= 0, no projection is used before transform.
      seed: The seed to begin drawing random features. Once the seed is set, the
        pseudo number generation is deterministic. Users should pass different
        seed for different layers. For multi-worker, each layer will use the
        same projection at each step.
      redraw: Whether to redraw projection every forward pass during training.
        The argument is only effective when num_random_features > 0.
      is_short_seq: boolean predicate indicating whether input data consists of
        very short sequences or not; in most cases this should be False (default
        option).
      begin_kernel: Apply kernel_attention after this sequence id and apply
        softmax attention before this.
      scale: The value to scale the dot product as described in `Attention Is
        All You Need`. If None, we use 1/sqrt(dk) as described in the paper.
      scale_by_length: boolean predicate indicating whether to additionally
        scale the dot product by log_512(n), where n is the number of unmasked
        keys (the key length when no mask is given), to stabilize attention
        entropy against length. Refer to https://kexue.fm/archives/8823 for
        details.
      use_windowed_causal: If true perform windowed causal attention. See
        windowed_causal_performer_attention function docstring for more details.
      chunk_length: Length of each chunk in tokens.
      window_length: Length of attention window in chunks, including the
        current chunk.
      **kwargs: The same arguments `MultiHeadAttention` layer.

    Raises:
      ValueError: If `feature_transform` is unsupported, if `redraw` is set
        while `num_random_features <= 0`, or if both `use_windowed_causal` and
        `is_short_seq` are set.
    """
    # "expplus" is handled by `expplus` directly rather than via the map.
    supported_transforms = sorted(list(_TRANSFORM_MAP) + ["expplus"])
    if feature_transform not in supported_transforms:
      raise ValueError("Unsupported feature_transform. The supported "
                       "feature_transform are %s. "
                       "Got '%s'." % (supported_transforms, feature_transform))
    if num_random_features <= 0 and redraw:
      raise ValueError(
          "There is nothing to redraw when num_random_features <= 0.")
    self._feature_transform = feature_transform
    self._num_random_features = num_random_features
    self._redraw = redraw
    self._is_short_seq = is_short_seq
    self._begin_kernel = begin_kernel
    self._scale_by_length = scale_by_length
    # We use the seed for two scenarios:
    # 1. inference
    # 2. no redraw
    self._seed = seed
    super().__init__(**kwargs)
    if scale is None:
      self._scale = 1.0 / math.sqrt(float(self._key_dim))
    else:
      self._scale = scale
    self._projection_matrix = None
    if num_random_features > 0:
      self._projection_matrix = create_projection_matrix(
          self._num_random_features, self._key_dim,
          tf.constant([self._seed, self._seed + 1]))
    self.use_windowed_causal = use_windowed_causal
    self.chunk_length = chunk_length
    self.window_length = window_length
    if self.use_windowed_causal and self._is_short_seq:
      raise ValueError(
          "use_windowed_causal and short_seq methods are mutually exclusive")

  def _compute_attention(self,
                         query,
                         key,
                         value,
                         feature_transform,
                         is_short_seq,
                         attention_mask=None,
                         training=False,
                         numeric_stabler=_NUMERIC_STABLER):
    """Applies kernel attention with query, key, value tensors.

    This function defines the computation inside `call` with projected
    multi-head Q, K, V inputs. Users can override this function for customized
    attention implementation.

    Args:
      query: Projected query `Tensor` of shape `[B, T, N, key_dim]`.
      key: Projected key `Tensor` of shape `[B, S, N, key_dim]`.
      value: Projected value `Tensor` of shape `[B, S, N, value_dim]`.
      feature_transform: A non-linear transform of the keys and queries.
      is_short_seq: boolean predicate indicating whether input data consists of
        short or long sequences; usually short sequence is defined as having
        length L <= 1024.
      attention_mask: a mask of shape `[B, S]` that prevents attending to masked
        positions: it is multiplied into the transformed keys, so 0 masks a key.
        Note that the mask is only applied to the keys. User may want to mask
        the output if query contains pads.
      training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (doing nothing).
      numeric_stabler: A scalar value added to avoid divide by 0 in the
        long-sequence, non-windowed path.

    Returns:
      attention_output: Multi-headed outputs of attention computation.
    """
    projection_matrix = None
    if self._num_random_features > 0:
      if self._redraw and training:
        # No seed: draw fresh projections with the stateful random API.
        projection_matrix = create_projection_matrix(self._num_random_features,
                                                     self._key_dim)
      else:
        projection_matrix = self._projection_matrix

    if self._scale_by_length:
      if attention_mask is None:
        # Without a mask every key position counts towards the length.
        key_length = tf.cast(tf.shape(key)[1], key.dtype)
      else:
        key_length = tf.reduce_sum(attention_mask, axis=-1)
      scale = tf.math.log(key_length) * self._scale / math.log(512)
      scale = tf.reshape(scale, [-1, 1, 1, 1])
    else:
      scale = self._scale

    if is_short_seq:
      # Note: Applying scalar multiply at the smaller end of einsum improves
      # XLA performance, but may introduce slight numeric differences in
      # the Transformer attention head.
      query = query * scale
    else:
      # Note: we suspect splitting the scale to key, query yields smaller
      # approximation variance when random projection is used.
      # For simplicity, we also split when there's no random projection.
      key *= tf.math.sqrt(scale)
      query *= tf.math.sqrt(scale)

    if feature_transform != "expplus":
      key_prime = _TRANSFORM_MAP[feature_transform](key, projection_matrix)
      query_prime = _TRANSFORM_MAP[feature_transform](query, projection_matrix)
    else:
      key_prime = expplus(key, query, False, projection_matrix)
      query_prime = expplus(query, key, True, projection_matrix)

    if attention_mask is not None:
      key_prime = tf.einsum("BSNH,BS->BSNH", key_prime, attention_mask)

    if is_short_seq:
      # Materialize the [T, S] attention matrix and apply softmax over keys.
      attention_scores = tf.einsum("BTNH,BSNH->BTSN", query_prime, key_prime)
      attention_scores = tf.nn.softmax(attention_scores, axis=2)
      attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
    elif self.use_windowed_causal:
      attention_output = windowed_causal_performer_attention(
          query_prime, key_prime, value, self.chunk_length, self.window_length)
    else:
      # Linear attention: Q' (K'^T V) normalized by Q' (sum of K'), never
      # forming the [T, S] matrix.
      kv = tf.einsum("BSNH,BSND->BNDH", key_prime, value)
      denominator = 1.0 / (
          tf.einsum("BTNH,BNH->BTN", query_prime,
                    tf.reduce_sum(key_prime, axis=1)) + numeric_stabler)
      attention_output = tf.einsum("BTNH,BNDH,BTN->BTND", query_prime, kv,
                                   denominator)
    return attention_output

  def _build_from_signature(self, query, value, key=None):
    """Builds the base layers plus the softmax-branch output/dropout layers."""
    super()._build_from_signature(query=query, value=value, key=key)  # pytype: disable=attribute-error  # typed-keras
    if self._begin_kernel > 0:
      common_kwargs = dict(
          kernel_initializer=self._kernel_initializer,
          bias_initializer=self._bias_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer,
          activity_regularizer=self._activity_regularizer,
          kernel_constraint=self._kernel_constraint,
          bias_constraint=self._bias_constraint)
      self._output_dense_softmax = self._make_output_dense(
          self._query_shape.rank - 1,
          common_kwargs,
          name="attention_output_softmax")
      self._dropout_softmax = tf.keras.layers.Dropout(rate=self._dropout)

  def call(self, query, value, key=None, attention_mask=None, training=False):
    """Compute attention with kernel mechanism.

    Args:
      query: Query `Tensor` of shape `[B, T, dim]`.
      value: Value `Tensor` of shape `[B, S, dim]`.
      key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
        `value` for both `key` and `value`, which is the most common case.
      attention_mask: a mask of shape `[B, S]` that prevents attending to masked
        positions. Note that the mask is only applied to the keys. User may want
        to mask the output if query contains pads.
      training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (doing nothing).

    Returns:
      Multi-headed outputs of attention computation.
    """
    if not self._built_from_signature:
      self._build_from_signature(query=query, value=value, key=key)
    if key is None:
      key = value

    #   N = `num_attention_heads`
    #   H = `size_per_head`
    # `query` = [B, T, N ,H]
    query = self._query_dense(query)

    # `key` = [B, S, N, H]
    key = self._key_dense(key)

    # `value` = [B, S, N, D]
    value = self._value_dense(value)

    if self._begin_kernel > 0:
      # The first `begin_kernel` query positions use the short-sequence softmax
      # path with the identity feature map; the rest use kernel attention.
      attention_output_softmax = self._compute_attention(
          query[:, :self._begin_kernel], key, value, "identity", True,
          attention_mask, training)
      attention_output_softmax = self._dropout_softmax(attention_output_softmax)
      attention_output_softmax = self._output_dense_softmax(
          attention_output_softmax)

      attention_output_kernel = self._compute_attention(
          query[:, self._begin_kernel:], key, value, self._feature_transform,
          self._is_short_seq, attention_mask, training)
      attention_output_kernel = self._dropout_layer(attention_output_kernel)
      attention_output_kernel = self._output_dense(attention_output_kernel)
      attention_output = tf.concat(
          [attention_output_softmax, attention_output_kernel], axis=1)
    else:
      attention_output = self._compute_attention(query, key, value,
                                                 self._feature_transform,
                                                 self._is_short_seq,
                                                 attention_mask, training)
      # This is actually dropping out entire tokens to attend to, which might
      # seem a bit unusual, but is taken from the original Transformer paper.
      attention_output = self._dropout_layer(attention_output)
      attention_output = self._output_dense(attention_output)
    return attention_output

  def get_config(self):
    """Returns the layer config, including every constructor argument."""
    config = {
        "feature_transform": self._feature_transform,
        "num_random_features": self._num_random_features,
        "seed": self._seed,
        "redraw": self._redraw,
        "is_short_seq": self._is_short_seq,
        "begin_kernel": self._begin_kernel,
        "scale": self._scale,
        "scale_by_length": self._scale_by_length,
        "use_windowed_causal": self.use_windowed_causal,
        "chunk_length": self.chunk_length,
        "window_length": self.window_length,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))