Commit bddc4930 authored by Chen Chen, committed by A. Unique TensorFlower

Move PackedSequenceEmbedding to model garden

PiperOrigin-RevId: 339454557
parent 5ec73492
# Networks

Networks are combinations of layers (and possibly other networks). They are
sub-units of models that would not be trained alone. Each encapsulates a common
network structure, such as a classification head or a transformer encoder, in
an easily handled object with a standardized configuration.

* [`BertEncoder`](bert_encoder.py) implements a bi-directional
  Transformer-based encoder as described in ["BERT: Pre-training of Deep
  Bidirectional Transformers for Language Understanding"](https://arxiv.org/abs/1810.04805).
  It includes the embedding lookups, transformer layers and pooling layer.

* [`AlbertEncoder`](albert_encoder.py) implements a Transformer-encoder
  described in the paper ["ALBERT: A Lite BERT for Self-supervised Learning of
  Language Representations"](https://arxiv.org/abs/1909.11942). Compared with
  [BERT](https://arxiv.org/abs/1810.04805), ALBERT refactorizes embedding
  parameters into two smaller matrices and shares parameters across layers.

* [`MobileBERTEncoder`](mobile_bert_encoder.py) implements the MobileBERT
  network described in the paper ["MobileBERT: a Compact Task-Agnostic BERT for
  Resource-Limited Devices"](https://arxiv.org/abs/2004.02984).

@@ -24,6 +26,15 @@

  intended for use as a classification or regression (if number of classes is
  set to 1) head.

* [`PackedSequenceEmbedding`](packed_sequence_embedding.py) implements an
  embedding network that supports packed sequences and position ids; a minimal
  usage sketch follows this list.

* [`SpanLabeling`](span_labeling.py) implements a single-span labeler (that is,
  a prediction head that can predict one start and end index per batch item)
  based on a single dense hidden layer. It can be used in the SQuAD task.

* [`XLNetBase`](xlnet_base.py) implements the base network used in ["XLNet:
  Generalized Autoregressive Pretraining for Language Understanding"](https://arxiv.org/abs/1906.08237).
  It includes embedding lookups, relative position encodings, mask computations,
  segment matrix computations and Transformer XL layers using one or two stream
  relative self-attention.
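
A minimal usage sketch for `PackedSequenceEmbedding`; the hyperparameters and
input values below are illustrative, not taken from the library:

```python
import numpy as np
import tensorflow as tf

from official.nlp.modeling import networks

# Illustrative configuration; any consistent sizes work.
embedding_network = networks.PackedSequenceEmbedding(
    vocab_size=100,
    type_vocab_size=2,
    hidden_size=32,
    max_seq_length=16,
    initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
    dropout_rate=0.1,
    pack_multiple_sequences=True)

batch_size, seq_length = 2, 16
inputs = {
    'input_word_ids': np.random.randint(
        100, size=(batch_size, seq_length), dtype=np.int32),
    'input_mask': np.ones((batch_size, seq_length), dtype=np.int32),
    'input_type_ids': np.zeros((batch_size, seq_length), dtype=np.int32),
}
# Returns the combined embeddings and a [batch, seq_length, seq_length]
# attention mask that also separates packed sub sequences.
embeddings, attention_mask = embedding_network(inputs)
```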
@@ -18,6 +18,7 @@
from official.nlp.modeling.networks.bert_encoder import BertEncoder
from official.nlp.modeling.networks.classification import Classification
from official.nlp.modeling.networks.encoder_scaffold import EncoderScaffold
from official.nlp.modeling.networks.mobile_bert_encoder import MobileBERTEncoder
from official.nlp.modeling.networks.packed_sequence_embedding import PackedSequenceEmbedding
from official.nlp.modeling.networks.span_labeling import SpanLabeling
from official.nlp.modeling.networks.span_labeling import XLNetSpanLabeling
from official.nlp.modeling.networks.xlnet_base import XLNetBase
...
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An embedding network supporting packed sequences and position ids."""
# pylint: disable=g-classes-have-attributes
import collections
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling import layers
@tf.keras.utils.register_keras_serializable(package='Text')
class PackedSequenceEmbedding(tf.keras.Model):
"""An embedding network supporting packed sequences and position ids.
This network implements an embedding layer similar to the one described in
"BERT: Pre-training of Deep Bidirectional Transformers for Language
Understanding" (https://arxiv.org/abs/1810.04805). On top of it, it supports
to (1) pack multiple sequences into one sequence and (2) allow additional
"position_ids" as input.
Arguments:
vocab_size: The size of the token vocabulary.
type_vocab_size: The size of the type vocabulary.
hidden_size: The hidden size for this encoder.
max_seq_length: The maximum sequence length for this encoder.
initializer: The initializer for the embedding portion of this encoder.
dropout_rate: The dropout rate to apply before the encoding layers.
pack_multiple_sequences: If True, we can feed multiple sequences into one
sequence for training and inference (they don't impact each other).
use_position_id: Whether to expect `position_ids` as an input to the
network. If False, the `position_ids` will be inferred: (1) when
pack_multiple_sequences is False, we assume the position ids are 0, 1,
2, ..., seq_length - 1; (2) when pack_multiple_sequences is True, there
may be multiple sub sequences, and for each sub sequence, its position
ids start from 0, 1, 2, ...
"""

  def __init__(self,
               vocab_size,
               type_vocab_size,
               hidden_size,
               max_seq_length,
               initializer,
               dropout_rate,
               use_position_id=False,
               pack_multiple_sequences=False,
               **kwargs):
    initializer = tf.keras.initializers.get(initializer)
    config_dict = {
        'vocab_size': vocab_size,
        'type_vocab_size': type_vocab_size,
        'hidden_size': hidden_size,
        'max_seq_length': max_seq_length,
        'initializer': tf.keras.initializers.serialize(initializer),
        'dropout_rate': dropout_rate,
        'use_position_id': use_position_id,
        'pack_multiple_sequences': pack_multiple_sequences,
    }

    word_ids = tf.keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_word_ids')
    mask = tf.keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_mask')
    type_ids = tf.keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_type_ids')
    inputs = {
        'input_word_ids': word_ids,
        'input_mask': mask,
        'input_type_ids': type_ids,
    }
    if use_position_id:
      position_ids = tf.keras.layers.Input(
          shape=(None,), dtype=tf.int32, name='position_ids')
      inputs['position_ids'] = position_ids
    else:
      position_ids = None

    if pack_multiple_sequences:
      sub_seq_mask = PackedSequenceMask()(word_ids)
    else:
      sub_seq_mask = None

    embedding_layer = layers.OnDeviceEmbedding(
        vocab_size=vocab_size,
        embedding_width=hidden_size,
        initializer=initializer,
        name='word_embeddings')
    word_embeddings = embedding_layer(word_ids)

    # Always uses dynamic slicing for simplicity.
    position_embedding_layer = PositionEmbeddingWithSubSeqMask(
        initializer=initializer,
        use_dynamic_slicing=True,
        max_sequence_length=max_seq_length,
        name='position_embedding')
    position_embeddings = position_embedding_layer(
        word_embeddings, position_ids, sub_seq_mask)

    type_embeddings = (
        layers.OnDeviceEmbedding(
            vocab_size=type_vocab_size,
            embedding_width=hidden_size,
            initializer=initializer,
            use_one_hot=True,
            name='type_embeddings')(type_ids))

    embeddings = tf.keras.layers.Add()(
        [word_embeddings, position_embeddings, type_embeddings])
    embeddings = tf.keras.layers.LayerNormalization(
        name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)(
            embeddings)
    embeddings = tf.keras.layers.Dropout(
        rate=dropout_rate, dtype=tf.float32)(
            embeddings)

    attention_mask = layers.SelfAttentionMask()([embeddings, mask])
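    # When multiple sequences are packed together, zero out attention between
    # tokens that come from different sub sequences.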
    if sub_seq_mask is not None:
      attention_mask = tf.keras.layers.Lambda(
          lambda x: x[0] * tf.cast(x[1], x[0].dtype))(
              [attention_mask, sub_seq_mask])

    outputs = [embeddings, attention_mask]
    super(PackedSequenceEmbedding, self).__init__(
        inputs=inputs, outputs=outputs, **kwargs)

    # TF does not track immutable attrs which do not contain Trackables,
    # so by creating a config namedtuple instead of a dict we avoid tracking it.
    config_cls = collections.namedtuple('Config', config_dict.keys())
    self._config = config_cls(**config_dict)
    self._embedding_layer = embedding_layer
    self._position_embedding_layer = position_embedding_layer

  def get_embedding_table(self):
    return self._embedding_layer.embeddings

  def get_config(self):
    return dict(self._config._asdict())

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)


@tf.keras.utils.register_keras_serializable(package='Text')
class PackedSequenceMask(tf.keras.layers.Layer):
  """A layer to create a mask to indicate multiple sub sequences."""

  def call(self, input_ids):
    """Implements call() for the layer.

    Args:
      input_ids: int32 Tensor of shape [batch_size, seq_length].

    Returns:
      boolean Tensor of shape [batch_size, seq_length, seq_length]. [x, y, z]
      is True if for x'th instance in a batch, y'th token and z'th token are
      from the same sub sequence.
    """
    # Suppose
    # - the first token in the parent sequence is [CLS].
    # - every sequence starts from [CLS].
    # - every sequence only contains one [CLS].
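    # For example (illustrative values, not from the original code): with
    # input_ids = [[CLS, a, b, CLS, c]], seq_start_loc is [[1, 0, 0, 1, 0]] and
    # its cumulative sum [[1, 1, 1, 2, 2]] assigns a sub sequence id to every
    # token, so (CLS, a, b) and (CLS, c) form two separate sub sequences.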
    seq_start_token = input_ids[:, 0:1]
    seq_start_loc = tf.cast(tf.equal(input_ids, seq_start_token), tf.int32)
    # Set different ids for different sub sequences.
    seq_ids = tf.expand_dims(tf.cumsum(seq_start_loc, -1), -1)
    return tf.equal(seq_ids, tf.transpose(seq_ids, [0, 2, 1]))


@tf.keras.utils.register_keras_serializable(package='Text')
class PositionEmbeddingWithSubSeqMask(tf.keras.layers.Layer):
  """Creates a positional embedding with sub-sequence masking.

  This layer creates a positional embedding as described in "BERT: Pre-training
  of Deep Bidirectional Transformers for Language Understanding"
  (https://arxiv.org/abs/1810.04805). On top of it, it supports
  `position_ids` and `sub_sequence_mask` tensors.

  This layer can be set up to either create a statically shaped slice or a
  dynamically shaped slice. If `use_dynamic_slicing` is True, the input tensor
  can have a dynamic 1st dimension, while if `use_dynamic_slicing` is False the
  input size must be fixed.

  Arguments:
    initializer: The initializer to use for the embedding weights. Defaults to
      "glorot_uniform".
    use_dynamic_slicing: Whether to use the dynamic slicing path.
    max_sequence_length: The maximum size of the dynamic sequence. Only
      applicable if `use_dynamic_slicing` is True.
  """

  def __init__(self,
               initializer='glorot_uniform',
               use_dynamic_slicing=False,
               max_sequence_length=None,
               **kwargs):
    # We need to have a default dtype of float32, since the inputs (which Keras
    # usually uses to infer the dtype) will always be int32.
    if 'dtype' not in kwargs:
      kwargs['dtype'] = 'float32'
    super(PositionEmbeddingWithSubSeqMask, self).__init__(**kwargs)
    if use_dynamic_slicing and max_sequence_length is None:
      raise ValueError(
          'If `use_dynamic_slicing` is True, `max_sequence_length` must be set.'
      )
    self._max_sequence_length = max_sequence_length
    self._initializer = tf.keras.initializers.get(initializer)
    self._use_dynamic_slicing = use_dynamic_slicing

  def get_config(self):
    config = {
        'max_sequence_length': self._max_sequence_length,
        'initializer': tf.keras.initializers.serialize(self._initializer),
        'use_dynamic_slicing': self._use_dynamic_slicing,
    }
    base_config = super(PositionEmbeddingWithSubSeqMask, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def build(self, input_shape):
    """Implements build() for the layer."""
    dimension_list = input_shape.as_list()

    if len(dimension_list) != 3:
      raise ValueError('PositionEmbedding expects a 3-dimensional input tensor '
                       'of shape [batch, sequence, width]')
    seq_length = dimension_list[1]
    width = dimension_list[2]

    # If we are not using dynamic slicing, we must assume that the sequence
    # length is fixed and max_sequence_length should not be specified.
    if not self._use_dynamic_slicing:
      if seq_length is None:
        raise ValueError(
            'PositionEmbedding must have `use_dynamic_slicing` set '
            'to True (and max_sequence_length set) when the '
            'sequence (1st) dimension of the input is None.')
      if self._max_sequence_length is not None:
        raise ValueError(
            'When `use_dynamic_slicing` is False, max_sequence_length should '
            'not be specified and we ought to use seq_length to get the '
            'variable shape.')

    if self._max_sequence_length is not None:
      weight_sequence_length = self._max_sequence_length
    else:
      weight_sequence_length = seq_length

    self._position_embeddings = self.add_weight(
        'embeddings',
        shape=[weight_sequence_length, width],
        initializer=self._initializer)

    super(PositionEmbeddingWithSubSeqMask, self).build(input_shape)

  def call(self, inputs, position_ids=None, sub_sequence_mask=None):
    """Implements call() for the layer.

    When `position_ids` is specified, it will return the position embeddings
    corresponding to this `position_ids`; otherwise, `position_ids` will be
    inferred in the following way:

    (1) When `sub_sequence_mask` is None, we assume the position ids are
        0, 1, 2, ..., seq_length - 1.
    (2) When `sub_sequence_mask` is specified, there may be multiple sub
        sequences, and for each sub sequence, its position ids start from
        0, 1, 2, ...

    Args:
      inputs: Word embeddings in shape [batch, seq_length, embedding_dim].
      position_ids: An optional int32 tensor in shape [batch, seq_length].
      sub_sequence_mask: An optional bool tensor in shape [batch, seq_length,
        seq_length]. [x, y, z] is True if for x'th instance in a batch, y'th
        token and z'th token are from the same sub sequence.

    Returns:
      The position embeddings in shape [batch, seq_length, embedding_dim].
    """
    input_shape = tf_utils.get_shape_list(inputs, expected_rank=3)
    if self._use_dynamic_slicing:
      position_embeddings = self._position_embeddings[:input_shape[1], :]
    else:
      position_embeddings = self._position_embeddings

    if position_ids is not None:
      return tf.gather(position_embeddings, position_ids)

    if sub_sequence_mask is None:
      return tf.broadcast_to(position_embeddings, input_shape)
    else:
      sub_sequence_mask = tf.cast(sub_sequence_mask, tf.int32)
      # For each sub sequence, its position ids start from 0, 1, 2, ...
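      # For example (illustrative values, not from the original code): if the
      # packed sequence holds two sub sequences of lengths 3 and 2, the
      # diagonal of the cumulative sum below is [1, 2, 3, 1, 2], so
      # position_ids becomes [0, 1, 2, 0, 1].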
      position_ids = tf.linalg.diag_part(tf.cumsum(sub_sequence_mask, -1)) - 1
      return tf.gather(position_embeddings, position_ids)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.modeling.networks.packed_sequence_embedding."""
# Import libraries
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.nlp.modeling.networks import packed_sequence_embedding


class PackedSequenceEmbeddingTest(tf.test.TestCase, parameterized.TestCase):

  def tearDown(self):
    super(PackedSequenceEmbeddingTest, self).tearDown()
    tf.keras.mixed_precision.experimental.set_policy('float32')

  @parameterized.parameters([
      (True, True, True),
      (False, False, True),
      (False, True, False),
      (True, False, False),
  ])
  def test_network_creation(self, use_position_id, pack_multiple_sequences,
                            use_float16):
    """Validate that the Keras object can be created."""
    if use_float16:
      tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
    seq_length = 16
    vocab_size = 100
    max_position_embeddings = 32
    type_vocab_size = 2
    hidden_size = 32
    embedding_cfg = dict(
        vocab_size=vocab_size,
        type_vocab_size=2,
        hidden_size=hidden_size,
        max_seq_length=max_position_embeddings,
        initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
        dropout_rate=0.1,
        use_position_id=use_position_id,
        pack_multiple_sequences=pack_multiple_sequences,
    )
    test_object = packed_sequence_embedding.PackedSequenceEmbedding(
        **embedding_cfg)

    input_word_ids = tf.keras.Input(shape=(seq_length,), dtype=tf.int32)
    input_mask = tf.keras.Input(shape=(seq_length,), dtype=tf.int32)
    input_type_ids = tf.keras.Input(shape=(seq_length,), dtype=tf.int32)
    network_inputs = {
        'input_word_ids': input_word_ids,
        'input_mask': input_mask,
        'input_type_ids': input_type_ids,
    }
    if use_position_id:
      network_inputs['position_ids'] = tf.keras.Input(
          shape=(seq_length,), dtype=tf.int32)
    embedding, mask = test_object(network_inputs)

    # Create a model based off of this network:
    model = tf.keras.Model(network_inputs, [embedding, mask])

    # Invoke the model. We can't validate the output data here (the model is
    # too complex) but this will catch structural runtime errors.
    batch_size = 3
    word_id_data = np.random.randint(vocab_size, size=(batch_size, seq_length))
    mask_data = np.random.randint(2, size=(batch_size, seq_length))
    type_id_data = np.random.randint(
        type_vocab_size, size=(batch_size, seq_length))
    feed_input = {
        'input_word_ids': word_id_data,
        'input_mask': mask_data,
        'input_type_ids': type_id_data,
    }
    if use_position_id:
      feed_input['position_ids'] = np.random.randint(
          seq_length, size=(batch_size, seq_length))
    embeddings, attention_mask = model.predict(feed_input)

    expected_embeddings_shape = [3, seq_length, hidden_size]
    expected_attention_mask_shape = [3, seq_length, seq_length]
    self.assertAllEqual(expected_embeddings_shape, embeddings.shape)
    self.assertAllEqual(expected_attention_mask_shape, attention_mask.shape)

  def test_serialize_deserialize(self):
    tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
    # Create a network object that sets all of its config options.
    embedding_cfg = dict(
        vocab_size=100,
        type_vocab_size=2,
        hidden_size=64,
        max_seq_length=32,
        initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
        dropout_rate=0.1,
        use_position_id=True,
        pack_multiple_sequences=False,
    )
    network = packed_sequence_embedding.PackedSequenceEmbedding(**embedding_cfg)
    expected_config = dict(embedding_cfg)
    expected_config['initializer'] = tf.keras.initializers.serialize(
        tf.keras.initializers.get(expected_config['initializer']))
    self.assertEqual(network.get_config(), expected_config)

    # Create another network object from the first object's config.
    new_network = packed_sequence_embedding.PackedSequenceEmbedding.from_config(
        network.get_config())

    # Validate that the config can be forced to JSON.
    _ = new_network.to_json()

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(network.get_config(), new_network.get_config())


if __name__ == '__main__':
  tf.test.main()