Commit b0ccdb11 authored by Shixin Luo

resolve conflict with master

parents e61588cd 1611a8c5
......@@ -93,6 +93,7 @@ class EncoderScaffold(tf.keras.Model):
"kernel_initializer": The initializer for the transformer layers.
return_all_layer_outputs: Whether to output sequence embedding outputs of
all encoder transformer layers.
dict_outputs: Whether to use a dictionary as the model outputs.
"""
def __init__(self,
......@@ -106,6 +107,7 @@ class EncoderScaffold(tf.keras.Model):
hidden_cls=layers.Transformer,
hidden_cfg=None,
return_all_layer_outputs=False,
dict_outputs=False,
**kwargs):
self._self_setattr_tracking = False
self._hidden_cls = hidden_cls
......@@ -117,6 +119,7 @@ class EncoderScaffold(tf.keras.Model):
self._embedding_cfg = embedding_cfg
self._embedding_data = embedding_data
self._return_all_layer_outputs = return_all_layer_outputs
self._dict_outputs = dict_outputs
self._kwargs = kwargs
if embedding_cls:
......@@ -138,7 +141,7 @@ class EncoderScaffold(tf.keras.Model):
shape=(seq_length,), dtype=tf.int32, name='input_type_ids')
inputs = [word_ids, mask, type_ids]
self._embedding_layer = layers.OnDeviceEmbedding(
self._embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
vocab_size=embedding_cfg['vocab_size'],
embedding_width=embedding_cfg['hidden_size'],
initializer=embedding_cfg['initializer'],
......@@ -147,13 +150,13 @@ class EncoderScaffold(tf.keras.Model):
word_embeddings = self._embedding_layer(word_ids)
# Always uses dynamic slicing for simplicity.
self._position_embedding_layer = keras_nlp.PositionEmbedding(
self._position_embedding_layer = keras_nlp.layers.PositionEmbedding(
initializer=embedding_cfg['initializer'],
max_length=embedding_cfg['max_seq_length'],
name='position_embedding')
position_embeddings = self._position_embedding_layer(word_embeddings)
self._type_embedding_layer = layers.OnDeviceEmbedding(
self._type_embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
vocab_size=embedding_cfg['type_vocab_size'],
embedding_width=embedding_cfg['hidden_size'],
initializer=embedding_cfg['initializer'],
......@@ -200,7 +203,13 @@ class EncoderScaffold(tf.keras.Model):
name='cls_transform')
cls_output = self._pooler_layer(first_token_tensor)
if return_all_layer_outputs:
if dict_outputs:
outputs = dict(
sequence_output=layer_output_data[-1],
pooled_output=cls_output,
encoder_outputs=layer_output_data,
)
elif return_all_layer_outputs:
outputs = [layer_output_data, cls_output]
else:
outputs = [layer_output_data[-1], cls_output]
......@@ -219,6 +228,7 @@ class EncoderScaffold(tf.keras.Model):
'embedding_cfg': self._embedding_cfg,
'hidden_cfg': self._hidden_cfg,
'return_all_layer_outputs': self._return_all_layer_outputs,
'dict_outputs': self._dict_outputs,
}
if inspect.isclass(self._hidden_cls):
config_dict['hidden_cls_string'] = tf.keras.utils.get_registered_name(
......
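As context for the new `dict_outputs` flag: a functional Keras model built with a dict of outputs makes `predict()` return a dict keyed the same way, which is what the updated test below relies on when it indexes `preds["pooled_output"]`. The following is a minimal, hypothetical sketch of that pattern using a toy model (not `EncoderScaffold` itself; its embedding/hidden config is omitted here).

```python
import numpy as np
import tensorflow as tf

# Toy stand-in for an encoder built with dict_outputs=True (hypothetical
# layers; the real EncoderScaffold configuration is omitted).
inputs = tf.keras.Input(shape=(4,), dtype=tf.float32)
hidden = tf.keras.layers.Dense(8, activation="relu")(inputs)
outputs = dict(
    sequence_output=tf.keras.layers.Dense(8, name="sequence")(hidden),
    pooled_output=tf.keras.layers.Dense(2, name="pooled")(hidden),
)
model = tf.keras.Model(inputs, outputs)

# predict() preserves the dict structure, so callers index outputs by name
# instead of relying on positional unpacking.
preds = model.predict(np.zeros((3, 4), dtype="float32"))
print(preds["pooled_output"].shape)  # (3, 2)
```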
......@@ -12,11 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for transformer-based text encoder network."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Tests for EncoderScaffold network."""
from absl.testing import parameterized
import numpy as np
......@@ -218,16 +214,17 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02),
hidden_cfg=hidden_cfg,
embedding_cfg=embedding_cfg)
embedding_cfg=embedding_cfg,
dict_outputs=True)
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
data, pooled = test_network([word_ids, mask, type_ids])
outputs = test_network([word_ids, mask, type_ids])
# Create a model based off of this network:
model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
model = tf.keras.Model([word_ids, mask, type_ids], outputs)
# Invoke the model. We can't validate the output data here (the model is too
# complex) but this will catch structural runtime errors.
......@@ -237,7 +234,8 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
mask_data = np.random.randint(2, size=(batch_size, sequence_length))
type_id_data = np.random.randint(
num_types, size=(batch_size, sequence_length))
_ = model.predict([word_id_data, mask_data, type_id_data])
preds = model.predict([word_id_data, mask_data, type_id_data])
self.assertEqual(preds["pooled_output"].shape, (3, hidden_size))
# Creates an EncoderScaffold with max_sequence_length != sequence_length
num_types = 7
......@@ -272,8 +270,8 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
stddev=0.02),
hidden_cfg=hidden_cfg,
embedding_cfg=embedding_cfg)
model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
outputs = test_network([word_ids, mask, type_ids])
model = tf.keras.Model([word_ids, mask, type_ids], outputs)
_ = model.predict([word_id_data, mask_data, type_id_data])
def test_serialize_deserialize(self):
......
......@@ -101,18 +101,18 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
self.max_sequence_length = max_sequence_length
self.dropout_rate = dropout_rate
self.word_embedding = layers.OnDeviceEmbedding(
self.word_embedding = keras_nlp.layers.OnDeviceEmbedding(
self.word_vocab_size,
self.word_embed_size,
initializer=initializer,
name='word_embedding')
self.type_embedding = layers.OnDeviceEmbedding(
self.type_embedding = keras_nlp.layers.OnDeviceEmbedding(
self.type_vocab_size,
self.output_embed_size,
use_one_hot=True,
initializer=initializer,
name='type_embedding')
self.pos_embedding = keras_nlp.PositionEmbedding(
self.pos_embedding = keras_nlp.layers.PositionEmbedding(
max_length=max_sequence_length,
initializer=initializer,
name='position_embedding')
......@@ -127,7 +127,7 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
self.dropout_rate,
name='embedding_dropout')
def call(self, input_ids, token_type_ids=None, training=False):
def call(self, input_ids, token_type_ids=None):
word_embedding_out = self.word_embedding(input_ids)
word_embedding_out = tf.concat(
[tf.pad(word_embedding_out[:, 1:], ((0, 0), (0, 1), (0, 0))),
......@@ -142,7 +142,7 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
type_embedding_out = self.type_embedding(token_type_ids)
embedding_out += type_embedding_out
embedding_out = self.layer_norm(embedding_out)
embedding_out = self.dropout_layer(embedding_out, training=training)
embedding_out = self.dropout_layer(embedding_out)
return embedding_out
......@@ -300,7 +300,6 @@ class TransformerLayer(tf.keras.layers.Layer):
def call(self,
input_tensor,
attention_mask=None,
training=False,
return_attention_scores=False):
"""Implementes the forward pass.
......@@ -309,7 +308,6 @@ class TransformerLayer(tf.keras.layers.Layer):
attention_mask: (optional) int32 tensor of shape [batch_size, seq_length,
seq_length], with 1 for positions that can be attended to and 0 in
positions that should not be.
training: If the model is in training mode.
return_attention_scores: Whether to return attention scores.
Returns:
......@@ -326,7 +324,6 @@ class TransformerLayer(tf.keras.layers.Layer):
f'hidden size {self.hidden_size}'))
prev_output = input_tensor
# input bottleneck
dense_layer = self.block_layers['bottleneck_input'][0]
layer_norm = self.block_layers['bottleneck_input'][1]
......@@ -355,7 +352,6 @@ class TransformerLayer(tf.keras.layers.Layer):
key_tensor,
attention_mask,
return_attention_scores=True,
training=training
)
attention_output = layer_norm(attention_output + layer_input)
......@@ -375,7 +371,7 @@ class TransformerLayer(tf.keras.layers.Layer):
dropout_layer = self.block_layers['bottleneck_output'][1]
layer_norm = self.block_layers['bottleneck_output'][2]
layer_output = bottleneck(layer_output)
layer_output = dropout_layer(layer_output, training=training)
layer_output = dropout_layer(layer_output)
layer_output = layer_norm(layer_output + prev_output)
if return_attention_scores:
......@@ -406,8 +402,6 @@ class MobileBERTEncoder(tf.keras.Model):
num_feedforward_networks=4,
normalization_type='no_norm',
classifier_activation=False,
return_all_layers=False,
return_attention_score=False,
**kwargs):
"""Class initialization.
......@@ -438,8 +432,6 @@ class MobileBERTEncoder(tf.keras.Model):
MobileBERT paper. 'layer_norm' is used for the teacher model.
classifier_activation: Whether to use the tanh activation for the final
representation of the [CLS] token in fine-tuning.
return_all_layers: If return all layer outputs.
return_attention_score: If return attention scores for each layer.
**kwargs: Other keyword arguments.
"""
self._self_setattr_tracking = False
......@@ -513,12 +505,11 @@ class MobileBERTEncoder(tf.keras.Model):
else:
self._pooler_layer = None
if return_all_layers:
outputs = [all_layer_outputs, first_token]
else:
outputs = [prev_output, first_token]
if return_attention_score:
outputs.append(all_attention_scores)
outputs = dict(
sequence_output=prev_output,
pooled_output=first_token,
encoder_outputs=all_layer_outputs,
attention_scores=all_attention_scores)
super(MobileBERTEncoder, self).__init__(
inputs=self.inputs, outputs=outputs, **kwargs)
......
......@@ -32,7 +32,7 @@ def generate_fake_input(batch_size=1, seq_len=5, vocab_size=10000, seed=0):
return fake_input
class ModelingTest(parameterized.TestCase, tf.test.TestCase):
class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase):
def test_embedding_layer_with_token_type(self):
layer = mobile_bert_encoder.MobileBertEmbedding(10, 8, 2, 16)
......@@ -116,7 +116,9 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase):
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
layer_output, pooler_output = test_network([word_ids, mask, type_ids])
outputs = test_network([word_ids, mask, type_ids])
layer_output, pooler_output = outputs['sequence_output'], outputs[
'pooled_output']
self.assertIsInstance(test_network.transformer_layers, list)
self.assertLen(test_network.transformer_layers, num_blocks)
......@@ -134,13 +136,13 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase):
test_network = mobile_bert_encoder.MobileBERTEncoder(
word_vocab_size=100,
hidden_size=hidden_size,
num_blocks=num_blocks,
return_all_layers=True)
num_blocks=num_blocks)
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
all_layer_output, _ = test_network([word_ids, mask, type_ids])
outputs = test_network([word_ids, mask, type_ids])
all_layer_output = outputs['encoder_outputs']
self.assertIsInstance(all_layer_output, list)
self.assertLen(all_layer_output, num_blocks + 1)
......@@ -153,16 +155,13 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase):
test_network = mobile_bert_encoder.MobileBERTEncoder(
word_vocab_size=vocab_size,
hidden_size=hidden_size,
num_blocks=num_blocks,
return_all_layers=False)
num_blocks=num_blocks)
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
layer_out_tensor, pooler_out_tensor = test_network(
[word_ids, mask, type_ids])
model = tf.keras.Model([word_ids, mask, type_ids],
[layer_out_tensor, pooler_out_tensor])
outputs = test_network([word_ids, mask, type_ids])
model = tf.keras.Model([word_ids, mask, type_ids], outputs)
input_seq = generate_fake_input(
batch_size=1, seq_len=sequence_length, vocab_size=vocab_size)
......@@ -170,13 +169,12 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase):
batch_size=1, seq_len=sequence_length, vocab_size=2)
token_type = generate_fake_input(
batch_size=1, seq_len=sequence_length, vocab_size=2)
layer_output, pooler_output = model.predict(
[input_seq, input_mask, token_type])
outputs = model.predict([input_seq, input_mask, token_type])
layer_output_shape = [1, sequence_length, hidden_size]
self.assertAllEqual(layer_output.shape, layer_output_shape)
pooler_output_shape = [1, hidden_size]
self.assertAllEqual(pooler_output.shape, pooler_output_shape)
sequence_output_shape = [1, sequence_length, hidden_size]
self.assertAllEqual(outputs['sequence_output'].shape, sequence_output_shape)
pooled_output_shape = [1, hidden_size]
self.assertAllEqual(outputs['pooled_output'].shape, pooled_output_shape)
def test_mobilebert_encoder_invocation_with_attention_score(self):
vocab_size = 100
......@@ -186,18 +184,13 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase):
test_network = mobile_bert_encoder.MobileBERTEncoder(
word_vocab_size=vocab_size,
hidden_size=hidden_size,
num_blocks=num_blocks,
return_all_layers=False,
return_attention_score=True)
num_blocks=num_blocks)
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
layer_out_tensor, pooler_out_tensor, attention_out_tensor = test_network(
[word_ids, mask, type_ids])
model = tf.keras.Model(
[word_ids, mask, type_ids],
[layer_out_tensor, pooler_out_tensor, attention_out_tensor])
outputs = test_network([word_ids, mask, type_ids])
model = tf.keras.Model([word_ids, mask, type_ids], outputs)
input_seq = generate_fake_input(
batch_size=1, seq_len=sequence_length, vocab_size=vocab_size)
......@@ -205,9 +198,8 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase):
batch_size=1, seq_len=sequence_length, vocab_size=2)
token_type = generate_fake_input(
batch_size=1, seq_len=sequence_length, vocab_size=2)
_, _, attention_score_output = model.predict(
[input_seq, input_mask, token_type])
self.assertLen(attention_score_output, num_blocks)
outputs = model.predict([input_seq, input_mask, token_type])
self.assertLen(outputs['attention_scores'], num_blocks)
@parameterized.named_parameters(
('sequence_classification', models.BertClassifier, [None, 5]),
......
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based XLNet Model."""
from absl import logging
import tensorflow as tf
from official.nlp.modeling import layers
from official.nlp.modeling.layers import transformer_xl
_SEG_ID_CLS = 2
def _create_causal_attention_mask(
seq_length,
memory_length,
dtype=tf.float32,
same_length=False):
"""Creates a causal attention mask with a single-sided context.
When applying the attention mask in `MultiHeadRelativeAttention`, the
attention scores are of shape `[(batch dimensions), S, S + M]`, where:
- S = sequence length.
- M = memory length.
In the simple case where S = 2 and M = 1, here is an illustration of the
`attention_scores` matrix, where `a` represents an attention function:
token_0 [[a(token_0, mem_0)  a(token_0, token_0)  a(token_0, token_1)],
token_1  [a(token_1, mem_0)  a(token_1, token_0)  a(token_1, token_1)]]
              mem_0               token_0              token_1
For uni-directional attention, we want to mask out values in the attention
scores that represent a(token_i, token_j) where j > i. We can achieve this by
concatenating 0s (representing memory positions) with a strictly upper
triangular matrix of 1s.
Arguments:
seq_length: int, The length of each sequence.
memory_length: int, The length of memory blocks.
dtype: dtype of the mask.
same_length: bool, whether to use the same attention length for each token.
Returns:
A unidirectional attention mask of shape
`[seq_length, seq_length + memory_length]`. E.g.:
[[0. 0. 0. 1. 1. 1.]
[0. 0. 0. 0. 1. 1.]
[0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 0.]]
"""
ones_matrix = tf.ones([seq_length, seq_length], dtype=dtype)
upper_triangular = tf.linalg.band_part(ones_matrix, 0, -1)
diagonal = tf.linalg.band_part(ones_matrix, 0, 0)
padding = tf.zeros([seq_length, memory_length], dtype=dtype)
causal_attention_mask = tf.concat(
[padding, upper_triangular - diagonal], 1)
if same_length:
lower_triangular = tf.linalg.band_part(ones_matrix, -1, 0)
strictly_lower_triangular = lower_triangular - diagonal
causal_attention_mask = tf.concat(
[causal_attention_mask[:, :seq_length] + strictly_lower_triangular,
causal_attention_mask[:, seq_length:]], 1)
return causal_attention_mask
def _compute_attention_mask(
input_mask,
permutation_mask,
attention_type,
seq_length,
memory_length,
batch_size,
dtype=tf.float32):
"""Combines all input attention masks for XLNet.
In XLNet modeling, `0` represents tokens that can be attended to, and `1`
represents tokens that cannot be attended to.
For XLNet pre-training and fine-tuning, a few masks are used:
- Causal attention mask: If the attention type is unidirectional, then all
tokens after the current position cannot be attended to.
- Input mask: when generating data, padding is added up to a max sequence
length to make all sequences the same length. This mask distinguishes real
tokens (`0`) from padding tokens (`1`).
- Permutation mask: during XLNet pretraining, the input sequence is factorized
into a factorization sequence `z`. During partial prediction, `z` is split
at a cutting point `c` (an index of the factorization sequence) and
prediction is only applied to all tokens after `c`. Therefore, tokens at
factorization positions `i` > `c` can be attended to and tokens at
factorization positions `i` <= `c` cannot be attended to.
This function broadcasts and combines all attention masks to produce the
query attention mask and the content attention mask.
Args:
input_mask: Tensor, the input mask related to padding. Input shape:
`(B, S)`.
permutation_mask: Tensor, the permutation mask used in partial prediction.
Input shape: `(B, S, S)`.
attention_type: str, the attention type. Can be "uni" (directional) or
"bi" (directional).
seq_length: int, the length of each sequence.
memory_length: int, the length of memory blocks.
batch_size: int, the batch size.
dtype: The dtype of the masks.
Returns:
attention_mask, content_attention_mask: The query (position-based) attention
mask and the content attention mask, respectively.
"""
attention_mask = None
# `1` values mean do not attend to this position.
if attention_type == "uni":
causal_attention_mask = _create_causal_attention_mask(
seq_length=seq_length,
memory_length=memory_length,
dtype=dtype)
causal_attention_mask = causal_attention_mask[None, None, :, :]
# `causal_attention_mask`: [1, 1, S, S + M]
# input_mask: [B, S]
# permutation_mask: [B, S, S]
if input_mask is not None and permutation_mask is not None:
data_mask = input_mask[:, None, :] + permutation_mask
elif input_mask is not None and permutation_mask is None:
data_mask = input_mask[:, None, :]
elif input_mask is None and permutation_mask is not None:
data_mask = permutation_mask
else:
data_mask = None
# data_mask: [B, S, S] or [B, 1, S]
if data_mask is not None:
# All positions within state can be attended to.
state_mask = tf.zeros([batch_size, tf.shape(data_mask)[1], memory_length],
dtype=dtype)
# state_mask: [B, 1, M] or [B, S, M]
data_mask = tf.concat([state_mask, data_mask], 2)
# data_mask: [B, 1, S + M] or [B, S, S + M]
if attention_type == "uni":
attention_mask = causal_attention_mask + data_mask[:, None, :, :]
else:
attention_mask = data_mask[:, None, :, :]
# Construct the content attention mask.
if attention_mask is not None:
attention_mask = tf.cast(attention_mask > 0, dtype=dtype)
non_tgt_mask = -tf.eye(seq_length, dtype=dtype)
non_tgt_mask = tf.concat(
[tf.zeros([seq_length, memory_length], dtype=dtype),
non_tgt_mask], axis=-1)
content_attention_mask = tf.cast(
(attention_mask + non_tgt_mask[None, None, :, :]) > 0,
dtype=dtype)
else:
content_attention_mask = None
return attention_mask, content_attention_mask
def _compute_segment_matrix(
segment_ids,
memory_length,
batch_size,
use_cls_mask):
"""Computes the segment embedding matrix.
XLNet introduced segment-based attention for attention calculations. This
extends the idea of relative encodings in Transformer XL by considering
whether or not two positions are within the same segment, rather than
which segments they come from.
This function generates a segment matrix by broadcasting provided segment IDs
in two different dimensions and checking where values are equal. This output
matrix shows `True` whenever two tokens are NOT in the same segment and
`False` whenever they are.
Args:
segment_ids: A Tensor of size `[B, S]` that represents which segment
each token belongs to.
memory_length: int, the length of memory blocks.
batch_size: int, the batch size.
use_cls_mask: bool, whether or not to introduce cls mask in
input sequences.
Returns:
A boolean Tensor of size `[B, S, S + M]`, where `True` means that two
tokens are NOT in the same segment, and `False` means they are in the same
segment.
"""
if segment_ids is None:
return None
memory_padding = tf.zeros([batch_size, memory_length], dtype=tf.int32)
padded_segment_ids = tf.concat([memory_padding, segment_ids], 1)
# segment_ids: [B, S]
# padded_segment_ids: [B, S + M]
if use_cls_mask:
# `1` indicates not in the same segment.
# Target result: [B, S, S + M]
# segment_ids: [B, S]
# padded_segment_ids: [B, S + M]
broadcasted_segment_class_indices = (
tf.equal(segment_ids,
tf.constant([_SEG_ID_CLS]))[:, :, None])
broadcasted_padded_class_indices = (
tf.equal(
padded_segment_ids,
tf.constant([_SEG_ID_CLS]))[:, None, :])
class_index_matrix = tf.logical_or(broadcasted_segment_class_indices,
broadcasted_padded_class_indices)
segment_matrix = tf.equal(segment_ids[:, :, None],
padded_segment_ids[:, None, :])
segment_matrix = tf.logical_or(class_index_matrix, segment_matrix)
else:
# TODO(allencwang) - address this legacy mismatch from `use_cls_mask`.
segment_matrix = tf.logical_not(
tf.equal(segment_ids[:, :, None], padded_segment_ids[:, None, :]))
return segment_matrix
def _compute_positional_encoding(
attention_type,
position_encoding_layer,
hidden_size,
batch_size,
total_length,
seq_length,
clamp_length,
bi_data,
dtype=tf.float32):
"""Computes the relative position encoding.
Args:
attention_type: str, the attention type. Can be "uni" (directional) or
"bi" (directional).
position_encoding_layer: An instance of `RelativePositionEncoding`.
hidden_size: int, the hidden size.
batch_size: int, the batch size.
total_length: int, the sequence length added to the memory length.
seq_length: int, the length of each sequence.
clamp_length: int, clamp all relative distances larger than clamp_length. -1
means no clamping.
bi_data: bool, whether to use bidirectional input pipeline. Usually set to
True during pretraining and False during finetuning.
dtype: the dtype of the encoding.
Returns:
A Tensor, representing the position encoding.
"""
freq_seq = tf.range(0, hidden_size, 2.0)
if dtype is not None and dtype != tf.float32:
freq_seq = tf.cast(freq_seq, dtype=dtype)
if attention_type == "bi":
beg, end = total_length, -seq_length
elif attention_type == "uni":
beg, end = total_length, -1
else:
raise ValueError("Unknown `attention_type` {}.".format(attention_type))
if bi_data:
forward_position_sequence = tf.range(beg, end, -1.0)
backward_position_sequence = tf.range(-beg, -end, 1.0)
if dtype is not None and dtype != tf.float32:
forward_position_sequence = tf.cast(forward_position_sequence,
dtype=dtype)
backward_position_sequence = tf.cast(backward_position_sequence,
dtype=dtype)
if clamp_length > 0:
forward_position_sequence = tf.clip_by_value(
forward_position_sequence,
-clamp_length,
clamp_length)
backward_position_sequence = tf.clip_by_value(
backward_position_sequence,
-clamp_length,
clamp_length)
if batch_size is not None:
forward_positional_encoding = position_encoding_layer(
forward_position_sequence, batch_size // 2)
backward_positional_encoding = position_encoding_layer(
backward_position_sequence, batch_size // 2)
else:
forward_positional_encoding = position_encoding_layer(
forward_position_sequence, None)
backward_positional_encoding = position_encoding_layer(
backward_position_sequence, None)
relative_position_encoding = tf.concat(
[forward_positional_encoding, backward_positional_encoding], axis=0)
else:
forward_position_sequence = tf.range(beg, end, -1.0)
if dtype is not None and dtype != tf.float32:
forward_position_sequence = tf.cast(
forward_position_sequence, dtype=dtype)
if clamp_length > 0:
forward_position_sequence = tf.clip_by_value(
forward_position_sequence,
-clamp_length,
clamp_length)
relative_position_encoding = position_encoding_layer(
forward_position_sequence, batch_size)
return relative_position_encoding
class RelativePositionEncoding(tf.keras.layers.Layer):
"""Creates a relative positional encoding.
This layer creates a relative positional encoding as described in
"Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
(https://arxiv.org/abs/1901.02860).
Rather than an absolute position embedding as in Transformer, this
formulation represents position as the relative distance between tokens using
sinusoidal positional embeddings.
Note: This layer is currently experimental.
Attributes:
hidden_size: The dimensionality of the input embeddings.
"""
def __init__(self, hidden_size, **kwargs):
super(RelativePositionEncoding, self).__init__(**kwargs)
self._hidden_size = hidden_size
self._inv_freq = 1.0 / (10000.0**(
tf.range(0, self._hidden_size, 2.0) / self._hidden_size))
def call(self, pos_seq, batch_size=None):
"""Implements call() for the layer.
Arguments:
pos_seq: A 1-D `Tensor` of positions.
batch_size: The optionally provided batch size that tiles the relative
positional encoding.
Returns:
The relative positional encoding of shape:
[batch_size, len(pos_seq), hidden_size] if batch_size is provided, else
[1, len(pos_seq), hidden_size].
"""
sinusoid_input = tf.einsum("i,d->id", pos_seq, self._inv_freq)
relative_position_encoding = tf.concat([tf.sin(sinusoid_input),
tf.cos(sinusoid_input)], -1)
relative_position_encoding = relative_position_encoding[None, :, :]
if batch_size is not None:
relative_position_encoding = tf.tile(relative_position_encoding,
[batch_size, 1, 1])
return relative_position_encoding
@tf.keras.utils.register_keras_serializable(package="Text")
class XLNetBase(tf.keras.layers.Layer):
"""Base XLNet model.
Attributes:
vocab_size: int, the number of tokens in vocabulary.
num_layers: int, the number of layers.
hidden_size: int, the hidden size.
num_attention_heads: int, the number of attention heads.
head_size: int, the dimension size of each attention head.
inner_size: int, the hidden size in feed-forward layers.
dropout_rate: float, dropout rate.
attention_dropout_rate: float, dropout rate on attention probabilities.
attention_type: str, "uni" or "bi".
bi_data: bool, whether to use bidirectional input pipeline. Usually set to
True during pretraining and False during finetuning.
initializer: A tf initializer.
two_stream: bool, whether or not to use `TwoStreamRelativeAttention` used
in the XLNet pretrainer. If `False`, then it will use
`MultiHeadRelativeAttention` as in Transformer XL.
tie_attention_biases: bool, whether or not to tie the biases together.
Usually set to `True`. Used for backwards compatibility.
memory_length: int, the number of tokens to cache.
same_length: bool, whether to use the same attention length for each
token.
clamp_length: int, clamp all relative distances larger than clamp_length. -1
means no clamping.
reuse_length: int, the number of tokens in the current batch to be cached
and reused in the future.
inner_activation: str, "relu" or "gelu".
use_cls_mask: bool, whether or not cls mask is included in the
input sequences.
embedding_width: The width of the word embeddings. If the embedding width
is not equal to hidden size, embedding parameters will be factorized
into two matrices in the shape of ["vocab_size", "embedding_width"] and
["embedding_width", "hidden_size"] ("embedding_width" is usually much
smaller than "hidden_size").
embedding_layer: The word embedding layer. `None` means we will create a
new embedding layer. Otherwise, we will reuse the given embedding layer.
This parameter was originally added for the ELECTRA model, which needs to
tie the generator embeddings with the discriminator embeddings.
"""
def __init__(self,
vocab_size,
num_layers,
hidden_size,
num_attention_heads,
head_size,
inner_size,
dropout_rate,
attention_dropout_rate,
attention_type,
bi_data,
initializer,
two_stream=False,
tie_attention_biases=True,
memory_length=None,
clamp_length=-1,
reuse_length=None,
inner_activation="relu",
use_cls_mask=False,
embedding_width=None,
**kwargs):
super(XLNetBase, self).__init__(**kwargs)
self._vocab_size = vocab_size
self._initializer = initializer
self._attention_type = attention_type
self._num_layers = num_layers
self._hidden_size = hidden_size
self._num_attention_heads = num_attention_heads
self._head_size = head_size
self._inner_size = inner_size
self._inner_activation = inner_activation
self._dropout_rate = dropout_rate
self._attention_dropout_rate = attention_dropout_rate
self._tie_attention_biases = tie_attention_biases
self._two_stream = two_stream
self._memory_length = memory_length
self._reuse_length = reuse_length
self._bi_data = bi_data
self._clamp_length = clamp_length
self._use_cls_mask = use_cls_mask
self._segment_embedding = None
self._mask_embedding = None
self._embedding_width = embedding_width
if embedding_width is None:
embedding_width = hidden_size
self._embedding_layer = layers.OnDeviceEmbedding(
vocab_size=self._vocab_size,
embedding_width=embedding_width,
initializer=self._initializer,
dtype=tf.float32,
name="word_embedding")
self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
self.embedding_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
self.position_encoding = RelativePositionEncoding(self._hidden_size)
self._transformer_xl = transformer_xl.TransformerXL(
vocab_size=vocab_size,
num_layers=num_layers,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
head_size=head_size,
inner_size=inner_size,
dropout_rate=dropout_rate,
attention_dropout_rate=attention_dropout_rate,
initializer=initializer,
two_stream=two_stream,
tie_attention_biases=tie_attention_biases,
memory_length=memory_length,
reuse_length=reuse_length,
inner_activation=inner_activation,
name="transformer_xl")
def get_config(self):
config = {
"vocab_size":
self._vocab_size,
"num_layers":
self._num_layers,
"hidden_size":
self._hidden_size,
"num_attention_heads":
self._num_attention_heads,
"head_size":
self._head_size,
"inner_size":
self._inner_size,
"dropout_rate":
self._dropout_rate,
"attention_dropout_rate":
self._attention_dropout_rate,
"attention_type":
self._attention_type,
"bi_data":
self._bi_data,
"initializer":
self._initializer,
"two_stream":
self._two_stream,
"tie_attention_biases":
self._tie_attention_biases,
"memory_length":
self._memory_length,
"clamp_length":
self._clamp_length,
"reuse_length":
self._reuse_length,
"inner_activation":
self._inner_activation,
"use_cls_mask":
self._use_cls_mask,
"embedding_width":
self._embedding_width,
}
base_config = super(XLNetBase, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def get_embedding_lookup_table(self):
"""Returns the embedding layer weights."""
return self._embedding_layer.embeddings
def __call__(self,
input_ids,
segment_ids=None,
input_mask=None,
state=None,
permutation_mask=None,
target_mapping=None,
masked_tokens=None,
**kwargs):
# Uses dict to feed inputs into call() in order to keep state as a python
# list.
inputs = {
"input_ids": input_ids,
"segment_ids": segment_ids,
"input_mask": input_mask,
"state": state,
"permutation_mask": permutation_mask,
"target_mapping": target_mapping,
"masked_tokens": masked_tokens
}
return super(XLNetBase, self).__call__(inputs, **kwargs)
def call(self, inputs):
"""Implements call() for the layer."""
input_ids = inputs["input_ids"]
segment_ids = inputs["segment_ids"]
input_mask = inputs["input_mask"]
state = inputs["state"]
permutation_mask = inputs["permutation_mask"]
target_mapping = inputs["target_mapping"]
masked_tokens = inputs["masked_tokens"]
batch_size = tf.shape(input_ids)[0]
seq_length = input_ids.shape.as_list()[1]
memory_length = state[0].shape.as_list()[1] if state is not None else 0
total_length = memory_length + seq_length
if self._two_stream and masked_tokens is None:
raise ValueError("`masked_tokens` must be provided in order to "
"initialize the query stream in "
"`TwoStreamRelativeAttention`.")
if masked_tokens is not None and not self._two_stream:
logging.warning("`masked_tokens` is provided but `two_stream` is not "
"enabled. Please enable `two_stream` to enable two "
"stream attention.")
query_attention_mask, content_attention_mask = _compute_attention_mask(
input_mask=input_mask,
permutation_mask=permutation_mask,
attention_type=self._attention_type,
seq_length=seq_length,
memory_length=memory_length,
batch_size=batch_size,
dtype=tf.float32)
relative_position_encoding = _compute_positional_encoding(
attention_type=self._attention_type,
position_encoding_layer=self.position_encoding,
hidden_size=self._hidden_size,
batch_size=batch_size,
total_length=total_length,
seq_length=seq_length,
clamp_length=self._clamp_length,
bi_data=self._bi_data,
dtype=tf.float32)
relative_position_encoding = self.embedding_dropout(
relative_position_encoding)
if segment_ids is None:
segment_embedding = None
segment_matrix = None
else:
if self._segment_embedding is None:
self._segment_embedding = self.add_weight(
"seg_embed",
shape=[self._num_layers, 2, self._num_attention_heads,
self._head_size],
dtype=tf.float32,
initializer=self._initializer)
segment_embedding = self._segment_embedding
segment_matrix = _compute_segment_matrix(
segment_ids=segment_ids,
memory_length=memory_length,
batch_size=batch_size,
use_cls_mask=self._use_cls_mask)
word_embeddings = self._embedding_layer(input_ids)
content_stream = self._dropout(word_embeddings)
if self._two_stream:
if self._mask_embedding is None:
self._mask_embedding = self.add_weight(
"mask_emb/mask_emb",
shape=[1, 1, self._hidden_size],
dtype=tf.float32)
if target_mapping is None:
masked_tokens = masked_tokens[:, :, None]
masked_token_embedding = (
masked_tokens * self._mask_embedding +
(1 - masked_tokens) * word_embeddings)
else:
masked_token_embedding = tf.tile(
self._mask_embedding,
[batch_size, tf.shape(target_mapping)[1], 1])
query_stream = self._dropout(masked_token_embedding)
else:
query_stream = None
return self._transformer_xl(
content_stream=content_stream,
query_stream=query_stream,
target_mapping=target_mapping,
state=state,
relative_position_encoding=relative_position_encoding,
segment_matrix=segment_matrix,
segment_embedding=segment_embedding,
content_attention_mask=content_attention_mask,
query_attention_mask=query_attention_mask)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.networks import xlnet_base
@keras_parameterized.run_all_keras_modes
class RelativePositionEncodingTest(keras_parameterized.TestCase):
def test_positional_embedding(self):
"""A low-dimensional example is tested.
With len(pos_seq)=2 and d_model=4:
pos_seq = [[1.], [0.]]
inv_freq = [1., 0.01]
pos_seq x inv_freq = [[1, 0.01], [0., 0.]]
pos_emb = [[sin(1.), sin(0.01), cos(1.), cos(0.01)],
[sin(0.), sin(0.), cos(0.), cos(0.)]]
= [[0.84147096, 0.00999983, 0.54030228, 0.99994999],
[0., 0., 1., 1.]]
"""
target = np.array([[[0.84147096, 0.00999983, 0.54030228, 0.99994999],
[0., 0., 1., 1.]]])
hidden_size = 4
pos_seq = tf.range(1, -1, -1.0) # [1., 0.]
encoding_layer = xlnet_base.RelativePositionEncoding(
hidden_size=hidden_size)
encoding = encoding_layer(pos_seq, batch_size=None).numpy().astype(float)
self.assertAllClose(encoding, target)
class ComputePositionEncodingTest(keras_parameterized.TestCase):
@combinations.generate(combinations.combine(
attention_type=["uni", "bi"],
bi_data=[False, True],
))
def test_compute_position_encoding_smoke(self, attention_type, bi_data):
hidden_size = 4
batch_size = 4
total_length = 8
seq_length = 4
position_encoding_layer = xlnet_base.RelativePositionEncoding(
hidden_size=hidden_size)
encoding = xlnet_base._compute_positional_encoding(
attention_type=attention_type,
position_encoding_layer=position_encoding_layer,
hidden_size=hidden_size,
batch_size=batch_size,
total_length=total_length,
seq_length=seq_length,
clamp_length=2,
bi_data=bi_data,
dtype=tf.float32)
self.assertEqual(encoding.shape[0], batch_size)
self.assertEqual(encoding.shape[2], hidden_size)
class CausalAttentionMaskTests(tf.test.TestCase):
def test_causal_attention_mask_with_no_memory(self):
seq_length, memory_length = 3, 0
causal_attention_mask = xlnet_base._create_causal_attention_mask(
seq_length=seq_length,
memory_length=memory_length)
expected_output = np.array([[0, 1, 1],
[0, 0, 1],
[0, 0, 0]])
self.assertAllClose(causal_attention_mask, expected_output)
def test_causal_attention_mask_with_memory(self):
seq_length, memory_length = 3, 2
causal_attention_mask = xlnet_base._create_causal_attention_mask(
seq_length=seq_length,
memory_length=memory_length)
expected_output = np.array([[0, 0, 0, 1, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 0]])
self.assertAllClose(causal_attention_mask, expected_output)
def test_causal_attention_mask_with_same_length(self):
seq_length, memory_length = 3, 2
causal_attention_mask = xlnet_base._create_causal_attention_mask(
seq_length=seq_length,
memory_length=memory_length,
same_length=True)
expected_output = np.array([[0, 0, 0, 1, 1],
[1, 0, 0, 0, 1],
[1, 1, 0, 0, 0]])
self.assertAllClose(causal_attention_mask, expected_output)
class MaskComputationTests(keras_parameterized.TestCase):
@combinations.generate(combinations.combine(
use_input_mask=[False, True],
use_permutation_mask=[False, True],
attention_type=["uni", "bi"],
memory_length=[0, 4],
))
def test_compute_attention_mask_smoke(self,
use_input_mask,
use_permutation_mask,
attention_type,
memory_length):
"""Tests coverage and functionality for different configurations."""
batch_size = 2
seq_length = 8
if use_input_mask:
input_mask = tf.zeros(shape=(batch_size, seq_length))
else:
input_mask = None
if use_permutation_mask:
permutation_mask = tf.zeros(shape=(batch_size, seq_length, seq_length))
else:
permutation_mask = None
_, content_mask = xlnet_base._compute_attention_mask(
input_mask=input_mask,
permutation_mask=permutation_mask,
attention_type=attention_type,
seq_length=seq_length,
memory_length=memory_length,
batch_size=batch_size,
dtype=tf.float32)
expected_mask_shape = (batch_size, 1,
seq_length, seq_length + memory_length)
if use_input_mask or use_permutation_mask:
self.assertEqual(content_mask.shape, expected_mask_shape)
def test_no_input_masks(self):
query_mask, content_mask = xlnet_base._compute_attention_mask(
input_mask=None,
permutation_mask=None,
attention_type="uni",
seq_length=8,
memory_length=2,
batch_size=2,
dtype=tf.float32)
self.assertIsNone(query_mask)
self.assertIsNone(content_mask)
def test_input_mask_no_permutation(self):
"""Tests if an input mask is provided but not permutation.
In the case that only one of input mask or permutation mask is provided
and the attention type is bidirectional, the query mask should be
a broadcasted version of the provided mask.
Content mask should be a broadcasted version of the query mask, where the
diagonal is 0s.
"""
seq_length = 4
batch_size = 1
memory_length = 0
input_mask = np.array([[0, 0, 1, 1]])
permutation_mask = None
expected_query_mask = input_mask[None, None, :, :]
expected_content_mask = np.array([[[
[0, 0, 1, 1],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 1, 0]]]])
query_mask, content_mask = xlnet_base._compute_attention_mask(
input_mask=input_mask,
permutation_mask=permutation_mask,
attention_type="bi",
seq_length=seq_length,
memory_length=memory_length,
batch_size=batch_size,
dtype=tf.float32)
self.assertAllClose(query_mask, expected_query_mask)
self.assertAllClose(content_mask, expected_content_mask)
def test_permutation_mask_no_input_mask(self):
"""Tests if a permutation mask is provided but not input."""
seq_length = 2
batch_size = 1
memory_length = 0
input_mask = None
permutation_mask = np.array([
[[0, 1],
[0, 1]],
])
expected_query_mask = permutation_mask[:, None, :, :]
expected_content_mask = np.array([[[
[0, 1],
[0, 0]]]])
query_mask, content_mask = xlnet_base._compute_attention_mask(
input_mask=input_mask,
permutation_mask=permutation_mask,
attention_type="bi",
seq_length=seq_length,
memory_length=memory_length,
batch_size=batch_size,
dtype=tf.float32)
self.assertAllClose(query_mask, expected_query_mask)
self.assertAllClose(content_mask, expected_content_mask)
def test_permutation_and_input_mask(self):
"""Tests if both an input and permutation mask are provided."""
seq_length = 4
batch_size = 1
memory_length = 0
input_mask = np.array([[0, 0, 1, 1]])
permutation_mask = np.array([[
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
]])
expected_query_mask = np.array([[[
[1, 0, 1, 1],
[0, 1, 1, 1],
[0, 0, 1, 1],
[0, 0, 1, 1]]]])
expected_content_mask = np.array([[[
[0, 0, 1, 1],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 1, 0]]]])
query_mask, content_mask = xlnet_base._compute_attention_mask(
input_mask=input_mask,
permutation_mask=permutation_mask,
attention_type="bi",
seq_length=seq_length,
memory_length=memory_length,
batch_size=batch_size,
dtype=tf.float32)
self.assertAllClose(query_mask, expected_query_mask)
self.assertAllClose(content_mask, expected_content_mask)
def test_permutation_input_uni_mask(self):
"""Tests if an input, permutation and causal mask are provided."""
seq_length = 4
batch_size = 1
memory_length = 0
input_mask = np.array([[0, 0, 0, 1]])
permutation_mask = np.array([[
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
]])
expected_query_mask = np.array([[[
[1, 1, 1, 1],
[0, 1, 1, 1],
[0, 0, 1, 1],
[0, 0, 0, 1]]]])
expected_content_mask = np.array([[[
[0, 1, 1, 1],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 0, 0]]]])
query_mask, content_mask = xlnet_base._compute_attention_mask(
input_mask=input_mask,
permutation_mask=permutation_mask,
attention_type="uni",
seq_length=seq_length,
memory_length=memory_length,
batch_size=batch_size,
dtype=tf.float32)
self.assertAllClose(query_mask, expected_query_mask)
self.assertAllClose(content_mask, expected_content_mask)
class SegmentMatrixTests(tf.test.TestCase):
def test_no_segment_ids(self):
segment_matrix = xlnet_base._compute_segment_matrix(
segment_ids=None,
memory_length=2,
batch_size=1,
use_cls_mask=False)
self.assertIsNone(segment_matrix)
def test_basic(self):
batch_size = 1
memory_length = 0
segment_ids = np.array([
[1, 1, 2, 1]
])
expected_segment_matrix = np.array([[
[False, False, True, False],
[False, False, True, False],
[True, True, False, True],
[False, False, True, False]
]])
segment_matrix = xlnet_base._compute_segment_matrix(
segment_ids=segment_ids,
memory_length=memory_length,
batch_size=batch_size,
use_cls_mask=False)
self.assertAllClose(segment_matrix, expected_segment_matrix)
def test_basic_with_memory(self):
batch_size = 1
memory_length = 1
segment_ids = np.array([
[1, 1, 2, 1]
])
expected_segment_matrix = np.array([[
[True, False, False, True, False],
[True, False, False, True, False],
[True, True, True, False, True],
[True, False, False, True, False]
]]).astype(int)
segment_matrix = tf.cast(xlnet_base._compute_segment_matrix(
segment_ids=segment_ids,
memory_length=memory_length,
batch_size=batch_size,
use_cls_mask=False), dtype=tf.uint8)
self.assertAllClose(segment_matrix, expected_segment_matrix)
def dont_test_basic_with_class_mask(self):
# TODO(allencwang) - this test should pass but illustrates the legacy issue
# of using class mask. Enable once addressed.
batch_size = 1
memory_length = 0
segment_ids = np.array([
[1, 1, 2, 1]
])
expected_segment_matrix = np.array([[
[False, False, True, False],
[False, False, True, False],
[True, True, False, True],
[False, False, True, False]
]]).astype(int)
segment_matrix = tf.cast(xlnet_base._compute_segment_matrix(
segment_ids=segment_ids,
memory_length=memory_length,
batch_size=batch_size,
use_cls_mask=True), dtype=tf.uint8)
self.assertAllClose(segment_matrix, expected_segment_matrix)
class XLNetModelTests(tf.test.TestCase):
def _generate_data(self,
batch_size,
seq_length,
num_predictions=None):
"""Generates sample XLNet data for testing."""
sequence_shape = (batch_size, seq_length)
target_mapping = None
if num_predictions is not None:
target_mapping = tf.random.uniform(
shape=(batch_size, num_predictions, seq_length))
return {
"input_ids": np.random.randint(10, size=sequence_shape, dtype="int32"),
"segment_ids":
np.random.randint(2, size=sequence_shape, dtype="int32"),
"input_mask":
np.random.randint(2, size=sequence_shape).astype("float32"),
"permutation_mask":
np.random.randint(
2, size=(batch_size, seq_length, seq_length)).astype("float32"),
"target_mapping": target_mapping,
"masked_tokens": tf.random.uniform(shape=sequence_shape),
}
def test_xlnet_model(self):
batch_size = 2
seq_length = 8
num_predictions = 2
hidden_size = 4
xlnet_model = xlnet_base.XLNetBase(
vocab_size=32000,
num_layers=2,
hidden_size=hidden_size,
num_attention_heads=2,
head_size=2,
inner_size=2,
dropout_rate=0.,
attention_dropout_rate=0.,
attention_type="bi",
bi_data=True,
initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
two_stream=False,
tie_attention_biases=True,
reuse_length=0,
inner_activation="relu")
input_data = self._generate_data(batch_size=batch_size,
seq_length=seq_length,
num_predictions=num_predictions)
model_output = xlnet_model(**input_data)
self.assertEqual(model_output[0].shape,
(batch_size, seq_length, hidden_size))
def test_get_config(self):
xlnet_model = xlnet_base.XLNetBase(
vocab_size=32000,
num_layers=12,
hidden_size=36,
num_attention_heads=12,
head_size=12,
inner_size=12,
dropout_rate=0.,
attention_dropout_rate=0.,
attention_type="bi",
bi_data=True,
initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
two_stream=False,
tie_attention_biases=True,
memory_length=0,
reuse_length=0,
inner_activation="relu")
config = xlnet_model.get_config()
new_xlnet = xlnet_base.XLNetBase.from_config(config)
self.assertEqual(config, new_xlnet.get_config())
if __name__ == "__main__":
tf.random.set_seed(0)
tf.test.main()
......@@ -14,12 +14,13 @@
# ==============================================================================
"""Test beam search helper methods."""
from absl.testing import parameterized
import tensorflow as tf
from official.nlp.modeling.ops import beam_search
class BeamSearchHelperTests(tf.test.TestCase):
class BeamSearchTests(tf.test.TestCase, parameterized.TestCase):
def test_expand_to_beam_size(self):
x = tf.ones([7, 4, 2, 5])
......@@ -67,6 +68,41 @@ class BeamSearchHelperTests(tf.test.TestCase):
[[[4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [20, 21, 22, 23]]],
y)
@parameterized.named_parameters([
('padded_decode_true', True),
('padded_decode_false', False),
])
def test_sequence_beam_search(self, padded_decode):
# batch_size*beam_size, max_decode_length, vocab_size
probabilities = tf.constant([[[0.2, 0.7, 0.1], [0.5, 0.3, 0.2],
[0.1, 0.8, 0.1]],
[[0.1, 0.8, 0.1], [0.3, 0.4, 0.3],
[0.2, 0.1, 0.7]]])
# batch_size, max_decode_length, num_heads, embed_size per head
x = tf.zeros([1, 3, 2, 32], dtype=tf.float32)
cache = {'layer_%d' % layer: {'k': x, 'v': x} for layer in range(2)}
def _get_test_symbols_to_logits_fn():
"""Test function that returns logits for next token."""
def symbols_to_logits_fn(_, i, cache):
logits = tf.cast(probabilities[:, i, :], tf.float32)
return logits, cache
return symbols_to_logits_fn
predictions, _ = beam_search.sequence_beam_search(
symbols_to_logits_fn=_get_test_symbols_to_logits_fn(),
initial_ids=tf.zeros([1], dtype=tf.int32),
initial_cache=cache,
vocab_size=3,
beam_size=2,
alpha=0.6,
max_decode_length=3,
eos_id=9,
padded_decode=padded_decode,
dtype=tf.float32)
self.assertAllEqual([[[0, 1, 0, 1], [0, 1, 1, 2]]], predictions)
if __name__ == '__main__':
tf.test.main()
......@@ -104,14 +104,13 @@ Please first install TensorFlow 2 and Tensorflow Model Garden following the
```shell
$ python3 trainer.py \
--mode=train_and_eval \
--vocab=/path/to/bert_checkpoint/vocab.txt \
--init_checkpoint=/path/to/bert_checkpoint/bert_model.ckpt \
--params_override='init_from_bert2bert=false' \
--train_file_pattern=$DATA_FOLDER/processed/train.tfrecord* \
--model_dir=/path/to/output/model \
--len_title=15 \
--len_passage=200 \
--max_num_articles=5 \
--num_nhnet_articles=5 \
--model_type=nhnet \
--train_batch_size=16 \
--train_steps=10000 \
......@@ -123,14 +122,13 @@ $ python3 trainer.py \
```shell
$ python3 trainer.py \
--mode=train_and_eval \
--vocab=/path/to/bert_checkpoint/vocab.txt \
--init_checkpoint=/path/to/bert_checkpoint/bert_model.ckpt \
--params_override='init_from_bert2bert=false' \
--train_file_pattern=$DATA_FOLDER/processed/train.tfrecord* \
--model_dir=/path/to/output/model \
--len_title=15 \
--len_passage=200 \
--max_num_articles=5 \
--num_nhnet_articles=5 \
--model_type=nhnet \
--train_batch_size=1024 \
--train_steps=10000 \
......
......@@ -22,7 +22,6 @@ from __future__ import print_function
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.modeling.layers import transformer
from official.nlp.transformer import model_utils as transformer_utils
......@@ -59,7 +58,7 @@ class TransformerDecoder(tf.keras.layers.Layer):
self.layers = []
for i in range(self.num_hidden_layers):
self.layers.append(
transformer.TransformerDecoderLayer(
layers.TransformerDecoderBlock(
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
intermediate_activation=self.intermediate_activation,
......
......@@ -15,11 +15,6 @@
# ==============================================================================
"""Evaluation for Bert2Bert."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import os
# Import libraries
from absl import logging
......@@ -114,7 +109,6 @@ def continuous_eval(strategy,
dtype=tf.int64,
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
shape=[])
model.global_step = global_step
@tf.function
def test_step(inputs):
......@@ -149,7 +143,7 @@ def continuous_eval(strategy,
eval_results = {}
for latest_checkpoint in tf.train.checkpoints_iterator(
model_dir, timeout=timeout):
checkpoint = tf.train.Checkpoint(model=model)
checkpoint = tf.train.Checkpoint(model=model, global_step=global_step)
checkpoint.restore(latest_checkpoint).expect_partial()
logging.info("Loaded checkpoint %s", latest_checkpoint)
......@@ -162,7 +156,7 @@ def continuous_eval(strategy,
metric.update_state(func(logits.numpy(), targets.numpy()))
with eval_summary_writer.as_default():
step = model.global_step.numpy()
step = global_step.numpy()
for metric, _ in metrics_and_funcs:
eval_results[metric.name] = metric.result().numpy().astype(float)
tf.summary.scalar(
......
......@@ -27,13 +27,13 @@ from absl import flags
from absl import logging
from six.moves import zip
import tensorflow as tf
from official.common import distribute_utils
from official.modeling.hyperparams import params_dict
from official.nlp.nhnet import evaluation
from official.nlp.nhnet import input_pipeline
from official.nlp.nhnet import models
from official.nlp.nhnet import optimizer
from official.nlp.transformer import metrics as transformer_metrics
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
FLAGS = flags.FLAGS
......@@ -145,7 +145,6 @@ def train(params, strategy, dataset=None):
FLAGS.model_type, params, init_checkpoint=FLAGS.init_checkpoint)
opt = optimizer.create_optimizer(params)
trainer = Trainer(model, params)
model.global_step = opt.iterations
trainer.compile(
optimizer=opt,
......@@ -153,12 +152,13 @@ def train(params, strategy, dataset=None):
summary_dir = os.path.join(FLAGS.model_dir, "summaries")
summary_callback = tf.keras.callbacks.TensorBoard(
summary_dir, update_freq=max(100, FLAGS.steps_per_loop))
checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
checkpoint = tf.train.Checkpoint(
model=model, optimizer=opt, global_step=opt.iterations)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
directory=FLAGS.model_dir,
max_to_keep=10,
step_counter=model.global_step,
step_counter=opt.iterations,
checkpoint_interval=FLAGS.checkpoint_interval)
if checkpoint_manager.restore_or_initialize():
logging.info("Training restored from the checkpoints in: %s",
......@@ -185,7 +185,7 @@ def run():
if FLAGS.enable_mlir_bridge:
tf.config.experimental.enable_mlir_bridge()
strategy = distribution_utils.get_distribution_strategy(
strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy, tpu_address=FLAGS.tpu)
if strategy:
logging.info("***** Number of cores used : %d",
......
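The NHNet trainer and evaluation changes above replace the ad-hoc `model.global_step` attribute with a step counter tracked directly by `tf.train.Checkpoint` under the name `global_step` (the optimizer's iteration counter on the training side, a standalone variable on the evaluation side). A small self-contained sketch of that save/restore pattern, using stock Keras objects rather than the NHNet model:

```python
import os
import tempfile

import tensorflow as tf

# Sketch of the trainer-side pattern: the checkpointed "global_step" is
# simply the optimizer's iteration variable.
opt = tf.keras.optimizers.Adam()
var = tf.Variable(1.0)
opt.apply_gradients([(tf.constant(0.5), var)])  # opt.iterations is now 1

train_ckpt = tf.train.Checkpoint(optimizer=opt, global_step=opt.iterations)
prefix = train_ckpt.save(os.path.join(tempfile.mkdtemp(), "ckpt"))

# Sketch of the evaluation-side pattern: a standalone variable restored under
# the same name recovers the step without a model.global_step attribute.
global_step = tf.Variable(0, dtype=tf.int64)
eval_ckpt = tf.train.Checkpoint(global_step=global_step)
eval_ckpt.restore(prefix).expect_partial()
print(global_step.numpy())  # 1
```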
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based bigbird attention layer."""
import numpy as np
import tensorflow as tf
MAX_SEQ_LEN = 4096
def create_band_mask_from_inputs(from_blocked_mask, to_blocked_mask):
"""Create 3D attention mask from a 2D tensor mask.
Args:
from_blocked_mask: 2D Tensor of shape [batch_size,
from_seq_length//from_block_size, from_block_size].
to_blocked_mask: int32 Tensor of shape [batch_size,
to_seq_length//to_block_size, to_block_size].
Returns:
float Tensor of shape [batch_size, 1, from_seq_length//from_block_size-4,
from_block_size, 3*to_block_size].
"""
exp_blocked_to_pad = tf.concat([
to_blocked_mask[:, 1:-3], to_blocked_mask[:, 2:-2], to_blocked_mask[:,
3:-1]
], 2)
band_mask = tf.einsum("BLQ,BLK->BLQK", from_blocked_mask[:, 2:-2],
exp_blocked_to_pad)
band_mask = tf.expand_dims(band_mask, 1)
return band_mask
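# Shape sketch (illustrative editorial comment, not part of the original
# module): with a batch of 1, an 8-block sequence, and 64-token blocks,
# from_blocked_mask and to_blocked_mask are [1, 8, 64]; the three shifted
# slices concatenate to [1, 4, 192], and the einsum plus expand_dims yields a
# band mask of shape [1, 1, 8 - 4, 64, 3 * 64] = [1, 1, 4, 64, 192], i.e.
# each middle query block sees its own key block plus one block on either side.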
def bigbird_block_rand_mask(from_seq_length,
to_seq_length,
from_block_size,
to_block_size,
num_rand_blocks,
last_idx=-1):
"""Create adjacency list of random attention.
Args:
from_seq_length: int. length of from sequence.
to_seq_length: int. length of to sequence.
from_block_size: int. size of block in from sequence.
to_block_size: int. size of block in to sequence.
num_rand_blocks: int. Number of random chunks per row.
last_idx: if -1 then num_rand_blocks blocks chosen anywhere in to sequence,
if positive then num_rand_blocks blocks chosen only up to last_idx.
Returns:
adjacency list of size from_seq_length//from_block_size-2 by num_rand_blocks
"""
assert from_seq_length//from_block_size == to_seq_length//to_block_size, \
"Error the number of blocks needs to be same!"
rand_attn = np.zeros(
(from_seq_length // from_block_size - 2, num_rand_blocks), dtype=np.int32)
middle_seq = np.arange(1, to_seq_length // to_block_size - 1, dtype=np.int32)
last = to_seq_length // to_block_size - 1
if last_idx > (2 * to_block_size):
last = (last_idx // to_block_size) - 1
r = num_rand_blocks # shorthand
for i in range(1, from_seq_length // from_block_size - 1):
start = i - 2
end = i
if i == 1:
rand_attn[i - 1, :] = np.random.permutation(middle_seq[2:last])[:r]
elif i == 2:
rand_attn[i - 1, :] = np.random.permutation(middle_seq[3:last])[:r]
elif i == from_seq_length // from_block_size - 3:
rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r]
    # Missing -3: should have been sliced up to last - 3.
elif i == from_seq_length // from_block_size - 2:
rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r]
    # Missing -4: should have been sliced up to last - 4.
else:
if start > last:
start = last
rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r]
elif (end + 1) == last:
rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r]
else:
rand_attn[i - 1, :] = np.random.permutation(
np.concatenate((middle_seq[:start], middle_seq[end + 1:last])))[:r]
return rand_attn
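A small illustrative call (not from the commit) showing the adjacency list returned by the function above; the sequence and block sizes are toy values.

from official.nlp.projects.bigbird import attention

# 512-token sequences split into 64-token blocks, 3 random blocks per row.
adj = attention.bigbird_block_rand_mask(
    from_seq_length=512, to_seq_length=512,
    from_block_size=64, to_block_size=64, num_rand_blocks=3)
print(adj.shape)   # (6, 3): one row per non-boundary query block
print(adj.dtype)   # int32: indices of the randomly attended key blocks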
def create_rand_mask_from_inputs(from_blocked_mask, to_blocked_mask, rand_attn,
num_attention_heads, num_rand_blocks,
batch_size, from_seq_length, from_block_size):
"""Create 3D attention mask from a 2D tensor mask.
Args:
from_blocked_mask: 2D Tensor of shape [batch_size,
from_seq_length//from_block_size, from_block_size].
to_blocked_mask: int32 Tensor of shape [batch_size,
to_seq_length//to_block_size, to_block_size].
rand_attn: [batch_size, num_attention_heads,
from_seq_length//from_block_size-2, num_rand_blocks]
num_attention_heads: int. Number of attention heads.
num_rand_blocks: int. Number of random chunks per row.
batch_size: int. Batch size for computation.
from_seq_length: int. length of from sequence.
from_block_size: int. size of block in from sequence.
Returns:
float Tensor of shape [batch_size, num_attention_heads,
from_seq_length//from_block_size-2,
from_block_size, num_rand_blocks*to_block_size].
"""
num_windows = from_seq_length // from_block_size - 2
rand_mask = tf.reshape(
tf.gather(to_blocked_mask, rand_attn, batch_dims=1), [
batch_size, num_attention_heads, num_windows,
num_rand_blocks * from_block_size
])
rand_mask = tf.einsum("BLQ,BHLK->BHLQK", from_blocked_mask[:, 1:-1],
rand_mask)
return rand_mask
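A shape sketch (not part of the commit) for the random attention mask built above; the adjacency list is tiled over batch and heads here only for illustration, and all sizes are toy values.

import tensorflow as tf
from official.nlp.projects.bigbird import attention

batch, heads, blocks, block, r = 2, 2, 8, 4, 3  # toy sizes
blocked_mask = tf.ones((batch, blocks, block), dtype=tf.float32)
adj = attention.bigbird_block_rand_mask(
    blocks * block, blocks * block, block, block, r)  # (blocks - 2, r)
rand_attn = tf.tile(tf.constant(adj)[tf.newaxis, tf.newaxis],
                    [batch, heads, 1, 1])
rand_mask = attention.create_rand_mask_from_inputs(
    blocked_mask, blocked_mask, rand_attn,
    num_attention_heads=heads, num_rand_blocks=r, batch_size=batch,
    from_seq_length=blocks * block, from_block_size=block)
# [batch, heads, blocks - 2, block, r * block]
print(rand_mask.shape)  # (2, 2, 6, 4, 12)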
def bigbird_block_sparse_attention(
query_layer, key_layer, value_layer, band_mask, from_mask, to_mask,
from_blocked_mask, to_blocked_mask, rand_attn, num_attention_heads,
num_rand_blocks, size_per_head, batch_size, from_seq_length, to_seq_length,
from_block_size, to_block_size):
"""BigBird attention sparse calculation using blocks in linear time.
Assumes from_seq_length//from_block_size == to_seq_length//to_block_size.
Args:
    query_layer: float Tensor of shape [batch_size, from_seq_length,
      num_attention_heads, size_per_head] (the per-head projected queries).
    key_layer: float Tensor of shape [batch_size, to_seq_length,
      num_attention_heads, size_per_head].
    value_layer: float Tensor of shape [batch_size, to_seq_length,
      num_attention_heads, size_per_head].
band_mask: (optional) int32 Tensor of shape [batch_size, 1,
from_seq_length//from_block_size-4, from_block_size, 3*to_block_size]. The
values should be 1 or 0. The attention scores will effectively be set to
-infinity for any positions in the mask that are 0, and will be unchanged
for positions that are 1.
from_mask: (optional) int32 Tensor of shape [batch_size, 1, from_seq_length,
1]. The values should be 1 or 0. The attention scores will effectively be
set to -infinity for any positions in the mask that are 0, and will be
unchanged for positions that are 1.
to_mask: (optional) int32 Tensor of shape [batch_size, 1, 1, to_seq_length].
The values should be 1 or 0. The attention scores will effectively be set
to -infinity for any positions in the mask that are 0, and will be
unchanged for positions that are 1.
from_blocked_mask: (optional) int32 Tensor of shape [batch_size,
from_seq_length//from_block_size, from_block_size]. Same as from_mask,
just reshaped.
to_blocked_mask: (optional) int32 Tensor of shape [batch_size,
to_seq_length//to_block_size, to_block_size]. Same as to_mask, just
reshaped.
rand_attn: [batch_size, num_attention_heads,
from_seq_length//from_block_size-2, num_rand_blocks]
num_attention_heads: int. Number of attention heads.
num_rand_blocks: int. Number of random chunks per row.
size_per_head: int. Size of each attention head.
batch_size: int. Batch size for computation.
from_seq_length: int. length of from sequence.
to_seq_length: int. length of to sequence.
from_block_size: int. size of block in from sequence.
to_block_size: int. size of block in to sequence.
Returns:
float Tensor of shape [batch_size, from_seq_length, num_attention_heads,
size_per_head].
"""
rand_attn = tf.expand_dims(rand_attn, 0)
rand_attn = tf.repeat(rand_attn, batch_size, 0)
rand_mask = create_rand_mask_from_inputs(
from_blocked_mask,
to_blocked_mask,
rand_attn,
num_attention_heads,
num_rand_blocks,
batch_size,
from_seq_length,
from_block_size,
)
# Define shorthands
h = num_attention_heads
r = num_rand_blocks
d = size_per_head
b = batch_size
m = from_seq_length
n = to_seq_length
wm = from_block_size
wn = to_block_size
query_layer = tf.transpose(query_layer, perm=[0, 2, 1, 3])
key_layer = tf.transpose(key_layer, perm=[0, 2, 1, 3])
value_layer = tf.transpose(value_layer, perm=[0, 2, 1, 3])
blocked_query_matrix = tf.reshape(query_layer, (b, h, m // wm, wm, -1))
blocked_key_matrix = tf.reshape(key_layer, (b, h, n // wn, wn, -1))
blocked_value_matrix = tf.reshape(value_layer, (b, h, n // wn, wn, -1))
gathered_key = tf.reshape(
tf.gather(blocked_key_matrix, rand_attn, batch_dims=2, name="gather_key"),
(b, h, m // wm - 2, r * wn, -1)) # [b, h, n//wn-2, r, wn, -1]
gathered_value = tf.reshape(
tf.gather(
blocked_value_matrix, rand_attn, batch_dims=2, name="gather_value"),
(b, h, m // wm - 2, r * wn, -1)) # [b, h, n//wn-2, r, wn, -1]
first_product = tf.einsum(
"BHQD,BHKD->BHQK", blocked_query_matrix[:, :, 0],
key_layer) # [b, h, wm, -1] x [b, h, n, -1] ==> [b, h, wm, n]
first_product = tf.multiply(first_product, 1.0 / np.sqrt(d))
first_product += (1.0 - tf.cast(to_mask, dtype=tf.float32)) * -10000.0
first_attn_weights = tf.nn.softmax(first_product) # [b, h, wm, n]
first_context_layer = tf.einsum(
"BHQK,BHKD->BHQD", first_attn_weights,
value_layer) # [b, h, wm, n] x [b, h, n, -1] ==> [b, h, wm, -1]
first_context_layer = tf.expand_dims(first_context_layer, 2)
second_key_mat = tf.concat([
blocked_key_matrix[:, :, 0], blocked_key_matrix[:, :, 1],
blocked_key_matrix[:, :, 2], blocked_key_matrix[:, :,
-1], gathered_key[:, :, 0]
], 2) # [b, h, (4+r)*wn, -1]
second_value_mat = tf.concat([
blocked_value_matrix[:, :, 0], blocked_value_matrix[:, :, 1],
blocked_value_matrix[:, :, 2], blocked_value_matrix[:, :, -1],
gathered_value[:, :, 0]
], 2) # [b, h, (4+r)*wn, -1]
second_product = tf.einsum(
"BHQD,BHKD->BHQK", blocked_query_matrix[:, :, 1], second_key_mat
) # [b, h, wm, -1] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, (4+r)*wn]
second_seq_pad = tf.concat([
to_mask[:, :, :, :3 * wn], to_mask[:, :, :, -wn:],
tf.ones([b, 1, 1, r * wn], dtype=tf.float32)
], 3)
second_rand_pad = tf.concat(
[tf.ones([b, h, wm, 4 * wn], dtype=tf.float32), rand_mask[:, :, 0]], 3)
second_product = tf.multiply(second_product, 1.0 / np.sqrt(d))
second_product += (1.0 -
tf.minimum(second_seq_pad, second_rand_pad)) * -10000.0
second_attn_weights = tf.nn.softmax(second_product) # [b , h, wm, (4+r)*wn]
second_context_layer = tf.einsum(
"BHQK,BHKD->BHQD", second_attn_weights, second_value_mat
) # [b, h, wm, (4+r)*wn] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, -1]
second_context_layer = tf.expand_dims(second_context_layer, 2)
exp_blocked_key_matrix = tf.concat([
blocked_key_matrix[:, :, 1:-3], blocked_key_matrix[:, :, 2:-2],
blocked_key_matrix[:, :, 3:-1]
], 3) # [b, h, m//wm-4, 3*wn, -1]
exp_blocked_value_matrix = tf.concat([
blocked_value_matrix[:, :, 1:-3], blocked_value_matrix[:, :, 2:-2],
blocked_value_matrix[:, :, 3:-1]
], 3) # [b, h, m//wm-4, 3*wn, -1]
middle_query_matrix = blocked_query_matrix[:, :, 2:-2]
inner_band_product = tf.einsum(
"BHLQD,BHLKD->BHLQK", middle_query_matrix, exp_blocked_key_matrix
) # [b, h, m//wm-4, wm, -1] x [b, h, m//wm-4, 3*wn, -1]
# ==> [b, h, m//wm-4, wm, 3*wn]
inner_band_product = tf.multiply(inner_band_product, 1.0 / np.sqrt(d))
rand_band_product = tf.einsum(
"BHLQD,BHLKD->BHLQK", middle_query_matrix,
gathered_key[:, :,
1:-1]) # [b, h, m//wm-4, wm, -1] x [b, h, m//wm-4, r*wn, -1]
# ==> [b, h, m//wm-4, wm, r*wn]
rand_band_product = tf.multiply(rand_band_product, 1.0 / np.sqrt(d))
first_band_product = tf.einsum(
"BHLQD,BHKD->BHLQK", middle_query_matrix, blocked_key_matrix[:, :, 0]
) # [b, h, m//wm-4, wm, -1] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, wn]
first_band_product = tf.multiply(first_band_product, 1.0 / np.sqrt(d))
last_band_product = tf.einsum(
"BHLQD,BHKD->BHLQK", middle_query_matrix, blocked_key_matrix[:, :, -1]
) # [b, h, m//wm-4, wm, -1] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, wn]
last_band_product = tf.multiply(last_band_product, 1.0 / np.sqrt(d))
inner_band_product += (1.0 - band_mask) * -10000.0
first_band_product += (1.0 -
tf.expand_dims(to_mask[:, :, :, :wn], 3)) * -10000.0
last_band_product += (1.0 -
tf.expand_dims(to_mask[:, :, :, -wn:], 3)) * -10000.0
rand_band_product += (1.0 - rand_mask[:, :, 1:-1]) * -10000.0
band_product = tf.concat([
first_band_product, inner_band_product, rand_band_product,
last_band_product
], -1) # [b, h, m//wm-4, wm, (5+r)*wn]
attn_weights = tf.nn.softmax(band_product) # [b, h, m//wm-4, wm, (5+r)*wn]
context_layer = tf.einsum(
"BHLQK,BHLKD->BHLQD", attn_weights[:, :, :, :,
wn:4 * wn], exp_blocked_value_matrix
) # [b, h, m//wm-4, wm, 3*wn] x [b, h, m//wm-4, 3*wn, -1]
# ==> [b, h, m//wm-4, wm, -1]
context_layer += tf.einsum(
"BHLQK,BHLKD->BHLQD", attn_weights[:, :, :, :,
4 * wn:-wn], gathered_value[:, :, 1:-1]
) # [b, h, m//wm-4, wm, r*wn] x [b, h, m//wm-4, r*wn, -1]
# ==> [b, h, m//wm-4, wm, -1]
context_layer += tf.einsum(
"BHLQK,BHKD->BHLQD", attn_weights[:, :, :, :, :wn],
blocked_value_matrix[:, :, 0]
) # [b, h, m//wm-4, wm, wn] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, -1]
context_layer += tf.einsum(
"BHLQK,BHKD->BHLQD", attn_weights[:, :, :, :,
-wn:], blocked_value_matrix[:, :, -1]
) # [b, h, m//wm-4, wm, wn] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, -1]
second_last_key_mat = tf.concat([
blocked_key_matrix[:, :, 0], blocked_key_matrix[:, :, -3],
blocked_key_matrix[:, :, -2], blocked_key_matrix[:, :, -1],
gathered_key[:, :, -1]
], 2) # [b, h, (4+r)*wn, -1]
second_last_value_mat = tf.concat([
blocked_value_matrix[:, :, 0], blocked_value_matrix[:, :, -3],
blocked_value_matrix[:, :, -2], blocked_value_matrix[:, :, -1],
gathered_value[:, :, -1]
], 2) # [b, h, (4+r)*wn, -1]
second_last_product = tf.einsum(
"BHQD,BHKD->BHQK", blocked_query_matrix[:, :, -2], second_last_key_mat
) # [b, h, wm, -1] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, (4+r)*wn]
second_last_seq_pad = tf.concat([
to_mask[:, :, :, :wn], to_mask[:, :, :, -3 * wn:],
tf.ones([b, 1, 1, r * wn], dtype=tf.float32)
], 3)
second_last_rand_pad = tf.concat(
[tf.ones([b, h, wm, 4 * wn], dtype=tf.float32), rand_mask[:, :, -1]], 3)
second_last_product = tf.multiply(second_last_product, 1.0 / np.sqrt(d))
second_last_product += (
1.0 - tf.minimum(second_last_seq_pad, second_last_rand_pad)) * -10000.0
second_last_attn_weights = tf.nn.softmax(
second_last_product) # [b, h, wm, (4+r)*wn]
second_last_context_layer = tf.einsum(
"BHQK,BHKD->BHQD", second_last_attn_weights, second_last_value_mat
) # [b, h, wm, (4+r)*wn] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, -1]
second_last_context_layer = tf.expand_dims(second_last_context_layer, 2)
last_product = tf.einsum(
"BHQD,BHKD->BHQK", blocked_query_matrix[:, :, -1],
key_layer) # [b, h, wm, -1] x [b, h, n, -1] ==> [b, h, wm, n]
last_product = tf.multiply(last_product, 1.0 / np.sqrt(d))
last_product += (1.0 - to_mask) * -10000.0
last_attn_weights = tf.nn.softmax(last_product) # [b, h, wm, n]
last_context_layer = tf.einsum(
"BHQK,BHKD->BHQD", last_attn_weights,
value_layer) # [b, h, wm, n] x [b, h, n, -1] ==> [b, h, wm, -1]
last_context_layer = tf.expand_dims(last_context_layer, 2)
context_layer = tf.concat([
first_context_layer, second_context_layer, context_layer,
second_last_context_layer, last_context_layer
], 2)
context_layer = tf.reshape(context_layer, (b, h, m, -1)) * from_mask
context_layer = tf.transpose(context_layer, (0, 2, 1, 3))
return context_layer
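An end-to-end shape sketch (not part of the commit) for the block-sparse attention above, using all-ones padding masks and toy sizes; the query/key/value layout matches what BigBirdAttention passes in, i.e. [batch, seq_length, num_heads, size_per_head].

import numpy as np
import tensorflow as tf
from official.nlp.projects.bigbird import attention

b, h, m, d, wm, r = 1, 2, 512, 8, 64, 3  # toy sizes
q = tf.random.normal((b, m, h, d))       # [batch, seq, heads, size_per_head]
k = tf.random.normal((b, m, h, d))
v = tf.random.normal((b, m, h, d))
mask = tf.ones((b, m), dtype=tf.float32)
blocked_mask = tf.reshape(mask, (b, m // wm, wm))
band_mask = attention.create_band_mask_from_inputs(blocked_mask, blocked_mask)
from_mask = tf.reshape(mask, (b, 1, m, 1))
to_mask = tf.reshape(mask, (b, 1, 1, m))
# One adjacency list per head, generated for the actual sequence length.
rand_attn = tf.constant(np.stack(
    [attention.bigbird_block_rand_mask(m, m, wm, wm, r) for _ in range(h)]))
out = attention.bigbird_block_sparse_attention(
    q, k, v, band_mask, from_mask, to_mask, blocked_mask, blocked_mask,
    rand_attn, num_attention_heads=h, num_rand_blocks=r, size_per_head=d,
    batch_size=b, from_seq_length=m, to_seq_length=m,
    from_block_size=wm, to_block_size=wm)
print(out.shape)  # (1, 512, 2, 8): [batch, from_seq_length, heads, size_per_head]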
class BigBirdMasks(tf.keras.layers.Layer):
"""Creates bigbird attention masks."""
def __init__(self, block_size, **kwargs):
super().__init__(**kwargs)
self._block_size = block_size
def call(self, inputs):
encoder_shape = tf.shape(inputs)
batch_size, seq_length = encoder_shape[0], encoder_shape[1]
# reshape and cast for blocking
inputs = tf.cast(inputs, dtype=tf.float32)
blocked_encoder_mask = tf.reshape(
inputs, (batch_size, seq_length // self._block_size, self._block_size))
encoder_from_mask = tf.reshape(inputs, (batch_size, 1, seq_length, 1))
encoder_to_mask = tf.reshape(inputs, (batch_size, 1, 1, seq_length))
band_mask = create_band_mask_from_inputs(blocked_encoder_mask,
blocked_encoder_mask)
return [band_mask, encoder_from_mask, encoder_to_mask, blocked_encoder_mask]
@tf.keras.utils.register_keras_serializable(package="Text")
class BigBirdAttention(tf.keras.layers.MultiHeadAttention):
"""BigBird, a sparse attention mechanism.
This layer follows the paper "Big Bird: Transformers for Longer Sequences"
(https://arxiv.org/abs/2007.14062).
  It reduces the quadratic dependency of attention computation on sequence
  length to linear.
  Arguments are the same as the `tf.keras.layers.MultiHeadAttention` layer,
  plus:
    num_rand_blocks: Number of random key blocks each query block attends to.
    from_block_size: Size of the blocks in the from (query) sequence.
    to_block_size: Size of the blocks in the to (key/value) sequence.
    max_rand_mask_length: Maximum sequence length used to pre-generate the
      random attention pattern.
    seed: Optional seed for the numpy RNG that generates the random pattern.
  """
def __init__(self,
num_rand_blocks=3,
from_block_size=64,
to_block_size=64,
max_rand_mask_length=MAX_SEQ_LEN,
seed=None,
**kwargs):
super().__init__(**kwargs)
self._num_rand_blocks = num_rand_blocks
self._from_block_size = from_block_size
self._to_block_size = to_block_size
self._seed = seed
# Generates random attention.
np.random.seed(self._seed)
# pylint: disable=g-complex-comprehension
rand_attn = [
bigbird_block_rand_mask(
max_rand_mask_length,
max_rand_mask_length,
from_block_size,
to_block_size,
num_rand_blocks,
last_idx=1024) for _ in range(self._num_heads)
]
# pylint: enable=g-complex-comprehension
rand_attn = np.stack(rand_attn, axis=0)
self.rand_attn = tf.constant(rand_attn, dtype=tf.int32)
def _compute_attention(self, query, key, value, attention_mask=None):
(band_mask, encoder_from_mask, encoder_to_mask,
blocked_encoder_mask) = attention_mask
query_shape = tf.shape(query)
from_seq_length = query_shape[1]
to_seq_length = tf.shape(key)[1]
rand_attn = self.rand_attn[:, :(from_seq_length // self._from_block_size -
2)]
return bigbird_block_sparse_attention(
query,
key,
value,
band_mask,
encoder_from_mask,
encoder_to_mask,
blocked_encoder_mask,
blocked_encoder_mask,
num_attention_heads=self._num_heads,
num_rand_blocks=self._num_rand_blocks,
size_per_head=self._key_dim,
batch_size=query_shape[0],
from_seq_length=from_seq_length,
to_seq_length=to_seq_length,
from_block_size=self._from_block_size,
to_block_size=self._to_block_size,
rand_attn=rand_attn)
def call(self, query, value, key=None, attention_mask=None, **kwargs):
if not self._built_from_signature:
self._build_from_signature(query=query, value=value, key=key)
if key is None:
key = value
# N = `num_attention_heads`
# H = `size_per_head`
# `query` = [B, T, N ,H]
query = self._query_dense(query)
# `key` = [B, S, N, H]
key = self._key_dense(key)
# `value` = [B, S, N, H]
value = self._value_dense(value)
attention_output = self._compute_attention(query, key, value,
attention_mask)
attention_output.set_shape([None, None, self._num_heads, self._key_dim])
attention_output = self._output_dense(attention_output)
return attention_output
def get_config(self):
config = {
"num_rand_blocks": self._num_rand_blocks,
"from_block_size": self._from_block_size,
"to_block_size": self._to_block_size,
"seed": self._seed
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -13,46 +12,55 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for image classification task."""
"""Tests for official.nlp.projects.bigbird.attention."""
# pylint: disable=unused-import
from absl.testing import parameterized
import orbit
import tensorflow as tf
from official.core import exp_factory
from official.modeling import optimization
from official.vision import beta
from official.vision.beta.tasks import image_classification as img_cls_task
class ImageClassificationTaskTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(('resnet_imagenet'),
('revnet_imagenet'))
def test_task(self, config_name):
config = exp_factory.get_exp_config(config_name)
config.task.train_data.global_batch_size = 2
task = img_cls_task.ImageClassificationTask(config.task)
model = task.build_model()
metrics = task.build_metrics()
strategy = tf.distribute.get_strategy()
dataset = orbit.utils.make_distributed_dataset(strategy, task.build_inputs,
config.task.train_data)
iterator = iter(dataset)
opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
logs = task.train_step(next(iterator), model, optimizer, metrics=metrics)
self.assertIn('loss', logs)
self.assertIn('accuracy', logs)
self.assertIn('top_5_accuracy', logs)
logs = task.validation_step(next(iterator), model, metrics=metrics)
self.assertIn('loss', logs)
self.assertIn('accuracy', logs)
self.assertIn('top_5_accuracy', logs)
from official.nlp.projects.bigbird import attention
class BigbirdAttentionTest(tf.test.TestCase):
def test_attention(self):
num_heads = 12
key_dim = 64
seq_length = 1024
batch_size = 2
block_size = 64
mask_layer = attention.BigBirdMasks(block_size=block_size)
encoder_inputs_mask = tf.zeros((batch_size, seq_length), dtype=tf.int32)
masks = mask_layer(encoder_inputs_mask)
test_layer = attention.BigBirdAttention(
num_heads=num_heads,
key_dim=key_dim,
from_block_size=block_size,
to_block_size=block_size,
seed=0)
query = tf.random.normal(
shape=(batch_size, seq_length, key_dim))
value = query
output = test_layer(
query=query,
value=value,
attention_mask=masks)
self.assertEqual(output.shape, [batch_size, seq_length, key_dim])
def test_config(self):
num_heads = 12
key_dim = 64
block_size = 64
test_layer = attention.BigBirdAttention(
num_heads=num_heads,
key_dim=key_dim,
from_block_size=block_size,
to_block_size=block_size,
seed=0)
print(test_layer.get_config())
new_layer = attention.BigBirdAttention.from_config(
test_layer.get_config())
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(test_layer.get_config(), new_layer.get_config())
if __name__ == '__main__':
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Transformer-based text encoder network."""
# pylint: disable=g-classes-have-attributes
import tensorflow as tf
from official.modeling import activations
from official.nlp import keras_nlp
from official.nlp.modeling import layers
from official.nlp.projects.bigbird import attention
@tf.keras.utils.register_keras_serializable(package='Text')
class BigBirdEncoder(tf.keras.Model):
"""Transformer-based encoder network with BigBird attentions.
*Note* that the network is constructed by
[Keras Functional API](https://keras.io/guides/functional_api/).
Arguments:
vocab_size: The size of the token vocabulary.
hidden_size: The size of the transformer hidden layers.
num_layers: The number of transformer layers.
num_attention_heads: The number of attention heads for each transformer. The
hidden size must be divisible by the number of attention heads.
max_sequence_length: The maximum sequence length that this encoder can
consume. If None, max_sequence_length uses the value from sequence length.
This determines the variable shape for positional embeddings.
type_vocab_size: The number of types that the 'type_ids' input can take.
    intermediate_size: The intermediate size for the transformer layers.
    block_size: The block size used by the BigBird attention layers.
    num_rand_blocks: The number of random key blocks each query block attends
      to in the BigBird attention layers.
activation: The activation to use for the transformer layers.
dropout_rate: The dropout rate to use for the transformer layers.
attention_dropout_rate: The dropout rate to use for the attention layers
within the transformer layers.
    initializer: The initializer to use for all weights in this encoder.
embedding_width: The width of the word embeddings. If the embedding width is
not equal to hidden size, embedding parameters will be factorized into two
matrices in the shape of ['vocab_size', 'embedding_width'] and
['embedding_width', 'hidden_size'] ('embedding_width' is usually much
smaller than 'hidden_size').
"""
def __init__(self,
vocab_size,
hidden_size=768,
num_layers=12,
num_attention_heads=12,
max_sequence_length=attention.MAX_SEQ_LEN,
type_vocab_size=16,
intermediate_size=3072,
block_size=64,
num_rand_blocks=3,
activation=activations.gelu,
dropout_rate=0.1,
attention_dropout_rate=0.1,
initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
embedding_width=None,
**kwargs):
activation = tf.keras.activations.get(activation)
initializer = tf.keras.initializers.get(initializer)
self._self_setattr_tracking = False
self._config_dict = {
'vocab_size': vocab_size,
'hidden_size': hidden_size,
'num_layers': num_layers,
'num_attention_heads': num_attention_heads,
'max_sequence_length': max_sequence_length,
'type_vocab_size': type_vocab_size,
'intermediate_size': intermediate_size,
'block_size': block_size,
'num_rand_blocks': num_rand_blocks,
'activation': tf.keras.activations.serialize(activation),
'dropout_rate': dropout_rate,
'attention_dropout_rate': attention_dropout_rate,
'initializer': tf.keras.initializers.serialize(initializer),
'embedding_width': embedding_width,
}
word_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_word_ids')
mask = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_mask')
type_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_type_ids')
if embedding_width is None:
embedding_width = hidden_size
self._embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=initializer,
name='word_embeddings')
word_embeddings = self._embedding_layer(word_ids)
# Always uses dynamic slicing for simplicity.
self._position_embedding_layer = keras_nlp.layers.PositionEmbedding(
initializer=initializer,
max_length=max_sequence_length,
name='position_embedding')
position_embeddings = self._position_embedding_layer(word_embeddings)
self._type_embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=initializer,
use_one_hot=True,
name='type_embeddings')
type_embeddings = self._type_embedding_layer(type_ids)
embeddings = tf.keras.layers.Add()(
[word_embeddings, position_embeddings, type_embeddings])
self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)
embeddings = self._embedding_norm_layer(embeddings)
embeddings = tf.keras.layers.Dropout(rate=dropout_rate)(embeddings)
# We project the 'embedding' output to 'hidden_size' if it is not already
# 'hidden_size'.
if embedding_width != hidden_size:
self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
kernel_initializer=initializer,
name='embedding_projection')
embeddings = self._embedding_projection(embeddings)
self._transformer_layers = []
data = embeddings
masks = attention.BigBirdMasks(block_size=block_size)(mask)
encoder_outputs = []
attn_head_dim = hidden_size // num_attention_heads
for i in range(num_layers):
layer = layers.TransformerScaffold(
num_attention_heads,
intermediate_size,
activation,
attention_cls=attention.BigBirdAttention,
attention_cfg=dict(
num_heads=num_attention_heads,
key_dim=attn_head_dim,
kernel_initializer=initializer,
from_block_size=block_size,
to_block_size=block_size,
num_rand_blocks=num_rand_blocks,
max_rand_mask_length=max_sequence_length,
seed=i),
dropout_rate=dropout_rate,
attention_dropout_rate=dropout_rate,
kernel_initializer=initializer)
self._transformer_layers.append(layer)
data = layer([data, masks])
encoder_outputs.append(data)
outputs = dict(
sequence_output=encoder_outputs[-1], encoder_outputs=encoder_outputs)
super().__init__(
inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)
def get_embedding_table(self):
return self._embedding_layer.embeddings
def get_embedding_layer(self):
return self._embedding_layer
def get_config(self):
return self._config_dict
@property
def transformer_layers(self):
"""List of Transformer layers in the encoder."""
return self._transformer_layers
@property
def pooler_layer(self):
"""The pooler dense layer after the transformer layers."""
return self._pooler_layer
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
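A hedged construction sketch (not part of the commit) of the embedding factorization path described in the class docstring above; the hyperparameter values are illustrative, and the module path follows the encoder tests below.

import numpy as np
from official.nlp.projects.bigbird import encoder

# A small embedding_width triggers the 'embedding_projection' dense layer.
net = encoder.BigBirdEncoder(
    vocab_size=1024, num_layers=1, hidden_size=768,
    embedding_width=128, max_sequence_length=4096)
print(net.get_embedding_table().shape)   # (1024, 128)
word_ids = np.random.randint(1024, size=(2, 1024))
mask = np.ones((2, 1024), dtype=np.int32)
type_ids = np.zeros((2, 1024), dtype=np.int32)
outputs = net([word_ids, mask, type_ids])
print(outputs['sequence_output'].shape)  # (2, 1024, 768)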
......@@ -12,65 +12,52 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for roi_sampler.py."""
"""Tests for official.nlp.projects.bigbird.encoder."""
# Import libraries
import numpy as np
import tensorflow as tf
from official.vision.beta.modeling.layers import roi_sampler
class ROISamplerTest(tf.test.TestCase):
def test_roi_sampler(self):
boxes_np = np.array(
[[[0, 0, 5, 5], [2.5, 2.5, 7.5, 7.5],
[5, 5, 10, 10], [7.5, 7.5, 12.5, 12.5]]])
boxes = tf.constant(boxes_np, dtype=tf.float32)
gt_boxes_np = np.array(
[[[10, 10, 15, 15], [2.5, 2.5, 7.5, 7.5],
[-1, -1, -1, -1]]])
gt_boxes = tf.constant(gt_boxes_np, dtype=tf.float32)
gt_classes_np = np.array([[2, 10, -1]])
gt_classes = tf.constant(gt_classes_np, dtype=tf.int32)
generator = roi_sampler.ROISampler(
mix_gt_boxes=True,
num_sampled_rois=2,
foreground_fraction=0.5,
foreground_iou_threshold=0.5,
background_iou_high_threshold=0.5,
background_iou_low_threshold=0.0)
# Runs on TPU.
strategy = tf.distribute.experimental.TPUStrategy()
with strategy.scope():
_ = generator(boxes, gt_boxes, gt_classes)
# Runs on CPU.
_ = generator(boxes, gt_boxes, gt_classes)
def test_serialize_deserialize(self):
kwargs = dict(
mix_gt_boxes=True,
num_sampled_rois=512,
foreground_fraction=0.25,
foreground_iou_threshold=0.5,
background_iou_high_threshold=0.5,
background_iou_low_threshold=0.5,
)
generator = roi_sampler.ROISampler(**kwargs)
expected_config = dict(kwargs)
self.assertEqual(generator.get_config(), expected_config)
new_generator = roi_sampler.ROISampler.from_config(
generator.get_config())
self.assertAllEqual(generator.get_config(), new_generator.get_config())
if __name__ == '__main__':
from official.nlp.projects.bigbird import encoder
class BigBirdEncoderTest(tf.test.TestCase):
def test_encoder(self):
sequence_length = 1024
batch_size = 2
vocab_size = 1024
network = encoder.BigBirdEncoder(
num_layers=1, vocab_size=1024, max_sequence_length=4096)
word_id_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length))
mask_data = np.random.randint(2, size=(batch_size, sequence_length))
type_id_data = np.random.randint(2, size=(batch_size, sequence_length))
outputs = network([word_id_data, mask_data, type_id_data])
self.assertEqual(outputs["sequence_output"].shape,
(batch_size, sequence_length, 768))
def test_save_restore(self):
sequence_length = 1024
batch_size = 2
vocab_size = 1024
network = encoder.BigBirdEncoder(
num_layers=1, vocab_size=1024, max_sequence_length=4096)
word_id_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length))
mask_data = np.random.randint(2, size=(batch_size, sequence_length))
type_id_data = np.random.randint(2, size=(batch_size, sequence_length))
inputs = dict(
input_word_ids=word_id_data,
input_mask=mask_data,
input_type_ids=type_id_data)
ref_outputs = network(inputs)
model_path = self.get_temp_dir() + "/model"
network.save(model_path)
loaded = tf.keras.models.load_model(model_path)
outputs = loaded(inputs)
self.assertAllClose(outputs["sequence_output"],
ref_outputs["sequence_output"])
if __name__ == "__main__":
tf.test.main()
......@@ -63,31 +63,33 @@ class MaskedLMTask(base_task.Task):
model_outputs,
metrics,
aux_losses=None) -> tf.Tensor:
metrics = dict([(metric.name, metric) for metric in metrics])
lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
labels['masked_lm_ids'],
tf.cast(model_outputs['lm_output'], tf.float32),
from_logits=True)
lm_label_weights = labels['masked_lm_weights']
lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
lm_denominator_loss = tf.reduce_sum(lm_label_weights)
mlm_loss = tf.math.divide_no_nan(lm_numerator_loss, lm_denominator_loss)
metrics['lm_example_loss'].update_state(mlm_loss)
if 'next_sentence_labels' in labels:
sentence_labels = labels['next_sentence_labels']
sentence_outputs = tf.cast(
model_outputs['next_sentence'], dtype=tf.float32)
sentence_loss = tf.reduce_mean(
tf.keras.losses.sparse_categorical_crossentropy(
sentence_labels, sentence_outputs, from_logits=True))
metrics['next_sentence_loss'].update_state(sentence_loss)
total_loss = mlm_loss + sentence_loss
else:
total_loss = mlm_loss
if aux_losses:
total_loss += tf.add_n(aux_losses)
return total_loss
with tf.name_scope('MaskedLMTask/losses'):
metrics = dict([(metric.name, metric) for metric in metrics])
lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
labels['masked_lm_ids'],
tf.cast(model_outputs['mlm_logits'], tf.float32),
from_logits=True)
lm_label_weights = labels['masked_lm_weights']
lm_numerator_loss = tf.reduce_sum(lm_prediction_losses *
lm_label_weights)
lm_denominator_loss = tf.reduce_sum(lm_label_weights)
mlm_loss = tf.math.divide_no_nan(lm_numerator_loss, lm_denominator_loss)
metrics['lm_example_loss'].update_state(mlm_loss)
if 'next_sentence_labels' in labels:
sentence_labels = labels['next_sentence_labels']
sentence_outputs = tf.cast(
model_outputs['next_sentence'], dtype=tf.float32)
sentence_loss = tf.reduce_mean(
tf.keras.losses.sparse_categorical_crossentropy(
sentence_labels, sentence_outputs, from_logits=True))
metrics['next_sentence_loss'].update_state(sentence_loss)
total_loss = mlm_loss + sentence_loss
else:
total_loss = mlm_loss
if aux_losses:
total_loss += tf.add_n(aux_losses)
return total_loss
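A toy illustration (not from the commit) of the weighted masked-LM loss computed above: per-position losses are summed with the label weights and normalized by the total weight, so padded prediction slots contribute nothing.

import tensorflow as tf

per_position_loss = tf.constant([2.0, 0.5, 1.5])
label_weights = tf.constant([1.0, 1.0, 0.0])   # third prediction slot is padding
mlm_loss = tf.math.divide_no_nan(
    tf.reduce_sum(per_position_loss * label_weights),
    tf.reduce_sum(label_weights))
print(float(mlm_loss))  # 1.25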
def build_inputs(self, params, input_context=None):
"""Returns tf.data.Dataset for pretraining."""
......@@ -128,14 +130,15 @@ class MaskedLMTask(base_task.Task):
return metrics
def process_metrics(self, metrics, labels, model_outputs):
metrics = dict([(metric.name, metric) for metric in metrics])
if 'masked_lm_accuracy' in metrics:
metrics['masked_lm_accuracy'].update_state(labels['masked_lm_ids'],
model_outputs['lm_output'],
labels['masked_lm_weights'])
if 'next_sentence_accuracy' in metrics:
metrics['next_sentence_accuracy'].update_state(
labels['next_sentence_labels'], model_outputs['next_sentence'])
with tf.name_scope('MaskedLMTask/process_metrics'):
metrics = dict([(metric.name, metric) for metric in metrics])
if 'masked_lm_accuracy' in metrics:
metrics['masked_lm_accuracy'].update_state(
labels['masked_lm_ids'], model_outputs['mlm_logits'],
labels['masked_lm_weights'])
if 'next_sentence_accuracy' in metrics:
metrics['next_sentence_accuracy'].update_state(
labels['next_sentence_labels'], model_outputs['next_sentence'])
def train_step(self, inputs, model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer, metrics):
......
......@@ -16,25 +16,51 @@
"""Common utils for tasks."""
from typing import Any, Callable
from absl import logging
import orbit
import tensorflow as tf
import tensorflow_hub as hub
def get_encoder_from_hub(hub_module: str) -> tf.keras.Model:
"""Gets an encoder from hub."""
def get_encoder_from_hub(hub_model) -> tf.keras.Model:
"""Gets an encoder from hub.
Args:
hub_model: A tfhub model loaded by `hub.load(...)`.
Returns:
A tf.keras.Model.
"""
input_word_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_word_ids')
input_mask = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_mask')
input_type_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_type_ids')
hub_layer = hub.KerasLayer(hub_module, trainable=True)
pooled_output, sequence_output = hub_layer(
[input_word_ids, input_mask, input_type_ids])
return tf.keras.Model(
inputs=[input_word_ids, input_mask, input_type_ids],
outputs=[sequence_output, pooled_output])
hub_layer = hub.KerasLayer(hub_model, trainable=True)
output_dict = {}
dict_input = dict(
input_word_ids=input_word_ids,
input_mask=input_mask,
input_type_ids=input_type_ids)
# The legacy hub model takes a list as input and returns a Tuple of
# `pooled_output` and `sequence_output`, while the new hub model takes dict
# as input and returns a dict.
# TODO(chendouble): Remove the support of legacy hub model when the new ones
# are released.
hub_output_signature = hub_model.signatures['serving_default'].outputs
if len(hub_output_signature) == 2:
logging.info('Use the legacy hub module with list as input/output.')
pooled_output, sequence_output = hub_layer(
[input_word_ids, input_mask, input_type_ids])
output_dict['pooled_output'] = pooled_output
output_dict['sequence_output'] = sequence_output
else:
logging.info('Use the new hub module with dict as input/output.')
output_dict = hub_layer(dict_input)
return tf.keras.Model(inputs=dict_input, outputs=output_dict)
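A hedged usage sketch (not part of the commit) for the get_encoder_from_hub function above; the SavedModel path is a placeholder, and the returned Keras model exposes 'sequence_output' and 'pooled_output' whether the hub module is the legacy or the new style.

import numpy as np
import tensorflow_hub as hub

# Placeholder path to an exported BERT-style SavedModel (illustrative only).
hub_model = hub.load('/path/to/exported/bert_encoder')
bert_encoder = get_encoder_from_hub(hub_model)
features = dict(
    input_word_ids=np.zeros((2, 128), dtype=np.int32),
    input_mask=np.ones((2, 128), dtype=np.int32),
    input_type_ids=np.zeros((2, 128), dtype=np.int32))
outputs = bert_encoder(features)
print(outputs['sequence_output'].shape)  # (2, 128, hidden_size)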
def predict(predict_step_fn: Callable[[Any], Any],
......
......@@ -20,6 +20,7 @@ from absl import flags
import gin
from official.core import train_utils
from official.common import distribute_utils
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
......@@ -27,7 +28,6 @@ from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.modeling import performance
from official.utils.misc import distribution_utils
FLAGS = flags.FLAGS
......@@ -48,11 +48,12 @@ def main(_):
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale)
distribution_strategy = distribution_utils.get_distribution_strategy(
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
tpu_address=params.runtime.tpu,
**params.runtime.model_parallelism())
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
......
......@@ -15,9 +15,10 @@
# ==============================================================================
"""TFM continuous finetuning+eval training driver."""
import gc
import os
import time
from typing import Mapping, Any
from typing import Any, Mapping, Optional
from absl import app
from absl import flags
......@@ -28,38 +29,44 @@ import tensorflow as tf
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.modeling.hyperparams import config_definitions
from official.utils.misc import distribution_utils
FLAGS = flags.FLAGS
flags.DEFINE_integer(
'pretrain_steps',
default=None,
help='The number of total training steps for the pretraining job.')
def run_continuous_finetune(
mode: str,
params: config_definitions.ExperimentConfig,
model_dir: str,
run_post_eval: bool = False,
pretrain_steps: Optional[int] = None,
) -> Mapping[str, Any]:
"""Run modes with continuous training.
Currently only supports continuous_train_and_eval.
Args:
mode: A 'str', specifying the mode.
continuous_train_and_eval - monitors a checkpoint directory. Once a new
checkpoint is discovered, loads the checkpoint, finetune the model by
training it (probably on another dataset or with another task), then
evaluate the finetuned model.
    mode: A 'str', specifying the mode. continuous_train_and_eval - monitors a
      checkpoint directory. Once a new checkpoint is discovered, loads the
      checkpoint, finetunes the model by training it (probably on another
      dataset or with another task), then evaluates the finetuned model.
params: ExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries.
run_post_eval: Whether to run post eval once after training, metrics logs
are returned.
pretrain_steps: Optional, the number of total training steps for the
pretraining job.
Returns:
eval logs: returns eval metrics logs when run_post_eval is set to True,
......@@ -77,7 +84,7 @@ def run_continuous_finetune(
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale)
distribution_strategy = distribution_utils.get_distribution_strategy(
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
......@@ -95,10 +102,24 @@ def run_continuous_finetune(
summary_writer = tf.summary.create_file_writer(
os.path.join(model_dir, 'eval'))
global_step = 0
def timeout_fn():
if pretrain_steps and global_step < pretrain_steps:
# Keeps waiting for another timeout period.
logging.info(
'Continue waiting for new checkpoint as current pretrain '
'global_step=%d and target is %d.', global_step, pretrain_steps)
return False
# Quits the loop.
return True
for pretrain_ckpt in tf.train.checkpoints_iterator(
checkpoint_dir=params.task.init_checkpoint,
min_interval_secs=10,
timeout=params.trainer.continuous_eval_timeout):
timeout=params.trainer.continuous_eval_timeout,
timeout_fn=timeout_fn):
with distribution_strategy.scope():
global_step = train_utils.read_global_step_from_checkpoint(pretrain_ckpt)
......@@ -139,6 +160,13 @@ def run_continuous_finetune(
train_utils.write_summary(summary_writer, global_step, summaries)
train_utils.remove_ckpts(model_dir)
# In TF2, the resource life cycle is bound with the python object life
# cycle. Force trigger python garbage collection here so those resources
# can be deallocated in time, so it doesn't cause OOM when allocating new
# objects.
# TODO(b/169178664): Fix cycle reference in Keras model and revisit to see
# if we need gc here.
gc.collect()
if run_post_eval:
return eval_metrics
......@@ -150,7 +178,7 @@ def main(_):
params = train_utils.parse_configuration(FLAGS)
model_dir = FLAGS.model_dir
train_utils.serialize_config(params, model_dir)
run_continuous_finetune(FLAGS.mode, params, model_dir)
run_continuous_finetune(FLAGS.mode, params, model_dir, FLAGS.pretrain_steps)
if __name__ == '__main__':
......
......@@ -15,10 +15,9 @@
# ==============================================================================
import os
# Import libraries
from absl import flags
from absl.testing import flagsaver
from absl.testing import parameterized
import tensorflow as tf
from official.common import flags as tfm_flags
from official.core import task_factory
......@@ -31,14 +30,14 @@ FLAGS = flags.FLAGS
tfm_flags.define_flags()
class MainContinuousFinetuneTest(tf.test.TestCase):
class ContinuousFinetuneTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(MainContinuousFinetuneTest, self).setUp()
super().setUp()
self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir')
@flagsaver.flagsaver
def testTrainCtl(self):
@parameterized.parameters(None, 1)
def testTrainCtl(self, pretrain_steps):
src_model_dir = self.get_temp_dir()
flags_dict = dict(
experiment='mock',
......@@ -81,7 +80,11 @@ class MainContinuousFinetuneTest(tf.test.TestCase):
params = train_utils.parse_configuration(FLAGS)
eval_metrics = train_ctl_continuous_finetune.run_continuous_finetune(
FLAGS.mode, params, FLAGS.model_dir, run_post_eval=True)
FLAGS.mode,
params,
FLAGS.model_dir,
run_post_eval=True,
pretrain_steps=pretrain_steps)
self.assertIn('best_acc', eval_metrics)
......