Commit b5f5bbdd authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 332992539
parent 83704fde
@@ -29,7 +29,7 @@ assemble new layers, networks, or models.
   described in
   ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762).

-* [TransformerDecoderLayer](transformer.py) TransformerDecoderLayer is made up
+* [TransformerDecoderBlock](transformer.py) TransformerDecoderBlock is made up
   of self multi-head attention, cross multi-head attention and feedforward
   network.
......
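Usage note (not part of the commit): the README bullet above describes the renamed block's composition. Below is a minimal construction sketch restricted to arguments that are visible elsewhere in this diff (`num_attention_heads`, `intermediate_size`, `intermediate_activation`, `attention_initializer`); the full constructor signature is not shown here, so other options such as dropout rates are assumed to keep their defaults.

```python
# Minimal sketch, not part of the commit: construct the renamed block using
# only arguments that appear in this diff. The full signature is not shown
# here, so remaining options are assumed to have sensible defaults.
import tensorflow as tf
from official.nlp.modeling import layers

decoder_block = layers.TransformerDecoderBlock(
    num_attention_heads=2,
    intermediate_size=32,
    intermediate_activation='relu',
    attention_initializer=tf.keras.initializers.RandomUniform(
        minval=0., maxval=1.))
```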
@@ -109,7 +109,7 @@ class CompiledTransformer(Transformer):
 @tf.keras.utils.register_keras_serializable(package="Text")
-class TransformerDecoderLayer(tf.keras.layers.Layer):
+class TransformerDecoderBlock(tf.keras.layers.Layer):
   """Single transformer layer for decoder.

   It has three sub-layers:
@@ -163,7 +163,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
                intermediate_dropout=0.0,
                attention_initializer=None,
                **kwargs):
-    super(TransformerDecoderLayer, self).__init__(**kwargs)
+    super().__init__(**kwargs)
     self.num_attention_heads = num_attention_heads
     self.intermediate_size = intermediate_size
     self.intermediate_activation = tf.keras.activations.get(
@@ -274,7 +274,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
     self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
     self.output_layer_norm = tf.keras.layers.LayerNormalization(
         name="output_layer_norm", axis=-1, epsilon=self._norm_epsilon)
-    super(TransformerDecoderLayer, self).build(input_shape)
+    super().build(input_shape)

   def get_config(self):
     config = {
@@ -315,7 +315,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         "attention_initializer":
             tf.keras.initializers.serialize(self._attention_initializer)
     }
-    base_config = super(TransformerDecoderLayer, self).get_config()
+    base_config = super().get_config()
     return dict(list(base_config.items()) + list(config.items()))

   def common_layers_with_encoder(self):
@@ -329,11 +329,11 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
     if self.multi_channel_cross_attention:
       if len(inputs) != 5:
         raise ValueError(
-            "TransformerDecoderLayer must have 5 inputs, when it uses "
+            "TransformerDecoderBlock must have 5 inputs, when it uses "
             "multi_channel_cross_attention. But it got: %d" % len(inputs))
     elif len(inputs) != 4:
       raise ValueError(
-          "TransformerDecoderLayer must have 4 inputs, but it got: %d" %
+          "TransformerDecoderBlock must have 4 inputs, but it got: %d" %
           len(inputs))
     input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
     source_tensor = input_tensor
......
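Calling-convention note (not part of the commit): the validation above means that, without `multi_channel_cross_attention`, the block expects a list of exactly four inputs, unpacked as `(input_tensor, memory, attention_mask, self_attention_mask)`; with it, a fifth input is required. A hedged call sketch follows; the tensor shapes, mask semantics, and the layer's return structure are assumptions, since this diff does not show them.

```python
# Hedged sketch of the four-input call enforced by the check above. Shapes
# and mask conventions are assumptions; the diff only shows the unpacking
# order (input_tensor, memory, attention_mask, self_attention_mask).
import tensorflow as tf
from official.nlp.modeling import layers

decoder_block = layers.TransformerDecoderBlock(
    num_attention_heads=2,
    intermediate_size=32,
    intermediate_activation='relu')

batch_size, seq_len, hidden_size = 2, 8, 16
input_tensor = tf.random.uniform([batch_size, seq_len, hidden_size])
memory = tf.random.uniform([batch_size, seq_len, hidden_size])
attention_mask = tf.ones([batch_size, seq_len, seq_len])
self_attention_mask = tf.ones([batch_size, seq_len, seq_len])

# The return structure is not visible in this diff, so assign it generically.
outputs = decoder_block(
    [input_tensor, memory, attention_mask, self_attention_mask])
```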
@@ -32,12 +32,12 @@ def _create_cache(batch_size, init_decode_length, num_heads, head_size):
 @keras_parameterized.run_all_keras_modes
-class TransformerDecoderLayerTest(keras_parameterized.TestCase):
+class TransformerDecoderBlockTest(keras_parameterized.TestCase):

   def test_decoder_block_with_cache(self):
     num_attention_heads = 2
     hidden_size = 16
-    decoder_block = transformer.TransformerDecoderLayer(
+    decoder_block = transformer.TransformerDecoderBlock(
         num_attention_heads=num_attention_heads,
         intermediate_size=32,
         intermediate_activation='relu',
@@ -56,7 +56,7 @@ class TransformerDecoderLayerTest(keras_parameterized.TestCase):
   def test_use_bias_norm_first(self):
     num_attention_heads = 2
     hidden_size = 16
-    decoder_block = transformer.TransformerDecoderLayer(
+    decoder_block = transformer.TransformerDecoderBlock(
         num_attention_heads=num_attention_heads,
         intermediate_size=32,
         intermediate_activation='relu',
@@ -77,7 +77,7 @@ class TransformerDecoderLayerTest(keras_parameterized.TestCase):
   def test_get_config(self):
     num_attention_heads = 2
-    decoder_block = transformer.TransformerDecoderLayer(
+    decoder_block = transformer.TransformerDecoderBlock(
         num_attention_heads=num_attention_heads,
         intermediate_size=32,
         intermediate_activation='relu',
@@ -90,7 +90,7 @@ class TransformerDecoderLayerTest(keras_parameterized.TestCase):
         attention_initializer=tf.keras.initializers.RandomUniform(
             minval=0., maxval=1.))
     decoder_block_config = decoder_block.get_config()
-    new_decoder_block = transformer.TransformerDecoderLayer.from_config(
+    new_decoder_block = transformer.TransformerDecoderBlock.from_config(
         decoder_block_config)
     self.assertEqual(decoder_block_config, new_decoder_block.get_config())
......
@@ -581,7 +581,7 @@ class TransformerDecoder(tf.keras.layers.Layer):
     self.decoder_layers = []
     for i in range(self.num_layers):
       self.decoder_layers.append(
-          layers.TransformerDecoderLayer(
+          layers.TransformerDecoderBlock(
               num_attention_heads=self.num_attention_heads,
               intermediate_size=self._intermediate_size,
               intermediate_activation=self._activation,
......
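Note (not part of the commit): the hunk above shows the usual composition pattern, where a decoder builds one `TransformerDecoderBlock` per layer and keeps them in a list. A hedged standalone sketch of the same pattern, with illustrative hyperparameter values:

```python
# Hedged sketch of the stacking pattern from the hunk above: one block per
# decoder layer, held in a Python list. Hyperparameter values are illustrative.
from official.nlp.modeling import layers

num_layers = 6
decoder_blocks = []
for _ in range(num_layers):
  decoder_blocks.append(
      layers.TransformerDecoderBlock(
          num_attention_heads=8,
          intermediate_size=2048,
          intermediate_activation='relu'))
```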
@@ -22,7 +22,6 @@ from __future__ import print_function
 import tensorflow as tf

 from official.modeling import tf_utils
 from official.nlp.modeling import layers
-from official.nlp.modeling.layers import transformer
 from official.nlp.transformer import model_utils as transformer_utils
@@ -59,7 +58,7 @@ class TransformerDecoder(tf.keras.layers.Layer):
     self.layers = []
     for i in range(self.num_hidden_layers):
       self.layers.append(
-          transformer.TransformerDecoderLayer(
+          layers.TransformerDecoderBlock(
              num_attention_heads=self.num_attention_heads,
              intermediate_size=self.intermediate_size,
              intermediate_activation=self.intermediate_activation,
......
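Migration note (not part of the commit): the last file drops the module-level import and constructs the block through the `layers` package alias instead. A hedged before/after sketch of the call-site change, with illustrative argument values:

```python
# Before this commit (removed in the hunk above):
#   from official.nlp.modeling.layers import transformer
#   block = transformer.TransformerDecoderLayer(...)

# After this commit, call sites use the package alias and the new class name:
from official.nlp.modeling import layers

block = layers.TransformerDecoderBlock(
    num_attention_heads=8,          # illustrative values, not from the diff
    intermediate_size=2048,
    intermediate_activation='relu')
```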