Commit 563d923a authored by Hongkun Yu, committed by A. Unique TensorFlower

Introduce output_range for transformer layer and encoder last layer.

PiperOrigin-RevId: 310032518
parent 9c1429cf
@@ -6,68 +6,41 @@ The TensorFlow official models are a collection of models
 that use TensorFlow’s high-level APIs.
 They are intended to be well-maintained, tested, and kept up to date
 with the latest TensorFlow API.
 They should also be reasonably optimized for fast performance while still
 being easy to read.
 These models are used as end-to-end tests, ensuring that the models run
 with the same or improved speed and performance with each new TensorFlow build.
 
-## More models to come!
-
-The team is actively developing new models.
-In the near future, we will add:
-
-* State-of-the-art language understanding models:
-  More members in Transformer family
-* Start-of-the-art image classification models:
-  EfficientNet, MnasNet, and variants
-* A set of excellent objection detection models.
-
-## Table of Contents
-
-- [Models and Implementations](#models-and-implementations)
-  * [Computer Vision](#computer-vision)
-    + [Image Classification](#image-classification)
-    + [Object Detection and Segmentation](#object-detection-and-segmentation)
-  * [Natural Language Processing](#natural-language-processing)
-  * [Recommendation](#recommendation)
-- [How to get started with the official models](#how-to-get-started-with-the-official-models)
-
-## Models and Implementations
-
-### Computer Vision
-
-#### Image Classification
-
-| Model | Reference (Paper) |
-|-------|-------------------|
-| [MNIST](vision/image_classification) | A basic model to classify digits from the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) |
-| [ResNet](vision/image_classification) | [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) |
-
-#### Object Detection and Segmentation
-
-| Model | Reference (Paper) |
-|-------|-------------------|
-| [RetinaNet](vision/detection) | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) |
-| [Mask R-CNN](vision/detection) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) |
-
-### Natural Language Processing
-
-| Model | Reference (Paper) |
-|-------|-------------------|
-| [ALBERT (A Lite BERT)](nlp/albert) | [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942) |
-| [BERT (Bidirectional Encoder Representations from Transformers)](nlp/bert) | [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) |
-| [NHNet (News Headline generation model)](nlp/nhnet) | [Generating Representative Headlines for News Stories](https://arxiv.org/abs/2001.09386) |
-| [Transformer](nlp/transformer) | [Attention Is All You Need](https://arxiv.org/abs/1706.03762) |
-| [XLNet](nlp/xlnet) | [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) |
-
-### Recommendation
-
-| Model | Reference (Paper) |
-|-------|-------------------|
-| [NCF](recommendation) | [Neural Collaborative Filtering](https://arxiv.org/abs/1708.05031) |
-
-## How to get started with the official models
+## Model Implementations
+
+### Natural Language Processing
+
+| Model | Description | Reference |
+| ----- | ----------- | --------- |
+| [ALBERT](nlp/albert) | A Lite BERT for Self-supervised Learning of Language Representations | [arXiv:1909.11942](https://arxiv.org/abs/1909.11942) |
+| [BERT](nlp/bert) | A powerful pre-trained language representation model: BERT (Bidirectional Encoder Representations from Transformers) | [arXiv:1810.04805](https://arxiv.org/abs/1810.04805) |
+| [NHNet](nlp/nhnet) | A transformer-based multi-sequence to sequence model: Generating Representative Headlines for News Stories | [arXiv:2001.09386](https://arxiv.org/abs/2001.09386) |
+| [Transformer](nlp/transformer) | A transformer model to translate the WMT English to German dataset | [arXiv:1706.03762](https://arxiv.org/abs/1706.03762) |
+| [XLNet](nlp/xlnet) | XLNet: Generalized Autoregressive Pretraining for Language Understanding | [arXiv:1906.08237](https://arxiv.org/abs/1906.08237) |
+
+### Computer Vision
+
+| Model | Description | Reference |
+| ----- | ----------- | --------- |
+| [MNIST](vision/image_classification) | A basic model to classify digits from the MNIST dataset | [Link](http://yann.lecun.com/exdb/mnist/) |
+| [ResNet](vision/image_classification) | A deep residual network for image recognition | [arXiv:1512.03385](https://arxiv.org/abs/1512.03385) |
+| [RetinaNet](vision/detection) | A fast and powerful object detector | [arXiv:1708.02002](https://arxiv.org/abs/1708.02002) |
+| [Mask R-CNN](vision/detection) | An object detection and instance segmentation model | [arXiv:1703.06870](https://arxiv.org/abs/1703.06870) |
+
+### Other models
+
+| Model | Description | Reference |
+| ----- | ----------- | --------- |
+| [NCF](recommendation) | Neural Collaborative Filtering model for recommendation tasks | [arXiv:1708.05031](https://arxiv.org/abs/1708.05031) |
+
+---
+
+## How to get started with the Model Garden official models
 
 * The models in the master branch are developed using TensorFlow 2,
 and they target the TensorFlow [nightly binaries](https://github.com/tensorflow/tensorflow#installation)
@@ -135,6 +108,44 @@ os.environ['PYTHONPATH'] += ":/path/to/models"
 pip3 install --user -r official/requirements.txt
 ```
+
+---
+
+## More models to come!
+
+The team is actively developing new models.
+In the near future, we will add:
+
+- State-of-the-art language understanding models:
+  More members in Transformer family
+- State-of-the-art image classification models:
+  EfficientNet, MnasNet and variants.
+- A set of excellent object detection models.
+
+If you would like to make any fixes or improvements to the models, please
+[submit a pull request](https://github.com/tensorflow/models/compare).
+
+---
+
 ## Contributions
-If you want to contribute, please review the [contribution guidelines](../../../wiki/How-to-contribute).
+
+Every model should follow our guidelines to uphold our objectives of readable,
+usable, and maintainable code.
+
+### General Guidelines
+
+- Code should be well documented and tested.
+- Runnable from a blank environment with ease.
+- Trainable on: single GPU/CPU (baseline), multiple GPUs & TPUs
+- Compatible with Python 3 (using [six](https://pythonhosted.org/six/)
+  when being compatible with Python 2 is necessary)
+- Conform to the
+  [Google Python Style Guide](https://github.com/google/styleguide/blob/gh-pages/pyguide.md)
+
+### Implementation Guidelines
+
+These guidelines ensure consistent model implementations for
+better readability and maintainability.
+
+- Use [common utility functions](utils)
+- Export SavedModel at the end of training.
+- Consistent flags and flag-parsing library ([read more here](utils/flags/guidelines.md))
@@ -101,7 +101,8 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
 @gin.configurable
 def get_transformer_encoder(bert_config,
                             sequence_length,
-                            transformer_encoder_cls=None):
+                            transformer_encoder_cls=None,
+                            output_range=None):
   """Gets a 'TransformerEncoder' object.
 
   Args:
@@ -109,6 +110,8 @@ def get_transformer_encoder(bert_config,
     sequence_length: Maximum sequence length of the training data.
     transformer_encoder_cls: A EncoderScaffold class. If it is None, uses the
       default BERT encoder implementation.
+    output_range: the sequence output range, [0, output_range). Default setting
+      is to return the entire sequence output.
 
   Returns:
     A networks.TransformerEncoder object.
@@ -161,6 +164,7 @@ def get_transformer_encoder(bert_config,
     return networks.AlbertTransformerEncoder(**kwargs)
   else:
     assert isinstance(bert_config, configs.BertConfig)
+    kwargs['output_range'] = output_range
     return networks.TransformerEncoder(**kwargs)
@@ -320,7 +324,8 @@ def classifier_model(bert_config,
       stddev=bert_config.initializer_range)
   if not hub_module_url:
-    bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
+    bert_encoder = get_transformer_encoder(
+        bert_config, max_seq_length, output_range=1)
     return models.BertClassifier(
         bert_encoder,
         num_classes=num_labels,
...
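The only call site changed here is the classifier: `BertClassifier` reads just the first ([CLS]) position of the sequence output, so the encoder can skip computing the rest of the final layer. Below is a minimal usage sketch of the new argument, assuming the function above is `official.nlp.bert.bert_models.get_transformer_encoder` as in the Model Garden; the config values are illustrative only.

```python
from official.nlp.bert import bert_models
from official.nlp.bert import configs

# Illustrative (small) BERT config; real configs are much larger.
bert_config = configs.BertConfig(
    vocab_size=30522,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=128)

# Default: the full sequence output, shape [batch, 128, 64].
full_encoder = bert_models.get_transformer_encoder(
    bert_config, sequence_length=128)

# output_range=1: the last transformer layer only computes position 0,
# so the sequence output has shape [batch, 1, 64]. This is what
# classifier_model() now requests, since only the [CLS] embedding feeds
# the classification head.
cls_encoder = bert_models.get_transformer_encoder(
    bert_config, sequence_length=128, output_range=1)
```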
@@ -39,6 +39,8 @@ class Transformer(tf.keras.layers.Layer):
     intermediate_activation: Activation for the intermediate layer.
     dropout_rate: Dropout probability for the post-attention and output dropout.
     attention_dropout_rate: Dropout probability for within the attention layer.
+    output_range: the sequence output range, [0, output_range) by slicing the
+      target sequence. `None` means the target sequence is not sliced.
     kernel_initializer: Initializer for dense layer kernels.
     bias_initializer: Initializer for dense layer biases.
     kernel_regularizer: Regularizer for dense layer kernels.
@@ -54,6 +56,7 @@ class Transformer(tf.keras.layers.Layer):
                intermediate_activation,
                dropout_rate=0.0,
                attention_dropout_rate=0.0,
+               output_range=None,
                kernel_initializer="glorot_uniform",
                bias_initializer="zeros",
                kernel_regularizer=None,
@@ -69,6 +72,7 @@ class Transformer(tf.keras.layers.Layer):
     self._intermediate_activation = intermediate_activation
     self._attention_dropout_rate = attention_dropout_rate
     self._dropout_rate = dropout_rate
+    self._output_range = output_range
     self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
     self._bias_initializer = tf.keras.initializers.get(bias_initializer)
     self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
@@ -168,6 +172,8 @@ class Transformer(tf.keras.layers.Layer):
             self._dropout_rate,
         "attention_dropout_rate":
             self._attention_dropout_rate,
+        "output_range":
+            self._output_range,
         "kernel_initializer":
             tf.keras.initializers.serialize(self._kernel_initializer),
         "bias_initializer":
@@ -192,11 +198,16 @@ class Transformer(tf.keras.layers.Layer):
     else:
       input_tensor, attention_mask = (inputs, None)
 
-    attention_inputs = [input_tensor, input_tensor]
+    if self._output_range:
+      target_tensor = input_tensor[:, 0:self._output_range, :]
+      attention_mask = attention_mask[:, 0:self._output_range, :]
+    else:
+      target_tensor = input_tensor
+    attention_inputs = [target_tensor, input_tensor]
 
     attention_output = self._attention_layer(attention_inputs, attention_mask)
     attention_output = self._attention_dropout(attention_output)
-    attention_output = self._attention_layer_norm(input_tensor +
+    attention_output = self._attention_layer_norm(target_tensor +
                                                   attention_output)
     intermediate_output = self._intermediate_dense(attention_output)
     intermediate_output = self._intermediate_activation_layer(
...
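In the `call` above, only the attention query is restricted to the first `output_range` positions; the keys and values still come from the full input, so each retained position attends to the whole sequence while the layer output shrinks to `[batch, output_range, width]`. A small sketch of that behavior, assuming the `Transformer` layer exported from `official.nlp.modeling.layers`; note that, as written, slicing the mask means an attention mask should be passed whenever `output_range` is set.

```python
import tensorflow as tf

from official.nlp.modeling import layers

batch_size, seq_len, width = 2, 16, 64

# Query only the first position; keys/values still cover all 16 positions.
layer = layers.Transformer(
    num_attention_heads=4,
    intermediate_size=128,
    intermediate_activation='relu',
    output_range=1)

data = tf.random.uniform((batch_size, seq_len, width))
mask = tf.ones((batch_size, seq_len, seq_len))  # full visibility

output = layer([data, mask])
print(output.shape)  # (2, 1, 64): only the first position is computed.
```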
@@ -29,8 +29,8 @@ from official.nlp.modeling.layers import transformer
 # This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
 # guarantees forward compatibility of this code for the V2 switchover.
 @keras_parameterized.run_all_keras_modes
-@parameterized.parameters(transformer.Transformer,
-                          transformer.CompiledTransformer)
+@parameterized.named_parameters(('base', transformer.Transformer),
+                                ('xla', transformer.CompiledTransformer))
 class TransformerLayerTest(keras_parameterized.TestCase):
 
   def tearDown(self):
@@ -127,6 +127,33 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         2, size=(batch_size, sequence_length, sequence_length))
     _ = model.predict([input_data, mask_data])
 
+  def test_layer_output_range(self, transformer_cls):
+    test_layer = transformer_cls(
+        num_attention_heads=10,
+        intermediate_size=2048,
+        intermediate_activation='relu')
+    sequence_length = 21
+    width = 80
+    batch_size = 6
+
+    input_data = 10 * np.random.random_sample(
+        (batch_size, sequence_length, width))
+    mask_data = np.random.randint(
+        2, size=(batch_size, sequence_length, sequence_length))
+    output_tensor = test_layer([input_data, mask_data])
+
+    # The layer only attends to the first token and outputs the first token
+    # embedding.
+    new_layer = transformer_cls(
+        num_attention_heads=10,
+        intermediate_size=2048,
+        intermediate_activation='relu',
+        output_range=1)
+    _ = new_layer([input_data, mask_data])
+    new_layer.set_weights(test_layer.get_weights())
+    new_output_tensor = new_layer([input_data, mask_data])
+    self.assertAllClose(new_output_tensor, output_tensor[:, 0:1, :])
+
   def test_layer_invocation_with_float16_dtype(self, transformer_cls):
     tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
     test_layer = transformer_cls(
...
@@ -61,6 +61,10 @@ class TransformerEncoder(network.Network):
     initializer: The initializer to use for all weights in this encoder.
     return_all_encoder_outputs: Whether to output sequence embedding outputs of
       all encoder transformer layers.
+    output_range: the sequence output range, [0, output_range), by slicing the
+      target sequence of the last transformer layer. `None` means the entire
+      target sequence will attend to the source sequence, which yields the full
+      output.
   """
 
   def __init__(self,
@@ -77,6 +81,7 @@ class TransformerEncoder(network.Network):
                attention_dropout_rate=0.1,
                initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
                return_all_encoder_outputs=False,
+               output_range=None,
                **kwargs):
     activation = tf.keras.activations.get(activation)
     initializer = tf.keras.initializers.get(initializer)
@@ -98,6 +103,7 @@ class TransformerEncoder(network.Network):
         'attention_dropout_rate': attention_dropout_rate,
         'initializer': tf.keras.initializers.serialize(initializer),
         'return_all_encoder_outputs': return_all_encoder_outputs,
+        'output_range': output_range,
     }
 
     word_ids = tf.keras.layers.Input(
@@ -146,12 +152,17 @@ class TransformerEncoder(network.Network):
     attention_mask = layers.SelfAttentionMask()([data, mask])
 
     encoder_outputs = []
     for i in range(num_layers):
+      if i == num_layers - 1 and output_range is not None:
+        transformer_output_range = output_range
+      else:
+        transformer_output_range = None
       layer = layers.Transformer(
           num_attention_heads=num_attention_heads,
           intermediate_size=intermediate_size,
           intermediate_activation=activation,
           dropout_rate=dropout_rate,
          attention_dropout_rate=attention_dropout_rate,
+          output_range=transformer_output_range,
          kernel_initializer=initializer,
          name='transformer/layer_%d' % i)
      self._transformer_layers.append(layer)
...
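At the network level the slicing is applied only to the final transformer layer, so every earlier layer still produces the full-length sequence that the last layer attends over. A hedged sketch of the end-to-end effect, assuming the `TransformerEncoder` exported from `official.nlp.modeling.networks` and constructor defaults as in this version of the repository; shapes are illustrative.

```python
import tensorflow as tf

from official.nlp.modeling import networks

# Small encoder for illustration; only the last of the 3 layers is sliced.
encoder = networks.TransformerEncoder(
    vocab_size=100,
    hidden_size=32,
    num_layers=3,
    num_attention_heads=2,
    sequence_length=128,
    type_vocab_size=2,
    output_range=1)

word_ids = tf.zeros((4, 128), dtype=tf.int32)
mask = tf.ones((4, 128), dtype=tf.int32)
type_ids = tf.zeros((4, 128), dtype=tf.int32)

sequence_output, pooled_output = encoder([word_ids, mask, type_ids])
print(sequence_output.shape)  # (4, 1, 32): only the [CLS] position remains.
print(pooled_output.shape)    # (4, 32)
```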
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf
 
@@ -114,7 +115,11 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
     self.assertAllEqual(tf.float32, data.dtype)
     self.assertAllEqual(tf.float16, pooled.dtype)
 
-  def test_network_invocation(self):
+  @parameterized.named_parameters(
+      ("all_sequence", None, 21),
+      ("output_range", 1, 1),
+  )
+  def test_network_invocation(self, output_range, out_seq_len):
     hidden_size = 32
     sequence_length = 21
     vocab_size = 57
@@ -126,7 +131,8 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
         sequence_length=sequence_length,
         num_attention_heads=2,
         num_layers=3,
-        type_vocab_size=num_types)
+        type_vocab_size=num_types,
+        output_range=output_range)
     self.assertTrue(
         test_network._position_embedding_layer._use_dynamic_slicing)
     # Create the inputs (note that the first dimension is implicit).
@@ -160,7 +166,8 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
         type_vocab_size=num_types)
     self.assertTrue(test_network._position_embedding_layer._use_dynamic_slicing)
     model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
-    _ = model.predict([word_id_data, mask_data, type_id_data])
+    outputs = model.predict([word_id_data, mask_data, type_id_data])
+    self.assertEqual(outputs[0].shape[1], out_seq_len)
 
   def test_serialize_deserialize(self):
     tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
@@ -178,7 +185,8 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
         dropout_rate=0.05,
         attention_dropout_rate=0.22,
         initializer="glorot_uniform",
-        return_all_encoder_outputs=False)
+        return_all_encoder_outputs=False,
+        output_range=-1)
     network = transformer_encoder.TransformerEncoder(**kwargs)
     expected_config = dict(kwargs)
...