Commit 26565d0d authored by Hongkun Yu, committed by A. Unique TensorFlower

Transformer Encoder: when embedding width differs from hidden size, add a projection to hidden size.

PiperOrigin-RevId: 312708922
parent 5a68ac62
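
For orientation, here is a minimal usage sketch of the behaviour this change adds, mirroring the new test at the bottom of the diff. It is not part of the commit: the import path is assumed from the TF Model Garden layout and the numeric values are illustrative. When `embedding_width` differs from `hidden_size`, the encoder factorizes the word embeddings and projects them back to `hidden_size`, so the sequence output keeps the hidden size.

```python
# Hedged sketch, not part of the commit: constructor arguments mirror the new
# test below; the module path and the concrete values are assumptions.
import numpy as np
from official.nlp.modeling.networks import transformer_encoder

encoder = transformer_encoder.TransformerEncoder(
    vocab_size=100,
    hidden_size=32,
    sequence_length=21,
    max_sequence_length=21,
    num_attention_heads=2,
    num_layers=3,
    type_vocab_size=12,
    embedding_width=16)  # != hidden_size -> an 'embedding_projection' is added

word_ids = np.random.randint(100, size=(2, 21)).astype(np.int32)
mask = np.ones((2, 21), dtype=np.int32)
type_ids = np.random.randint(12, size=(2, 21)).astype(np.int32)

sequence_output, cls_output = encoder.predict([word_ids, mask, type_ids])
print(sequence_output.shape)  # (2, 21, 32): projected back to hidden_size
print(cls_output.shape)       # (2, 32)
```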
@@ -27,14 +27,12 @@ class AlbertConfig(configs.BertConfig):
   """Configuration for `ALBERT`."""

   def __init__(self,
-               embedding_size,
                num_hidden_groups=1,
                inner_group_num=1,
                **kwargs):
     """Constructs AlbertConfig.

     Args:
-      embedding_size: Size of the factorized word embeddings.
       num_hidden_groups: Number of group for the hidden layers, parameters in
         the same group are shared. Note that this value and also the following
         'inner_group_num' has to be 1 for now, because all released ALBERT
@@ -43,7 +41,6 @@ class AlbertConfig(configs.BertConfig):
       **kwargs: The remaining arguments are the same as above 'BertConfig'.
     """
     super(AlbertConfig, self).__init__(**kwargs)
-    self.embedding_size = embedding_size
     # TODO(chendouble): 'inner_group_num' and 'num_hidden_groups' are always 1
     # in the released ALBERT. Support other values in AlbertTransformerEncoder
@@ -55,7 +52,7 @@ class AlbertConfig(configs.BertConfig):
   @classmethod
   def from_dict(cls, json_object):
     """Constructs a `AlbertConfig` from a Python dictionary of parameters."""
-    config = AlbertConfig(embedding_size=None, vocab_size=None)
+    config = AlbertConfig(vocab_size=None)
     for (key, value) in six.iteritems(json_object):
       config.__dict__[key] = value
     return config
@@ -160,10 +160,10 @@ def get_transformer_encoder(bert_config,
       sequence_length=sequence_length,
       max_sequence_length=bert_config.max_position_embeddings,
       type_vocab_size=bert_config.type_vocab_size,
+      embedding_width=bert_config.embedding_size,
       initializer=tf.keras.initializers.TruncatedNormal(
           stddev=bert_config.initializer_range))
   if isinstance(bert_config, albert_configs.AlbertConfig):
-    kwargs['embedding_width'] = bert_config.embedding_size
     return networks.AlbertTransformerEncoder(**kwargs)
   else:
     assert isinstance(bert_config, configs.BertConfig)
...
@@ -39,6 +39,7 @@ class BertConfig(object):
                max_position_embeddings=512,
                type_vocab_size=16,
                initializer_range=0.02,
+               embedding_size=None,
                backward_compatible=True):
     """Constructs BertConfig.
@@ -63,6 +64,7 @@ class BertConfig(object):
         `BertModel`.
       initializer_range: The stdev of the truncated_normal_initializer for
         initializing all weight matrices.
+      embedding_size: (Optional) width of the factorized word embeddings.
       backward_compatible: Boolean, whether the variables shape are compatible
         with checkpoints converted from TF 1.x BERT.
     """
@@ -77,6 +79,7 @@ class BertConfig(object):
     self.max_position_embeddings = max_position_embeddings
     self.type_vocab_size = type_vocab_size
     self.initializer_range = initializer_range
+    self.embedding_size = embedding_size
     self.backward_compatible = backward_compatible

   @classmethod
...
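
With `embedding_size` now living on `BertConfig` (default `None`), `AlbertConfig` simply inherits it through `**kwargs`. Below is a rough sketch of both configs under the new layout; the values are illustrative, and `configs`/`albert_configs` refer to the modules already used in `get_transformer_encoder` above.

```python
# Hedged sketch of the relocated embedding_size field; values are illustrative.
bert_config = configs.BertConfig(
    vocab_size=30522,
    hidden_size=768)         # embedding_size defaults to None -> no projection

albert_config = albert_configs.AlbertConfig(
    vocab_size=30000,
    hidden_size=768,
    embedding_size=128,      # factorized embeddings, projected up to 768
    num_hidden_groups=1,
    inner_group_num=1)

# Both configs now expose the same attribute for get_transformer_encoder():
print(bert_config.embedding_size)    # None
print(albert_config.embedding_size)  # 128
```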
@@ -42,9 +42,9 @@ class AlbertTransformerEncoder(tf.keras.Model):
   Arguments:
     vocab_size: The size of the token vocabulary.
-    embedding_width: The width of the word embeddings. If the embedding width
-      is not equal to hidden size, embedding parameters will be factorized into
-      two matrices in the shape of ['vocab_size', 'embedding_width'] and
+    embedding_width: The width of the word embeddings. If the embedding width is
+      not equal to hidden size, embedding parameters will be factorized into two
+      matrices in the shape of ['vocab_size', 'embedding_width'] and
       ['embedding_width', 'hidden_size'] ('embedding_width' is usually much
       smaller than 'hidden_size').
     hidden_size: The size of the transformer hidden layers.
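
To make the factorization described in this docstring concrete, here is a back-of-the-envelope parameter count with ALBERT-like sizes (numbers chosen purely for illustration, not taken from the commit):

```python
# Illustrative parameter count for the factorized embedding described above.
vocab_size, embedding_width, hidden_size = 30000, 128, 768

factorized = vocab_size * embedding_width + embedding_width * hidden_size
unfactorized = vocab_size * hidden_size

print(factorized)    # 3938304  ([vocab_size, 128] plus [128, 768])
print(unfactorized)  # 23040000 ([vocab_size, 768])
```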
@@ -110,6 +110,8 @@ class AlbertTransformerEncoder(tf.keras.Model):
     type_ids = tf.keras.layers.Input(
         shape=(sequence_length,), dtype=tf.int32, name='input_type_ids')
+    if embedding_width is None:
+      embedding_width = hidden_size
     self._embedding_layer = layers.OnDeviceEmbedding(
         vocab_size=vocab_size,
         embedding_width=embedding_width,
@@ -141,13 +143,14 @@ class AlbertTransformerEncoder(tf.keras.Model):
             axis=-1,
             epsilon=1e-12,
             dtype=tf.float32)(embeddings))
-    embeddings = (
-        tf.keras.layers.Dropout(rate=dropout_rate)(embeddings))
+    embeddings = (tf.keras.layers.Dropout(rate=dropout_rate)(embeddings))
     # We project the 'embedding' output to 'hidden_size' if it is not already
     # 'hidden_size'.
     if embedding_width != hidden_size:
-      embeddings = layers.DenseEinsum(
+      embeddings = tf.keras.layers.experimental.EinsumDense(
+          '...x,xy->...y',
           output_shape=hidden_size,
+          bias_axes='y',
           kernel_initializer=initializer,
           name='embedding_projection')(
               embeddings)
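
For reference, here is a standalone sketch of what the `EinsumDense` replacement computes: with equation `'...x,xy->...y'` and `bias_axes='y'` it acts as a per-token dense projection of the last axis from `embedding_width` to `hidden_size`, matching the role of the `DenseEinsum` layer it replaces. The shapes below are illustrative.

```python
# Hedged sketch of the projection layer introduced above; shapes illustrative.
import tensorflow as tf

embedding_width, hidden_size = 128, 768
projection = tf.keras.layers.experimental.EinsumDense(
    '...x,xy->...y',           # contract the last axis with a [x, y] kernel
    output_shape=hidden_size,  # y = hidden_size
    bias_axes='y',             # one bias per output unit, like a Dense layer
    name='embedding_projection')

embeddings = tf.random.uniform([2, 16, embedding_width])  # [batch, seq, width]
print(projection(embeddings).shape)  # (2, 16, 768)
```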
@@ -176,9 +179,7 @@ class AlbertTransformerEncoder(tf.keras.Model):
             first_token_tensor)
     super(AlbertTransformerEncoder, self).__init__(
-        inputs=[word_ids, mask, type_ids],
-        outputs=[data, cls_output],
-        **kwargs)
+        inputs=[word_ids, mask, type_ids], outputs=[data, cls_output], **kwargs)

   def get_embedding_table(self):
     return self._embedding_layer.embeddings
...
@@ -64,6 +64,11 @@ class TransformerEncoder(tf.keras.Model):
       target sequence of the last transformer layer. `None` means the entire
       target sequence will attend to the source sequence, which yeilds the full
       output.
+    embedding_width: The width of the word embeddings. If the embedding width
+      is not equal to hidden size, embedding parameters will be factorized into
+      two matrices in the shape of ['vocab_size', 'embedding_width'] and
+      ['embedding_width', 'hidden_size'] ('embedding_width' is usually much
+      smaller than 'hidden_size').
   """

   def __init__(self,
@@ -81,6 +86,7 @@ class TransformerEncoder(tf.keras.Model):
                initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
                return_all_encoder_outputs=False,
                output_range=None,
+               embedding_width=None,
                **kwargs):
     activation = tf.keras.activations.get(activation)
     initializer = tf.keras.initializers.get(initializer)
@@ -103,6 +109,7 @@ class TransformerEncoder(tf.keras.Model):
         'initializer': tf.keras.initializers.serialize(initializer),
         'return_all_encoder_outputs': return_all_encoder_outputs,
         'output_range': output_range,
+        'embedding_width': embedding_width,
     }

     word_ids = tf.keras.layers.Input(
@@ -112,9 +119,11 @@ class TransformerEncoder(tf.keras.Model):
     type_ids = tf.keras.layers.Input(
         shape=(sequence_length,), dtype=tf.int32, name='input_type_ids')
+    if embedding_width is None:
+      embedding_width = hidden_size
     self._embedding_layer = layers.OnDeviceEmbedding(
         vocab_size=vocab_size,
-        embedding_width=hidden_size,
+        embedding_width=embedding_width,
         initializer=initializer,
         name='word_embeddings')
     word_embeddings = self._embedding_layer(word_ids)
@@ -126,17 +135,27 @@ class TransformerEncoder(tf.keras.Model):
         max_sequence_length=max_sequence_length,
         name='position_embedding')
     position_embeddings = self._position_embedding_layer(word_embeddings)
-    type_embeddings = (
-        layers.OnDeviceEmbedding(
-            vocab_size=type_vocab_size,
-            embedding_width=hidden_size,
-            initializer=initializer,
-            use_one_hot=True,
-            name='type_embeddings')(type_ids))
+    self._type_embedding_layer = layers.OnDeviceEmbedding(
+        vocab_size=type_vocab_size,
+        embedding_width=embedding_width,
+        initializer=initializer,
+        use_one_hot=True,
+        name='type_embeddings')
+    type_embeddings = self._type_embedding_layer(type_ids)

     embeddings = tf.keras.layers.Add()(
         [word_embeddings, position_embeddings, type_embeddings])

+    # We project the 'embedding' output to 'hidden_size' if it is not already
+    # 'hidden_size'.
+    if embedding_width != hidden_size:
+      self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
+          '...x,xy->...y',
+          output_shape=hidden_size,
+          bias_axes='y',
+          kernel_initializer=initializer,
+          name='embedding_projection')
+      embeddings = self._embedding_projection(embeddings)
+
     embeddings = (
         tf.keras.layers.LayerNormalization(
             name='embeddings/layer_norm',
...
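
A quick sanity sketch of the backward-compatible default added above: leaving `embedding_width` unset keeps the word embedding table at `[vocab_size, hidden_size]` and no `_embedding_projection` attribute is created. The constructor values are illustrative, and a `get_embedding_table()` helper is assumed to exist on `TransformerEncoder` as it does on the ALBERT encoder above.

```python
# Hedged check of the embedding_width=None default; values are illustrative.
encoder = transformer_encoder.TransformerEncoder(
    vocab_size=100,
    hidden_size=32,
    sequence_length=21,
    max_sequence_length=21,
    num_attention_heads=2,
    num_layers=3,
    type_vocab_size=12)  # embedding_width left as None -> uses hidden_size

print(encoder.get_embedding_table().shape)        # (100, 32)
print(hasattr(encoder, '_embedding_projection'))  # False: no projection layer
```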
@@ -173,6 +173,21 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
     outputs = model.predict([word_id_data, mask_data, type_id_data])
     self.assertEqual(outputs[0].shape[1], out_seq_len)

+    # Creates a TransformerEncoder with embedding_width != hidden_size
+    test_network = transformer_encoder.TransformerEncoder(
+        vocab_size=vocab_size,
+        hidden_size=hidden_size,
+        sequence_length=sequence_length,
+        max_sequence_length=max_sequence_length,
+        num_attention_heads=2,
+        num_layers=3,
+        type_vocab_size=num_types,
+        embedding_width=16)
+    model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
+    outputs = model.predict([word_id_data, mask_data, type_id_data])
+    self.assertEqual(outputs[0].shape[-1], hidden_size)
+    self.assertTrue(hasattr(test_network, "_embedding_projection"))
+
   def test_serialize_deserialize(self):
     tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
     # Create a network object that sets all of its config options.
@@ -190,7 +205,8 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
         attention_dropout_rate=0.22,
         initializer="glorot_uniform",
         return_all_encoder_outputs=False,
-        output_range=-1)
+        output_range=-1,
+        embedding_width=16)
     network = transformer_encoder.TransformerEncoder(**kwargs)
     expected_config = dict(kwargs)
...