Commit 55b5100e authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 327149314
parent 92ea3959
...@@ -11,6 +11,12 @@ assemble new layers, networks, or models. ...@@ -11,6 +11,12 @@ assemble new layers, networks, or models.
* [CachedAttention](attention.py) implements an attention layer with cache * [CachedAttention](attention.py) implements an attention layer with cache
used for auto-regressive decoding. used for auto-regressive decoding.
* [MatMulWithMargin](mat_mul_with_margin.py) implements a matrix
multiplication with margin layer used for training retrieval / ranking
tasks, as described in ["Improving Multilingual Sentence Embedding using
Bi-directional Dual Encoder with Additive Margin
Softmax"](https://www.ijcai.org/Proceedings/2019/0746.pdf).
* [MultiChannelAttention](multi_channel_attention.py) implements a variant of * [MultiChannelAttention](multi_channel_attention.py) implements a variant of
multi-head attention which can be used to merge multiple streams for multi-head attention which can be used to merge multiple streams for
cross-attentions. cross-attentions.
...@@ -24,8 +30,8 @@ assemble new layers, networks, or models. ...@@ -24,8 +30,8 @@ assemble new layers, networks, or models.
["Attention Is All You Need"](https://arxiv.org/abs/1706.03762). ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762).
* [TransformerDecoderLayer](transformer.py) TransformerDecoderLayer is made up * [TransformerDecoderLayer](transformer.py) TransformerDecoderLayer is made up
of self multi-head attention, cross multi-head attention and of self multi-head attention, cross multi-head attention and feedforward
feedforward network. network.
* [ReZeroTransformer](rezero_transformer.py) implements Transformer with * [ReZeroTransformer](rezero_transformer.py) implements Transformer with
ReZero described in ReZero described in
......
...@@ -20,6 +20,7 @@ from official.nlp.modeling.layers.dense_einsum import DenseEinsum ...@@ -20,6 +20,7 @@ from official.nlp.modeling.layers.dense_einsum import DenseEinsum
from official.nlp.modeling.layers.gated_feedforward import GatedFeedforward from official.nlp.modeling.layers.gated_feedforward import GatedFeedforward
from official.nlp.modeling.layers.masked_lm import MaskedLM from official.nlp.modeling.layers.masked_lm import MaskedLM
from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
from official.nlp.modeling.layers.mat_mul_with_margin import MatMulWithMargin
from official.nlp.modeling.layers.multi_channel_attention import * from official.nlp.modeling.layers.multi_channel_attention import *
from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
from official.nlp.modeling.layers.position_embedding import PositionEmbedding from official.nlp.modeling.layers.position_embedding import PositionEmbedding
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dot product with margin layer."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
from typing import Tuple
# Import libraries
import tensorflow as tf
from official.modeling import tf_utils
@tf.keras.utils.register_keras_serializable(package='Text')
class MatMulWithMargin(tf.keras.layers.Layer):
  """Computes a pairwise dot-product matrix between two encoded batches.

  Given a batch of left encodings and a batch of right encodings, this layer
  returns the scaled matrix of all pairwise dot products. The diagonal
  (positive-pair) entries have a margin subtracted before scaling, as used
  for additive-margin softmax training of retrieval / ranking models.

  Arguments:
    logit_scale: The scaling factor applied to the dot products.
    logit_margin: The margin subtracted from the diagonal (positive pair)
      entries of the dot-product matrix during training.
  """

  def __init__(self, logit_scale=1.0, logit_margin=0.0, **kwargs):
    super(MatMulWithMargin, self).__init__(**kwargs)
    self.logit_scale = logit_scale
    self.logit_margin = logit_margin

  def call(self, left_encoded: tf.Tensor,
           right_encoded: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    """Returns (left->right, right->left) logits for the two encodings."""
    batch_size = tf_utils.get_shape_list(
        left_encoded, name='sequence_output_tensor')[0]

    # Left -> right: all pairwise dot products, with the diagonal
    # (positive-pair) scores pushed down by the margin, then scaled.
    dot_products = tf.matmul(left_encoded, right_encoded, transpose_b=True)
    self.left_logits = self.logit_scale * (
        dot_products - self.logit_margin * tf.eye(batch_size))

    # Right -> left logits are simply the transpose of the left -> right ones.
    self.right_logits = tf.transpose(self.left_logits)

    return self.left_logits, self.right_logits

  def get_config(self):
    config = {
        'logit_scale': self.logit_scale,
        'logit_margin': self.logit_margin}
    config.update(super(MatMulWithMargin, self).get_config())
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for mat_mul_with_margin layer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Import libraries
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.layers import mat_mul_with_margin
class MatMulWithMarginTest(keras_parameterized.TestCase):

  def test_layer_invocation(self):
    """Validate that the Keras object can be created and invoked."""
    input_width = 512
    test_layer = mat_mul_with_margin.MatMulWithMargin()

    # Build two 2-dimensional symbolic inputs (batch dimension implicit).
    left_encoded = tf.keras.Input(shape=(input_width,), dtype=tf.float32)
    right_encoded = tf.keras.Input(shape=(input_width,), dtype=tf.float32)
    left_logits, right_logits = test_layer(left_encoded, right_encoded)

    # Each output is a (batch, batch) logits matrix; both dims are dynamic.
    expected_output_shape = [None, None]
    for logits in (left_logits, right_logits):
      self.assertEqual(expected_output_shape, logits.shape.as_list())

  def test_serialize_deserialize(self):
    """Validate round-tripping the layer through its config."""
    layer = mat_mul_with_margin.MatMulWithMargin()
    # Rebuild a second layer from the first layer's config.
    new_layer = mat_mul_with_margin.MatMulWithMargin.from_config(
        layer.get_config())
    # A successful round trip preserves the config exactly.
    self.assertAllEqual(layer.get_config(), new_layer.get_config())
# Run the test suite when executed as a script.
if __name__ == '__main__':
  tf.test.main()
...@@ -20,3 +20,6 @@ index and an end token index), suitable for SQuAD-style tasks. ...@@ -20,3 +20,6 @@ index and an end token index), suitable for SQuAD-style tasks.
* [`BertPretrainer`](bert_pretrainer.py) implements a masked LM and a * [`BertPretrainer`](bert_pretrainer.py) implements a masked LM and a
classification head using the Masked LM and Classification networks, classification head using the Masked LM and Classification networks,
respectively. respectively.
* [`DualEncoder`](dual_encoder.py) implements a dual encoder model, suitable for
retrieval tasks.
...@@ -17,4 +17,5 @@ from official.nlp.modeling.models.bert_classifier import BertClassifier ...@@ -17,4 +17,5 @@ from official.nlp.modeling.models.bert_classifier import BertClassifier
from official.nlp.modeling.models.bert_pretrainer import * from official.nlp.modeling.models.bert_pretrainer import *
from official.nlp.modeling.models.bert_span_labeler import BertSpanLabeler from official.nlp.modeling.models.bert_span_labeler import BertSpanLabeler
from official.nlp.modeling.models.bert_token_classifier import BertTokenClassifier from official.nlp.modeling.models.bert_token_classifier import BertTokenClassifier
from official.nlp.modeling.models.dual_encoder import DualEncoder
from official.nlp.modeling.models.electra_pretrainer import ElectraPretrainer from official.nlp.modeling.models.electra_pretrainer import ElectraPretrainer
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trainer network for dual encoder style models."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
# Import libraries
import tensorflow as tf
from official.nlp.modeling import layers
@tf.keras.utils.register_keras_serializable(package='Text')
class DualEncoder(tf.keras.Model):
  """A dual encoder model based on a transformer-based encoder.

  This is an implementation of the dual encoder network structure based on the
  transformer stack, as described in ["Language-agnostic BERT Sentence
  Embedding"](https://arxiv.org/abs/2007.01852).

  The DualEncoder allows a user to pass in a transformer stack, and build a
  dual encoder model based on the transformer stack.

  Arguments:
    network: A transformer network which should output an encoding output.
    max_seq_length: The maximum allowed sequence length for transformer.
    normalize: If set to True, normalize the encoding produced by transformer.
    logit_scale: The scaling factor of dot products when doing training.
    logit_margin: The margin between positive and negative when doing training.
    output: The output style for this network. Can be either 'logits' or
      'predictions'. If set to 'predictions', it will output the embedding
      produced by the transformer network.
  """

  def __init__(self,
               network: tf.keras.Model,
               max_seq_length: int = 32,
               normalize: bool = True,
               logit_scale: float = 1.0,
               logit_margin: float = 0.0,
               output: str = 'logits',
               **kwargs) -> None:
    # Disable Keras attribute tracking so that `_config` / `network` set
    # below are stored as plain attributes before the functional model is
    # built by super().__init__ at the end of this constructor.
    self._self_setattr_tracking = False
    # Saved verbatim so get_config()/from_config() can round-trip the model.
    self._config = {
        'network': network,
        'max_seq_length': max_seq_length,
        'normalize': normalize,
        'logit_scale': logit_scale,
        'logit_margin': logit_margin,
        'output': output,
    }
    self.network = network

    # Symbolic inputs for the "left" sequence: word ids, padding mask and
    # token type ids, each of shape (batch, max_seq_length).
    left_word_ids = tf.keras.layers.Input(
        shape=(max_seq_length,), dtype=tf.int32, name='left_word_ids')
    left_mask = tf.keras.layers.Input(
        shape=(max_seq_length,), dtype=tf.int32, name='left_mask')
    left_type_ids = tf.keras.layers.Input(
        shape=(max_seq_length,), dtype=tf.int32, name='left_type_ids')
    left_inputs = [left_word_ids, left_mask, left_type_ids]
    # NOTE(review): only the network's second output is used — presumably
    # (sequence_output, pooled_encoding); confirm against the encoder network.
    _, left_encoded = network(left_inputs)
    if normalize:
      # L2-normalize so dot products below are cosine similarities.
      left_encoded = tf.keras.layers.Lambda(
          lambda x: tf.nn.l2_normalize(x, axis=1))(left_encoded)

    if output == 'logits':
      # Training mode: encode the "right" sequence with the SAME network
      # (shared weights) and produce pairwise similarity logits.
      right_word_ids = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='right_word_ids')
      right_mask = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='right_mask')
      right_type_ids = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='right_type_ids')
      right_inputs = [right_word_ids, right_mask, right_type_ids]
      _, right_encoded = network(right_inputs)
      if normalize:
        right_encoded = tf.keras.layers.Lambda(
            lambda x: tf.nn.l2_normalize(x, axis=1))(right_encoded)

      # Scaled dot products with an additive margin on positive pairs.
      dot_products = layers.MatMulWithMargin(logit_scale=logit_scale,
                                             logit_margin=logit_margin,
                                             name='dot_product')

      inputs = [left_word_ids, left_mask, left_type_ids, right_word_ids,
                right_mask, right_type_ids]
      left_logits, right_logits = dot_products(left_encoded, right_encoded)
      outputs = [left_logits, right_logits]

    elif output == 'predictions':
      # Inference mode: only the left tower is needed; output its encoding.
      inputs = [left_word_ids, left_mask, left_type_ids]
      outputs = left_encoded
    else:
      raise ValueError('output type %s is not supported' % output)

    super(DualEncoder, self).__init__(
        inputs=inputs, outputs=outputs, **kwargs)

  def get_config(self):
    return self._config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    items = dict(encoder=self.network)
    return items
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for dual encoder network."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Import libraries
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling import networks
from official.nlp.modeling.models import dual_encoder
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class DualEncoderTest(keras_parameterized.TestCase):
  """Tests for the DualEncoder model in both output modes."""

  @parameterized.parameters((192, 'logits'), (768, 'predictions'))
  def test_dual_encoder(self, hidden_size, output):
    """Validate that the Keras object can be created."""
    # Build a transformer network to use within the dual encoder model.
    vocab_size = 100
    sequence_length = 512
    test_network = networks.TransformerEncoder(
        vocab_size=vocab_size, num_layers=2, hidden_size=hidden_size,
        sequence_length=sequence_length)

    # Create a dual encoder model with the created network.
    dual_encoder_model = dual_encoder.DualEncoder(
        test_network, max_seq_length=sequence_length, output=output)

    # Create a set of 2-dimensional inputs (the first dimension is implicit).
    left_word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    left_mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    left_type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    right_word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    right_mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    right_type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)

    if output == 'logits':
      outputs = dual_encoder_model([
          left_word_ids, left_mask, left_type_ids,
          right_word_ids, right_mask, right_type_ids])
      left_output, _ = outputs
      # In 'logits' mode the first output is the (batch, batch) similarity
      # matrix, so both static dimensions are dynamic. The previous fixed
      # expectation of [None, 768] did not hold for this parameterization.
      expected_output_shape = [None, None]
    elif output == 'predictions':
      left_output = dual_encoder_model(
          [left_word_ids, left_mask, left_type_ids])
      # In 'predictions' mode the output is the pooled encoding, whose width
      # is the encoder's hidden size (parameterized, not hard-coded 768).
      expected_output_shape = [None, hidden_size]

    # Validate that the output is of the expected shape.
    self.assertAllEqual(expected_output_shape, left_output.shape.as_list())

  @parameterized.parameters((192, 'logits'), (768, 'predictions'))
  def test_dual_encoder_tensor_call(self, hidden_size, output):
    """Validate that the Keras object can be invoked."""
    # Build a transformer network to use within the dual encoder model. (Here,
    # we use a short sequence_length for convenience.)
    sequence_length = 2
    test_network = networks.TransformerEncoder(
        vocab_size=100, num_layers=2, sequence_length=sequence_length)

    # Create a dual encoder model with the created network.
    dual_encoder_model = dual_encoder.DualEncoder(
        test_network, max_seq_length=sequence_length, output=output)

    # Create a set of 2-dimensional data tensors to feed into the model.
    word_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
    mask = tf.constant([[1, 1], [1, 0]], dtype=tf.int32)
    type_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)

    # Invoke the model on the tensors. In Eager mode, this does the actual
    # calculation. (We can't validate the outputs, since the network is too
    # complex: this simply ensures we're not hitting runtime errors.)
    if output == 'logits':
      _ = dual_encoder_model(
          [word_ids, mask, type_ids, word_ids, mask, type_ids])
    elif output == 'predictions':
      _ = dual_encoder_model([word_ids, mask, type_ids])

  def test_serialize_deserialize(self):
    """Validate that the dual encoder model can be serialized / deserialized."""
    # Build a transformer network to use within the dual encoder model. (Here,
    # we use a short sequence_length for convenience.)
    sequence_length = 32
    test_network = networks.TransformerEncoder(
        vocab_size=100, num_layers=2, sequence_length=sequence_length)

    # Create a dual encoder model with the created network. (Note that all the
    # args are different, so we can catch any serialization mismatches.)
    dual_encoder_model = dual_encoder.DualEncoder(
        test_network, max_seq_length=sequence_length, output='predictions')

    # Create another dual encoder model via serialization and deserialization.
    config = dual_encoder_model.get_config()
    new_dual_encoder = dual_encoder.DualEncoder.from_config(config)

    # Validate that the config can be forced to JSON.
    _ = new_dual_encoder.to_json()

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(dual_encoder_model.get_config(),
                        new_dual_encoder.get_config())
# Run the test suite when executed as a script.
if __name__ == '__main__':
  tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment