kernel_attention.py 21.6 KB
Newer Older
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
1
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Keras-based kernel attention layer."""

import functools
import math
import tensorflow as tf

21
22
from official.modeling import tf_utils

23
24
25
_NUMERIC_STABLER = 1e-6


Frederick Liu's avatar
Frederick Liu committed
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class KernelMask(tf.keras.layers.Layer):
  """Converts a padding mask into the mask format KernelAttention consumes.

  `call` arguments:
    inputs: 2D or 3D `Tensor` of shape `[batch_size, from_seq_length, ...]`.
      Only its dtype is read.
    mask: `Tensor` of shape `[batch_size, from_seq_length]` indicating which
      positions of `inputs` should not be attended to.

  Returns:
    `mask` cast to the dtype of `inputs`: a float `Tensor` of shape
    `[batch_size, from_seq_length]` that KernelAttention takes as mask.
  """

  def call(self, inputs, mask):
    return tf.cast(mask, dtype=inputs.dtype)


44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def create_projection_matrix(m, d, seed=None):
  r"""Constructs the matrix of random projections.

  Constructs a matrix of random orthogonal projections. Each projection vector
  has a direction chosen uniformly at random and a length taken from the
  \chi(d) distribution.

  Args:
    m: number of random projections.
    d: dimensionality of each random projection.
    seed: random seed used to construct projections, in the form accepted by
      the `tf.random.stateless_*` ops (e.g. a shape-[2] integer tensor). If
      None, the stateful `tf.random` ops are used instead, so every call draws
      a new matrix.

  Returns:
    The matrix of random projections of the shape [m, d].
  """
  nb_full_blocks = math.ceil(m / d)
  block_list = tf.TensorArray(
      tf.float32, size=tf.cast(nb_full_blocks, dtype=tf.int32))
  stateful = seed is None
  current_seed = seed
  for i in range(nb_full_blocks):
    if stateful:
      unstructured_block = tf.random.normal((d, d))
    else:
      unstructured_block = tf.random.stateless_normal((d, d), seed=current_seed)
      # Derive the seed for the next draw from the current one so that every
      # block (and the multiplier below) uses a different stateless seed.
      current_seed = tf.random.stateless_uniform([2],
                                                 seed=current_seed,
                                                 minval=None,
                                                 dtype=tf.int32)
    # The Q factor has orthonormal columns; after transposing, its rows are
    # orthonormal projection directions.
    q, _ = tf.linalg.qr(unstructured_block)
    q = tf.transpose(q)
    block_list = block_list.write(i, q)
  # Stack the [d, d] blocks row-wise and keep the first m directions.
  final_matrix = block_list.concat()[:m]
  # Row norms of an [m, d] standard-normal matrix give the chi(d) lengths.
  # The stateful branch was previously unreachable (`stateful is None` on a
  # bool), so the lengths were always drawn from a fixed dummy seed.
  if stateful:
    multiplier = tf.norm(tf.random.normal((m, d)), axis=1)
  else:
    multiplier = tf.norm(
        tf.random.stateless_normal((m, d), seed=current_seed), axis=1)
  # Scale row i by multiplier[i]; same result as diag(multiplier) @ matrix
  # without building an [m, m] matrix.
  return tf.expand_dims(multiplier, axis=1) * final_matrix


Frederick Liu's avatar
Frederick Liu committed
90
def _generalized_kernel(x, projection_matrix, f, h):
  """Generalized kernel in RETHINKING ATTENTION WITH PERFORMERS.

  Args:
    x: The feature being transformed with shape [B, T, N, H].
    projection_matrix: The matrix with shape [M, H] that we project x to, where
      M is the number of projections. If None, no projection is applied.
    f: A non-linear function applied on x or projected x.
    h: A multiplier which is a function of x. It multiplies the transformed
      feature in both cases (with or without projection).

  Returns:
    Transformed feature: `h(x) * f(x)` without projection, otherwise
    `h(x) * f(x @ projection_matrix^T) / sqrt(M)`.
  """
  if projection_matrix is None:
    return h(x) * f(x)
  # Match x's dtype (e.g. float16/bfloat16 under mixed precision) so that the
  # einsum and the 1/sqrt(M) rescaling never mix dtypes; `expplus` performs the
  # same cast.
  projection_matrix = tf.cast(projection_matrix, x.dtype)
  x_projected = tf.einsum("BTNH,MH->BTNM", x, projection_matrix)
  num_projections = tf.cast(tf.shape(projection_matrix)[0], x.dtype)
  return h(x) * f(x_projected) / tf.math.sqrt(num_projections)


113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
def expplus(data_orig,
            other_data,
            is_query,
            projection_matrix=None,
            numerical_stabilizer=0.000001,
            normalize_data=True,
            numerical_renormalizer=True,
            extra_renormalize_exp_fun=False):
  """FAVOR++ mechanism from the CRT paper: https://arxiv.org/abs/2205.15317 .

  Args:
    data_orig: data tensor of shape [B,T,H,D] for which random features are to
      be computed
    other_data: additional tensor of the shape [B,F,H,D] used to collect stats
      to determine the exact instantiation of the random feature mechanism
    is_query: boolean indicating whether <data_orig> tensor is a query tensor
    projection_matrix: tensor of the shape [M,D] encoding random projections for
      random features (M stands for the number of random features). If None,
      `data_orig` is returned unchanged.
    numerical_stabilizer: numerical stabilizer for the kernel features
    normalize_data: whether to sqrt-d-normalize queries/keys as in the regular
      attention
    numerical_renormalizer: whether to apply additional renormalization for
      numerical stability
    extra_renormalize_exp_fun: extra renormalizer for the exponential mapping
      applied to construct random features

  Returns:
    Random feature map tensor of shape [B,T,H,M] for the unbiased
    softmax-kernel estimation.
  """
  data = data_orig
  if projection_matrix is None:
    return data_orig
  projection_matrix = tf.cast(projection_matrix, data.dtype)
  if normalize_data:
    data_normalizer = 1.0 / tf.math.sqrt(
        (tf.math.sqrt(tf.dtypes.cast(data.shape[-1], data.dtype))))
  else:
    data_normalizer = 1.0
    # Without sqrt-d normalization, rescale each vector of `data` to unit norm.
    # NOTE(review): `other_data` is not rescaled the same way here; confirm
    # this asymmetry is intended.
    lengths = tf.math.square(data)
    lengths = tf.reduce_sum(lengths, axis=tf.keras.backend.ndim(data) - 1)
    lengths = tf.expand_dims(lengths, axis=tf.keras.backend.ndim(data) - 1)
    lengths = tf.math.sqrt(lengths)
    data /= lengths
  ratio = 1.0 / tf.math.sqrt(
      tf.dtypes.cast(projection_matrix.shape[0], data.dtype))
  data_dash = tf.einsum("blhd,md->blhm", data_normalizer * data,
                        projection_matrix)
  # diag_data = |c * x|^2 / 2 per position (c = data_normalizer), [B,T,H,1].
  diag_data = tf.math.square(data)
  diag_data = tf.math.reduce_sum(
      diag_data, axis=tf.keras.backend.ndim(data) - 1)
  diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
  diag_data = tf.expand_dims(diag_data, axis=tf.keras.backend.ndim(data) - 1)

  # Calculating coefficients A, B of the FAVOR++ mechanism.
  # `ave` is the average of |c*x + c*y|^2 over all pairs (x in data,
  # y in other_data):
  #   mean|c*x|^2 + mean|c*y|^2 + 2 * mean(c*x) . mean(c*y).
  # Each mean uses its own tensor's sequence length, which matters whenever
  # the two lengths differ (e.g. cross-attention or `begin_kernel` slicing).
  _, data_len, _, _ = tf_utils.get_shape_list(data_orig)
  _, other_len, _, _ = tf_utils.get_shape_list(other_data)
  # Cast to the data dtype (not a fixed float32) so non-float32 inputs work.
  data_len = tf.cast(data_len, dtype=data.dtype)
  other_len = tf.cast(other_len, dtype=data.dtype)
  first_sum_of_squares = tf.math.square(data)
  first_sum_of_squares = tf.math.reduce_sum(
      first_sum_of_squares, axis=(1, -1), keepdims=True)
  first_sum_of_squares *= (data_normalizer * data_normalizer)
  first_sum_of_squares /= data_len
  second_sum_of_squares = tf.math.square(other_data)
  second_sum_of_squares = tf.math.reduce_sum(
      second_sum_of_squares, axis=(1, -1), keepdims=True)
  second_sum_of_squares *= (data_normalizer * data_normalizer)
  second_sum_of_squares /= other_len
  data_sum = tf.math.reduce_sum(data, axis=(1,), keepdims=True)
  other_data_sum = tf.math.reduce_sum(other_data, axis=(1,), keepdims=True)
  d_prod = tf.einsum("blhd,blhd->blh", data_sum, other_data_sum)
  d_prod = tf.expand_dims(d_prod, axis=-1)
  d_prod *= (data_normalizer * data_normalizer)
  d_prod *= (2.0 / (data_len * other_len))
  ave = first_sum_of_squares + second_sum_of_squares + d_prod
  dim = projection_matrix.shape[-1]
  # Intermediate quantity from which A is derived; then A, B = sqrt(1 - 4A)
  # and the normalizing factor D = (1 - 4A)^(d/4).
  a_coeff = (1.0 / (4.0 * ave)) * (
      tf.math.sqrt((2.0 * ave + dim) *
                   (2.0 * ave + dim) + 8.0 * dim * ave) - 2.0 * ave - dim)
  a_coeff = (1.0 - 1.0 / a_coeff) / 8.0
  b_coeff = tf.math.sqrt(1.0 - 4.0 * a_coeff)
  d_coeff = tf.math.pow(1.0 - 4.0 * a_coeff, dim / 4.0)
  # The coefficients are data statistics; do not backpropagate through them.
  a_coeff = tf.stop_gradient(a_coeff)
  b_coeff = tf.stop_gradient(b_coeff)
  d_coeff = tf.stop_gradient(d_coeff)

  # Calculating diag_omega for the FAVOR++ mechanism: A * |omega_m|^2 for each
  # projection row omega_m, broadcast from [M] to [1,1,1,M].
  diag_omega = tf.math.square(projection_matrix)
  diag_omega = tf.math.reduce_sum(
      diag_omega, axis=tf.keras.backend.ndim(projection_matrix) - 1)
  diag_omega = tf.expand_dims(diag_omega, axis=0)
  diag_omega = tf.expand_dims(diag_omega, axis=0)
  diag_omega = tf.expand_dims(diag_omega, axis=0)
  diag_omega = a_coeff * diag_omega

  if numerical_renormalizer:
    # Subtract a max-based shift inside the exponent to avoid overflow.
    # Queries use a per-position max over the feature axis; keys use one max
    # over the whole tensor.
    if is_query:
      last_dims_t = (len(data_dash.shape) - 1,)
      stab = b_coeff * tf.math.reduce_max(
          data_dash, axis=last_dims_t, keepdims=True)
    else:
      stab = b_coeff * tf.math.reduce_max(data_dash, keepdims=True)
    if extra_renormalize_exp_fun:
      extra_stab = tf.reduce_max(diag_data, axis=1, keepdims=True)
      stab = tf.math.maximum(stab, extra_stab)
    data_dash = ratio * d_coeff * (
        tf.math.exp(b_coeff * data_dash - stab - diag_data + diag_omega) +
        numerical_stabilizer)
  else:
    data_dash = ratio * d_coeff * (
        tf.math.exp(b_coeff * data_dash - diag_data + diag_omega) +
        numerical_stabilizer)

  return data_dash


230
231
232
233
234
235
236
237
238
# pylint: disable=g-long-lambda
# Maps a feature_transform name to a callable (x, projection_matrix) -> features
# built on `_generalized_kernel`. "expplus" is not listed here; KernelAttention
# dispatches it separately to `expplus`.
_TRANSFORM_MAP = {
    "elu":
        functools.partial(
            _generalized_kernel,
            f=lambda x: tf.keras.activations.elu(x) + 1,
            h=lambda x: 1),
    "relu":
        functools.partial(
            _generalized_kernel,
            # Improve numerical stability and avoid NaNs in some cases by adding
            # a tiny epsilon.
            f=lambda x: tf.keras.activations.relu(x) + 1e-3,
            h=lambda x: 1),
    "square":
        functools.partial(_generalized_kernel, f=tf.math.square, h=lambda x: 1),
    "exp":
        functools.partial(
            _generalized_kernel,
            # Avoid exp explosion by shifting.
            f=lambda x: tf.math.exp(x - tf.math.reduce_max(
                x, axis=[1, 2, 3], keepdims=True)),
            h=lambda x: tf.math.exp(-0.5 * tf.math.reduce_sum(
                tf.math.square(x), axis=-1, keepdims=True)),
        ),
    "expmod":
        functools.partial(
            _generalized_kernel,
            # Avoid exp explosion by shifting.
            f=lambda x: tf.math.exp(x - tf.math.reduce_max(
                x, axis=[1, 2, 3], keepdims=True)),
            # Cast to x's dtype rather than a fixed float32 so the scalar
            # multiplier does not break float16/bfloat16 inputs.
            h=lambda x: tf.math.exp(-0.5 * tf.math.sqrt(
                tf.cast(tf.shape(x)[-1], x.dtype))),
        ),
    "identity":
        functools.partial(_generalized_kernel, f=lambda x: x, h=lambda x: 1)
}
# pylint: enable=g-long-lambda


class KernelAttention(tf.keras.layers.MultiHeadAttention):
  """A variant of efficient transformers which replaces softmax with kernels.

  This module combines ideas from the following papers:

  Rethinking Attention with Performers
  (https://arxiv.org/abs/2009.14794)
  - exp (Lemma 1, positive), relu
  - random/deterministic projection

  Chefs' Random Tables: Non-Trigonometric Random Features
  (https://arxiv.org/abs/2205.15317)
  - expplus (OPRF mechanism)

  Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
  (https://arxiv.org/abs/2006.16236)
  - elu

  with the theory of approximating angular Performer kernels from go/performer.

  The module enables computing efficient attention in both: long sequence and
  shorter sequence regimes. In the former setting, the attention matrix is never
  explicitly computed and instead its low-rank decomposition obtained with given
  kernel feature maps is leveraged to conduct attention module calculations
  (see: https://arxiv.org/abs/2006.16236). In the latter setting, attention
  matrix is constructed, but kernel features providing dimensionality reduction
  are applied, resulting in more efficient computation of the attention matrix.
  """

  def __init__(self,
               feature_transform="exp",
               num_random_features=256,
               seed=0,
               redraw=False,
               is_short_seq=False,
               begin_kernel=0,
               scale=None,
               scale_by_length=False,
               **kwargs):
    r"""Constructor of KernelAttention.

    Args:
      feature_transform: A non-linear transform of the keys and queries.
        Possible transforms are "elu", "relu", "square", "exp", "expplus",
        "expmod", "identity".
      num_random_features: Number of random features to be used for projection.
        If num_random_features <= 0, no projection is used before transform.
      seed: The seed to begin drawing random features. Once the seed is set, the
        pseudo-random number generation is deterministic. Users should pass
        different seeds for different layers. For multi-worker, each layer will
        use the same projection at each step.
      redraw: Whether to redraw projection every forward pass during training.
        The argument is only effective when num_random_features > 0.
      is_short_seq: boolean predicate indicating whether input data consists of
        very short sequences or not; in most cases this should be False (default
        option).
      begin_kernel: Apply kernel_attention after this sequence id and apply
        softmax attention before this.
      scale: The value to scale the dot product as described in `Attention Is
        All You Need`. If None, we use 1/sqrt(dk) as described in the paper.
      scale_by_length: boolean predicate indicating whether to additionally
        scale the dot product based on key length. Set as log_512^(n) to
        stabilize attention entropy against length. Refer to
        https://kexue.fm/archives/8823 for details.
      **kwargs: The same arguments `MultiHeadAttention` layer.

    Raises:
      ValueError: if `feature_transform` is not supported, or if `redraw` is
        True while `num_random_features <= 0`.
    """
    if feature_transform not in _TRANSFORM_MAP and feature_transform != "expplus":
      # "expplus" is handled outside _TRANSFORM_MAP, so list it explicitly.
      supported = sorted(list(_TRANSFORM_MAP.keys()) + ["expplus"])
      raise ValueError("Unsupported feature_transform. The supported "
                       "feature_transform are %s. "
                       "Got '%s'." % (supported, feature_transform))
    if num_random_features <= 0 and redraw:
      raise ValueError(
          "There is nothing to redraw when num_random_features <= 0.")
    self._feature_transform = feature_transform
    self._num_random_features = num_random_features
    self._redraw = redraw
    self._is_short_seq = is_short_seq
    self._begin_kernel = begin_kernel
    self._scale_by_length = scale_by_length
    # We use the seed for two scenarios:
    # 1. inference
    # 2. no redraw
    self._seed = seed
    super().__init__(**kwargs)
    if scale is None:
      self._scale = 1.0 / math.sqrt(float(self._key_dim))
    else:
      self._scale = scale
    self._projection_matrix = None
    if num_random_features > 0:
      self._projection_matrix = create_projection_matrix(
          self._num_random_features, self._key_dim,
          tf.constant([self._seed, self._seed + 1]))

  def _compute_attention(self,
                         query,
                         key,
                         value,
                         feature_transform,
                         is_short_seq,
                         attention_mask=None,
                         training=False,
                         numeric_stabler=_NUMERIC_STABLER):
    """Applies kernel attention with query, key, value tensors.

    This function defines the computation inside `call` with projected
    multi-head Q, K, V inputs. Users can override this function for customized
    attention implementation.

    Args:
      query: Projected query `Tensor` of shape `[B, T, N, key_dim]`.
      key: Projected key `Tensor` of shape `[B, S, N, key_dim]`.
      value: Projected value `Tensor` of shape `[B, S, N, value_dim]`.
      feature_transform: A non-linear transform of the keys and queries.
      is_short_seq: boolean predicate indicating whether input data consists of
        short or long sequences; usually short sequence is defined as having
        length L <= 1024.
      attention_mask: a mask of shape `[B, S]` that prevents attending to
        masked positions: it is multiplied into the key features, so 1 keeps a
        key and 0 masks it out. Note that the mask is only applied to the keys.
        Users may want to mask the output if query contains pads.
      training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (doing nothing).
      numeric_stabler: A scalar value added to avoid divide by 0.

    Returns:
      attention_output: Multi-headed outputs of attention computation.
    """
    projection_matrix = None
    if self._num_random_features > 0:
      if self._redraw and training:
        projection_matrix = create_projection_matrix(self._num_random_features,
                                                     self._key_dim)
      else:
        projection_matrix = self._projection_matrix

    if self._scale_by_length:
      if attention_mask is not None:
        # Number of unmasked keys per example, shape [B].
        key_length = tf.reduce_sum(attention_mask, axis=-1)
      else:
        # Without a mask every key position counts.
        key_length = tf.cast(tf.shape(key)[1], key.dtype)
      scale = tf.math.log(key_length) * self._scale / math.log(512)
      scale = tf.reshape(scale, [-1, 1, 1, 1])
    else:
      scale = self._scale

    if is_short_seq:
      # Note: Applying scalar multiply at the smaller end of einsum improves
      # XLA performance, but may introduce slight numeric differences in
      # the Transformer attention head.
      query = query * scale
    else:
      # Note: we suspect spliting the scale to key, query yields smaller
      # approximation variance when random projection is used.
      # For simplicity, we also split when there's no random projection.
      # Cast to the key dtype so non-float32 inputs are not mixed with float32.
      sqrt_scale = tf.math.sqrt(tf.cast(scale, key.dtype))
      key *= sqrt_scale
      query *= sqrt_scale

    if feature_transform != "expplus":
      key_prime = _TRANSFORM_MAP[feature_transform](key, projection_matrix)
      query_prime = _TRANSFORM_MAP[feature_transform](query, projection_matrix)
    else:
      key_prime = expplus(key, query, False, projection_matrix)
      query_prime = expplus(query, key, True, projection_matrix)

    if attention_mask is not None:
      key_prime = tf.einsum("BSNH,BS->BSNH", key_prime, attention_mask)

    if is_short_seq:
      attention_scores = tf.einsum("BTNH,BSNH->BTSN", query_prime, key_prime)
      if attention_mask is not None:
        # Zeroed key features only zero the score, which softmax would still
        # weight by exp(0); push masked keys to a large negative score instead.
        mask = tf.cast(attention_mask, attention_scores.dtype)
        large_negative = (
            tf.float16.min if attention_scores.dtype == tf.float16 else -1e9)
        attention_scores += (
            (1.0 - mask)[:, tf.newaxis, :, tf.newaxis] * large_negative)
      attention_scores = tf.nn.softmax(attention_scores, axis=2)
      attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
    else:
      # Linear attention: the [T, S] attention matrix is never materialized.
      kv = tf.einsum("BSNH,BSND->BNDH", key_prime, value)
      denominator = 1.0 / (
          tf.einsum("BTNH,BNH->BTN", query_prime,
                    tf.reduce_sum(key_prime, axis=1)) + numeric_stabler)
      attention_output = tf.einsum("BTNH,BNDH,BTN->BTND", query_prime, kv,
                                   denominator)
    return attention_output

  def _build_from_signature(self, query, value, key=None):
    super()._build_from_signature(query=query, value=value, key=key)  # pytype: disable=attribute-error  # typed-keras
    if self._begin_kernel > 0:
      # Separate output projection and dropout for the softmax-attended prefix.
      common_kwargs = dict(
          kernel_initializer=self._kernel_initializer,
          bias_initializer=self._bias_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer,
          activity_regularizer=self._activity_regularizer,
          kernel_constraint=self._kernel_constraint,
          bias_constraint=self._bias_constraint)
      self._output_dense_softmax = self._make_output_dense(
          self._query_shape.rank - 1,
          common_kwargs,
          name="attention_output_softmax")
      self._dropout_softmax = tf.keras.layers.Dropout(rate=self._dropout)

  def call(self, query, value, key=None, attention_mask=None, training=False):
    """Compute attention with kernel mechanism.

    Args:
      query: Query `Tensor` of shape `[B, T, dim]`.
      value: Value `Tensor` of shape `[B, S, dim]`.
      key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
        `value` for both `key` and `value`, which is the most common case.
      attention_mask: a mask of shape `[B, S]` that prevents attending to
        masked positions (1 keeps a key, 0 masks it out). Note that the mask is
        only applied to the keys. Users may want to mask the output if query
        contains pads.
      training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (doing nothing).

    Returns:
      Multi-headed outputs of attention computation.
    """
    if not self._built_from_signature:
      self._build_from_signature(query=query, value=value, key=key)
    if key is None:
      key = value

    #   N = `num_attention_heads`
    #   H = `size_per_head`
    # `query` = [B, T, N, H]
    query = self._query_dense(query)

    # `key` = [B, S, N, H]
    key = self._key_dense(key)

    # `value` = [B, S, N, D]
    value = self._value_dense(value)

    if self._begin_kernel > 0:
      # The first `begin_kernel` query positions go through the short-sequence
      # (softmax) path with the "identity" transform; the rest use the
      # configured kernel transform.
      attention_output_softmax = self._compute_attention(
          query[:, :self._begin_kernel], key, value, "identity", True,
          attention_mask, training)
      attention_output_softmax = self._dropout_softmax(
          attention_output_softmax, training=training)
      attention_output_softmax = self._output_dense_softmax(
          attention_output_softmax)

      attention_output_kernel = self._compute_attention(
          query[:, self._begin_kernel:], key, value, self._feature_transform,
          self._is_short_seq, attention_mask, training)
      attention_output_kernel = self._dropout_layer(
          attention_output_kernel, training=training)
      attention_output_kernel = self._output_dense(attention_output_kernel)
      attention_output = tf.concat(
          [attention_output_softmax, attention_output_kernel], axis=1)
    else:
      attention_output = self._compute_attention(query, key, value,
                                                 self._feature_transform,
                                                 self._is_short_seq,
                                                 attention_mask, training)
      # Dropout is applied to the attention output; the long-sequence path
      # never materializes attention probabilities to drop out.
      attention_output = self._dropout_layer(
          attention_output, training=training)
      attention_output = self._output_dense(attention_output)
    return attention_output

  def get_config(self):
    config = {
        "feature_transform": self._feature_transform,
        "num_random_features": self._num_random_features,
        "seed": self._seed,
        "redraw": self._redraw,
        "is_short_seq": self._is_short_seq,
        "begin_kernel": self._begin_kernel,
        "scale": self._scale,
        # Previously omitted, so the setting was lost on from_config().
        "scale_by_length": self._scale_by_length,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))