Commit 6d9256ce authored by Hongkun Yu, committed by A. Unique TensorFlower


Adds a return_all_encoder_outputs option to TransformerEncoder to build a network that returns the outputs of all encoder transformer layers. This is useful for Transformer Seq2Seq.

PiperOrigin-RevId: 293395112
parent 0015eedf
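For orientation, here is a minimal usage sketch of the new option. It is not part of this commit: the import path is an assumption, and the small layer sizes simply mirror the unit test added below.

import tensorflow as tf
from official.nlp.modeling.networks import transformer_encoder  # assumed module path

sequence_length = 21
encoder = transformer_encoder.TransformerEncoder(
    vocab_size=100,
    hidden_size=32,
    sequence_length=sequence_length,
    num_attention_heads=2,
    num_layers=3,
    return_all_encoder_outputs=True)

word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)

# With return_all_encoder_outputs=True the first output is a list holding one
# [batch, sequence_length, hidden_size] tensor per transformer layer; the
# second output is the pooled CLS vector, unchanged from before.
all_layer_outputs, pooled = encoder([word_ids, mask, type_ids])
assert len(all_layer_outputs) == 3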
@@ -60,6 +60,8 @@ class TransformerEncoder(network.Network):
      within the transformer layers.
    initializer: The initializer to use for all weights in this encoder.
    float_dtype: The dtype of this encoder. Can be 'float32' or 'float16'.
+    return_all_encoder_outputs: Whether to output sequence embedding outputs of
+      all encoder transformer layers.
  """

  def __init__(self,
@@ -76,6 +78,7 @@ class TransformerEncoder(network.Network):
               attention_dropout_rate=0.1,
               initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
               float_dtype='float32',
+               return_all_encoder_outputs=False,
               **kwargs):
    activation = tf.keras.activations.get(activation)
    initializer = tf.keras.initializers.get(initializer)
@@ -97,6 +100,7 @@ class TransformerEncoder(network.Network):
        'attention_dropout_rate': attention_dropout_rate,
        'initializer': tf.keras.initializers.serialize(initializer),
        'float_dtype': float_dtype,
+        'return_all_encoder_outputs': return_all_encoder_outputs,
    }
    word_ids = tf.keras.layers.Input(
@@ -146,6 +150,7 @@ class TransformerEncoder(network.Network):
    self._transformer_layers = []
    data = embeddings
    attention_mask = layers.SelfAttentionMask()([data, mask])
+    encoder_outputs = []
    for i in range(num_layers):
      layer = layers.Transformer(
          num_attention_heads=num_attention_heads,
@@ -158,10 +163,11 @@ class TransformerEncoder(network.Network):
          name='transformer/layer_%d' % i)
      self._transformer_layers.append(layer)
      data = layer([data, attention_mask])
+      encoder_outputs.append(data)

    first_token_tensor = (
-        tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(data)
-    )
+        tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(
+            encoder_outputs[-1]))
    cls_output = tf.keras.layers.Dense(
        units=hidden_size,
        activation='tanh',
@@ -169,10 +175,13 @@ class TransformerEncoder(network.Network):
        name='pooler_transform')(
            first_token_tensor)
+    if return_all_encoder_outputs:
+      outputs = [encoder_outputs, cls_output]
+    else:
+      outputs = [encoder_outputs[-1], cls_output]

    super(TransformerEncoder, self).__init__(
-        inputs=[word_ids, mask, type_ids],
-        outputs=[data, cls_output],
-        **kwargs)
+        inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)

  def get_embedding_table(self):
    return self._embedding_layer.embeddings
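The commit message names Transformer Seq2Seq as the motivating use case. Below is an illustrative, hedged sketch of consuming the per-layer outputs; the decoder wiring itself is out of scope here, and the import path is again an assumption.

import tensorflow as tf
from official.nlp.modeling.networks import transformer_encoder  # assumed module path

sequence_length = 21
encoder = transformer_encoder.TransformerEncoder(
    vocab_size=100, hidden_size=32, sequence_length=sequence_length,
    num_attention_heads=2, num_layers=3, return_all_encoder_outputs=True)
inputs = [tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
          for _ in range(3)]  # word_ids, mask, type_ids
layer_outputs, _ = encoder(inputs)

# A seq2seq decoder would typically cross-attend to the final layer's sequence
# output; earlier layers stay available, e.g. for feature-style pooling.
decoder_memory = layer_outputs[-1]                                   # [batch, seq, hidden]
last_two = tf.keras.layers.Concatenate(axis=-1)(layer_outputs[-2:])  # [batch, seq, 2*hidden]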
@@ -55,6 +55,34 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
    self.assertAllEqual(tf.float32, data.dtype)
    self.assertAllEqual(tf.float32, pooled.dtype)

+  def test_all_encoder_outputs_network_creation(self):
+    hidden_size = 32
+    sequence_length = 21
+    # Create a small TransformerEncoder for testing.
+    test_network = transformer_encoder.TransformerEncoder(
+        vocab_size=100,
+        hidden_size=hidden_size,
+        sequence_length=sequence_length,
+        num_attention_heads=2,
+        num_layers=3,
+        return_all_encoder_outputs=True)
+    # Create the inputs (note that the first dimension is implicit).
+    word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+    mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+    type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+    all_encoder_outputs, pooled = test_network([word_ids, mask, type_ids])
+
+    expected_data_shape = [None, sequence_length, hidden_size]
+    expected_pooled_shape = [None, hidden_size]
+    self.assertLen(all_encoder_outputs, 3)
+    for data in all_encoder_outputs:
+      self.assertAllEqual(expected_data_shape, data.shape.as_list())
+    self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
+
+    # The default output dtype is float32.
+    self.assertAllEqual(tf.float32, all_encoder_outputs[-1].dtype)
+    self.assertAllEqual(tf.float32, pooled.dtype)
+
  def test_network_creation_with_float16_dtype(self):
    hidden_size = 32
    sequence_length = 21
@@ -146,7 +174,8 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
        dropout_rate=0.05,
        attention_dropout_rate=0.22,
        initializer="glorot_uniform",
-        float_dtype="float16")
+        float_dtype="float16",
+        return_all_encoder_outputs=False)
    network = transformer_encoder.TransformerEncoder(**kwargs)
    expected_config = dict(kwargs)
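Since the flag is recorded in the constructor config dict (see the first file's diff), it should survive get_config-based serialization. A small illustrative check, under the same assumed import path:

from official.nlp.modeling.networks import transformer_encoder  # assumed module path

encoder = transformer_encoder.TransformerEncoder(
    vocab_size=100, hidden_size=32, sequence_length=21,
    num_attention_heads=2, num_layers=3, return_all_encoder_outputs=True)
config = encoder.get_config()
assert config['return_all_encoder_outputs'] is True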