Commit 26565d0d authored by Hongkun Yu, committed by A. Unique TensorFlower


Transformer Encoder: when embedding width differs from hidden size, add a projection to hidden size.

PiperOrigin-RevId: 312708922
parent 5a68ac62
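For orientation, here is a minimal sketch of the factorization this change enables (not part of the commit; the sizes are illustrative assumptions, and the standard Keras Embedding layer stands in for the repo's OnDeviceEmbedding). The word-embedding table is kept at embedding_width, and its output is projected up to hidden_size with the same kind of EinsumDense 'embedding_projection' layer added in the diffs below.

import tensorflow as tf

vocab_size, embedding_width, hidden_size, seq_len = 30522, 128, 768, 16  # assumed sizes

word_ids = tf.keras.layers.Input(shape=(seq_len,), dtype=tf.int32)
# Factorized lookup: a [vocab_size, embedding_width] table instead of [vocab_size, hidden_size].
embeddings = tf.keras.layers.Embedding(vocab_size, embedding_width)(word_ids)
if embedding_width != hidden_size:
  # Project up to hidden_size so the transformer layers see the width they expect,
  # mirroring the 'embedding_projection' layer added in this commit.
  embeddings = tf.keras.layers.experimental.EinsumDense(
      '...x,xy->...y', output_shape=hidden_size, bias_axes='y',
      name='embedding_projection')(embeddings)
model = tf.keras.Model(word_ids, embeddings)

# Rough parameter count for the embedding block at these sizes:
#   factorized:   30522*128 + 128*768 + 768  ~ 4.0M
#   unfactorized: 30522*768                  ~ 23.4M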
@@ -27,14 +27,12 @@ class AlbertConfig(configs.BertConfig):
"""Configuration for `ALBERT`."""
def __init__(self,
embedding_size,
num_hidden_groups=1,
inner_group_num=1,
**kwargs):
"""Constructs AlbertConfig.
Args:
embedding_size: Size of the factorized word embeddings.
num_hidden_groups: Number of groups for the hidden layers; parameters in
the same group are shared. Note that this value and also the following
'inner_group_num' have to be 1 for now, because all released ALBERT
@@ -43,7 +41,6 @@ class AlbertConfig(configs.BertConfig):
**kwargs: The remaining arguments are the same as above 'BertConfig'.
"""
super(AlbertConfig, self).__init__(**kwargs)
self.embedding_size = embedding_size
# TODO(chendouble): 'inner_group_num' and 'num_hidden_groups' are always 1
# in the released ALBERT. Support other values in AlbertTransformerEncoder
@@ -55,7 +52,7 @@ class AlbertConfig(configs.BertConfig):
@classmethod
def from_dict(cls, json_object):
"""Constructs a `AlbertConfig` from a Python dictionary of parameters."""
config = AlbertConfig(embedding_size=None, vocab_size=None)
config = AlbertConfig(vocab_size=None)
for (key, value) in six.iteritems(json_object):
config.__dict__[key] = value
return config
@@ -160,10 +160,10 @@ def get_transformer_encoder(bert_config,
sequence_length=sequence_length,
max_sequence_length=bert_config.max_position_embeddings,
type_vocab_size=bert_config.type_vocab_size,
embedding_width=bert_config.embedding_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range))
if isinstance(bert_config, albert_configs.AlbertConfig):
kwargs['embedding_width'] = bert_config.embedding_size
return networks.AlbertTransformerEncoder(**kwargs)
else:
assert isinstance(bert_config, configs.BertConfig)
@@ -39,6 +39,7 @@ class BertConfig(object):
max_position_embeddings=512,
type_vocab_size=16,
initializer_range=0.02,
embedding_size=None,
backward_compatible=True):
"""Constructs BertConfig.
@@ -63,6 +64,7 @@
`BertModel`.
initializer_range: The stdev of the truncated_normal_initializer for
initializing all weight matrices.
embedding_size: (Optional) width of the factorized word embeddings.
backward_compatible: Boolean, whether the variable shapes are compatible
with checkpoints converted from TF 1.x BERT.
"""
@@ -77,6 +79,7 @@
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.embedding_size = embedding_size
self.backward_compatible = backward_compatible
@classmethod
@@ -42,9 +42,9 @@ class AlbertTransformerEncoder(tf.keras.Model):
Arguments:
vocab_size: The size of the token vocabulary.
embedding_width: The width of the word embeddings. If the embedding width is
not equal to hidden size, embedding parameters will be factorized into two
matrices in the shape of ['vocab_size', 'embedding_width'] and
['embedding_width', 'hidden_size'] ('embedding_width' is usually much
smaller than 'hidden_size').
hidden_size: The size of the transformer hidden layers.
@@ -110,6 +110,8 @@ class AlbertTransformerEncoder(tf.keras.Model):
type_ids = tf.keras.layers.Input(
shape=(sequence_length,), dtype=tf.int32, name='input_type_ids')
if embedding_width is None:
embedding_width = hidden_size
self._embedding_layer = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
@@ -141,13 +143,14 @@
axis=-1,
epsilon=1e-12,
dtype=tf.float32)(embeddings))
embeddings = (tf.keras.layers.Dropout(rate=dropout_rate)(embeddings))
# We project the 'embedding' output to 'hidden_size' if it is not already
# 'hidden_size'.
if embedding_width != hidden_size:
embeddings = layers.DenseEinsum(
embeddings = tf.keras.layers.experimental.EinsumDense(
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
kernel_initializer=initializer,
name='embedding_projection')(
embeddings)
@@ -176,9 +179,7 @@ class AlbertTransformerEncoder(tf.keras.Model):
first_token_tensor)
super(AlbertTransformerEncoder, self).__init__(
inputs=[word_ids, mask, type_ids], outputs=[data, cls_output], **kwargs)
def get_embedding_table(self):
return self._embedding_layer.embeddings
@@ -64,6 +64,11 @@ class TransformerEncoder(tf.keras.Model):
target sequence of the last transformer layer. `None` means the entire
target sequence will attend to the source sequence, which yields the full
output.
embedding_width: The width of the word embeddings. If the embedding width
is not equal to hidden size, embedding parameters will be factorized into
two matrices in the shape of ['vocab_size', 'embedding_width'] and
['embedding_width', 'hidden_size'] ('embedding_width' is usually much
smaller than 'hidden_size').
"""
def __init__(self,
@@ -81,6 +86,7 @@
initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
return_all_encoder_outputs=False,
output_range=None,
embedding_width=None,
**kwargs):
activation = tf.keras.activations.get(activation)
initializer = tf.keras.initializers.get(initializer)
@@ -103,6 +109,7 @@
'initializer': tf.keras.initializers.serialize(initializer),
'return_all_encoder_outputs': return_all_encoder_outputs,
'output_range': output_range,
'embedding_width': embedding_width,
}
word_ids = tf.keras.layers.Input(
@@ -112,9 +119,11 @@
type_ids = tf.keras.layers.Input(
shape=(sequence_length,), dtype=tf.int32, name='input_type_ids')
if embedding_width is None:
embedding_width = hidden_size
self._embedding_layer = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=hidden_size,
embedding_width=embedding_width,
initializer=initializer,
name='word_embeddings')
word_embeddings = self._embedding_layer(word_ids)
@@ -126,17 +135,27 @@
max_sequence_length=max_sequence_length,
name='position_embedding')
position_embeddings = self._position_embedding_layer(word_embeddings)
type_embeddings = (
layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=hidden_size,
initializer=initializer,
use_one_hot=True,
name='type_embeddings')(type_ids))
self._type_embedding_layer = layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=initializer,
use_one_hot=True,
name='type_embeddings')
type_embeddings = self._type_embedding_layer(type_ids)
embeddings = tf.keras.layers.Add()(
[word_embeddings, position_embeddings, type_embeddings])
# We project the 'embedding' output to 'hidden_size' if it is not already
# 'hidden_size'.
if embedding_width != hidden_size:
self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
kernel_initializer=initializer,
name='embedding_projection')
embeddings = self._embedding_projection(embeddings)
embeddings = (
tf.keras.layers.LayerNormalization(
name='embeddings/layer_norm',
@@ -173,6 +173,21 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
outputs = model.predict([word_id_data, mask_data, type_id_data])
self.assertEqual(outputs[0].shape[1], out_seq_len)
# Creates a TransformerEncoder with embedding_width != hidden_size
test_network = transformer_encoder.TransformerEncoder(
vocab_size=vocab_size,
hidden_size=hidden_size,
sequence_length=sequence_length,
max_sequence_length=max_sequence_length,
num_attention_heads=2,
num_layers=3,
type_vocab_size=num_types,
embedding_width=16)
model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
outputs = model.predict([word_id_data, mask_data, type_id_data])
self.assertEqual(outputs[0].shape[-1], hidden_size)
self.assertTrue(hasattr(test_network, "_embedding_projection"))
def test_serialize_deserialize(self):
tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
# Create a network object that sets all of its config options.
@@ -190,7 +205,8 @@
attention_dropout_rate=0.22,
initializer="glorot_uniform",
return_all_encoder_outputs=False,
output_range=-1)
output_range=-1,
embedding_width=16)
network = transformer_encoder.TransformerEncoder(**kwargs)
expected_config = dict(kwargs)
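Taken together, the plumbing above means embedding_size can now be set on the config and reaches the encoder as embedding_width. A hedged usage sketch follows, assuming the model-garden import paths official.nlp.albert.configs and official.nlp.bert.bert_models and a (bert_config, sequence_length) call signature for get_transformer_encoder, none of which are shown in full in this diff:

from official.nlp.albert import configs as albert_configs  # assumed import path
from official.nlp.bert import bert_models                  # assumed import path

# embedding_size now lives on BertConfig and is inherited by AlbertConfig;
# vocab_size and hidden_size are placeholder values.
config = albert_configs.AlbertConfig(
    vocab_size=30000,
    embedding_size=128,  # factorized width, smaller than hidden_size
    hidden_size=768)

# For ALBERT configs, get_transformer_encoder copies embedding_size into
# kwargs['embedding_width'] and builds an AlbertTransformerEncoder.
encoder = bert_models.get_transformer_encoder(config, sequence_length=128)

When embedding_width is left unset, TransformerEncoder simply defaults it to hidden_size and no projection layer is created.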