Commit 32e4ca51 authored by qianyj's avatar qianyj
Browse files

Update code to v2.11.0

parents 9485aa1d 71060f67
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -19,6 +19,7 @@ import tempfile ...@@ -19,6 +19,7 @@ import tempfile
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow import estimator as tf_estimator
from sentencepiece import SentencePieceTrainer from sentencepiece import SentencePieceTrainer
from official.nlp.modeling.layers import text_layers from official.nlp.modeling.layers import text_layers
...@@ -120,10 +121,10 @@ class BertTokenizerTest(tf.test.TestCase): ...@@ -120,10 +121,10 @@ class BertTokenizerTest(tf.test.TestCase):
def model_fn(features, labels, mode): def model_fn(features, labels, mode):
del labels # Unused. del labels # Unused.
return tf.estimator.EstimatorSpec(mode=mode, return tf_estimator.EstimatorSpec(mode=mode,
predictions=features["input_word_ids"]) predictions=features["input_word_ids"])
estimator = tf.estimator.Estimator(model_fn=model_fn) estimator = tf_estimator.Estimator(model_fn=model_fn)
outputs = list(estimator.predict(input_fn)) outputs = list(estimator.predict(input_fn))
self.assertAllEqual(outputs, np.array([[2, 6, 3, 0], self.assertAllEqual(outputs, np.array([[2, 6, 3, 0],
[2, 4, 5, 3]])) [2, 4, 5, 3]]))
...@@ -231,10 +232,10 @@ class SentencepieceTokenizerTest(tf.test.TestCase): ...@@ -231,10 +232,10 @@ class SentencepieceTokenizerTest(tf.test.TestCase):
def model_fn(features, labels, mode): def model_fn(features, labels, mode):
del labels # Unused. del labels # Unused.
return tf.estimator.EstimatorSpec(mode=mode, return tf_estimator.EstimatorSpec(mode=mode,
predictions=features["input_word_ids"]) predictions=features["input_word_ids"])
estimator = tf.estimator.Estimator(model_fn=model_fn) estimator = tf_estimator.Estimator(model_fn=model_fn)
outputs = list(estimator.predict(input_fn)) outputs = list(estimator.predict(input_fn))
self.assertAllEqual(outputs, np.array([[2, 8, 3, 0], self.assertAllEqual(outputs, np.array([[2, 8, 3, 0],
[2, 12, 3, 0]])) [2, 12, 3, 0]]))
...@@ -537,10 +538,10 @@ class FastWordPieceBertTokenizerTest(tf.test.TestCase): ...@@ -537,10 +538,10 @@ class FastWordPieceBertTokenizerTest(tf.test.TestCase):
def model_fn(features, labels, mode): def model_fn(features, labels, mode):
del labels # Unused. del labels # Unused.
return tf.estimator.EstimatorSpec(mode=mode, return tf_estimator.EstimatorSpec(mode=mode,
predictions=features["input_word_ids"]) predictions=features["input_word_ids"])
estimator = tf.estimator.Estimator(model_fn=model_fn) estimator = tf_estimator.Estimator(model_fn=model_fn)
outputs = list(estimator.predict(input_fn)) outputs = list(estimator.predict(input_fn))
self.assertAllEqual(outputs, np.array([[2, 6, 3, 0], self.assertAllEqual(outputs, np.array([[2, 6, 3, 0],
[2, 4, 5, 3]])) [2, 4, 5, 3]]))
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
from typing import List, Optional, Text, Any, Dict from typing import List, Optional, Text, Any, Dict
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils
Layer = tf.keras.layers.Layer Layer = tf.keras.layers.Layer
activations = tf.keras.activations activations = tf.keras.activations
initializers = tf.keras.initializers initializers = tf.keras.initializers
...@@ -64,7 +66,7 @@ class TNExpandCondense(Layer): ...@@ -64,7 +66,7 @@ class TNExpandCondense(Layer):
if 'input_shape' not in kwargs and 'input_dim' in kwargs: if 'input_shape' not in kwargs and 'input_dim' in kwargs:
kwargs['input_shape'] = (kwargs.pop('input_dim'),) kwargs['input_shape'] = (kwargs.pop('input_dim'),)
super(TNExpandCondense, self).__init__(**kwargs) super().__init__(**kwargs)
assert proj_multiplier in [ assert proj_multiplier in [
2, 4, 6, 8, 10, 12 2, 4, 6, 8, 10, 12
...@@ -84,7 +86,7 @@ class TNExpandCondense(Layer): ...@@ -84,7 +86,7 @@ class TNExpandCondense(Layer):
'The last dimension of the inputs to `TNExpandCondense` ' 'The last dimension of the inputs to `TNExpandCondense` '
'should be defined. Found `None`.') 'should be defined. Found `None`.')
super(TNExpandCondense, self).build(input_shape) super().build(input_shape)
self.proj_size = self.proj_multiplier * input_shape[-1] self.proj_size = self.proj_multiplier * input_shape[-1]
...@@ -98,24 +100,24 @@ class TNExpandCondense(Layer): ...@@ -98,24 +100,24 @@ class TNExpandCondense(Layer):
name='w1', name='w1',
shape=(input_shape[-1], input_shape[-1]), shape=(input_shape[-1], input_shape[-1]),
trainable=True, trainable=True,
initializer=self.kernel_initializer) initializer=tf_utils.clone_initializer(self.kernel_initializer))
self.w2 = self.add_weight( self.w2 = self.add_weight(
name='w2', name='w2',
shape=(128, (128 * (self.proj_size // input_shape[-1]))), shape=(128, (128 * (self.proj_size // input_shape[-1]))),
trainable=True, trainable=True,
initializer=self.kernel_initializer) initializer=tf_utils.clone_initializer(self.kernel_initializer))
self.w3 = self.add_weight( self.w3 = self.add_weight(
name='w3', name='w3',
shape=(128 * (self.proj_size // input_shape[-1]), 128), shape=(128 * (self.proj_size // input_shape[-1]), 128),
trainable=True, trainable=True,
initializer=self.kernel_initializer) initializer=tf_utils.clone_initializer(self.kernel_initializer))
self.w4 = self.add_weight( self.w4 = self.add_weight(
name='w4', name='w4',
shape=(input_shape[-1] // 128, 128, input_shape[-1]), shape=(input_shape[-1] // 128, 128, input_shape[-1]),
trainable=True, trainable=True,
initializer=self.kernel_initializer) initializer=tf_utils.clone_initializer(self.kernel_initializer))
if self.use_bias: if self.use_bias:
self.bias = self.add_weight( self.bias = self.add_weight(
...@@ -176,5 +178,5 @@ class TNExpandCondense(Layer): ...@@ -176,5 +178,5 @@ class TNExpandCondense(Layer):
getattr(self, initializer_arg)) getattr(self, initializer_arg))
# Get base config # Get base config
base_config = super(TNExpandCondense, self).get_config() base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -19,8 +19,6 @@ import os ...@@ -19,8 +19,6 @@ import os
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.keras.testing_utils import layer_test
from official.nlp.modeling.layers.tn_expand_condense import TNExpandCondense from official.nlp.modeling.layers.tn_expand_condense import TNExpandCondense
...@@ -45,13 +43,9 @@ class TNLayerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -45,13 +43,9 @@ class TNLayerTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters((768, 6), (1024, 2)) @parameterized.parameters((768, 6), (1024, 2))
def test_keras_layer(self, input_dim, proj_multiple): def test_keras_layer(self, input_dim, proj_multiple):
self.skipTest('Disable the test for now since it imports '
'keras.testing_utils, will reenable this test after we '
'fix the b/184578869')
# TODO(scottzhu): Reenable after fix b/184578869
data = np.random.normal(size=(100, input_dim)) data = np.random.normal(size=(100, input_dim))
data = data.astype(np.float32) data = data.astype(np.float32)
layer_test( tf.keras.__internal__.utils.layer_test(
TNExpandCondense, TNExpandCondense,
kwargs={ kwargs={
'proj_multiplier': proj_multiple, 'proj_multiplier': proj_multiple,
...@@ -64,9 +58,9 @@ class TNLayerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -64,9 +58,9 @@ class TNLayerTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters((768, 6), (1024, 2)) @parameterized.parameters((768, 6), (1024, 2))
def test_train(self, input_dim, proj_multiple): def test_train(self, input_dim, proj_multiple):
tf.keras.utils.set_random_seed(0)
data = np.random.randint(10, size=(100, input_dim)) data = np.random.randint(10, size=(100, input_dim))
model = self._build_model(data, proj_multiple) model = self._build_model(data, proj_multiple)
tf.random.set_seed(0)
model.compile( model.compile(
optimizer='adam', loss='binary_crossentropy', metrics=['accuracy']) optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
...@@ -81,7 +75,7 @@ class TNLayerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -81,7 +75,7 @@ class TNLayerTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters((768, 6), (1024, 2)) @parameterized.parameters((768, 6), (1024, 2))
def test_weights_change(self, input_dim, proj_multiple): def test_weights_change(self, input_dim, proj_multiple):
tf.random.set_seed(0) tf.keras.utils.set_random_seed(0)
data = np.random.randint(10, size=(100, input_dim)) data = np.random.randint(10, size=(100, input_dim))
model = self._build_model(data, proj_multiple) model = self._build_model(data, proj_multiple)
model.compile( model.compile(
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
import gin import gin
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers.tn_expand_condense import TNExpandCondense from official.nlp.modeling.layers.tn_expand_condense import TNExpandCondense
...@@ -77,7 +78,7 @@ class TNTransformerExpandCondense(tf.keras.layers.Layer): ...@@ -77,7 +78,7 @@ class TNTransformerExpandCondense(tf.keras.layers.Layer):
intermediate_dropout=0.0, intermediate_dropout=0.0,
attention_initializer=None, attention_initializer=None,
**kwargs): **kwargs):
super(TNTransformerExpandCondense, self).__init__(**kwargs) super().__init__(**kwargs)
self._num_heads = num_attention_heads self._num_heads = num_attention_heads
self._intermediate_size = intermediate_size self._intermediate_size = intermediate_size
...@@ -100,7 +101,8 @@ class TNTransformerExpandCondense(tf.keras.layers.Layer): ...@@ -100,7 +101,8 @@ class TNTransformerExpandCondense(tf.keras.layers.Layer):
self._attention_initializer = tf.keras.initializers.get( self._attention_initializer = tf.keras.initializers.get(
attention_initializer) attention_initializer)
else: else:
self._attention_initializer = self._kernel_initializer self._attention_initializer = tf_utils.clone_initializer(
self._kernel_initializer)
def build(self, input_shape): def build(self, input_shape):
input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
...@@ -128,7 +130,6 @@ class TNTransformerExpandCondense(tf.keras.layers.Layer): ...@@ -128,7 +130,6 @@ class TNTransformerExpandCondense(tf.keras.layers.Layer):
"heads (%d)" % (hidden_size, self._num_heads)) "heads (%d)" % (hidden_size, self._num_heads))
self._attention_head_size = int(hidden_size // self._num_heads) self._attention_head_size = int(hidden_size // self._num_heads)
common_kwargs = dict( common_kwargs = dict(
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer, bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer, activity_regularizer=self._activity_regularizer,
...@@ -140,6 +141,7 @@ class TNTransformerExpandCondense(tf.keras.layers.Layer): ...@@ -140,6 +141,7 @@ class TNTransformerExpandCondense(tf.keras.layers.Layer):
dropout=self._attention_dropout_rate, dropout=self._attention_dropout_rate,
use_bias=self._use_bias, use_bias=self._use_bias,
kernel_initializer=self._attention_initializer, kernel_initializer=self._attention_initializer,
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
name="self_attention", name="self_attention",
**common_kwargs) **common_kwargs)
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate) self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
...@@ -168,7 +170,7 @@ class TNTransformerExpandCondense(tf.keras.layers.Layer): ...@@ -168,7 +170,7 @@ class TNTransformerExpandCondense(tf.keras.layers.Layer):
epsilon=self._norm_epsilon, epsilon=self._norm_epsilon,
dtype=tf.float32) dtype=tf.float32)
super(TNTransformerExpandCondense, self).build(input_shape) super().build(input_shape)
def get_config(self): def get_config(self):
config = { config = {
...@@ -209,7 +211,7 @@ class TNTransformerExpandCondense(tf.keras.layers.Layer): ...@@ -209,7 +211,7 @@ class TNTransformerExpandCondense(tf.keras.layers.Layer):
"attention_initializer": "attention_initializer":
tf.keras.initializers.serialize(self._attention_initializer) tf.keras.initializers.serialize(self._attention_initializer)
} }
base_config = super(TNTransformerExpandCondense, self).get_config() base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs): def call(self, inputs):
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -15,9 +15,11 @@ ...@@ -15,9 +15,11 @@
"""Keras-based transformer block layer.""" """Keras-based transformer block layer."""
# pylint: disable=g-classes-have-attributes # pylint: disable=g-classes-have-attributes
from absl import logging
import gin import gin
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import attention from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import multi_channel_attention from official.nlp.modeling.layers import multi_channel_attention
from official.nlp.modeling.layers import transformer_encoder_block from official.nlp.modeling.layers import transformer_encoder_block
...@@ -31,6 +33,9 @@ class Transformer(transformer_encoder_block.TransformerEncoderBlock): ...@@ -31,6 +33,9 @@ class Transformer(transformer_encoder_block.TransformerEncoderBlock):
This layer implements the Transformer from "Attention Is All You Need". This layer implements the Transformer from "Attention Is All You Need".
(https://arxiv.org/abs/1706.03762). (https://arxiv.org/abs/1706.03762).
**Warning: this layer is deprecated. Please don't use it. Use the
`TransformerEncoderBlock` layer instead.**
Args: Args:
num_attention_heads: Number of attention heads. num_attention_heads: Number of attention heads.
intermediate_size: Size of the intermediate layer. intermediate_size: Size of the intermediate layer.
...@@ -97,6 +102,8 @@ class Transformer(transformer_encoder_block.TransformerEncoderBlock): ...@@ -97,6 +102,8 @@ class Transformer(transformer_encoder_block.TransformerEncoderBlock):
inner_dropout=intermediate_dropout, inner_dropout=intermediate_dropout,
attention_initializer=attention_initializer, attention_initializer=attention_initializer,
**kwargs) **kwargs)
logging.warning("The `Transformer` layer is deprecated. Please directly "
"use `TransformerEncoderBlock`.")
def get_config(self): def get_config(self):
return { return {
...@@ -226,7 +233,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer): ...@@ -226,7 +233,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
self._attention_initializer = tf.keras.initializers.get( self._attention_initializer = tf.keras.initializers.get(
attention_initializer) attention_initializer)
else: else:
self._attention_initializer = self._kernel_initializer self._attention_initializer = tf_utils.clone_initializer(
self._kernel_initializer)
if self.multi_channel_cross_attention: if self.multi_channel_cross_attention:
self._cross_attention_cls = multi_channel_attention.MultiChannelAttention self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
else: else:
...@@ -244,7 +252,6 @@ class TransformerDecoderBlock(tf.keras.layers.Layer): ...@@ -244,7 +252,6 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
"heads (%d)" % (hidden_size, self.num_attention_heads)) "heads (%d)" % (hidden_size, self.num_attention_heads))
self.attention_head_size = int(hidden_size) // self.num_attention_heads self.attention_head_size = int(hidden_size) // self.num_attention_heads
common_kwargs = dict( common_kwargs = dict(
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer, bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer, activity_regularizer=self._activity_regularizer,
...@@ -256,14 +263,17 @@ class TransformerDecoderBlock(tf.keras.layers.Layer): ...@@ -256,14 +263,17 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
key_dim=self.attention_head_size, key_dim=self.attention_head_size,
dropout=self.attention_dropout_rate, dropout=self.attention_dropout_rate,
use_bias=self._use_bias, use_bias=self._use_bias,
kernel_initializer=self._attention_initializer, kernel_initializer=tf_utils.clone_initializer(
self._attention_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
name="self_attention", name="self_attention",
**common_kwargs) **common_kwargs)
self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense( self.self_attention_output_dense = tf.keras.layers.EinsumDense(
"abc,cd->abd", "abc,cd->abd",
output_shape=(None, hidden_size), output_shape=(None, hidden_size),
bias_axes="d", bias_axes="d",
kernel_initializer=self._kernel_initializer, kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
name="output", name="output",
**common_kwargs) **common_kwargs)
self.self_attention_dropout = tf.keras.layers.Dropout( self.self_attention_dropout = tf.keras.layers.Dropout(
...@@ -281,7 +291,9 @@ class TransformerDecoderBlock(tf.keras.layers.Layer): ...@@ -281,7 +291,9 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
dropout=self.attention_dropout_rate, dropout=self.attention_dropout_rate,
output_shape=hidden_size, output_shape=hidden_size,
use_bias=self._use_bias, use_bias=self._use_bias,
kernel_initializer=self._attention_initializer, kernel_initializer=tf_utils.clone_initializer(
self._attention_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
name="attention/encdec", name="attention/encdec",
**common_kwargs) **common_kwargs)
...@@ -295,22 +307,24 @@ class TransformerDecoderBlock(tf.keras.layers.Layer): ...@@ -295,22 +307,24 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
dtype="float32")) dtype="float32"))
# Feed-forward projection. # Feed-forward projection.
self.intermediate_dense = tf.keras.layers.experimental.EinsumDense( self.intermediate_dense = tf.keras.layers.EinsumDense(
"abc,cd->abd", "abc,cd->abd",
output_shape=(None, self.intermediate_size), output_shape=(None, self.intermediate_size),
bias_axes="d", bias_axes="d",
kernel_initializer=self._kernel_initializer, kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
name="intermediate", name="intermediate",
**common_kwargs) **common_kwargs)
self.intermediate_activation_layer = tf.keras.layers.Activation( self.intermediate_activation_layer = tf.keras.layers.Activation(
self.intermediate_activation) self.intermediate_activation)
self._intermediate_dropout_layer = tf.keras.layers.Dropout( self._intermediate_dropout_layer = tf.keras.layers.Dropout(
rate=self._intermediate_dropout) rate=self._intermediate_dropout)
self.output_dense = tf.keras.layers.experimental.EinsumDense( self.output_dense = tf.keras.layers.EinsumDense(
"abc,cd->abd", "abc,cd->abd",
output_shape=(None, hidden_size), output_shape=(None, hidden_size),
bias_axes="d", bias_axes="d",
kernel_initializer=self._kernel_initializer, kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
name="output", name="output",
**common_kwargs) **common_kwargs)
self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate) self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -13,9 +13,11 @@ ...@@ -13,9 +13,11 @@
# limitations under the License. # limitations under the License.
"""Keras-based TransformerEncoder block layer.""" """Keras-based TransformerEncoder block layer."""
from typing import Any, Optional
from absl import logging
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import util from official.nlp.modeling.layers import util
...@@ -54,9 +56,32 @@ class TransformerEncoderBlock(tf.keras.layers.Layer): ...@@ -54,9 +56,32 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
inner_dropout=0.0, inner_dropout=0.0,
attention_initializer=None, attention_initializer=None,
attention_axes=None, attention_axes=None,
use_query_residual=True,
key_dim=None,
value_dim=None,
output_last_dim=None,
diff_q_kv_att_layer_norm=False,
return_attention_scores=False,
**kwargs): **kwargs):
"""Initializes `TransformerEncoderBlock`. """Initializes `TransformerEncoderBlock`.
Note: If `output_last_dim` is used and `use_query_residual` is `True`, the
`output_last_dim`'s value must equal the first input's last dimension for
the query residual connection to work. This is because the residual
connection after the multi-head-attention requires their dimensions to
match. If `use_query_residual` is `False`, the `output_last_dim` dictactes
the last dimension of the output of this module and the
multi-head-attention.
E.g. let's say input dims are `[batch_size, seq_dim, input_last_dim]`.
Scenario 1: If `output_last_dim` is not `None`, then the output dims of this
module would be `[batch_size, seq_dim, output_last_dim]`. Note `key_dim` is
overriden by `output_last_dim`.
Scenario 2: If `output_last_dim` is `None` and `key_dim` is not `None`, then
the output dims of this module would be `[batch_size, seq_dim, key_dim]`.
Scenario 3: If the `output_last_dim` and `key_dim` are both `None`, the
output dims would be `[batch_size, seq_dim, input_last_dim]`.
Args: Args:
num_attention_heads: Number of attention heads. num_attention_heads: Number of attention heads.
inner_dim: The output dimension of the first Dense layer in a two-layer inner_dim: The output dimension of the first Dense layer in a two-layer
...@@ -88,17 +113,35 @@ class TransformerEncoderBlock(tf.keras.layers.Layer): ...@@ -88,17 +113,35 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
kernel. kernel.
attention_axes: axes over which the attention is applied. `None` means attention_axes: axes over which the attention is applied. `None` means
attention over all axes, but batch, heads, and features. attention over all axes, but batch, heads, and features.
use_query_residual: Toggle to execute residual connection after attention.
key_dim: `key_dim` for the `tf.keras.layers.MultiHeadAttention`. If
`None`, we use the first `input_shape`'s last dim.
value_dim: `value_dim` for the `tf.keras.layers.MultiHeadAttention`.
output_last_dim: Final dimension of the output of this module. This also
dictates the value for the final dimension of the multi-head-attention.
When it's `None`, we use, in order of decreasing precedence, `key_dim` *
`num_heads` or the first `input_shape`'s last dim as the output's last
dim.
diff_q_kv_att_layer_norm: If `True`, create a separate attention layer
norm layer for query and key-value if `norm_first` is `True`. Invalid to
set to `True` if `norm_first` is `False`.
return_attention_scores: If `True`, the output of this layer will be a
tuple and additionally contain the attention scores in the shape of
`[batch_size, num_attention_heads, seq_dim, seq_dim]`.
**kwargs: keyword arguments. **kwargs: keyword arguments.
""" """
util.filter_kwargs(kwargs) util.filter_kwargs(kwargs)
super().__init__(**kwargs) super().__init__(**kwargs)
# Deprecation warning.
if output_range is not None:
logging.warning("`output_range` is available as an argument for `call()`."
"The `output_range` as __init__ argument is deprecated.")
self._num_heads = num_attention_heads self._num_heads = num_attention_heads
self._inner_dim = inner_dim self._inner_dim = inner_dim
self._inner_activation = inner_activation self._inner_activation = inner_activation
self._attention_dropout = attention_dropout
self._attention_dropout_rate = attention_dropout self._attention_dropout_rate = attention_dropout
self._output_dropout = output_dropout
self._output_dropout_rate = output_dropout self._output_dropout_rate = output_dropout
self._output_range = output_range self._output_range = output_range
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer) self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
...@@ -112,13 +155,24 @@ class TransformerEncoderBlock(tf.keras.layers.Layer): ...@@ -112,13 +155,24 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
self._norm_first = norm_first self._norm_first = norm_first
self._norm_epsilon = norm_epsilon self._norm_epsilon = norm_epsilon
self._inner_dropout = inner_dropout self._inner_dropout = inner_dropout
self._use_query_residual = use_query_residual
self._key_dim = key_dim
self._value_dim = value_dim
self._output_last_dim = output_last_dim
self._diff_q_kv_att_layer_norm = diff_q_kv_att_layer_norm
self._return_attention_scores = return_attention_scores
if attention_initializer: if attention_initializer:
self._attention_initializer = tf.keras.initializers.get( self._attention_initializer = tf.keras.initializers.get(
attention_initializer) attention_initializer)
else: else:
self._attention_initializer = self._kernel_initializer self._attention_initializer = tf_utils.clone_initializer(
self._kernel_initializer)
self._attention_axes = attention_axes self._attention_axes = attention_axes
if self._diff_q_kv_att_layer_norm and not self._norm_first:
raise ValueError("Setting `diff_q_and_kv_attention_layer_norm` to True"
"when `norm_first` is False is invalid.")
def build(self, input_shape): def build(self, input_shape):
if isinstance(input_shape, tf.TensorShape): if isinstance(input_shape, tf.TensorShape):
input_tensor_shape = input_shape input_tensor_shape = input_shape
...@@ -133,27 +187,35 @@ class TransformerEncoderBlock(tf.keras.layers.Layer): ...@@ -133,27 +187,35 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
einsum_equation = "...bc,cd->...bd" einsum_equation = "...bc,cd->...bd"
hidden_size = input_tensor_shape[-1] hidden_size = input_tensor_shape[-1]
if hidden_size % self._num_heads != 0: if hidden_size % self._num_heads != 0:
raise ValueError( logging.warning(
"The input size (%d) is not a multiple of the number of attention " "The input size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, self._num_heads)) "heads (%d)", hidden_size, self._num_heads)
self._attention_head_size = int(hidden_size // self._num_heads) if self._key_dim is None:
self._key_dim = int(hidden_size // self._num_heads)
if self._output_last_dim is None:
last_output_shape = hidden_size
else:
last_output_shape = self._output_last_dim
common_kwargs = dict( common_kwargs = dict(
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer, bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer, activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint, kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint) bias_constraint=self._bias_constraint)
self._attention_layer = tf.keras.layers.MultiHeadAttention( self._attention_layer = tf.keras.layers.MultiHeadAttention(
num_heads=self._num_heads, num_heads=self._num_heads,
key_dim=self._attention_head_size, key_dim=self._key_dim,
dropout=self._attention_dropout, value_dim=self._value_dim,
dropout=self._attention_dropout_rate,
use_bias=self._use_bias, use_bias=self._use_bias,
kernel_initializer=self._attention_initializer, kernel_initializer=self._attention_initializer,
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
attention_axes=self._attention_axes, attention_axes=self._attention_axes,
output_shape=self._output_last_dim,
name="self_attention", name="self_attention",
**common_kwargs) **common_kwargs)
self._attention_dropout = tf.keras.layers.Dropout(rate=self._output_dropout) self._attention_dropout = tf.keras.layers.Dropout(
rate=self._attention_dropout_rate)
# Use float32 in layernorm for numeric stability. # Use float32 in layernorm for numeric stability.
# It is probably safe in mixed_float16, but we haven't validated this yet. # It is probably safe in mixed_float16, but we haven't validated this yet.
self._attention_layer_norm = ( self._attention_layer_norm = (
...@@ -162,11 +224,21 @@ class TransformerEncoderBlock(tf.keras.layers.Layer): ...@@ -162,11 +224,21 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
axis=-1, axis=-1,
epsilon=self._norm_epsilon, epsilon=self._norm_epsilon,
dtype=tf.float32)) dtype=tf.float32))
self._intermediate_dense = tf.keras.layers.experimental.EinsumDense( self._attention_layer_norm_kv = self._attention_layer_norm
if self._diff_q_kv_att_layer_norm:
self._attention_layer_norm_kv = (
tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm_kv",
axis=-1,
epsilon=self._norm_epsilon,
dtype=tf.float32))
self._intermediate_dense = tf.keras.layers.EinsumDense(
einsum_equation, einsum_equation,
output_shape=(None, self._inner_dim), output_shape=(None, self._inner_dim),
bias_axes="d", bias_axes="d",
kernel_initializer=self._kernel_initializer, kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
name="intermediate", name="intermediate",
**common_kwargs) **common_kwargs)
policy = tf.keras.mixed_precision.global_policy() policy = tf.keras.mixed_precision.global_policy()
...@@ -179,14 +251,16 @@ class TransformerEncoderBlock(tf.keras.layers.Layer): ...@@ -179,14 +251,16 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
self._inner_activation, dtype=policy) self._inner_activation, dtype=policy)
self._inner_dropout_layer = tf.keras.layers.Dropout( self._inner_dropout_layer = tf.keras.layers.Dropout(
rate=self._inner_dropout) rate=self._inner_dropout)
self._output_dense = tf.keras.layers.experimental.EinsumDense( self._output_dense = tf.keras.layers.EinsumDense(
einsum_equation, einsum_equation,
output_shape=(None, hidden_size), output_shape=(None, last_output_shape),
bias_axes="d", bias_axes="d",
name="output", name="output",
kernel_initializer=self._kernel_initializer, kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs) **common_kwargs)
self._output_dropout = tf.keras.layers.Dropout(rate=self._output_dropout) self._output_dropout = tf.keras.layers.Dropout(
rate=self._output_dropout_rate)
# Use float32 in layernorm for numeric stability. # Use float32 in layernorm for numeric stability.
self._output_layer_norm = tf.keras.layers.LayerNormalization( self._output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm", name="output_layer_norm",
...@@ -194,7 +268,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer): ...@@ -194,7 +268,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
epsilon=self._norm_epsilon, epsilon=self._norm_epsilon,
dtype=tf.float32) dtype=tf.float32)
super(TransformerEncoderBlock, self).build(input_shape) super().build(input_shape)
def get_config(self): def get_config(self):
config = { config = {
...@@ -234,22 +308,35 @@ class TransformerEncoderBlock(tf.keras.layers.Layer): ...@@ -234,22 +308,35 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
self._inner_dropout, self._inner_dropout,
"attention_initializer": "attention_initializer":
tf.keras.initializers.serialize(self._attention_initializer), tf.keras.initializers.serialize(self._attention_initializer),
"attention_axes": self._attention_axes, "attention_axes":
self._attention_axes,
"use_query_residual":
self._use_query_residual,
"key_dim":
self._key_dim,
"value_dim":
self._value_dim,
"output_last_dim":
self._output_last_dim,
"diff_q_kv_att_layer_norm":
self._diff_q_kv_att_layer_norm,
} }
base_config = super(TransformerEncoderBlock, self).get_config() base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs): def call(self, inputs: Any, output_range: Optional[tf.Tensor] = None) -> Any:
"""Transformer self-attention encoder block call. """Transformer self-attention encoder block call.
Args: Args:
inputs: a single tensor or a list of tensors. inputs: a single tensor or a list of tensors. `input tensor` as the single
`input tensor` as the single sequence of embeddings. sequence of embeddings. [`input tensor`, `attention mask`] to have the
[`input tensor`, `attention mask`] to have the additional attention additional attention mask. [`query tensor`, `key value tensor`,
mask. `attention mask`] to have separate input streams for the query, and
[`query tensor`, `key value tensor`, `attention mask`] to have separate key/value to the multi-head attention.
input streams for the query, and key/value to the multi-head output_range: the sequence output range, [0, output_range) for slicing the
attention. target sequence. `None` means the target sequence is not sliced. If you
would like to have no change to the model training, it is better to only
set the `output_range` for serving.
Returns: Returns:
An output tensor with the same dimensions as input/query tensor. An output tensor with the same dimensions as input/query tensor.
...@@ -266,33 +353,50 @@ class TransformerEncoderBlock(tf.keras.layers.Layer): ...@@ -266,33 +353,50 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
else: else:
input_tensor, key_value, attention_mask = (inputs, None, None) input_tensor, key_value, attention_mask = (inputs, None, None)
if self._output_range: if output_range is None:
output_range = self._output_range
if output_range:
if self._norm_first: if self._norm_first:
source_tensor = input_tensor[:, 0:self._output_range, :] source_tensor = input_tensor[:, 0:output_range, :]
input_tensor = self._attention_layer_norm(input_tensor) input_tensor = self._attention_layer_norm(input_tensor)
if key_value is not None: if key_value is not None:
key_value = self._attention_layer_norm(key_value) key_value = self._attention_layer_norm_kv(key_value)
target_tensor = input_tensor[:, 0:self._output_range, :] target_tensor = input_tensor[:, 0:output_range, :]
if attention_mask is not None: if attention_mask is not None:
attention_mask = attention_mask[:, 0:self._output_range, :] attention_mask = attention_mask[:, 0:output_range, :]
else: else:
if self._norm_first: if self._norm_first:
source_tensor = input_tensor source_tensor = input_tensor
input_tensor = self._attention_layer_norm(input_tensor) input_tensor = self._attention_layer_norm(input_tensor)
if key_value is not None: if key_value is not None:
key_value = self._attention_layer_norm(key_value) key_value = self._attention_layer_norm_kv(key_value)
target_tensor = input_tensor target_tensor = input_tensor
if key_value is None: if key_value is None:
key_value = input_tensor key_value = input_tensor
attention_output = self._attention_layer(
query=target_tensor, value=key_value, attention_mask=attention_mask) if self._return_attention_scores:
attention_output, attention_scores = self._attention_layer(
query=target_tensor,
value=key_value,
attention_mask=attention_mask,
return_attention_scores=True)
else:
attention_output = self._attention_layer(
query=target_tensor, value=key_value, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output) attention_output = self._attention_dropout(attention_output)
if self._norm_first: if self._norm_first:
attention_output = source_tensor + attention_output # Important to not combine `self._norm_first` and
# `self._use_query_residual` into one if clause because else is only for
# `_norm_first == False`.
if self._use_query_residual:
attention_output = source_tensor + attention_output
else: else:
attention_output = self._attention_layer_norm(target_tensor + if self._use_query_residual:
attention_output) attention_output = target_tensor + attention_output
attention_output = self._attention_layer_norm(attention_output)
if self._norm_first: if self._norm_first:
source_attention_output = attention_output source_attention_output = attention_output
attention_output = self._output_layer_norm(attention_output) attention_output = self._output_layer_norm(attention_output)
...@@ -303,9 +407,14 @@ class TransformerEncoderBlock(tf.keras.layers.Layer): ...@@ -303,9 +407,14 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
layer_output = self._output_dropout(layer_output) layer_output = self._output_dropout(layer_output)
if self._norm_first: if self._norm_first:
return source_attention_output + layer_output layer_output = source_attention_output + layer_output
else:
# During mixed precision training, layer norm output is always fp32 for
# now. Casts fp32 for the subsequent add.
layer_output = tf.cast(layer_output, tf.float32)
layer_output = self._output_layer_norm(layer_output + attention_output)
# During mixed precision training, layer norm output is always fp32 for now. if self._return_attention_scores:
# Casts fp32 for the subsequent add. return layer_output, attention_scores
layer_output = tf.cast(layer_output, tf.float32) else:
return self._output_layer_norm(layer_output + attention_output) return layer_output
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -23,8 +23,7 @@ from official.nlp.modeling.layers.transformer_encoder_block import TransformerEn ...@@ -23,8 +23,7 @@ from official.nlp.modeling.layers.transformer_encoder_block import TransformerEn
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes
@parameterized.named_parameters( @parameterized.named_parameters(('base', TransformerEncoderBlock))
('base', TransformerEncoderBlock))
class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase): class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase):
def tearDown(self): def tearDown(self):
...@@ -117,18 +116,22 @@ class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase): ...@@ -117,18 +116,22 @@ class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase):
new_layer = transformer_cls( new_layer = transformer_cls(
num_attention_heads=10, num_attention_heads=10,
inner_dim=2048, inner_dim=2048,
inner_activation='relu', inner_activation='relu')
output_range=1) _ = new_layer([input_data, mask_data], output_range=1)
_ = new_layer([input_data, mask_data])
new_layer.set_weights(test_layer.get_weights()) new_layer.set_weights(test_layer.get_weights())
new_output_tensor = new_layer([input_data, mask_data]) new_output_tensor = new_layer([input_data, mask_data], output_range=1)
self.assertAllClose( self.assertAllClose(
new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003) new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)
output_tensor = test_layer([input_data, mask_data], output_range=1)
self.assertAllClose(new_output_tensor, output_tensor, atol=5e-5, rtol=0.003)
def test_layer_output_range_without_mask(self, transformer_cls): def test_layer_output_range_without_mask(self, transformer_cls):
test_layer = transformer_cls( test_layer = transformer_cls(
num_attention_heads=10, inner_dim=2048, num_attention_heads=10,
inner_activation='relu', norm_first=True) inner_dim=2048,
inner_activation='relu',
norm_first=True)
sequence_length = 21 sequence_length = 21
width = 80 width = 80
...@@ -143,18 +146,19 @@ class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase): ...@@ -143,18 +146,19 @@ class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase):
num_attention_heads=10, num_attention_heads=10,
inner_dim=2048, inner_dim=2048,
inner_activation='relu', inner_activation='relu',
output_range=1,
norm_first=True) norm_first=True)
_ = new_layer(input_data) _ = new_layer(input_data, output_range=1)
new_layer.set_weights(test_layer.get_weights()) new_layer.set_weights(test_layer.get_weights())
new_output_tensor = new_layer(input_data) new_output_tensor = new_layer(input_data, output_range=1)
self.assertAllClose( self.assertAllClose(
new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003) new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)
def test_layer_output_range_with_pre_norm(self, transformer_cls): def test_layer_output_range_with_pre_norm(self, transformer_cls):
test_layer = transformer_cls( test_layer = transformer_cls(
num_attention_heads=10, inner_dim=2048, num_attention_heads=10,
inner_activation='relu', norm_first=True) inner_dim=2048,
inner_activation='relu',
norm_first=True)
sequence_length = 21 sequence_length = 21
width = 80 width = 80
...@@ -171,14 +175,16 @@ class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase): ...@@ -171,14 +175,16 @@ class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase):
num_attention_heads=10, num_attention_heads=10,
inner_dim=2048, inner_dim=2048,
inner_activation='relu', inner_activation='relu',
output_range=1,
norm_first=True) norm_first=True)
_ = new_layer([input_data, mask_data]) _ = new_layer([input_data, mask_data], output_range=1)
new_layer.set_weights(test_layer.get_weights()) new_layer.set_weights(test_layer.get_weights())
new_output_tensor = new_layer([input_data, mask_data]) new_output_tensor = new_layer([input_data, mask_data], output_range=1)
self.assertAllClose( self.assertAllClose(
new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003) new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)
output_tensor = test_layer([input_data, mask_data], output_range=1)
self.assertAllClose(new_output_tensor, output_tensor, atol=5e-5, rtol=0.003)
def test_layer_invocation_with_float16_dtype(self, transformer_cls): def test_layer_invocation_with_float16_dtype(self, transformer_cls):
tf.keras.mixed_precision.set_global_policy('mixed_float16') tf.keras.mixed_precision.set_global_policy('mixed_float16')
test_layer = transformer_cls( test_layer = transformer_cls(
...@@ -252,6 +258,155 @@ class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase): ...@@ -252,6 +258,155 @@ class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase):
self.assertEqual(output.shape, q_tensor.shape) self.assertEqual(output.shape, q_tensor.shape)
@keras_parameterized.run_all_keras_modes
class TransformerEncoderBlockLayerTestWithoutParams(keras_parameterized.TestCase
):
def tearDown(self):
super(TransformerEncoderBlockLayerTestWithoutParams, self).tearDown()
tf.keras.mixed_precision.set_global_policy('float32')
def test_raises_invalid_arg_error_when_q_kv_dims_are_different(self):
test_layer = TransformerEncoderBlock(
num_attention_heads=2,
inner_dim=128,
inner_activation='relu',
norm_first=True)
# Forward path.
q_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
kv_tensor = tf.zeros([2, 8, 32], dtype=tf.float32)
dummy_mask = tf.zeros([2, 4, 8], dtype=tf.float32)
inputs = [q_tensor, kv_tensor, dummy_mask]
with self.assertRaises(tf.errors.InvalidArgumentError):
test_layer(inputs)
@parameterized.named_parameters(('output_range_not_none', 2),
('output_range_none', None))
def test_needs_diff_q_kv_att_layer_norm_to_be_true_for_diff_q_and_kv_dims(
self, output_range):
test_layer = TransformerEncoderBlock(
num_attention_heads=2,
inner_dim=128,
inner_activation='relu',
norm_first=True)
# Forward path.
q_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
kv_tensor = tf.zeros([2, 8, 32], dtype=tf.float32)
dummy_mask = tf.zeros([2, 4, 8], dtype=tf.float32)
inputs = [q_tensor, kv_tensor, dummy_mask]
with self.assertRaises(tf.errors.InvalidArgumentError):
test_layer(inputs, output_range=output_range)
test_layer = TransformerEncoderBlock(
num_attention_heads=2,
inner_dim=128,
inner_activation='relu',
diff_q_kv_att_layer_norm=True,
norm_first=True)
# Forward path.
test_layer(inputs)
@parameterized.named_parameters(('norm_first_is_true', True),
('norm_first_is_false', False))
def test_use_query_residual_false_removes_add_op(self, norm_first):
graph_with_res = tf.Graph()
with graph_with_res.as_default():
layer = TransformerEncoderBlock(
num_attention_heads=2,
inner_dim=128,
inner_activation='relu',
norm_first=norm_first)
inputs = tf.keras.Input(shape=(None, None, 2))
outputs = layer(inputs)
tf.keras.Model(inputs=inputs, outputs=outputs)
graph_without_res = tf.Graph()
with graph_without_res.as_default():
layer = TransformerEncoderBlock(
num_attention_heads=2,
inner_dim=128,
inner_activation='relu',
norm_first=norm_first,
use_query_residual=False)
inputs = tf.keras.Input(shape=(None, None, 2))
outputs = layer(inputs)
tf.keras.Model(inputs=inputs, outputs=outputs)
graph_with_res_names = {x.name for x in graph_with_res.get_operations()}
graph_without_res_names = {
x.name for x in graph_without_res.get_operations()
}
self.assertIn('transformer_encoder_block/add',
list(graph_with_res_names - graph_without_res_names)[0])
self.assertEmpty(graph_without_res_names - graph_with_res_names)
@parameterized.named_parameters(('key_dim_is_none', None, 128, 2, 128 // 2),
('key_dim_is_not_none', 30, 128, 2, 30))
def test_key_dim(self, key_dim, q_tensor_last_dim, some_num_attention_heads,
expected):
some_inner_dim = 32
some_inner_activation = 'relu'
test_layer = TransformerEncoderBlock(
num_attention_heads=some_num_attention_heads,
inner_dim=some_inner_dim,
inner_activation=some_inner_activation,
key_dim=key_dim)
q_tensor = tf.zeros([2, 4, q_tensor_last_dim], dtype=tf.float32)
kv_tensor = tf.zeros([2, 8, 32], dtype=tf.float32)
dummy_mask = tf.zeros([2, 4, 8], dtype=tf.float32)
test_layer([q_tensor, kv_tensor, dummy_mask])
self.assertEqual(expected,
test_layer._attention_layer.get_config()['key_dim'])
@parameterized.named_parameters(
('output_last_dim_is_none_use_query_residual_false', False, None, 128,
128),
('output_last_dim_is_none_use_query_residual_true', True, None, 128, 128),
('output_last_dim_is_not_none', False, 30, 128, 30))
def test_output_last_dim(self, use_query_residual, output_last_dim,
q_tensor_last_dim, expected):
some_num_attention_heads = 2
some_inner_dim = 32
some_inner_activation = 'relu'
test_layer = TransformerEncoderBlock(
num_attention_heads=some_num_attention_heads,
inner_dim=some_inner_dim,
inner_activation=some_inner_activation,
# Must be false for multi-head output to be different from
# first input's last dim
use_query_residual=use_query_residual,
output_last_dim=output_last_dim)
q_tensor = tf.zeros([2, 4, q_tensor_last_dim], dtype=tf.float32)
kv_tensor = tf.zeros([2, 8, 32], dtype=tf.float32)
dummy_mask = tf.zeros([2, 4, 8], dtype=tf.float32)
output = test_layer([q_tensor, kv_tensor, dummy_mask])
self.assertEqual(output.numpy().shape[-1], expected)
@parameterized.named_parameters(('value_dim_is_none', None, 128, 2, 128 // 2),
('value_dim_is_not_none', 30, 128, 2, 30))
def test_value_dim(self, value_dim, q_tensor_last_dim,
some_num_attention_heads, expected):
some_inner_dim = 32
some_inner_activation = 'relu'
test_layer = TransformerEncoderBlock(
num_attention_heads=some_num_attention_heads,
inner_dim=some_inner_dim,
inner_activation=some_inner_activation,
value_dim=value_dim)
q_tensor = tf.zeros([2, 4, q_tensor_last_dim], dtype=tf.float32)
kv_tensor = tf.zeros([2, 8, 32], dtype=tf.float32)
dummy_mask = tf.zeros([2, 4, 8], dtype=tf.float32)
test_layer([q_tensor, kv_tensor, dummy_mask])
self.assertEqual(expected,
test_layer._attention_layer.get_config()['value_dim'])
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes
class TransformerArgumentTest(keras_parameterized.TestCase): class TransformerArgumentTest(keras_parameterized.TestCase):
...@@ -277,6 +432,138 @@ class TransformerArgumentTest(keras_parameterized.TestCase): ...@@ -277,6 +432,138 @@ class TransformerArgumentTest(keras_parameterized.TestCase):
output = encoder_block(inputs) output = encoder_block(inputs)
self.assertEqual(output.shape, (2, 4, hidden_size)) self.assertEqual(output.shape, (2, 4, hidden_size))
def test_norm_first_false_and_diff_q_kv_att_layer_norm_true_raises(self):
some_num_attention_heads = 2
some_inner_dim = 32
some_inner_activation = 'relu'
with self.assertRaises(ValueError):
TransformerEncoderBlock(
num_attention_heads=some_num_attention_heads,
inner_dim=some_inner_dim,
inner_activation=some_inner_activation,
norm_first=False,
diff_q_kv_att_layer_norm=True)
def test_diff_q_kv_att_layer_norm_is_part_of_config_1(self):
some_num_attention_heads = 2
some_inner_dim = 32
some_inner_activation = 'relu'
encoder = TransformerEncoderBlock(
num_attention_heads=some_num_attention_heads,
inner_dim=some_inner_dim,
inner_activation=some_inner_activation,
norm_first=False)
self.assertIn('diff_q_kv_att_layer_norm', encoder.get_config())
self.assertFalse(encoder.get_config()['diff_q_kv_att_layer_norm'])
def test_diff_q_kv_att_layer_norm_is_part_of_config_2(self):
some_num_attention_heads = 2
some_inner_dim = 32
some_inner_activation = 'relu'
encoder = TransformerEncoderBlock(
num_attention_heads=some_num_attention_heads,
inner_dim=some_inner_dim,
inner_activation=some_inner_activation,
norm_first=True,
diff_q_kv_att_layer_norm=True)
self.assertIn('diff_q_kv_att_layer_norm', encoder.get_config())
self.assertTrue(encoder.get_config()['diff_q_kv_att_layer_norm'])
def test_use_query_residual_is_part_of_config_1(self):
some_num_attention_heads = 2
some_inner_dim = 32
some_inner_activation = 'relu'
encoder = TransformerEncoderBlock(
num_attention_heads=some_num_attention_heads,
inner_dim=some_inner_dim,
inner_activation=some_inner_activation)
self.assertIn('use_query_residual', encoder.get_config())
self.assertTrue(encoder.get_config()['use_query_residual'])
def test_use_query_residual_is_part_of_config_2(self):
some_num_attention_heads = 2
some_inner_dim = 32
some_inner_activation = 'relu'
encoder = TransformerEncoderBlock(
num_attention_heads=some_num_attention_heads,
inner_dim=some_inner_dim,
inner_activation=some_inner_activation,
use_query_residual=False)
self.assertIn('use_query_residual', encoder.get_config())
self.assertFalse(encoder.get_config()['use_query_residual'])
def test_key_dim_is_part_of_config_1(self):
some_num_attention_heads = 2
some_inner_dim = 32
some_inner_activation = 'relu'
encoder = TransformerEncoderBlock(
num_attention_heads=some_num_attention_heads,
inner_dim=some_inner_dim,
inner_activation=some_inner_activation)
self.assertIn('key_dim', encoder.get_config())
self.assertIsNone(encoder.get_config()['key_dim'])
def test_key_dim_is_part_of_config_2(self):
some_num_attention_heads = 2
some_inner_dim = 32
some_inner_activation = 'relu'
key_dim = 10
encoder = TransformerEncoderBlock(
num_attention_heads=some_num_attention_heads,
inner_dim=some_inner_dim,
inner_activation=some_inner_activation,
key_dim=key_dim)
self.assertIn('key_dim', encoder.get_config())
self.assertEqual(key_dim, encoder.get_config()['key_dim'])
def test_value_dim_is_part_of_config_1(self):
some_num_attention_heads = 2
some_inner_dim = 32
some_inner_activation = 'relu'
encoder = TransformerEncoderBlock(
num_attention_heads=some_num_attention_heads,
inner_dim=some_inner_dim,
inner_activation=some_inner_activation)
self.assertIn('value_dim', encoder.get_config())
self.assertIsNone(encoder.get_config()['value_dim'])
def test_value_dim_is_part_of_config_2(self):
some_num_attention_heads = 2
some_inner_dim = 32
some_inner_activation = 'relu'
value_dim = 10
encoder = TransformerEncoderBlock(
num_attention_heads=some_num_attention_heads,
inner_dim=some_inner_dim,
inner_activation=some_inner_activation,
value_dim=value_dim)
self.assertIn('value_dim', encoder.get_config())
self.assertEqual(value_dim, encoder.get_config()['value_dim'])
def test_output_last_dim_is_part_of_config_1(self):
some_num_attention_heads = 2
some_inner_dim = 32
some_inner_activation = 'relu'
encoder = TransformerEncoderBlock(
num_attention_heads=some_num_attention_heads,
inner_dim=some_inner_dim,
inner_activation=some_inner_activation)
self.assertIn('output_last_dim', encoder.get_config())
self.assertIsNone(encoder.get_config()['output_last_dim'])
def test_output_last_dim_is_part_of_config_2(self):
some_num_attention_heads = 2
some_inner_dim = 32
some_inner_activation = 'relu'
output_last_dim = 10
encoder = TransformerEncoderBlock(
num_attention_heads=some_num_attention_heads,
inner_dim=some_inner_dim,
inner_activation=some_inner_activation,
output_last_dim=output_last_dim)
self.assertIn('output_last_dim', encoder.get_config())
self.assertEqual(output_last_dim, encoder.get_config()['output_last_dim'])
def test_get_config(self): def test_get_config(self):
num_attention_heads = 2 num_attention_heads = 2
encoder_block = TransformerEncoderBlock( encoder_block = TransformerEncoderBlock(
...@@ -290,7 +577,12 @@ class TransformerArgumentTest(keras_parameterized.TestCase): ...@@ -290,7 +577,12 @@ class TransformerArgumentTest(keras_parameterized.TestCase):
norm_epsilon=1e-6, norm_epsilon=1e-6,
inner_dropout=0.1, inner_dropout=0.1,
attention_initializer=tf.keras.initializers.RandomUniform( attention_initializer=tf.keras.initializers.RandomUniform(
minval=0., maxval=1.)) minval=0., maxval=1.),
use_query_residual=False,
key_dim=20,
value_dim=30,
output_last_dim=40,
diff_q_kv_att_layer_norm=True)
encoder_block_config = encoder_block.get_config() encoder_block_config = encoder_block.get_config()
new_encoder_block = TransformerEncoderBlock.from_config( new_encoder_block = TransformerEncoderBlock.from_config(
encoder_block_config) encoder_block_config)
...@@ -319,6 +611,88 @@ class TransformerArgumentTest(keras_parameterized.TestCase): ...@@ -319,6 +611,88 @@ class TransformerArgumentTest(keras_parameterized.TestCase):
# The default output of a transformer layer should be the same as the input. # The default output of a transformer layer should be the same as the input.
self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list()) self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
@parameterized.parameters(
{
'output_dropout': 0.1,
'attention_dropout': 0.2,
'inner_dropout': 0.3
}, {
'output_dropout': 0.0,
'attention_dropout': 0.2,
'inner_dropout': 0.3
}, {
'output_dropout': 0.1,
'attention_dropout': 0.0,
'inner_dropout': 0.3
}, {
'output_dropout': 0.1,
'attention_dropout': 0.2,
'inner_dropout': 0.0
})
def test_dropout_config(self, output_dropout, attention_dropout,
inner_dropout):
test_layer = TransformerEncoderBlock(
num_attention_heads=2,
inner_dim=32,
inner_activation='relu',
output_dropout=output_dropout,
attention_dropout=attention_dropout,
inner_dropout=inner_dropout)
seq_len = 21
hidden_size = 512
input_tensor = tf.keras.Input(shape=(seq_len, hidden_size))
_ = test_layer(input_tensor)
true_output_dropout = test_layer._output_dropout.get_config()['rate']
true_attention_dropout = test_layer._attention_dropout.get_config()['rate']
true_inner_dropout = test_layer._inner_dropout_layer.get_config()['rate']
self.assertEqual(true_output_dropout, output_dropout)
self.assertEqual(true_attention_dropout, attention_dropout)
self.assertEqual(true_inner_dropout, inner_dropout)
@parameterized.named_parameters(
(
'return_attention_scores_is_false',
False,
),
(
'return_attention_scores_is_true',
True,
),
)
def test_return_attention_scores(self, return_attention_scores):
num_attention_heads = 7
sequence_length = 21
width = 80
test_layer = TransformerEncoderBlock(
num_attention_heads=num_attention_heads,
inner_dim=2048,
inner_activation='relu',
return_attention_scores=return_attention_scores)
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
output_tensor = test_layer(data_tensor)
expected_layer_output_shape = [None, sequence_length, width]
expected_attention_scores_shape = [
None, num_attention_heads, sequence_length, sequence_length
]
if return_attention_scores:
self.assertIsInstance(output_tensor, tuple)
self.assertEqual(len(output_tensor), 2)
# First is the standard output.
self.assertEqual(output_tensor[0].shape.as_list(),
expected_layer_output_shape)
# Second is the attention scores.
self.assertEqual(output_tensor[1].shape.as_list(),
expected_attention_scores_shape)
else:
# Only the standard layer output.
self.assertEqual(output_tensor.shape.as_list(),
expected_layer_output_shape)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -19,7 +19,9 @@ from absl import logging ...@@ -19,7 +19,9 @@ from absl import logging
import gin import gin
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import attention from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import util
@tf.keras.utils.register_keras_serializable(package="Text") @tf.keras.utils.register_keras_serializable(package="Text")
...@@ -37,8 +39,10 @@ class TransformerScaffold(tf.keras.layers.Layer): ...@@ -37,8 +39,10 @@ class TransformerScaffold(tf.keras.layers.Layer):
Args: Args:
num_attention_heads: Number of attention heads. num_attention_heads: Number of attention heads.
intermediate_size: Size of the intermediate layer. inner_dim: The output dimension of the first Dense layer in a two-layer
intermediate_activation: Activation for the intermediate layer. feedforward network.
inner_activation: The activation for the first Dense layer in a two-layer
feedforward network.
attention_cls: A class to instantiate attention layer, or a layer instance. attention_cls: A class to instantiate attention layer, or a layer instance.
attention_cfg: The config with which to instantiate `attention_cls`. Ignored attention_cfg: The config with which to instantiate `attention_cls`. Ignored
if attention_cls is a layer instance or None. If `attention_cls` is a if attention_cls is a layer instance or None. If `attention_cls` is a
...@@ -58,8 +62,8 @@ class TransformerScaffold(tf.keras.layers.Layer): ...@@ -58,8 +62,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
Ignored if feedforward_cls is a layer instance or is None. If Ignored if feedforward_cls is a layer instance or is None. If
`feedforward_cls` is a class, but `feedforward_cfg` is None, following `feedforward_cls` is a class, but `feedforward_cfg` is None, following
kwargs will be used to instantiate the feedforward instance: { kwargs will be used to instantiate the feedforward instance: {
"intermediate_size": intermediate_size, "inner_dim": inner_dim,
"intermediate_activation": intermediate_activation, "inner_activation": inner_activation,
"dropout": dropout_rate, "dropout": dropout_rate,
"name": "feedforward" }. "name": "feedforward" }.
dropout_rate: Dropout probability for the post-attention and output dropout. dropout_rate: Dropout probability for the post-attention and output dropout.
...@@ -75,8 +79,8 @@ class TransformerScaffold(tf.keras.layers.Layer): ...@@ -75,8 +79,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
def __init__(self, def __init__(self,
num_attention_heads, num_attention_heads,
intermediate_size, inner_dim=768,
intermediate_activation, inner_activation=tf_utils.get_activation("gelu"),
attention_cls=attention.MultiHeadAttention, attention_cls=attention.MultiHeadAttention,
attention_cfg=None, attention_cfg=None,
feedforward_cls=None, feedforward_cls=None,
...@@ -92,7 +96,10 @@ class TransformerScaffold(tf.keras.layers.Layer): ...@@ -92,7 +96,10 @@ class TransformerScaffold(tf.keras.layers.Layer):
kernel_constraint=None, kernel_constraint=None,
bias_constraint=None, bias_constraint=None,
**kwargs): **kwargs):
super(TransformerScaffold, self).__init__(**kwargs) inner_dim = kwargs.pop("intermediate_size", inner_dim)
inner_activation = kwargs.pop("inner_activation", inner_activation)
util.filter_kwargs(kwargs)
super().__init__(**kwargs)
self._attention_cfg = attention_cfg self._attention_cfg = attention_cfg
self._attention_cls = attention_cls self._attention_cls = attention_cls
...@@ -100,8 +107,8 @@ class TransformerScaffold(tf.keras.layers.Layer): ...@@ -100,8 +107,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
self._feedforward_cfg = feedforward_cfg self._feedforward_cfg = feedforward_cfg
self._norm_first = norm_first self._norm_first = norm_first
self._num_heads = num_attention_heads self._num_heads = num_attention_heads
self._intermediate_size = intermediate_size self._inner_dim = inner_dim
self._intermediate_activation = intermediate_activation self._inner_activation = inner_activation
self._attention_dropout_rate = attention_dropout_rate self._attention_dropout_rate = attention_dropout_rate
self._dropout_rate = dropout_rate self._dropout_rate = dropout_rate
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer) self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
...@@ -112,9 +119,15 @@ class TransformerScaffold(tf.keras.layers.Layer): ...@@ -112,9 +119,15 @@ class TransformerScaffold(tf.keras.layers.Layer):
self._bias_constraint = tf.keras.constraints.get(bias_constraint) self._bias_constraint = tf.keras.constraints.get(bias_constraint)
def build(self, input_shape): def build(self, input_shape):
input_tensor_shape = input_shape[0] if ( if isinstance(input_shape, tf.TensorShape):
len(input_shape) == 2) else input_shape input_tensor_shape = input_shape
input_tensor_shape = tf.TensorShape(input_tensor_shape) elif isinstance(input_shape, (list, tuple)):
input_tensor_shape = tf.TensorShape(input_shape[0])
else:
raise ValueError(
"The type of input shape argument is not supported, got: %s" %
type(input_shape))
if len(input_tensor_shape.as_list()) != 3: if len(input_tensor_shape.as_list()) != 3:
raise ValueError( raise ValueError(
"TransformerScaffold expects a three-dimensional input of " "TransformerScaffold expects a three-dimensional input of "
...@@ -127,8 +140,6 @@ class TransformerScaffold(tf.keras.layers.Layer): ...@@ -127,8 +140,6 @@ class TransformerScaffold(tf.keras.layers.Layer):
self._attention_head_size = int(hidden_size // self._num_heads) self._attention_head_size = int(hidden_size // self._num_heads)
common_kwargs = dict( common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer, bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer, activity_regularizer=self._activity_regularizer,
...@@ -145,6 +156,9 @@ class TransformerScaffold(tf.keras.layers.Layer): ...@@ -145,6 +156,9 @@ class TransformerScaffold(tf.keras.layers.Layer):
return instance_or_cls(**config) return instance_or_cls(**config)
default_attention_cfg = { default_attention_cfg = {
"kernel_initializer": tf_utils.clone_initializer(
self._kernel_initializer),
"bias_initializer": tf_utils.clone_initializer(self._bias_initializer),
"num_heads": self._num_heads, "num_heads": self._num_heads,
"key_dim": self._attention_head_size, "key_dim": self._attention_head_size,
"dropout": self._attention_dropout_rate, "dropout": self._attention_dropout_rate,
...@@ -158,8 +172,15 @@ class TransformerScaffold(tf.keras.layers.Layer): ...@@ -158,8 +172,15 @@ class TransformerScaffold(tf.keras.layers.Layer):
if self._feedforward_cls is not None: if self._feedforward_cls is not None:
default_feedforward_cfg = { default_feedforward_cfg = {
"intermediate_size": self._intermediate_size, "kernel_initializer": tf_utils.clone_initializer(
"intermediate_activation": self._intermediate_activation, self._kernel_initializer),
"bias_initializer": tf_utils.clone_initializer(
self._bias_initializer),
"inner_dim": self._inner_dim,
"inner_activation": self._inner_activation,
# TODO(hongkuny): try to update all ffn block args.
"intermediate_size": self._inner_dim,
"intermediate_activation": self._inner_activation,
"dropout": self._dropout_rate, "dropout": self._dropout_rate,
"name": "feedforward", "name": "feedforward",
} }
...@@ -184,11 +205,14 @@ class TransformerScaffold(tf.keras.layers.Layer): ...@@ -184,11 +205,14 @@ class TransformerScaffold(tf.keras.layers.Layer):
dtype=tf.float32)) dtype=tf.float32))
if self._feedforward_block is None: if self._feedforward_block is None:
self._intermediate_dense = tf.keras.layers.experimental.EinsumDense( self._intermediate_dense = tf.keras.layers.EinsumDense(
"abc,cd->abd", "abc,cd->abd",
output_shape=(None, self._intermediate_size), output_shape=(None, self._inner_dim),
bias_axes="d", bias_axes="d",
name="intermediate", name="intermediate",
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs) **common_kwargs)
policy = tf.keras.mixed_precision.global_policy() policy = tf.keras.mixed_precision.global_policy()
if policy.name == "mixed_bfloat16": if policy.name == "mixed_bfloat16":
...@@ -197,12 +221,15 @@ class TransformerScaffold(tf.keras.layers.Layer): ...@@ -197,12 +221,15 @@ class TransformerScaffold(tf.keras.layers.Layer):
# TODO(b/154538392): Investigate this. # TODO(b/154538392): Investigate this.
policy = tf.float32 policy = tf.float32
self._intermediate_activation_layer = tf.keras.layers.Activation( self._intermediate_activation_layer = tf.keras.layers.Activation(
self._intermediate_activation, dtype=policy) self._inner_activation, dtype=policy)
self._output_dense = tf.keras.layers.experimental.EinsumDense( self._output_dense = tf.keras.layers.EinsumDense(
"abc,cd->abd", "abc,cd->abd",
output_shape=(None, hidden_size), output_shape=(None, hidden_size),
bias_axes="d", bias_axes="d",
name="output", name="output",
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs) **common_kwargs)
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate) self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
...@@ -210,7 +237,7 @@ class TransformerScaffold(tf.keras.layers.Layer): ...@@ -210,7 +237,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
self._output_layer_norm = tf.keras.layers.LayerNormalization( self._output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32) name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)
super(TransformerScaffold, self).build(input_shape) super().build(input_shape)
logging.info("%s configs: %s", self.__class__.__name__, self.get_config()) logging.info("%s configs: %s", self.__class__.__name__, self.get_config())
def get_config(self): def get_config(self):
...@@ -221,10 +248,10 @@ class TransformerScaffold(tf.keras.layers.Layer): ...@@ -221,10 +248,10 @@ class TransformerScaffold(tf.keras.layers.Layer):
self._feedforward_block, self._feedforward_block,
"num_attention_heads": "num_attention_heads":
self._num_heads, self._num_heads,
"intermediate_size": "inner_dim":
self._intermediate_size, self._inner_dim,
"intermediate_activation": "inner_activation":
self._intermediate_activation, self._inner_activation,
"dropout_rate": "dropout_rate":
self._dropout_rate, self._dropout_rate,
"attention_dropout_rate": "attention_dropout_rate":
...@@ -246,21 +273,31 @@ class TransformerScaffold(tf.keras.layers.Layer): ...@@ -246,21 +273,31 @@ class TransformerScaffold(tf.keras.layers.Layer):
"bias_constraint": "bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint) tf.keras.constraints.serialize(self._bias_constraint)
} }
base_config = super(TransformerScaffold, self).get_config() base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs, training=None): def call(self, inputs, training=None):
if isinstance(inputs, (list, tuple)) and len(inputs) == 2: if isinstance(inputs, (list, tuple)):
input_tensor, attention_mask = inputs if len(inputs) == 2:
input_tensor, attention_mask = inputs
key_value = None
elif len(inputs) == 3:
input_tensor, key_value, attention_mask = inputs
else:
raise ValueError("Unexpected inputs to %s with length at %d" %
(self.__class__, len(inputs)))
else: else:
input_tensor, attention_mask = (inputs, None) input_tensor, key_value, attention_mask = (inputs, None, None)
if key_value is None:
key_value = input_tensor
if self._norm_first: if self._norm_first:
source_tensor = input_tensor source_tensor = input_tensor
input_tensor = self._attention_layer_norm(input_tensor, training=training) input_tensor = self._attention_layer_norm(input_tensor, training=training)
attention_output = self._attention_layer( attention_output = self._attention_layer(
query=input_tensor, value=input_tensor, attention_mask=attention_mask, query=input_tensor, value=key_value, attention_mask=attention_mask,
training=training) training=training)
attention_output = self._attention_dropout(attention_output, attention_output = self._attention_dropout(attention_output,
training=training) training=training)
...@@ -298,7 +335,9 @@ class TransformerScaffold(tf.keras.layers.Layer): ...@@ -298,7 +335,9 @@ class TransformerScaffold(tf.keras.layers.Layer):
training=training) training=training)
layer_output += source_attention_output layer_output += source_attention_output
else: else:
# if not norm_first, assume that the feedforwad does apply layer norm # Attention: if not norm_first, assume that the feedforwad does apply
# layer norm. The feedford also apply residual connection. Please
# read the `GatedFeedforward` as a concrete example.
layer_output = self._feedforward_block(attention_output, layer_output = self._feedforward_block(attention_output,
training=training) training=training)
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -58,7 +58,7 @@ class ValidatedFeedforwardLayer(tf.keras.layers.Layer): ...@@ -58,7 +58,7 @@ class ValidatedFeedforwardLayer(tf.keras.layers.Layer):
def build(self, input_shape): def build(self, input_shape):
hidden_size = input_shape.as_list()[-1] hidden_size = input_shape.as_list()[-1]
self._feedforward_dense = tf.keras.layers.experimental.EinsumDense( self._feedforward_dense = tf.keras.layers.EinsumDense(
'...x,xy->...y', '...x,xy->...y',
output_shape=hidden_size, output_shape=hidden_size,
bias_axes='y', bias_axes='y',
...@@ -99,8 +99,8 @@ class TransformerLayerTest(keras_parameterized.TestCase): ...@@ -99,8 +99,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
attention_cls=ValidatedAttentionLayer, attention_cls=ValidatedAttentionLayer,
attention_cfg=attention_layer_cfg, attention_cfg=attention_layer_cfg,
num_attention_heads=10, num_attention_heads=10,
intermediate_size=2048, inner_dim=2048,
intermediate_activation='relu') inner_activation='relu')
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width)) data_tensor = tf.keras.Input(shape=(sequence_length, width))
...@@ -134,8 +134,8 @@ class TransformerLayerTest(keras_parameterized.TestCase): ...@@ -134,8 +134,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
feedforward_cls=ValidatedFeedforwardLayer, feedforward_cls=ValidatedFeedforwardLayer,
feedforward_cfg=feedforward_layer_cfg, feedforward_cfg=feedforward_layer_cfg,
num_attention_heads=10, num_attention_heads=10,
intermediate_size=None, inner_dim=None,
intermediate_activation=None) inner_activation=None)
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width)) data_tensor = tf.keras.Input(shape=(sequence_length, width))
...@@ -165,8 +165,8 @@ class TransformerLayerTest(keras_parameterized.TestCase): ...@@ -165,8 +165,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
attention_cls=ValidatedAttentionLayer, attention_cls=ValidatedAttentionLayer,
attention_cfg=attention_layer_cfg, attention_cfg=attention_layer_cfg,
num_attention_heads=10, num_attention_heads=10,
intermediate_size=2048, inner_dim=2048,
intermediate_activation='relu') inner_activation='relu')
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width)) data_tensor = tf.keras.Input(shape=(sequence_length, width))
...@@ -194,8 +194,8 @@ class TransformerLayerTest(keras_parameterized.TestCase): ...@@ -194,8 +194,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
attention_cls=ValidatedAttentionLayer, attention_cls=ValidatedAttentionLayer,
attention_cfg=attention_layer_cfg, attention_cfg=attention_layer_cfg,
num_attention_heads=10, num_attention_heads=10,
intermediate_size=2048, inner_dim=2048,
intermediate_activation='relu') inner_activation='relu')
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width)) data_tensor = tf.keras.Input(shape=(sequence_length, width))
...@@ -236,8 +236,8 @@ class TransformerLayerTest(keras_parameterized.TestCase): ...@@ -236,8 +236,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
attention_cfg=attention_layer_cfg, attention_cfg=attention_layer_cfg,
feedforward_cls=feedforward_layer, feedforward_cls=feedforward_layer,
num_attention_heads=10, num_attention_heads=10,
intermediate_size=None, inner_dim=None,
intermediate_activation=None) inner_activation=None)
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width)) data_tensor = tf.keras.Input(shape=(sequence_length, width))
...@@ -280,8 +280,8 @@ class TransformerLayerTest(keras_parameterized.TestCase): ...@@ -280,8 +280,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
attention_cls=ValidatedAttentionLayer, attention_cls=ValidatedAttentionLayer,
attention_cfg=attention_layer_cfg, attention_cfg=attention_layer_cfg,
num_attention_heads=10, num_attention_heads=10,
intermediate_size=2048, inner_dim=2048,
intermediate_activation='relu') inner_activation='relu')
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width)) data_tensor = tf.keras.Input(shape=(sequence_length, width))
...@@ -322,8 +322,8 @@ class TransformerLayerTest(keras_parameterized.TestCase): ...@@ -322,8 +322,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
attention_cls=ValidatedAttentionLayer, attention_cls=ValidatedAttentionLayer,
attention_cfg=attention_layer_cfg, attention_cfg=attention_layer_cfg,
num_attention_heads=10, num_attention_heads=10,
intermediate_size=2048, inner_dim=2048,
intermediate_activation='relu') inner_activation='relu')
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width)) data_tensor = tf.keras.Input(shape=(sequence_length, width))
...@@ -363,8 +363,8 @@ class TransformerLayerTest(keras_parameterized.TestCase): ...@@ -363,8 +363,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
attention_cls=ValidatedAttentionLayer, attention_cls=ValidatedAttentionLayer,
attention_cfg=attention_layer_cfg, attention_cfg=attention_layer_cfg,
num_attention_heads=10, num_attention_heads=10,
intermediate_size=2048, inner_dim=2048,
intermediate_activation='relu', inner_activation='relu',
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02)) kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
...@@ -392,8 +392,8 @@ class TransformerLayerTest(keras_parameterized.TestCase): ...@@ -392,8 +392,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
attention_cls=ValidatedAttentionLayer, attention_cls=ValidatedAttentionLayer,
attention_cfg=attention_layer_cfg, attention_cfg=attention_layer_cfg,
num_attention_heads=10, num_attention_heads=10,
intermediate_size=2048, inner_dim=2048,
intermediate_activation='relu') inner_activation='relu')
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width)) data_tensor = tf.keras.Input(shape=(sequence_length, width))
...@@ -458,8 +458,8 @@ class TransformerLayerTest(keras_parameterized.TestCase): ...@@ -458,8 +458,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
feedforward_cls=ValidatedFeedforwardLayer, feedforward_cls=ValidatedFeedforwardLayer,
feedforward_cfg=feedforward_layer_cfg, feedforward_cfg=feedforward_layer_cfg,
num_attention_heads=10, num_attention_heads=10,
intermediate_size=None, inner_dim=None,
intermediate_activation=None) inner_activation=None)
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width)) data_tensor = tf.keras.Input(shape=(sequence_length, width))
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -18,6 +18,7 @@ from absl import logging ...@@ -18,6 +18,7 @@ from absl import logging
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import relative_attention from official.nlp.modeling.layers import relative_attention
...@@ -102,7 +103,7 @@ class TransformerXLBlock(tf.keras.layers.Layer): ...@@ -102,7 +103,7 @@ class TransformerXLBlock(tf.keras.layers.Layer):
**kwargs): **kwargs):
"""Initializes TransformerXLBlock layer.""" """Initializes TransformerXLBlock layer."""
super(TransformerXLBlock, self).__init__(**kwargs) super().__init__(**kwargs)
self._vocab_size = vocab_size self._vocab_size = vocab_size
self._num_heads = num_attention_heads self._num_heads = num_attention_heads
self._head_size = head_size self._head_size = head_size
...@@ -148,7 +149,7 @@ class TransformerXLBlock(tf.keras.layers.Layer): ...@@ -148,7 +149,7 @@ class TransformerXLBlock(tf.keras.layers.Layer):
value_dim=self._head_size, value_dim=self._head_size,
dropout=self._attention_dropout_rate, dropout=self._attention_dropout_rate,
use_bias=False, use_bias=False,
kernel_initializer=self._kernel_initializer, kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
name="rel_attn") name="rel_attn")
self._attention_dropout = tf.keras.layers.Dropout( self._attention_dropout = tf.keras.layers.Dropout(
rate=self._attention_dropout_rate) rate=self._attention_dropout_rate)
...@@ -157,30 +158,30 @@ class TransformerXLBlock(tf.keras.layers.Layer): ...@@ -157,30 +158,30 @@ class TransformerXLBlock(tf.keras.layers.Layer):
axis=-1, axis=-1,
epsilon=self._norm_epsilon, epsilon=self._norm_epsilon,
dtype=tf.float32) dtype=tf.float32)
self._inner_dense = tf.keras.layers.experimental.EinsumDense( self._inner_dense = tf.keras.layers.EinsumDense(
"abc,cd->abd", "abc,cd->abd",
output_shape=(None, self._inner_size), output_shape=(None, self._inner_size),
bias_axes="d", bias_axes="d",
kernel_initializer=self._kernel_initializer, kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
name="inner") name="inner")
self._inner_activation_layer = tf.keras.layers.Activation( self._inner_activation_layer = tf.keras.layers.Activation(
self._inner_activation) self._inner_activation)
self._inner_dropout_layer = tf.keras.layers.Dropout( self._inner_dropout_layer = tf.keras.layers.Dropout(
rate=self._inner_dropout) rate=self._inner_dropout)
self._output_dense = tf.keras.layers.experimental.EinsumDense( self._output_dense = tf.keras.layers.EinsumDense(
"abc,cd->abd", "abc,cd->abd",
output_shape=(None, hidden_size), output_shape=(None, hidden_size),
bias_axes="d", bias_axes="d",
name="output", name="output",
kernel_initializer=self._kernel_initializer) kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer))
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate) self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
self._output_layer_norm = tf.keras.layers.LayerNormalization( self._output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm", name="output_layer_norm",
axis=-1, axis=-1,
epsilon=self._norm_epsilon) epsilon=self._norm_epsilon)
super(TransformerXLBlock, self).build(input_shape) super().build(input_shape)
def get_config(self): def get_config(self):
config = { config = {
...@@ -209,7 +210,7 @@ class TransformerXLBlock(tf.keras.layers.Layer): ...@@ -209,7 +210,7 @@ class TransformerXLBlock(tf.keras.layers.Layer):
"inner_dropout": "inner_dropout":
self._inner_dropout, self._inner_dropout,
} }
base_config = super(TransformerXLBlock, self).get_config() base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
def call(self, def call(self,
...@@ -370,7 +371,7 @@ class TransformerXL(tf.keras.layers.Layer): ...@@ -370,7 +371,7 @@ class TransformerXL(tf.keras.layers.Layer):
inner_activation="relu", inner_activation="relu",
**kwargs): **kwargs):
"""Initializes TransformerXL.""" """Initializes TransformerXL."""
super(TransformerXL, self).__init__(**kwargs) super().__init__(**kwargs)
self._vocab_size = vocab_size self._vocab_size = vocab_size
self._initializer = initializer self._initializer = initializer
...@@ -398,17 +399,17 @@ class TransformerXL(tf.keras.layers.Layer): ...@@ -398,17 +399,17 @@ class TransformerXL(tf.keras.layers.Layer):
"content_attention_bias", "content_attention_bias",
shape=attention_bias_shape, shape=attention_bias_shape,
dtype=tf.float32, dtype=tf.float32,
initializer=self._initializer) initializer=tf_utils.clone_initializer(self._initializer))
self.positional_attention_bias = self.add_weight( self.positional_attention_bias = self.add_weight(
"positional_attention_bias", "positional_attention_bias",
shape=attention_bias_shape, shape=attention_bias_shape,
dtype=tf.float32, dtype=tf.float32,
initializer=self._initializer) initializer=tf_utils.clone_initializer(self._initializer))
self.segment_attention_bias = self.add_weight( self.segment_attention_bias = self.add_weight(
"segment_attention_bias", "segment_attention_bias",
shape=attention_bias_shape, shape=attention_bias_shape,
dtype=tf.float32, dtype=tf.float32,
initializer=self._initializer) initializer=tf_utils.clone_initializer(self._initializer))
self.transformer_xl_layers = [] self.transformer_xl_layers = []
for i in range(self._num_layers): for i in range(self._num_layers):
...@@ -460,7 +461,7 @@ class TransformerXL(tf.keras.layers.Layer): ...@@ -460,7 +461,7 @@ class TransformerXL(tf.keras.layers.Layer):
"inner_activation": "inner_activation":
self._inner_activation, self._inner_activation,
} }
base_config = super(TransformerXL, self).get_config() base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
def call(self, def call(self,
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment