Commit a04d9e0e authored by Vishnu Banna

merged

parents 64f16d61 bcbce005
* @tensorflow/tf-garden-team @tensorflow/tf-model-garden-team
/official/ @rachellj218 @saberkun @jaeyounkim
/official/nlp/ @saberkun @lehougoogle @rachellj218 @jaeyounkim
/official/recommendation/ranking/ @gagika
/official/vision/ @xianzhidu @yeqingli @arashwan @saberkun @rachellj218 @jaeyounkim
/official/vision/beta/projects/assemblenet/ @mryoo
/official/vision/beta/projects/deepmac_maskrcnn/ @vighneshbirodkar
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
...@@ -38,7 +38,10 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
  # Special keys in train/validate step returned logs.
  loss = "loss"

  def __init__(self,
               params,
               logging_dir: Optional[str] = None,
               name: Optional[str] = None):
    """Task initialization.

    Args:
...@@ -294,11 +297,38 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
    return model(inputs, training=False)

  def aggregate_logs(self, state, step_logs):
    """Optional aggregation over logs returned from a validation step.

    Given step_logs from a validation step, this function aggregates the logs
    after each eval_step() (see the eval_reduce() function in
    official/core/base_trainer.py). It runs on CPU and can be used to aggregate
    metrics during validation when there are too many metrics to fit into TPU
    memory. Note that this may increase latency due to data transfer between
    TPU and CPU. Also, the step output from a validation step may be a tuple
    with elements from replicas, in which case the elements need to be
    concatenated.

    Args:
      state: The current state of training, for example, it can be a sequence
        of metrics.
      step_logs: Logs from a validation step. Can be a dictionary.
    """
    pass
  def reduce_aggregated_logs(self,
                             aggregated_logs,
                             global_step: Optional[tf.Tensor] = None):
    """Optional reduce of aggregated logs over validation steps.

    This function reduces aggregated logs at the end of validation, and can be
    used to compute the final metrics. It runs on CPU and is called from
    eval_end() in the base trainer (see the eval_end() function in
    official/core/base_trainer.py).

    Args:
      aggregated_logs: Aggregated logs over multiple validation steps.
      global_step: An optional variable of global step.

    Returns:
      A dictionary of reduced results.
    """
    return {}
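As an illustration of how the two hooks above fit together, a hypothetical subclass might accumulate per-step validation outputs on the host and then reduce them into a final metric. This is only a sketch: the dictionary keys `sentence_prediction` and `labels` are assumptions borrowed from the sentence-prediction task further down in this commit, not part of the base `Task` contract.

```python
import numpy as np


class MyEvalTask(Task):  # hypothetical subclass, for illustration only

  def aggregate_logs(self, state, step_logs):
    """Accumulates per-step validation outputs on the CPU host."""
    if state is None:
      state = {'sentence_prediction': [], 'labels': []}
    for key in state:
      values = step_logs[key]
      # A step output may be a per-replica tuple; flatten it before storing.
      if not isinstance(values, (list, tuple)):
        values = [values]
      state[key].append(np.concatenate([np.asarray(v) for v in values]))
    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    """Reduces the aggregated outputs into final metrics."""
    preds = np.concatenate(aggregated_logs['sentence_prediction'], axis=0)
    labels = np.concatenate(aggregated_logs['labels'], axis=0)
    return {'accuracy': float(np.mean(preds.reshape(-1) == labels.reshape(-1)))}
```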
...@@ -246,10 +246,11 @@ class Trainer(_AsyncTrainer):
    self._train_loss = tf.keras.metrics.Mean("training_loss", dtype=tf.float32)
    self._validation_loss = tf.keras.metrics.Mean(
        "validation_loss", dtype=tf.float32)
    model_metrics = model.metrics if hasattr(model, "metrics") else []
    self._train_metrics = self.task.build_metrics(
        training=True) + model_metrics
    self._validation_metrics = self.task.build_metrics(
        training=False) + model_metrics

    self.init_async()
......
...@@ -28,4 +28,4 @@ def hard_sigmoid(features):
    The activation value.
  """
  features = tf.convert_to_tensor(features)
  return tf.nn.relu6(features + tf.cast(3., features.dtype)) * 0.16667
...@@ -52,7 +52,8 @@ def hard_swish(features):
    The activation value.
  """
  features = tf.convert_to_tensor(features)
  fdtype = features.dtype
  return features * tf.nn.relu6(features + tf.cast(3., fdtype)) * (1. / 6.)

@tf.keras.utils.register_keras_serializable(package='Text')
......
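The `tf.cast` changes above make both activations dtype-agnostic: previously, adding a float32 `tf.constant(3.)` to a float16/bfloat16 input raised a dtype mismatch. A minimal check, assuming the functions are importable from `official.modeling.activations` (the import path is an assumption, not stated in this diff):

```python
import tensorflow as tf
from official.modeling import activations  # assumed import path

# With the cast, the constant 3.0 follows the input dtype, so bfloat16 inputs
# behave the same way float32 inputs do.
x = tf.constant([-4.0, 0.0, 4.0], dtype=tf.bfloat16)
print(activations.hard_swish(x))    # ~[0., 0., 4.] in bfloat16
print(activations.hard_sigmoid(x))  # ~[0., 0.5, 1.] in bfloat16
```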
...@@ -41,6 +41,7 @@ class OptimizerConfig(oneof.OneOfConfig):
    rmsprop: rmsprop optimizer.
    lars: lars optimizer.
    adagrad: adagrad optimizer.
    slide: slide optimizer.
  """
  type: Optional[str] = None
  sgd: opt_cfg.SGDConfig = opt_cfg.SGDConfig()
...@@ -50,6 +51,7 @@ class OptimizerConfig(oneof.OneOfConfig):
  rmsprop: opt_cfg.RMSPropConfig = opt_cfg.RMSPropConfig()
  lars: opt_cfg.LARSConfig = opt_cfg.LARSConfig()
  adagrad: opt_cfg.AdagradConfig = opt_cfg.AdagradConfig()
  slide: opt_cfg.SLIDEConfig = opt_cfg.SLIDEConfig()

@dataclasses.dataclass
......
...@@ -226,3 +226,24 @@ class LARSConfig(BaseOptimizerConfig):
  classic_momentum: bool = True
  exclude_from_weight_decay: Optional[List[str]] = None
  exclude_from_layer_adaptation: Optional[List[str]] = None

@dataclasses.dataclass
class SLIDEConfig(BaseOptimizerConfig):
  """Configuration for SLIDE optimizer.

  Details coming soon.
  """
  name: str = "SLIDE"
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-6
  weight_decay_rate: float = 0.0
  weight_decay_type: str = "inner"
  exclude_from_weight_decay: Optional[List[str]] = None
  exclude_from_layer_adaptation: Optional[List[str]] = None
  include_in_sparse_layer_adaptation: Optional[List[str]] = None
  sparse_layer_learning_rate: float = 0.1
  do_gradient_rescaling: bool = True
  norm_type: str = "layer"
  ratio_clip_norm: float = 1e5
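As a usage sketch, the new config can be selected through the one-of `OptimizerConfig` shown above. The module paths below are assumptions based on the imports in this commit, and the sketch only constructs the config dataclasses; building the actual SLIDE optimizer is not possible yet, since this commit leaves `slide_optimizer.SLIDE` as a placeholder.

```python
# Hypothetical usage sketch, assuming the standard model-garden config modules.
from official.modeling.optimization.configs import optimization_config

params = {
    "optimizer": {
        "type": "slide",
        "slide": {"weight_decay_rate": 0.01},
    },
}
config = optimization_config.OptimizationConfig(params)
print(config.optimizer.type)   # "slide"
print(config.optimizer.slide)  # SLIDEConfig with the override applied
```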
...@@ -14,7 +14,7 @@
"""Exponential moving average optimizer."""

from typing import List, Optional, Text

import tensorflow as tf
...@@ -106,7 +106,7 @@ class ExponentialMovingAverage(tf.keras.optimizers.Optimizer):
  def _create_slots(self, var_list):
    self._optimizer._create_slots(var_list=var_list)  # pylint: disable=protected-access

  def apply_gradients(self, grads_and_vars, name: Optional[Text] = None):
    result = self._optimizer.apply_gradients(grads_and_vars, name)
    self.update_average(self.iterations)
    return result
......
...@@ -13,12 +13,13 @@
# limitations under the License.
"""Optimizer factory class."""
from typing import Callable, Optional, Union

import gin
import tensorflow as tf
import tensorflow_addons.optimizers as tfa_optimizers

from official.modeling.optimization import slide_optimizer
from official.modeling.optimization import ema_optimizer
from official.modeling.optimization import lars_optimizer
from official.modeling.optimization import lr_schedule
...@@ -33,6 +34,7 @@ OPTIMIZERS_CLS = {
    'rmsprop': tf.keras.optimizers.RMSprop,
    'lars': lars_optimizer.LARS,
    'adagrad': tf.keras.optimizers.Adagrad,
    'slide': slide_optimizer.SLIDE
}

LR_CLS = {
...@@ -134,8 +136,8 @@ class OptimizerFactory:
  def build_optimizer(
      self,
      lr: Union[tf.keras.optimizers.schedules.LearningRateSchedule, float],
      postprocessor: Optional[Callable[[tf.keras.optimizers.Optimizer],
                                       tf.keras.optimizers.Optimizer]] = None):
    """Build optimizer.

    Builds optimizer from config. It takes learning rate as input, and builds
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SLIDE optimizer.
A new optimizer that will be open sourced soon.
"""
SLIDE = "Unimplemented"
...@@ -284,8 +284,11 @@ class ProgressiveTrainer(trainer_lib.Trainer):
        checkpoint_interval=checkpoint_interval,
    )

    # Make sure we export the last checkpoint.
    last_checkpoint = (
        self.global_step.numpy() == self._config.trainer.train_steps)
    checkpoint_path = self._export_ckpt_manager.save(
        checkpoint_number=self.global_step.numpy(),
        check_interval=not last_checkpoint)
    if checkpoint_path:
      logging.info('Checkpoints exported: %s.', checkpoint_path)
...@@ -181,20 +181,21 @@ class AxProcessor(DataProcessor):
class ColaProcessor(DataProcessor):
  """Processor for the CoLA data set (GLUE version)."""

  def __init__(self, process_text_fn=tokenization.convert_to_unicode):
    super(ColaProcessor, self).__init__(process_text_fn)
    self.dataset = tfds.load("glue/cola", try_gcs=True)

  def get_train_examples(self, data_dir):
    """See base class."""
    return self._create_examples_tfds("train")

  def get_dev_examples(self, data_dir):
    """See base class."""
    return self._create_examples_tfds("validation")

  def get_test_examples(self, data_dir):
    """See base class."""
    return self._create_examples_tfds("test")

  def get_labels(self):
    """See base class."""
...@@ -205,22 +206,19 @@ class ColaProcessor(DataProcessor):
    """See base class."""
    return "COLA"

  def _create_examples_tfds(self, set_type):
    """Creates examples for the training/dev/test sets."""
    dataset = self.dataset[set_type].as_numpy_iterator()
    examples = []
    for i, example in enumerate(dataset):
      guid = "%s-%s" % (set_type, i)
      label = "0"
      text_a = self.process_text_fn(example["sentence"])
      if set_type != "test":
        label = str(example["label"])
      examples.append(
          InputExample(
              guid=guid, text_a=text_a, text_b=None, label=label, weight=None))
    return examples
......
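For reference, the TFDS-backed processor above reads the `sentence` and `label` fields directly from the `glue/cola` dataset, with the TFDS "validation" split replacing the old `dev.tsv`. A minimal standalone sketch of that iteration (it downloads the dataset on first use):

```python
import tensorflow_datasets as tfds

# Same source the new ColaProcessor uses.
cola = tfds.load("glue/cola", try_gcs=True)
for i, example in enumerate(cola["validation"].as_numpy_iterator()):
  sentence = example["sentence"].decode("utf-8")
  label = int(example["label"])  # CoLA: 1 = acceptable, 0 = unacceptable
  print(i, label, sentence)
  if i >= 2:
    break
```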
...@@ -14,7 +14,7 @@
"""Loads dataset for the sentence prediction (classification) task."""
import functools
from typing import List, Mapping, Optional, Tuple

import dataclasses
import tensorflow as tf
...@@ -40,6 +40,10 @@ class SentencePredictionDataConfig(cfg.DataConfig):
  label_type: str = 'int'
  # Whether to include the example id number.
  include_example_id: bool = False
  label_field: str = 'label_ids'
  # Maps the key in TfExample to feature name.
  # E.g. 'label_ids' to 'next_sentence_labels'.
  label_name: Optional[Tuple[str, str]] = None

@data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
...@@ -50,6 +54,11 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
    self._params = params
    self._seq_length = params.seq_length
    self._include_example_id = params.include_example_id
    self._label_field = params.label_field
    if params.label_name:
      self._label_name_mapping = dict([params.label_name])
    else:
      self._label_name_mapping = dict()

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example."""
...@@ -58,7 +67,7 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
        'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        self._label_field: tf.io.FixedLenFeature([], label_type),
    }
    if self._include_example_id:
      name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)
...@@ -85,8 +94,12 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
    if self._include_example_id:
      x['example_id'] = record['example_id']
    x[self._label_field] = record[self._label_field]

    if self._label_field in self._label_name_mapping:
      x[self._label_name_mapping[self._label_field]] = record[self._label_field]
    return x

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.data.Dataset."""
...@@ -204,8 +217,8 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader):
    model_inputs = self._text_processor(segments)
    if self._include_example_id:
      model_inputs['example_id'] = record['example_id']
    model_inputs[self._label_field] = record[self._label_field]
    return model_inputs

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example."""
......
...@@ -132,14 +132,40 @@ class SentencePredictionDataTest(tf.test.TestCase, parameterized.TestCase):
        global_batch_size=batch_size,
        label_type=label_type)
    dataset = loader.SentencePredictionDataLoader(data_config).load()
    features = next(iter(dataset))
    self.assertCountEqual(
        ['input_word_ids', 'input_type_ids', 'input_mask', 'label_ids'],
        features.keys())
    self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
    self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
    self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
    self.assertEqual(features['label_ids'].shape, (batch_size,))
    self.assertEqual(features['label_ids'].dtype, expected_label_type)

  def test_load_dataset_with_label_mapping(self):
    input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
    batch_size = 10
    seq_length = 128
    _create_fake_preprocessed_dataset(input_path, seq_length, 'int')
    data_config = loader.SentencePredictionDataConfig(
        input_path=input_path,
        seq_length=seq_length,
        global_batch_size=batch_size,
        label_type='int',
        label_name=('label_ids', 'next_sentence_labels'))
    dataset = loader.SentencePredictionDataLoader(data_config).load()
    features = next(iter(dataset))
    self.assertCountEqual([
        'input_word_ids', 'input_mask', 'input_type_ids',
        'next_sentence_labels', 'label_ids'
    ], features.keys())
    self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
    self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
    self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
    self.assertEqual(features['label_ids'].shape, (batch_size,))
    self.assertEqual(features['label_ids'].dtype, tf.int32)
    self.assertEqual(features['next_sentence_labels'].shape, (batch_size,))
    self.assertEqual(features['next_sentence_labels'].dtype, tf.int32)
class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
...@@ -170,13 +196,15 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
        lower_case=lower_case,
        vocab_file=vocab_file_path)
    dataset = loader.SentencePredictionTextDataLoader(data_config).load()
    features = next(iter(dataset))
    label_field = data_config.label_field
    self.assertCountEqual(
        ['input_word_ids', 'input_type_ids', 'input_mask', label_field],
        features.keys())
    self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
    self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
    self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
    self.assertEqual(features[label_field].shape, (batch_size,))

  @parameterized.parameters(True, False)
  def test_python_sentencepiece_preprocessing(self, use_tfds):
...@@ -203,13 +231,15 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
        vocab_file=sp_model_file_path,
    )
    dataset = loader.SentencePredictionTextDataLoader(data_config).load()
    features = next(iter(dataset))
    label_field = data_config.label_field
    self.assertCountEqual(
        ['input_word_ids', 'input_type_ids', 'input_mask', label_field],
        features.keys())
    self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
    self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
    self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
    self.assertEqual(features[label_field].shape, (batch_size,))

  @parameterized.parameters(True, False)
  def test_saved_model_preprocessing(self, use_tfds):
...@@ -236,13 +266,15 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
        label_type='int' if use_tfds else 'float',
    )
    dataset = loader.SentencePredictionTextDataLoader(data_config).load()
    features = next(iter(dataset))
    label_field = data_config.label_field
    self.assertCountEqual(
        ['input_word_ids', 'input_type_ids', 'input_mask', label_field],
        features.keys())
    self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
    self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
    self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
    self.assertEqual(features[label_field].shape, (batch_size,))

if __name__ == '__main__':
......
...@@ -15,7 +15,7 @@
"""XLNet models."""
# pylint: disable=g-classes-have-attributes
from typing import Any, Mapping, Optional, Union

import tensorflow as tf
...@@ -99,7 +99,7 @@ class XLNetPretrainer(tf.keras.Model):
      network: Union[tf.keras.layers.Layer, tf.keras.Model],
      mlm_activation=None,
      mlm_initializer='glorot_uniform',
      name: Optional[str] = None,
      **kwargs):
    super().__init__(name=name, **kwargs)
    self._config = {
......
...@@ -431,17 +431,17 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
  def _continue_search(self, state) -> tf.Tensor:
    i = state[decoding_module.StateKeys.CUR_INDEX]
    # Have we reached max decoding length?
    not_at_end = tf.less(i, self.max_decode_length)
    # Have all sampled sequences reached an EOS?
    all_has_eos = tf.reduce_all(
        state[decoding_module.StateKeys.FINISHED_FLAGS],
        axis=None,
        name="search_finish_cond")
    return tf.logical_and(not_at_end, tf.logical_not(all_has_eos))

  def _finished_flags(self, topk_ids, state) -> tf.Tensor:
    new_finished_flags = tf.equal(topk_ids, self.eos_id)
    new_finished_flags = tf.logical_or(
        new_finished_flags, state[decoding_module.StateKeys.FINISHED_FLAGS])
    return new_finished_flags
...@@ -22,7 +22,7 @@ modeling library:
* [mobile_bert_encoder.py](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/mobile_bert_encoder.py)
  contains the `MobileBERTEncoder` implementation.
* [mobile_bert_layers.py](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/mobile_bert_layers.py)
  contains the `MobileBertEmbedding`, `MobileBertTransformer` and
  `MobileBertMaskedLM` implementations.

## Pre-trained Models
......
# TEAMS (Training ELECTRA Augmented with Multi-word Selection)
**Note:** This project is a work in progress; please stay tuned.
TEAMS is a text encoder pre-training method that simultaneously learns a
generator and a discriminator using multi-task learning. We propose a new
pre-training task, multi-word selection, and combine it with previous
pre-training tasks for efficient encoder pre-training. We also develop two
techniques, attention-based task-specific heads and partial layer sharing,
to further improve pre-training effectiveness.
Our academic paper [[1]](#1), which describes TEAMS in detail, can be found here:
https://arxiv.org/abs/2106.00139.
## References
<a id="1">[1]</a>
Jiaming Shen, Jialu Liu, Tianqi Liu, Cong Yu and Jiawei Han, "Training ELECTRA
Augmented with Multi-word Selection", Findings of the Association for
Computational Linguistics: ACL 2021.
...@@ -69,6 +69,10 @@ class SentencePredictionTask(base_task.Task):
    if params.metric_type not in METRIC_TYPES:
      raise ValueError('Invalid metric_type: {}'.format(params.metric_type))
    self.metric_type = params.metric_type
    if hasattr(params.train_data, 'label_field'):
      self.label_field = params.train_data.label_field
    else:
      self.label_field = 'label_ids'

  def build_model(self):
    if self.task_config.hub_module_url and self.task_config.init_checkpoint:
...@@ -95,11 +99,12 @@ class SentencePredictionTask(base_task.Task):
        use_encoder_pooler=self.task_config.model.use_encoder_pooler)

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    label_ids = labels[self.label_field]
    if self.task_config.model.num_classes == 1:
      loss = tf.keras.losses.mean_squared_error(label_ids, model_outputs)
    else:
      loss = tf.keras.losses.sparse_categorical_crossentropy(
          label_ids, tf.cast(model_outputs, tf.float32), from_logits=True)

    if aux_losses:
      loss += tf.add_n(aux_losses)
...@@ -120,7 +125,8 @@ class SentencePredictionTask(base_task.Task):
        y = tf.zeros((1,), dtype=tf.float32)
      else:
        y = tf.zeros((1, 1), dtype=tf.int32)
      x[self.label_field] = y
      return x

    dataset = tf.data.Dataset.range(1)
    dataset = dataset.repeat()
...@@ -142,16 +148,16 @@ class SentencePredictionTask(base_task.Task):
  def process_metrics(self, metrics, labels, model_outputs):
    for metric in metrics:
      metric.update_state(labels[self.label_field], model_outputs)

  def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
    compiled_metrics.update_state(labels[self.label_field], model_outputs)

  def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
    if self.metric_type == 'accuracy':
      return super(SentencePredictionTask,
                   self).validation_step(inputs, model, metrics)
    features, labels = inputs, inputs
    outputs = self.inference_step(features, model)
    loss = self.build_losses(
        labels=labels, model_outputs=outputs, aux_losses=model.losses)
...@@ -161,12 +167,12 @@ class SentencePredictionTask(base_task.Task):
          'sentence_prediction':  # Ensure one prediction along batch dimension.
              tf.expand_dims(tf.math.argmax(outputs, axis=1), axis=1),
          'labels':
              labels[self.label_field],
      })
    if self.metric_type == 'pearson_spearman_corr':
      logs.update({
          'sentence_prediction': outputs,
          'labels': labels[self.label_field],
      })
    return logs
...@@ -206,10 +212,10 @@ class SentencePredictionTask(base_task.Task):
  def initialize(self, model):
    """Load a pretrained checkpoint (if exists) and then train from iter 0."""
    ckpt_dir_or_file = self.task_config.init_checkpoint
    if not ckpt_dir_or_file:
      return
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    pretrain2finetune_mapping = {
        'encoder': model.checkpoint_items['encoder'],
...@@ -250,7 +256,7 @@ def predict(task: SentencePredictionTask,
  def predict_step(inputs):
    """Replicated prediction calculation."""
    x = inputs
    example_id = x.pop('example_id')
    outputs = task.inference_step(x, model)
    return dict(example_id=example_id, predictions=outputs)
......