Unverified Commit 5ffcc5b6 authored by Anirudh Vegesana, committed by GitHub

Merge branch 'purdue-yolo' into detection_generator_pr

parents 0b81a843 76e0c014
@@ -20,6 +20,7 @@ import tensorflow as tf
 import tensorflow_addons.optimizers as tfa_optimizers
 from official.modeling.optimization import slide_optimizer
+from official.modeling.optimization import adafactor_optimizer
 from official.modeling.optimization import ema_optimizer
 from official.modeling.optimization import lars_optimizer
 from official.modeling.optimization import lr_schedule
@@ -34,14 +35,15 @@ OPTIMIZERS_CLS = {
     'rmsprop': tf.keras.optimizers.RMSprop,
     'lars': lars_optimizer.LARS,
     'adagrad': tf.keras.optimizers.Adagrad,
-    'slide': slide_optimizer.SLIDE
+    'slide': slide_optimizer.SLIDE,
+    'adafactor': adafactor_optimizer.Adafactor,
 }
 LR_CLS = {
-    'stepwise': tf.keras.optimizers.schedules.PiecewiseConstantDecay,
-    'polynomial': tf.keras.optimizers.schedules.PolynomialDecay,
-    'exponential': tf.keras.optimizers.schedules.ExponentialDecay,
-    'cosine': tf.keras.experimental.CosineDecay,
+    'stepwise': lr_schedule.PiecewiseConstantDecayWithOffset,
+    'polynomial': lr_schedule.PolynomialDecayWithOffset,
+    'exponential': lr_schedule.ExponentialDecayWithOffset,
+    'cosine': lr_schedule.CosineDecayWithOffset,
     'power': lr_schedule.DirectPowerDecay,
     'power_linear': lr_schedule.PowerAndLinearDecay,
     'power_with_offset': lr_schedule.PowerDecayWithOffset,
...
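These dictionaries are plain name-to-class registries: a training config names an optimizer and a learning-rate schedule by key, and the factory instantiates the mapped class. A standalone sketch of that lookup pattern, using stock Keras classes so it runs without the `official` package (the real registries map, e.g., 'adafactor' and the `*WithOffset` schedules shown above):

```python
import tensorflow as tf

# Illustrative stand-ins for OPTIMIZERS_CLS / LR_CLS above.
OPTIMIZERS_CLS = {
    'sgd': tf.keras.optimizers.SGD,
    'adam': tf.keras.optimizers.Adam,
}
LR_CLS = {
    'stepwise': tf.keras.optimizers.schedules.PiecewiseConstantDecay,
}

# Config-driven instantiation: look up the class by key, pass its kwargs.
lr = LR_CLS['stepwise'](boundaries=[1000], values=[1e-3, 1e-4])
optimizer = OPTIMIZERS_CLS['adam'](learning_rate=lr)
```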
@@ -14,29 +14,16 @@
 """Functions and classes related to training performance."""
-from absl import logging
 import tensorflow as tf
 def configure_optimizer(optimizer,
                         use_float16=False,
                         use_graph_rewrite=False,
-                        loss_scale='dynamic',
-                        use_experimental_api=False):
+                        loss_scale=None):
   """Configures optimizer object with performance options."""
-  if use_experimental_api:
-    logging.warning('Passing use_experimental_api=True is deprecated. The '
-                    'argument will be removed in the future.')
   if use_float16:
-    # TODO(b/171936854): Move all methods to non-experimental api.
-    if use_experimental_api:
-      # Wraps optimizer with a LossScaleOptimizer. This is done automatically
-      # in compile() with the "mixed_float16" policy, but since we do not call
-      # compile(), we must wrap the optimizer manually.
-      optimizer = (
-          tf.keras.mixed_precision.experimental.LossScaleOptimizer(
-              optimizer, loss_scale=loss_scale))
-    elif loss_scale == 'dynamic':
+    if loss_scale in (None, 'dynamic'):
       optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
     else:
       # loss_scale is a number. We interpret that as a fixed loss scale.
@@ -52,34 +39,17 @@ def configure_optimizer(optimizer,
   return optimizer
-def set_mixed_precision_policy(dtype, loss_scale=None,
-                               use_experimental_api=False):
-  """Sets mix precision policy."""
-  if use_experimental_api:
-    logging.warning('Passing use_experimental_api=True is deprecated. The '
-                    'argument will be removed in the future.')
-  assert use_experimental_api or loss_scale is None, (
-      'loss_scale cannot be specified if use_experimental_api is False. If the '
-      'non-experimental API is used, specify the loss scaling configuration '
-      'when creating the LossScaleOptimizer instead.')
+def set_mixed_precision_policy(dtype, loss_scale=None):
+  """Sets the global `tf.keras.mixed_precision.Policy`."""
+  # TODO(b/191894773): Remove loss_scale argument
+  assert loss_scale is None, (
+      'The loss_scale argument must be None. The argument exists for '
+      'historical reasons and will be removed soon.')
   if dtype == tf.float16:
-    # TODO(b/171936854): Move all methods to non-experimental api.
-    if use_experimental_api:
-      policy = tf.keras.mixed_precision.experimental.Policy(
-          'mixed_float16', loss_scale=loss_scale)
-      tf.keras.mixed_precision.experimental.set_policy(policy)
-    else:
-      tf.keras.mixed_precision.set_global_policy('mixed_float16')
+    tf.keras.mixed_precision.set_global_policy('mixed_float16')
   elif dtype == tf.bfloat16:
-    if use_experimental_api:
-      tf.keras.mixed_precision.experimental.set_policy('mixed_bfloat16')
-    else:
-      tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
+    tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
   elif dtype == tf.float32:
-    if use_experimental_api:
-      tf.keras.mixed_precision.experimental.set_policy('float32')
-    else:
-      tf.keras.mixed_precision.set_global_policy('float32')
+    tf.keras.mixed_precision.set_global_policy('float32')
   else:
     raise ValueError('Unexpected dtype: %s' % dtype)
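With the experimental branches removed, mixed-precision setup reduces to two calls against the public Keras API. A minimal standalone sketch of the same flow (not the repo's helpers):

```python
import tensorflow as tf

# Equivalent of set_mixed_precision_policy(tf.float16): one global policy,
# so layers compute in float16 while keeping float32 variables.
tf.keras.mixed_precision.set_global_policy('mixed_float16')

# Equivalent of configure_optimizer(..., use_float16=True, loss_scale=None):
# wrap the optimizer for dynamic loss scaling.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
```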
@@ -108,6 +108,7 @@ def get_activation(identifier, use_keras_layer=False):
       "linear": "linear",
       "identity": "linear",
       "swish": "swish",
+      "sigmoid": "sigmoid",
       "relu6": tf.nn.relu6,
   }
   if identifier in keras_layer_allowlist:
...
@@ -46,6 +46,8 @@ class BertEncoderConfig(hyperparams.Config):
   embedding_size: Optional[int] = None
   output_range: Optional[int] = None
   return_all_encoder_outputs: bool = False
+  # Pre/Post-LN Transformer
+  norm_first: bool = False
 @dataclasses.dataclass
@@ -132,6 +134,8 @@ class BigBirdEncoderConfig(hyperparams.Config):
   intermediate_size: int = 3072
   dropout_rate: float = 0.1
   attention_dropout_rate: float = 0.1
+  # Pre/Post-LN Transformer
+  norm_first: bool = False
   max_position_embeddings: int = 4096
   num_rand_blocks: int = 3
   block_size: int = 64
@@ -152,6 +156,8 @@ class KernelEncoderConfig(hyperparams.Config):
   intermediate_size: int = 3072
   dropout_rate: float = 0.1
   attention_dropout_rate: float = 0.1
+  # Pre/Post-LN Transformer
+  norm_first: bool = False
   max_position_embeddings: int = 512
   type_vocab_size: int = 2
   initializer_range: float = 0.02
@@ -161,6 +167,7 @@ class KernelEncoderConfig(hyperparams.Config):
   redraw: bool = False
   is_short_seq: bool = False
   begin_kernel: int = 0
+  scale: Optional[float] = None
 @dataclasses.dataclass
@@ -339,6 +346,7 @@ def build_encoder(config: EncoderConfig,
             encoder_cfg.hidden_activation),
         dropout_rate=encoder_cfg.dropout_rate,
         attention_dropout_rate=encoder_cfg.attention_dropout_rate,
+        norm_first=encoder_cfg.norm_first,
         kernel_initializer=tf.keras.initializers.TruncatedNormal(
             stddev=encoder_cfg.initializer_range),
         attention_cls=layers.BigBirdAttention,
@@ -377,6 +385,7 @@ def build_encoder(config: EncoderConfig,
         redraw=encoder_cfg.redraw,
         is_short_seq=encoder_cfg.is_short_seq,
         begin_kernel=encoder_cfg.begin_kernel,
+        scale=encoder_cfg.scale,
     )
     hidden_cfg = dict(
         num_attention_heads=encoder_cfg.num_attention_heads,
@@ -385,6 +394,7 @@ def build_encoder(config: EncoderConfig,
             encoder_cfg.hidden_activation),
         dropout_rate=encoder_cfg.dropout_rate,
         attention_dropout_rate=encoder_cfg.attention_dropout_rate,
+        norm_first=encoder_cfg.norm_first,
         kernel_initializer=tf.keras.initializers.TruncatedNormal(
             stddev=encoder_cfg.initializer_range),
         attention_cls=layers.KernelAttention,
@@ -445,4 +455,5 @@ def build_encoder(config: EncoderConfig,
       embedding_width=encoder_cfg.embedding_size,
       embedding_layer=embedding_layer,
       return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs,
-      dict_outputs=True)
+      dict_outputs=True,
+      norm_first=encoder_cfg.norm_first)
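The new `norm_first` flag selects between the original post-layer-norm Transformer block and the pre-layer-norm variant. A schematic sketch of the two residual orderings (illustrative helper functions, not the repo's layer code):

```python
def post_ln_block(x, sublayer, norm):
  # norm_first=False (default): normalize the sublayer *output*.
  return norm(x + sublayer(x))

def pre_ln_block(x, sublayer, norm):
  # norm_first=True: normalize the sublayer *input*; the residual path stays
  # un-normalized, which tends to stabilize training of deep stacks.
  return x + sublayer(norm(x))
```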
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for third_party.tensorflow_models.official.nlp.data.classifier_data_lib."""
import os
import tempfile
from absl.testing import parameterized
import tensorflow as tf
import tensorflow_datasets as tfds
from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib
def decode_record(record, name_to_features):
"""Decodes a record to a TensorFlow example."""
return tf.io.parse_single_example(record, name_to_features)
class BertClassifierLibTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(BertClassifierLibTest, self).setUp()
self.model_dir = self.get_temp_dir()
self.processors = {
"CB": classifier_data_lib.CBProcessor,
"SUPERGLUE-RTE": classifier_data_lib.SuperGLUERTEProcessor,
"BOOLQ": classifier_data_lib.BoolQProcessor,
"WIC": classifier_data_lib.WiCProcessor,
}
vocab_tokens = [
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
"##ing", ","
]
with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens
]).encode("utf-8"))
vocab_file = vocab_writer.name
self.tokenizer = tokenization.FullTokenizer(vocab_file)
@parameterized.parameters(
{"task_type": "CB"},
{"task_type": "BOOLQ"},
{"task_type": "SUPERGLUE-RTE"},
{"task_type": "WIC"},
)
def test_generate_dataset_from_tfds_processor(self, task_type):
with tfds.testing.mock_data(num_examples=5):
output_path = os.path.join(self.model_dir, task_type)
processor = self.processors[task_type]()
classifier_data_lib.generate_tf_record_from_data_file(
processor,
None,
self.tokenizer,
train_data_output_path=output_path,
eval_data_output_path=output_path,
test_data_output_path=output_path)
files = tf.io.gfile.glob(output_path)
self.assertNotEmpty(files)
train_dataset = tf.data.TFRecordDataset(output_path)
seq_length = 128
label_type = tf.int64
name_to_features = {
"input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
"input_mask": tf.io.FixedLenFeature([seq_length], tf.int64),
"segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
"label_ids": tf.io.FixedLenFeature([], label_type),
}
train_dataset = train_dataset.map(
lambda record: decode_record(record, name_to_features))
# If data is retrieved without error, then all requirements
# including data type/shapes are met.
_ = next(iter(train_dataset))
if __name__ == "__main__":
tf.test.main()
@@ -50,7 +50,7 @@ flags.DEFINE_enum(
     "classification_task_name", "MNLI", [
         "AX", "COLA", "IMDB", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
         "SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI", "XTREME-PAWS-X",
-        "AX-g", "SUPERGLUE-RTE", "CB", "BoolQ"
+        "AX-g", "SUPERGLUE-RTE", "CB", "BoolQ", "WIC"
     ], "The name of the task to train BERT classifier. The "
     "difference between XTREME-XNLI and XNLI is: 1. the format "
     "of input tsv files; 2. the dev set for XTREME is english "
@@ -173,6 +173,24 @@ flags.DEFINE_string(
 def generate_classifier_dataset():
   """Generates classifier dataset and returns input meta data."""
+  if FLAGS.classification_task_name in [
+      "COLA",
+      "WNLI",
+      "SST-2",
+      "MRPC",
+      "QQP",
+      "STS-B",
+      "MNLI",
+      "QNLI",
+      "RTE",
+      "AX",
+      "SUPERGLUE-RTE",
+      "CB",
+      "BoolQ",
+      "WIC",
+  ]:
+    assert not FLAGS.input_data_dir or FLAGS.tfds_params
+  else:
     assert (FLAGS.input_data_dir and FLAGS.classification_task_name or
             FLAGS.tfds_params)
@@ -248,6 +266,8 @@ def generate_classifier_dataset():
           classifier_data_lib.CBProcessor,
       "boolq":
           classifier_data_lib.BoolQProcessor,
+      "wic":
+          classifier_data_lib.WiCProcessor,
   }
   task_name = FLAGS.classification_task_name.lower()
   if task_name not in processors:
...
@@ -60,8 +60,8 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
     else:
       self._label_name_mapping = dict()
-  def _decode(self, record: tf.Tensor):
-    """Decodes a serialized tf.Example."""
+  def name_to_features_spec(self):
+    """Defines features to decode. Subclass may override to append features."""
     label_type = LABEL_TYPES_MAP[self._params.label_type]
     name_to_features = {
         'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
@@ -72,7 +72,11 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
     if self._include_example_id:
       name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)
-    example = tf.io.parse_single_example(record, name_to_features)
+    return name_to_features
+
+  def _decode(self, record: tf.Tensor):
+    """Decodes a serialized tf.Example."""
+    example = tf.io.parse_single_example(record, self.name_to_features_spec())
     # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
     # So cast all int64 to int32.
@@ -86,20 +90,23 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
   def _parse(self, record: Mapping[str, tf.Tensor]):
     """Parses raw tensors into a dict of tensors to be consumed by the model."""
-    x = {
-        'input_word_ids': record['input_ids'],
-        'input_mask': record['input_mask'],
-        'input_type_ids': record['segment_ids']
-    }
-    if self._include_example_id:
-      x['example_id'] = record['example_id']
-    x[self._label_field] = record[self._label_field]
+    key_mapping = {
+        'input_ids': 'input_word_ids',
+        'input_mask': 'input_mask',
+        'segment_ids': 'input_type_ids'
+    }
+    ret = {}
+    for record_key in record:
+      if record_key in key_mapping:
+        ret[key_mapping[record_key]] = record[record_key]
+      else:
+        ret[record_key] = record[record_key]
     if self._label_field in self._label_name_mapping:
-      x[self._label_name_mapping[self._label_field]] = record[self._label_field]
+      ret[self._label_name_mapping[self._label_field]] = record[
+          self._label_field]
-    return x
+    return ret
   def load(self, input_context: Optional[tf.distribute.InputContext] = None):
     """Returns a tf.dataset.Dataset."""
@@ -215,13 +222,12 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader):
     """Berts preprocess."""
     segments = [record[x] for x in self._text_fields]
     model_inputs = self._text_processor(segments)
-    if self._include_example_id:
-      model_inputs['example_id'] = record['example_id']
-    model_inputs[self._label_field] = record[self._label_field]
+    for key in record:
+      if key not in self._text_fields:
+        model_inputs[key] = record[key]
     return model_inputs
-  def _decode(self, record: tf.Tensor):
-    """Decodes a serialized tf.Example."""
+  def name_to_features_spec(self):
     name_to_features = {}
     for text_field in self._text_fields:
       name_to_features[text_field] = tf.io.FixedLenFeature([], tf.string)
@@ -230,8 +236,11 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader):
     name_to_features[self._label_field] = tf.io.FixedLenFeature([], label_type)
     if self._include_example_id:
       name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)
-    example = tf.io.parse_single_example(record, name_to_features)
+    return name_to_features
+
+  def _decode(self, record: tf.Tensor):
+    """Decodes a serialized tf.Example."""
+    example = tf.io.parse_single_example(record, self.name_to_features_spec())
     # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
     # So cast all int64 to int32.
     for name in example:
...
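`name_to_features_spec()` is factored out precisely so subclasses can extend the decode spec without re-implementing `_decode`; and because `_parse` now passes through unknown record keys, the extra feature survives into the model inputs. A hypothetical subclass sketch (the import path follows the repo's data module layout, and `example_weight` is an illustrative feature name, not part of the repo):

```python
import tensorflow as tf
# Assumed import path for the loader class shown in the diff above.
from official.nlp.data import sentence_prediction_dataloader as loader_lib


class ExampleWeightDataLoader(loader_lib.SentencePredictionDataLoader):
  """Hypothetical loader that decodes one extra float feature per example."""

  def name_to_features_spec(self):
    name_to_features = super().name_to_features_spec()
    # Append a feature to the inherited spec; _decode picks it up unchanged.
    name_to_features['example_weight'] = tf.io.FixedLenFeature([], tf.float32)
    return name_to_features
```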
@@ -198,9 +198,12 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
     dataset = loader.SentencePredictionTextDataLoader(data_config).load()
     features = next(iter(dataset))
     label_field = data_config.label_field
-    self.assertCountEqual(
-        ['input_word_ids', 'input_type_ids', 'input_mask', label_field],
-        features.keys())
+    expected_keys = [
+        'input_word_ids', 'input_type_ids', 'input_mask', label_field
+    ]
+    if use_tfds:
+      expected_keys += ['idx']
+    self.assertCountEqual(expected_keys, features.keys())
     self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
@@ -233,9 +236,12 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
     dataset = loader.SentencePredictionTextDataLoader(data_config).load()
     features = next(iter(dataset))
     label_field = data_config.label_field
-    self.assertCountEqual(
-        ['input_word_ids', 'input_type_ids', 'input_mask', label_field],
-        features.keys())
+    expected_keys = [
+        'input_word_ids', 'input_type_ids', 'input_mask', label_field
+    ]
+    if use_tfds:
+      expected_keys += ['idx']
+    self.assertCountEqual(expected_keys, features.keys())
     self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
@@ -268,9 +274,12 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
     dataset = loader.SentencePredictionTextDataLoader(data_config).load()
     features = next(iter(dataset))
     label_field = data_config.label_field
-    self.assertCountEqual(
-        ['input_word_ids', 'input_type_ids', 'input_mask', label_field],
-        features.keys())
+    expected_keys = [
+        'input_word_ids', 'input_type_ids', 'input_mask', label_field
+    ]
+    if use_tfds:
+      expected_keys += ['idx']
+    self.assertCountEqual(expected_keys, features.keys())
     self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
...
@@ -69,6 +69,9 @@ class BertEncoder(tf.keras.Model):
       smaller than 'hidden_size').
     embedding_layer: An optional Layer instance which will be called to
       generate embeddings for the input word IDs.
+    norm_first: Whether to normalize inputs to attention and intermediate
+      dense layers. If set False, output of attention and intermediate dense
+      layers is normalized.
   """
   def __init__(
@@ -87,6 +90,7 @@ class BertEncoder(tf.keras.Model):
       output_range=None,
      embedding_width=None,
      embedding_layer=None,
+      norm_first=False,
       **kwargs):
     activation = tf.keras.activations.get(inner_activation)
     initializer = tf.keras.initializers.get(initializer)
@@ -162,6 +166,7 @@ class BertEncoder(tf.keras.Model):
           inner_activation=inner_activation,
           output_dropout=output_dropout,
           attention_dropout=attention_dropout,
+          norm_first=norm_first,
           output_range=transformer_output_range,
           kernel_initializer=initializer,
           name='transformer/layer_%d' % i)
@@ -211,6 +216,7 @@ class BertEncoder(tf.keras.Model):
         'output_range': output_range,
         'embedding_width': embedding_width,
         'embedding_layer': embedding_layer,
+        'norm_first': norm_first,
     }
     # We are storing the config dict as a namedtuple here to ensure checkpoint
...
@@ -205,7 +205,8 @@ class BertEncoderTest(keras_parameterized.TestCase):
         initializer="glorot_uniform",
         output_range=-1,
         embedding_width=16,
-        embedding_layer=None)
+        embedding_layer=None,
+        norm_first=False)
     network = bert_encoder.BertEncoder(**kwargs)
     expected_config = dict(kwargs)
     expected_config["inner_activation"] = tf.keras.activations.serialize(
...
@@ -48,12 +48,12 @@ class PositionEmbeddingLayerTest(keras_parameterized.TestCase):
     test_layer = position_embedding.PositionEmbedding(
         max_length=sequence_length, seq_axis=2)
     width = 30
-    input_tensor = tf.keras.Input(shape=(sequence_length, width, width))
+    input_tensor = tf.keras.Input(shape=(width, sequence_length, width))
     output_tensor = test_layer(input_tensor)
     # When using static positional embedding shapes, the output is expected
     # to be the same as the input shape in all dimensions save batch.
-    expected_output_shape = [None, sequence_length, width, width]
+    expected_output_shape = [None, width, sequence_length, width]
     self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
     # The default output dtype for this layer should be tf.float32.
     self.assertEqual(tf.float32, output_tensor.dtype)
...
@@ -249,7 +249,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         attention.
     Returns:
-      An ouput tensor with the same dimensions as input/query tensor.
+      An output tensor with the same dimensions as input/query tensor.
     """
     if isinstance(inputs, (list, tuple)):
       if len(inputs) == 2:
...
@@ -85,30 +85,20 @@ def create_projection_matrix(m, d, seed=None):
   return tf.linalg.matmul(tf.linalg.diag(multiplier), final_matrix)
-def _generalized_kernel(x, projection_matrix, is_query, f, h,
-                        data_normalizer_fn=None):
+def _generalized_kernel(x, projection_matrix, f, h):
   """Generalized kernel in RETHINKING ATTENTION WITH PERFORMERS.
   Args:
     x: The feature being transformed with shape [B, T, N, H].
     projection_matrix: The matrix with shape [M, H] that we project x to, where
       M is the number of projections.
-    is_query: Whether the transform is a query or key. This transform is
-      symmetric if the argument is not used.
     f: A non-linear function applied on x or projected x.
     h: A multiplier which is a function of x applied after projected and
       transformed. Only applied if projection_matrix is not None.
-    data_normalizer_fn: A function which takes x and returns a scalar that
-      normalizes data.
   Returns:
     Transformed feature.
   """
-  # No asymmetric operations.
-  del is_query
-  if data_normalizer_fn is not None:
-    x = data_normalizer_fn(x)
   if projection_matrix is None:
     return h(x) * f(x)
@@ -139,26 +129,18 @@ _TRANSFORM_MAP = {
                 x - tf.math.reduce_max(x, axis=[1, 2, 3], keepdims=True)),
             h=lambda x: tf.math.exp(
                 -0.5 * tf.math.reduce_sum(
-                    tf.math.square(x), axis=-1, keepdims=True)),
-            data_normalizer_fn=lambda x: x /
-            (tf.math.sqrt(tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32))))),
+                    tf.math.square(x), axis=-1, keepdims=True)),
+        ),
     "expmod":
         functools.partial(
             _generalized_kernel,
             # Avoid exp explosion by shifting.
-            f=lambda x: tf.math.exp(
-                x - tf.math.reduce_max(x, axis=[1, 2, 3], keepdims=True)),
-            h=lambda x: tf.math.exp(
-                -0.5 * tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32))),
-            data_normalizer_fn=lambda x: x /
-            (tf.math.sqrt(tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32))))),
-    "l2":
-        functools.partial(
-            _generalized_kernel,
-            f=lambda x: x,
-            h=lambda x: tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32)),
-            data_normalizer_fn=lambda x: x),
-    "identity": lambda x, projection_matrix, is_query: x
+            f=lambda x: tf.math.exp(x - tf.math.reduce_max(
+                x, axis=[1, 2, 3], keepdims=True)),
+            h=lambda x: tf.math.exp(-0.5 * tf.math.sqrt(
+                tf.cast(tf.shape(x)[-1], tf.float32))),
+        ),
+    "identity":
+        functools.partial(_generalized_kernel, f=lambda x: x, h=lambda x: 1)
 }
 # pylint: enable=g-long-lambda
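For context on why these feature maps exist: once keys and queries pass through a kernel feature map phi, softmax-style attention can be computed as phi(Q)(phi(K)^T V) with a row-wise normalizer, which is linear rather than quadratic in sequence length. This mirrors the kv/denominator einsums in the call() diff further down. A small numpy sketch of the factorization (illustrative: single head, no random projection):

```python
import numpy as np

def phi(x):
  # The "relu" entry of _TRANSFORM_MAP, written as a plain feature map.
  return np.maximum(x, 0.0)

rng = np.random.default_rng(0)
T, S, H, D = 6, 7, 8, 8  # query len, key len, key dim, value dim
Q, K = rng.normal(size=(T, H)), rng.normal(size=(S, H))
V = rng.normal(size=(S, D))

# Linear-time order of operations: phi(K)^T V is [H, D], built once; the
# [T, S] attention matrix is never materialized.
kv = phi(K).T @ V
numerator = phi(Q) @ kv
denominator = phi(Q) @ phi(K).sum(axis=0) + 1e-6  # row-wise normalizer
out = numerator / denominator[:, None]            # [T, D]
```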
@@ -170,7 +152,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
   Rethinking Attention with Performers
   (https://arxiv.org/abs/2009.14794)
-  - exp (Lemma 1, positive), relu, l2
+  - exp (Lemma 1, positive), relu
   - random/deterministic projection
   Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
@@ -195,14 +177,14 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
                redraw=False,
                is_short_seq=False,
                begin_kernel=0,
+               scale=None,
                **kwargs):
     r"""Constructor of KernelAttention.
     Args:
       feature_transform: A non-linear transform of the keys and queries.
-        Possible transforms are "elu", "relu", "square", "exp", "expmod",
-        "l2", "identity". If <is_short_seq> = True, it is recommended to choose
-        feature_transform as "l2".
+        Possible transforms are "elu", "relu", "square", "exp", "expmod",
+        "identity".
       num_random_features: Number of random features to be used for projection.
         If num_random_features <= 0, no projection is used before transform.
       seed: The seed to begin drawing random features. Once the seed is set, the
@@ -216,6 +198,8 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
         (default option).
       begin_kernel: Apply kernel_attention after this sequence id and apply
         softmax attention before this.
+      scale: The value to scale the dot product as described in `Attention Is
+        All You Need`. If None, we use 1/sqrt(dk) as described in the paper.
       **kwargs: The same arguments `MultiHeadAttention` layer.
     """
     if feature_transform not in _TRANSFORM_MAP:
...
@@ -234,8 +218,11 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     # 1. inference
     # 2. no redraw
     self._seed = seed
     super().__init__(**kwargs)
+    if scale is None:
+      self._scale = 1.0 / math.sqrt(float(self._key_dim))
+    else:
+      self._scale = scale
     self._projection_matrix = None
     if num_random_features > 0:
       self._projection_matrix = create_projection_matrix(
@@ -275,7 +262,6 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     Returns:
       attention_output: Multi-headed outputs of attention computation.
     """
     projection_matrix = None
     if self._num_random_features > 0:
       if self._redraw and training:
@@ -284,8 +270,20 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       else:
         projection_matrix = self._projection_matrix
-    key = _TRANSFORM_MAP[feature_transform](key, projection_matrix, False)
-    query = _TRANSFORM_MAP[feature_transform](query, projection_matrix, True)
+    if is_short_seq:
+      # Note: Applying scalar multiply at the smaller end of einsum improves
+      # XLA performance, but may introduce slight numeric differences in
+      # the Transformer attention head.
+      query = query * self._scale
+    else:
+      # Note: we suspect splitting the scale to key, query yields smaller
+      # approximation variance when random projection is used.
+      # For simplicity, we also split when there's no random projection.
+      key *= math.sqrt(self._scale)
+      query *= math.sqrt(self._scale)
+    key = _TRANSFORM_MAP[feature_transform](key, projection_matrix)
+    query = _TRANSFORM_MAP[feature_transform](query, projection_matrix)
     if attention_mask is not None:
       key = tf.einsum("BSNH,BS->BSNH", key, attention_mask)
@@ -294,13 +292,14 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       attention_scores = tf.einsum("BTNH,BSNH->BTSN", query, key)
       attention_scores = tf.nn.softmax(attention_scores, axis=2)
       attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
+      return attention_output
     else:
       kv = tf.einsum("BSNH,BSND->BNDH", key, value)
       denominator = 1.0 / (
           tf.einsum("BTNH,BNH->BTN", query, tf.reduce_sum(key, axis=1)) +
           _NUMERIC_STABLER)
-      return tf.einsum("BTNH,BNDH,BTN->BTND", query, kv, denominator)
+      attention_output = tf.einsum(
+          "BTNH,BNDH,BTN->BTND", query, kv, denominator)
+      return attention_output
   def _build_from_signature(self, query, value, key=None):
     super()._build_from_signature(query=query, value=value, key=key)
@@ -391,6 +390,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
         "redraw": self._redraw,
         "is_short_seq": self._is_short_seq,
         "begin_kernel": self._begin_kernel,
+        "scale": self._scale,
     }
     base_config = super().get_config()
     return dict(list(base_config.items()) + list(config.items()))
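The else-branch comment is easy to verify for the plain dot product: multiplying query and key each by sqrt(s) scales every logit by exactly s. A quick standalone numpy check:

```python
import numpy as np

rng = np.random.default_rng(0)
q = rng.normal(size=(4, 8))
k = rng.normal(size=(6, 8))
s = 1.0 / np.sqrt(8.0)  # the default 1/sqrt(d_k) scale

# (sqrt(s) * q) . (sqrt(s) * k) == s * (q . k), elementwise over all pairs.
logits_split = (q * np.sqrt(s)) @ (k * np.sqrt(s)).T
logits_direct = s * (q @ k.T)
np.testing.assert_allclose(logits_split, logits_direct, rtol=1e-12)
```

Once the nonlinear feature transforms are applied, the two placements are no longer exactly equivalent, which is why the code comment frames the split as an approximation-variance choice rather than an identity.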
@@ -21,7 +21,7 @@ import tensorflow as tf
 from official.nlp.modeling.layers import kernel_attention as attention
-_FEATURE_TRANSFORM = ['relu', 'elu', 'exp', 'l2']
+_FEATURE_TRANSFORM = ['relu', 'elu', 'exp']
 _REDRAW = [True, False]
 _TRAINING = [True, False]
 _IS_SHORT_SEQ = [True, False]
...
@@ -33,121 +33,6 @@ def _check_if_tf_text_installed():
         "'tensorflow-text-nightly'.")
-def _iterative_vectorized_fair_share(capacity: tf.Tensor,
-                                     limit: Union[int, tf.Tensor]):
-  """Iterative algorithm for max min fairness algorithm.
-  Reference: https://en.wikipedia.org/wiki/Max-min_fairness
-  The idea is for each example with some number of segments and a limit of
-  total segment length allowed, we grant each segment a fair share of the
-  limit. For example, if every segment has the same length, no work to do.
-  If one segment has below average length, its share will be spilt to others
-  fairly. In this way, the longest segment will be the shortest among all
-  potential capacity assignments.
-  Args:
-    capacity: A rank-2 Tensor of #Segments x Batch.
-    limit: The largest permissible number of tokens in total across one example.
-  Returns:
-    A rank-2 Tensor with new segment capacity assignment such that
-    the total number of tokens in each example does not exceed the `limit`.
-  """
-  # Firstly, we calculate the lower bound of the capacity assignment.
-  per_seg_limit = limit // capacity.shape[0]
-  limit_mask = tf.ones(capacity.shape, dtype=tf.int64) * per_seg_limit
-  lower_bound = tf.minimum(capacity, limit_mask)
-  # This step makes up the capacity that already statisfy the capacity limit.
-  remaining_cap_sum = limit - tf.math.reduce_sum(lower_bound, axis=0)
-  remaining_cap_mat = capacity - lower_bound
-  new_cap = lower_bound + remaining_cap_mat * tf.cast(
-      tf.math.reduce_sum(remaining_cap_mat, axis=0) <= remaining_cap_sum,
-      tf.int64)
-  # Process iteratively. This step is O(#segments), see analysis below.
-  while True:
-    remaining_limit = limit - tf.math.reduce_sum(new_cap, axis=0)
-    remaining_cap = capacity - new_cap
-    masked_remaining_slots = tf.cast(remaining_cap > 0, tf.int64)
-    remaining_cap_col_slots = tf.reduce_sum(masked_remaining_slots, axis=0)
-    masked_remaining_limit = tf.cast(remaining_cap_col_slots > 0,
-                                     tf.int64) * remaining_limit
-    # Total remaining segment limit is different for each example.
-    per_seg_limit = masked_remaining_limit // (
-        tf.cast(remaining_cap_col_slots <= 0, tf.int64) +
-        remaining_cap_col_slots)  # +1 to make sure 0/0 = 0
-    # Note that for each step, there is at least one more segment being
-    # fulfilled or the loop is finished.
-    # The idea is, if remaining per example limit > smallest among segments,
-    # the smallest segment ask is fullfilled. Otherwise, all remaining segments
-    # are truncated, the assignment is finished.
-    if tf.math.reduce_sum(per_seg_limit) > 0:
-      remaining_slots_mat = tf.cast(remaining_cap > 0, tf.int64)
-      new_cap = new_cap + remaining_slots_mat * per_seg_limit
-    else:
-      # Leftover assignment of limit that is smaller than #slots.
-      new_remained_assignment_mask = tf.cast(
-          (tf.cumsum(masked_remaining_slots, axis=0) <= masked_remaining_limit)
-          & (masked_remaining_slots > 0), tf.int64)
-      new_cap = new_cap + new_remained_assignment_mask
-      break
-  return new_cap
-def round_robin_truncate_inputs(
-    inputs: Union[tf.RaggedTensor, List[tf.RaggedTensor]],
-    limit: Union[int, tf.Tensor],
-) -> Union[tf.RaggedTensor, List[tf.RaggedTensor]]:
-  """Truncates a list of batched segments to fit a per-example length limit.
-  Available space is assigned one token at a time in a round-robin fashion
-  to the inputs that still need some, until the limit is reached.
-  (Or equivalently: the longest input is truncated by one token until the total
-  length of inputs fits the limit.) Examples that fit the limit as passed in
-  remain unchanged.
-  Args:
-    inputs: A list of rank-2 RaggedTensors. The i-th example is given by
-      the i-th row in each list element, that is, `inputs[:][i, :]`.
-    limit: The largest permissible number of tokens in total across one example.
-  Returns:
-    A list of rank-2 RaggedTensors at corresponding indices with the inputs,
-    in which the rows of each RaggedTensor have been truncated such that
-    the total number of tokens in each example does not exceed the `limit`.
-  """
-  if not isinstance(inputs, (list, tuple)):
-    return round_robin_truncate_inputs([inputs], limit)[0]
-  limit = tf.cast(limit, tf.int64)
-  if not all(rt.shape.rank == 2 for rt in inputs):
-    raise ValueError("All inputs must have shape [batch_size, (items)]")
-  if len(inputs) == 1:
-    return [_truncate_row_lengths(inputs[0], limit)]
-  elif len(inputs) == 2:
-    size_a, size_b = [rt.row_lengths() for rt in inputs]
-    # Here's a brain-twister: This does round-robin assignment of quota
-    # to both inputs until the limit is reached. Hint: consider separately
-    # the cases of zero, one, or two inputs exceeding half the limit.
-    floor_half = limit // 2
-    ceil_half = limit - floor_half
-    quota_a = tf.minimum(size_a, ceil_half + tf.nn.relu(floor_half - size_b))
-    quota_b = tf.minimum(size_b, floor_half + tf.nn.relu(ceil_half - size_a))
-    return [_truncate_row_lengths(inputs[0], quota_a),
-            _truncate_row_lengths(inputs[1], quota_b)]
-  else:
-    # Note that we don't merge with the 2 input case because the full algorithm
-    # is more expensive.
-    capacity = tf.stack([rt.row_lengths() for rt in inputs])  # #Segments x B
-    new_capacity = _iterative_vectorized_fair_share(capacity, limit)
-    return [
-        _truncate_row_lengths(inputs[i], new_capacity[i])
-        for i in range(capacity.shape[0])
-    ]
 def _truncate_row_lengths(ragged_tensor: tf.RaggedTensor,
                           new_lengths: tf.Tensor) -> tf.RaggedTensor:
   """Truncates the rows of `ragged_tensor` to the given row lengths."""
@@ -675,8 +560,8 @@ class BertPackInputs(tf.keras.layers.Layer):
       # fall back to some ad-hoc truncation.
       num_special_tokens = len(inputs) + 1
       if truncator == "round_robin":
-        trimmed_segments = round_robin_truncate_inputs(
-            inputs, seq_length - num_special_tokens)
+        trimmed_segments = text.RoundRobinTrimmer(seq_length -
+                                                  num_special_tokens).trim(inputs)
      elif truncator == "waterfall":
        trimmed_segments = text.WaterfallTrimmer(
            seq_length - num_special_tokens).trim(inputs)
...
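BertPackInputs now delegates round-robin truncation to tensorflow_text's RoundRobinTrimmer, mirroring the existing WaterfallTrimmer path, which is why the hand-rolled max-min fairness code above could be deleted. A small usage sketch, assuming a tensorflow_text release that provides RoundRobinTrimmer (the diff itself relies on it; token values here are illustrative):

```python
import tensorflow as tf
import tensorflow_text as text

segments = [
    tf.ragged.constant([[1, 2, 3, 4], [11, 12]]),
    tf.ragged.constant([[5, 6], [13, 14, 15, 16]]),
]
# A per-example budget of 4 tokens, shared round-robin across both segments;
# BertPackInputs passes seq_length - num_special_tokens as this budget.
trimmed = text.RoundRobinTrimmer(4).trim(segments)
```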
@@ -24,102 +24,6 @@ from sentencepiece import SentencePieceTrainer
 from official.nlp.modeling.layers import text_layers
-class RoundRobinTruncatorTest(tf.test.TestCase):
-  def _test_input(self, start, lengths):
-    return tf.ragged.constant([[start + 10 * j + i
-                                for i in range(length)]
-                               for j, length in enumerate(lengths)],
-                              dtype=tf.int32)
-  def test_single_segment(self):
-    # Single segment.
-    single_input = self._test_input(11, [4, 5, 6])
-    expected_single_output = tf.ragged.constant(
-        [[11, 12, 13, 14],
-         [21, 22, 23, 24, 25],
-         [31, 32, 33, 34, 35],  # Truncated.
-        ])
-    self.assertAllEqual(
-        expected_single_output,
-        text_layers.round_robin_truncate_inputs(single_input, limit=5))
-    # Test wrapping in a singleton list.
-    actual_single_list_output = text_layers.round_robin_truncate_inputs(
-        [single_input], limit=5)
-    self.assertIsInstance(actual_single_list_output, list)
-    self.assertAllEqual(expected_single_output, actual_single_list_output[0])
-  def test_two_segments(self):
-    input_a = self._test_input(111, [1, 2, 2, 3, 4, 5])
-    input_b = self._test_input(211, [1, 3, 4, 2, 2, 5])
-    expected_a = tf.ragged.constant(
-        [[111],
-         [121, 122],
-         [131, 132],
-         [141, 142, 143],
-         [151, 152, 153],  # Truncated.
-         [161, 162, 163],  # Truncated.
-        ])
-    expected_b = tf.ragged.constant(
-        [[211],
-         [221, 222, 223],
-         [231, 232, 233],  # Truncated.
-         [241, 242],
-         [251, 252],
-         [261, 262],  # Truncated.
-        ])
-    actual_a, actual_b = text_layers.round_robin_truncate_inputs(
-        [input_a, input_b], limit=5)
-    self.assertAllEqual(expected_a, actual_a)
-    self.assertAllEqual(expected_b, actual_b)
-  def test_three_segments(self):
-    input_a = self._test_input(111, [1, 2, 2, 3, 4, 5, 1])
-    input_b = self._test_input(211, [1, 3, 4, 2, 2, 5, 8])
-    input_c = self._test_input(311, [1, 3, 4, 2, 2, 5, 10])
-    seg_limit = 8
-    expected_a = tf.ragged.constant([
-        [111],
-        [121, 122],
-        [131, 132],
-        [141, 142, 143],
-        [151, 152, 153, 154],
-        [161, 162, 163],  # Truncated
-        [171]
-    ])
-    expected_b = tf.ragged.constant([
-        [211],
-        [221, 222, 223],
-        [231, 232, 233],  # Truncated
-        [241, 242],
-        [251, 252],
-        [261, 262, 263],  # Truncated
-        [271, 272, 273, 274]  # Truncated
-    ])
-    expected_c = tf.ragged.constant([
-        [311],
-        [321, 322, 323],
-        [331, 332, 333],  # Truncated
-        [341, 342],
-        [351, 352],
-        [361, 362],  # Truncated
-        [371, 372, 373]  # Truncated
-    ])
-    actual_a, actual_b, actual_c = text_layers.round_robin_truncate_inputs(
-        [input_a, input_b, input_c], limit=seg_limit)
-    self.assertAllEqual(expected_a, actual_a)
-    self.assertAllEqual(expected_b, actual_b)
-    self.assertAllEqual(expected_c, actual_c)
-    input_cap = tf.math.reduce_sum(
-        tf.stack([rt.row_lengths() for rt in [input_a, input_b, input_c]]),
-        axis=0)
-    per_example_usage = tf.math.reduce_sum(
-        tf.stack([rt.row_lengths() for rt in [actual_a, actual_b, actual_c]]),
-        axis=0)
-    self.assertTrue(all(per_example_usage <= tf.minimum(seg_limit, input_cap)))
 # This test covers the in-process behavior of a BertTokenizer layer.
 # For saving, restoring, and the restored behavior (incl. shape inference),
 # see nlp/tools/export_tfhub_lib_test.py.
...
@@ -111,13 +111,15 @@ class Seq2SeqTransformer(tf.keras.Model):
   def _embedding_linear(self, embedding_matrix, x):
     """Uses embeddings as linear transformation weights."""
+    embedding_matrix = tf.cast(embedding_matrix, dtype=self.compute_dtype)
+    x = tf.cast(x, dtype=self.compute_dtype)
     batch_size = tf.shape(x)[0]
     length = tf.shape(x)[1]
     hidden_size = tf.shape(x)[2]
     vocab_size = tf.shape(embedding_matrix)[0]
     x = tf.reshape(x, [-1, hidden_size])
-    logits = tf.matmul(x, tf.cast(embedding_matrix, x.dtype), transpose_b=True)
+    logits = tf.matmul(x, embedding_matrix, transpose_b=True)
     return tf.reshape(logits, [batch_size, length, vocab_size])
...
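This change moves the dtype handling to the top of the helper: both operands are cast once to the model's compute dtype, rather than casting the embedding matrix to x's dtype at matmul time. A standalone sketch of the resulting pattern (shapes and dtypes are illustrative):

```python
import tensorflow as tf

compute_dtype = tf.float16  # stand-in for self.compute_dtype
embedding_matrix = tf.random.normal([100, 32])  # e.g. a float32 variable
x = tf.random.normal([2, 5, 32])                # decoder outputs

# Cast both operands up front, as _embedding_linear now does, so the matmul
# runs in one well-defined dtype under mixed precision.
embedding_matrix = tf.cast(embedding_matrix, compute_dtype)
x = tf.cast(x, compute_dtype)
logits = tf.matmul(tf.reshape(x, [-1, 32]), embedding_matrix, transpose_b=True)
logits = tf.reshape(logits, [2, 5, 100])  # [batch, length, vocab]
```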
@@ -77,6 +77,9 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
       parameter is originally added for ELECTRA model which needs to tie the
       generator embeddings with the discriminator embeddings.
     dict_outputs: Whether to use a dictionary as the model outputs.
+    norm_first: Whether to normalize inputs to attention and intermediate
+      dense layers. If set False, output of attention and intermediate dense
+      layers is normalized.
   """
   def __init__(self,
@@ -97,6 +100,7 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
                embedding_width=None,
                embedding_layer=None,
                dict_outputs=False,
+               norm_first=False,
                **kwargs):
     # b/164516224
@@ -120,7 +124,8 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
         initializer=initializer,
         output_range=output_range,
         embedding_width=embedding_width,
-        embedding_layer=embedding_layer)
+        embedding_layer=embedding_layer,
+        norm_first=norm_first)
     self._embedding_layer_instance = embedding_layer
...
@@ -226,7 +226,8 @@ class BertEncoderTest(keras_parameterized.TestCase):
         output_range=-1,
         embedding_width=16,
         dict_outputs=True,
-        embedding_layer=None)
+        embedding_layer=None,
+        norm_first=False)
     network = bert_encoder.BertEncoder(**kwargs)
     expected_config = dict(kwargs)
     expected_config["activation"] = tf.keras.activations.serialize(
...