Commit e7667f6f authored by Kaushik Shivakumar

Merge branch 'master' of https://github.com/tensorflow/models into context_tf2

parents 974c463e 709a6617
...@@ -20,6 +20,14 @@ This repository provides a curated list of the GitHub repositories with machine
| [ResNet 50](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet50) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [ResNet 50v1.5](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet50v1_5) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
### Object Detection
| Model | Paper | Features | Maintainer |
|-------|-------|----------|------------|
| [R-FCN](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/rfcn) | [R-FCN: Object Detection<br/>via Region-based Fully Convolutional Networks](https://arxiv.org/pdf/1605.06409) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [SSD-MobileNet](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/ssd-mobilenet) | [MobileNets: Efficient Convolutional Neural Networks<br/>for Mobile Vision Applications](https://arxiv.org/pdf/1704.04861) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [SSD-ResNet34](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/ssd-resnet34) | [SSD: Single Shot MultiBox Detector](https://arxiv.org/pdf/1512.02325) | • Int8 Inference<br/>• FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
### Segmentation
| Model | Paper | Features | Maintainer |
......
...@@ -38,13 +38,18 @@ FLAGS = flags.FLAGS
class CtlBenchmark(PerfZeroBenchmark):
  """Base benchmark class with methods to simplify testing."""

  def __init__(self,
               output_dir=None,
               default_flags=None,
               flag_methods=None,
               **kwargs):
    self.default_flags = default_flags or {}
    self.flag_methods = flag_methods or {}
    super(CtlBenchmark, self).__init__(
        output_dir=output_dir,
        default_flags=self.default_flags,
        flag_methods=self.flag_methods,
        **kwargs)

  def _report_benchmark(self,
                        stats,
...@@ -190,13 +195,14 @@ class Resnet50CtlAccuracy(CtlBenchmark):
class Resnet50CtlBenchmarkBase(CtlBenchmark):
  """Resnet50 benchmarks."""

  def __init__(self, output_dir=None, default_flags=None, **kwargs):
    flag_methods = [common.define_keras_flags]
    super(Resnet50CtlBenchmarkBase, self).__init__(
        output_dir=output_dir,
        flag_methods=flag_methods,
        default_flags=default_flags,
        **kwargs)

  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self):
...@@ -381,12 +387,14 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
    FLAGS.single_l2_loss_op = True
    FLAGS.use_tf_function = True
    FLAGS.enable_checkpoint_and_export = False
    FLAGS.data_dir = 'gs://mlcompass-data/imagenet/imagenet-2012-tfrecord'

  def benchmark_2x2_tpu_bf16(self):
    self._setup()
    self._set_df_common()
    FLAGS.batch_size = 1024
    FLAGS.dtype = 'bf16'
    FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_bf16')
    self._run_and_report_benchmark()

  @owner_utils.Owner('tf-graph-compiler')
...@@ -396,6 +404,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
    FLAGS.batch_size = 1024
    FLAGS.dtype = 'bf16'
    tf.config.experimental.enable_mlir_bridge()
    FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_bf16_mlir')
    self._run_and_report_benchmark()

  def benchmark_4x4_tpu_bf16(self):
...@@ -403,6 +412,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
    self._set_df_common()
    FLAGS.batch_size = 4096
    FLAGS.dtype = 'bf16'
    FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_bf16')
    self._run_and_report_benchmark()

  @owner_utils.Owner('tf-graph-compiler')
...@@ -412,6 +422,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
    self._set_df_common()
    FLAGS.batch_size = 4096
    FLAGS.dtype = 'bf16'
    FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_bf16_mlir')
    tf.config.experimental.enable_mlir_bridge()
    self._run_and_report_benchmark()
...@@ -439,7 +450,7 @@ class Resnet50CtlBenchmarkSynth(Resnet50CtlBenchmarkBase):
    def_flags['log_steps'] = 10

    super(Resnet50CtlBenchmarkSynth, self).__init__(
        output_dir=output_dir, default_flags=def_flags, **kwargs)


class Resnet50CtlBenchmarkReal(Resnet50CtlBenchmarkBase):
...@@ -454,7 +465,7 @@ class Resnet50CtlBenchmarkReal(Resnet50CtlBenchmarkBase):
    def_flags['log_steps'] = 10

    super(Resnet50CtlBenchmarkReal, self).__init__(
        output_dir=output_dir, default_flags=def_flags, **kwargs)


if __name__ == '__main__':
......
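The `**kwargs` pass-through added above means subclasses no longer have to enumerate every argument the benchmark harness supplies. A minimal sketch of a hypothetical subclass (the forwarded keyword names such as `tpu` or `root_data_dir` are illustrative assumptions, not part of this diff):

# Hypothetical subclass: extra constructor arguments from the harness are
# forwarded unchanged through Resnet50CtlBenchmarkSynth and CtlBenchmark to
# PerfZeroBenchmark.
class MyResnet50SynthBenchmark(Resnet50CtlBenchmarkSynth):

  def __init__(self, output_dir=None, **kwargs):
    super(MyResnet50SynthBenchmark, self).__init__(
        output_dir=output_dir, **kwargs)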
...@@ -112,6 +112,8 @@ class RuntimeConfig(base_config.Config):
    run_eagerly: Whether or not to run the experiment eagerly.
    batchnorm_spatial_persistent: Whether or not to enable the spatial
      persistent mode for CuDNN batch norm kernel for improved GPU performance.
    allow_tpu_summary: Whether to allow summaries to be written inside the XLA
      program that runs on TPU, through automatic outside compilation.
  """
  distribution_strategy: str = "mirrored"
  enable_xla: bool = False
...@@ -183,14 +185,19 @@ class TrainerConfig(base_config.Config):
    validation_interval: number of training steps to run between evaluations.
  """
  optimizer_config: OptimizationConfig = OptimizationConfig()
  # Orbit settings.
  train_tf_while_loop: bool = True
  train_tf_function: bool = True
  eval_tf_function: bool = True
  allow_tpu_summary: bool = False
  # Trainer intervals.
  steps_per_loop: int = 1000
  summary_interval: int = 1000
  checkpoint_interval: int = 1000
  # Checkpoint manager.
  max_to_keep: int = 5
  continuous_eval_timeout: Optional[int] = None
  # Train/Eval routines.
  train_steps: int = 0
  validation_steps: Optional[int] = None
  validation_interval: int = 1000
......
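To make the new grouping concrete, a minimal sketch of overriding a few of these trainer settings (the numeric values are illustrative only; the fields and the `config_definitions` import alias are the ones used elsewhere in this commit):

from official.modeling.hyperparams import config_definitions as cfg

# Illustrative values: a 10k-step run with more frequent summaries and
# checkpoints, and TPU-side summaries allowed through outside compilation.
trainer_cfg = cfg.TrainerConfig(
    train_steps=10000,
    steps_per_loop=500,
    summary_interval=500,
    checkpoint_interval=1000,
    allow_tpu_summary=True)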
...@@ -34,6 +34,8 @@ class ELECTRAPretrainerConfig(base_config.Config):
  sequence_length: int = 512
  num_classes: int = 2
  discriminator_loss_weight: float = 50.0
  tie_embeddings: bool = True
  disallow_correct: bool = False
  generator_encoder: encoders.TransformerEncoderConfig = (
      encoders.TransformerEncoderConfig())
  discriminator_encoder: encoders.TransformerEncoderConfig = (
...@@ -60,23 +62,30 @@ def instantiate_pretrainer_from_cfg(
  """Instantiates ElectraPretrainer from the config."""
  generator_encoder_cfg = config.generator_encoder
  discriminator_encoder_cfg = config.discriminator_encoder
  # Copy discriminator's embeddings to generator for easier model serialization.
  if discriminator_network is None:
    discriminator_network = encoders.instantiate_encoder_from_cfg(
        discriminator_encoder_cfg)
  if generator_network is None:
    if config.tie_embeddings:
      embedding_layer = discriminator_network.get_embedding_layer()
      generator_network = encoders.instantiate_encoder_from_cfg(
          generator_encoder_cfg, embedding_layer=embedding_layer)
    else:
      generator_network = encoders.instantiate_encoder_from_cfg(
          generator_encoder_cfg)

  return electra_pretrainer.ElectraPretrainer(
      generator_network=generator_network,
      discriminator_network=discriminator_network,
      vocab_size=config.generator_encoder.vocab_size,
      num_classes=config.num_classes,
      sequence_length=config.sequence_length,
      num_token_predictions=config.num_masked_tokens,
      mlm_activation=tf_utils.get_activation(
          generator_encoder_cfg.hidden_activation),
      mlm_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=generator_encoder_cfg.initializer_range),
      classification_heads=instantiate_classification_heads_from_cfgs(
          config.cls_heads),
      disallow_correct=config.disallow_correct)
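
A minimal sketch of how the new embedding-tying path is exercised (vocab size and layer counts are illustrative; the config fields and factory function are the ones defined above):

from official.nlp.configs import electra
from official.nlp.configs import encoders

# Illustrative sizes: a shallow generator that reuses the discriminator's
# word embeddings because tie_embeddings is True.
config = electra.ELECTRAPretrainerConfig(
    tie_embeddings=True,
    num_masked_tokens=20,
    generator_encoder=encoders.TransformerEncoderConfig(
        vocab_size=30522, num_layers=3),
    discriminator_encoder=encoders.TransformerEncoderConfig(
        vocab_size=30522, num_layers=12))
pretrainer = electra.instantiate_pretrainer_from_cfg(config)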
...@@ -17,12 +17,13 @@
Includes configurations and instantiation methods.
"""
from typing import Optional
import dataclasses
import tensorflow as tf

from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.modeling import layers
from official.nlp.modeling import networks
...@@ -40,11 +41,13 @@ class TransformerEncoderConfig(base_config.Config):
  max_position_embeddings: int = 512
  type_vocab_size: int = 2
  initializer_range: float = 0.02
  embedding_size: Optional[int] = None


def instantiate_encoder_from_cfg(
    config: TransformerEncoderConfig,
    encoder_cls=networks.TransformerEncoder,
    embedding_layer: Optional[layers.OnDeviceEmbedding] = None):
  """Instantiate a Transformer encoder network from TransformerEncoderConfig."""
  if encoder_cls.__name__ == "EncoderScaffold":
    embedding_cfg = dict(
...@@ -91,5 +94,7 @@ def instantiate_encoder_from_cfg(config: TransformerEncoderConfig,
      max_sequence_length=config.max_position_embeddings,
      type_vocab_size=config.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=config.initializer_range),
      embedding_width=config.embedding_size,
      embedding_layer=embedding_layer)
  return encoder_network
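
A minimal sketch of the two new options, a factorized embedding width and an explicitly shared embedding layer (sizes are illustrative):

from official.nlp.configs import encoders

# Illustrative sizes: a 128-wide factorized embedding, then a second encoder
# that reuses the first encoder's embedding layer instead of creating its own.
main_cfg = encoders.TransformerEncoderConfig(
    vocab_size=30522, num_layers=12, embedding_size=128)
main_encoder = encoders.instantiate_encoder_from_cfg(main_cfg)

small_cfg = encoders.TransformerEncoderConfig(
    vocab_size=30522, num_layers=3, embedding_size=128)
shared_encoder = encoders.instantiate_encoder_from_cfg(
    small_cfg, embedding_layer=main_encoder.get_embedding_layer())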
...@@ -37,6 +37,9 @@ class BertClassifier(tf.keras.Model):
  instantiates a classification network based on the passed `num_classes`
  argument. If `num_classes` is set to 1, a regression network is instantiated.

  *Note* that the model is constructed by the
  [Keras Functional API](https://keras.io/guides/functional_api/).

  Arguments:
    network: A transformer network. This network should output a sequence output
      and a classification output. Furthermore, it should expose its embedding
......
...@@ -41,6 +41,9 @@ class BertPretrainer(tf.keras.Model):
  instantiates the masked language model and classification networks that are
  used to create the training objectives.

  *Note* that the model is constructed by the
  [Keras Functional API](https://keras.io/guides/functional_api/).

  Arguments:
    network: A transformer network. This network should output a sequence output
      and a classification output.
......
...@@ -32,9 +32,12 @@ class BertSpanLabeler(tf.keras.Model):
  encoder as described in "BERT: Pre-training of Deep Bidirectional Transformers
  for Language Understanding" (https://arxiv.org/abs/1810.04805).

  The BertSpanLabeler allows a user to pass in a transformer encoder, and
  instantiates a span labeling network based on a single dense layer.

  *Note* that the model is constructed by the
  [Keras Functional API](https://keras.io/guides/functional_api/).

  Arguments:
    network: A transformer network. This network should output a sequence output
      and a classification output. Furthermore, it should expose its embedding
......
...@@ -36,6 +36,9 @@ class BertTokenClassifier(tf.keras.Model):
  instantiates a token classification network based on the passed `num_classes`
  argument.

  *Note* that the model is constructed by the
  [Keras Functional API](https://keras.io/guides/functional_api/).

  Arguments:
    network: A transformer network. This network should output a sequence output
      and a classification output. Furthermore, it should expose its embedding
......
...@@ -39,6 +39,9 @@ class ElectraPretrainer(tf.keras.Model):
  model (at generator side) and classification networks (at discriminator side)
  that are used to create the training objectives.

  *Note* that the model is constructed by the Keras subclassing API: layers are
  defined inside __init__ and call() implements the computation.

  Arguments:
    generator_network: A transformer network for the generator; this network
      should output a sequence output and an optional classification output.
...@@ -48,7 +51,6 @@ class ElectraPretrainer(tf.keras.Model):
    num_classes: Number of classes to predict from the classification network
      for the generator network (not used now).
    sequence_length: Input sequence length.
    num_token_predictions: Number of tokens to predict from the masked LM.
    mlm_activation: The activation (if any) to use in the masked LM and
      classification networks. If None, no activation will be used.
...@@ -66,7 +68,6 @@ class ElectraPretrainer(tf.keras.Model):
               vocab_size,
               num_classes,
               sequence_length,
               num_token_predictions,
               mlm_activation=None,
               mlm_initializer='glorot_uniform',
...@@ -80,7 +81,6 @@ class ElectraPretrainer(tf.keras.Model):
        'vocab_size': vocab_size,
        'num_classes': num_classes,
        'sequence_length': sequence_length,
        'num_token_predictions': num_token_predictions,
        'mlm_activation': mlm_activation,
        'mlm_initializer': mlm_initializer,
...@@ -95,7 +95,6 @@ class ElectraPretrainer(tf.keras.Model):
    self.vocab_size = vocab_size
    self.num_classes = num_classes
    self.sequence_length = sequence_length
    self.num_token_predictions = num_token_predictions
    self.mlm_activation = mlm_activation
    self.mlm_initializer = mlm_initializer
...@@ -108,10 +107,15 @@ class ElectraPretrainer(tf.keras.Model):
        output=output_type,
        name='generator_masked_lm')
    self.classification = layers.ClassificationHead(
        inner_dim=generator_network._config_dict['hidden_size'],
        num_classes=num_classes,
        initializer=mlm_initializer,
        name='generator_classification_head')
    self.discriminator_projection = tf.keras.layers.Dense(
        units=discriminator_network._config_dict['hidden_size'],
        activation=mlm_activation,
        kernel_initializer=mlm_initializer,
        name='discriminator_projection_head')
    self.discriminator_head = tf.keras.layers.Dense(
        units=1, kernel_initializer=mlm_initializer)
...@@ -165,7 +169,8 @@ class ElectraPretrainer(tf.keras.Model):
    if isinstance(disc_sequence_output, list):
      disc_sequence_output = disc_sequence_output[-1]

    disc_logits = self.discriminator_head(
        self.discriminator_projection(disc_sequence_output))
    disc_logits = tf.squeeze(disc_logits, axis=-1)

    outputs = {
...@@ -214,6 +219,12 @@ class ElectraPretrainer(tf.keras.Model):
        'sampled_tokens': sampled_tokens
    }

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    items = dict(encoder=self.discriminator_network)
    return items

  def get_config(self):
    return self._config
......
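A minimal sketch of what the new `checkpoint_items` property enables (the path is illustrative, and `pretrainer` is assumed to be an `ElectraPretrainer` built as in the earlier config sketch):

import tensorflow as tf

# checkpoint_items maps "encoder" to the discriminator network, so a
# pretraining checkpoint written this way exposes the encoder weights under a
# name that downstream fine-tuning code can restore directly.
checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)
save_path = checkpoint.save('/tmp/electra_pretrain/ckpt')  # illustrative path
tf.train.Checkpoint(encoder=pretrainer.discriminator_network).restore(save_path)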
...@@ -49,7 +49,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
        vocab_size=vocab_size,
        num_classes=num_classes,
        sequence_length=sequence_length,
        num_token_predictions=num_token_predictions,
        disallow_correct=True)
...@@ -101,7 +100,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
        vocab_size=100,
        num_classes=2,
        sequence_length=3,
        num_token_predictions=2)

    # Create a set of 2-dimensional data tensors to feed into the model.
...@@ -140,7 +138,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
        vocab_size=100,
        num_classes=2,
        sequence_length=3,
        num_token_predictions=2)

    # Create another BERT trainer via serialization and deserialization.
......
...@@ -40,6 +40,8 @@ class AlbertTransformerEncoder(tf.keras.Model):
  The default values for this object are taken from the ALBERT-Base
  implementation described in the paper.

  *Note* that the network is constructed by the Keras Functional API.

  Arguments:
    vocab_size: The size of the token vocabulary.
    embedding_width: The width of the word embeddings. If the embedding width is
......
...@@ -29,6 +29,9 @@ class Classification(tf.keras.Model):
  This network implements a simple classifier head based on a dense layer. If
  num_classes is one, it can be considered as a regression problem.

  *Note* that the network is constructed by the
  [Keras Functional API](https://keras.io/guides/functional_api/).

  Arguments:
    input_width: The innermost dimension of the input tensor to this network.
    num_classes: The number of classes that this network should classify to. If
......
...@@ -49,6 +49,9 @@ class EncoderScaffold(tf.keras.Model):
  If the hidden_cls is not overridden, a default transformer layer will be
  instantiated.

  *Note* that the network is constructed by the
  [Keras Functional API](https://keras.io/guides/functional_api/).

  Arguments:
    pooled_output_dim: The dimension of pooled output.
    pooler_layer_initializer: The initializer for the classification
......
...@@ -27,6 +27,8 @@ class SpanLabeling(tf.keras.Model):
  """Span labeling network head for BERT modeling.

  This network implements a simple single-span labeler based on a dense layer.

  *Note* that the network is constructed by the
  [Keras Functional API](https://keras.io/guides/functional_api/).

  Arguments:
    input_width: The innermost dimension of the input tensor to this network.
......
...@@ -27,6 +27,8 @@ class TokenClassification(tf.keras.Model):
  """TokenClassification network head for BERT modeling.

  This network implements a simple token classifier head based on a dense layer.

  *Note* that the network is constructed by the
  [Keras Functional API](https://keras.io/guides/functional_api/).

  Arguments:
    input_width: The innermost dimension of the input tensor to this network.
......
...@@ -39,6 +39,9 @@ class TransformerEncoder(tf.keras.Model):
  in "BERT: Pre-training of Deep Bidirectional Transformers for Language
  Understanding".

  *Note* that the network is constructed by the
  [Keras Functional API](https://keras.io/guides/functional_api/).

  Arguments:
    vocab_size: The size of the token vocabulary.
    hidden_size: The size of the transformer hidden layers.
......
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ELECTRA pretraining task (Joint Masked LM and Replaced Token Detection)."""
import dataclasses
import tensorflow as tf
from official.core import base_task
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import bert
from official.nlp.configs import electra
from official.nlp.data import pretrain_dataloader
@dataclasses.dataclass
class ELECTRAPretrainConfig(cfg.TaskConfig):
  """The model config."""
  model: electra.ELECTRAPretrainerConfig = electra.ELECTRAPretrainerConfig(
      cls_heads=[
          bert.ClsHeadConfig(
              inner_dim=768,
              num_classes=2,
              dropout_rate=0.1,
              name='next_sentence')
      ])
  train_data: cfg.DataConfig = cfg.DataConfig()
  validation_data: cfg.DataConfig = cfg.DataConfig()


@base_task.register_task_cls(ELECTRAPretrainConfig)
class ELECTRAPretrainTask(base_task.Task):
  """ELECTRA Pretrain Task (Masked LM + Replaced Token Detection)."""

  def build_model(self):
    return electra.instantiate_pretrainer_from_cfg(
        self.task_config.model)
  def build_losses(self,
                   labels,
                   model_outputs,
                   metrics,
                   aux_losses=None) -> tf.Tensor:
    metrics = dict([(metric.name, metric) for metric in metrics])
    # generator lm and (optional) nsp loss.
    lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
        labels['masked_lm_ids'],
        tf.cast(model_outputs['lm_outputs'], tf.float32),
        from_logits=True)
    lm_label_weights = labels['masked_lm_weights']
    lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
    lm_denominator_loss = tf.reduce_sum(lm_label_weights)
    mlm_loss = tf.math.divide_no_nan(lm_numerator_loss, lm_denominator_loss)
    metrics['lm_example_loss'].update_state(mlm_loss)
    if 'next_sentence_labels' in labels:
      sentence_labels = labels['next_sentence_labels']
      sentence_outputs = tf.cast(
          model_outputs['sentence_outputs'], dtype=tf.float32)
      sentence_loss = tf.keras.losses.sparse_categorical_crossentropy(
          sentence_labels,
          sentence_outputs,
          from_logits=True)
      metrics['next_sentence_loss'].update_state(sentence_loss)
      total_loss = mlm_loss + sentence_loss
    else:
      total_loss = mlm_loss

    # discriminator replaced token detection (rtd) loss.
    rtd_logits = model_outputs['disc_logits']
    rtd_labels = tf.cast(model_outputs['disc_label'], tf.float32)
    input_mask = tf.cast(labels['input_mask'], tf.float32)
    rtd_ind_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=rtd_logits, labels=rtd_labels)
    rtd_numerator = tf.reduce_sum(input_mask * rtd_ind_loss)
    rtd_denominator = tf.reduce_sum(input_mask)
    rtd_loss = tf.math.divide_no_nan(rtd_numerator, rtd_denominator)
    metrics['discriminator_loss'].update_state(rtd_loss)
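    # Total pretraining loss: the masked-LM loss (weighted mean over masked
    # positions), plus the optional next-sentence loss, plus
    # discriminator_loss_weight (default 50.0) times the RTD loss averaged
    # over non-padding positions.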
    total_loss = total_loss + \
        self.task_config.model.discriminator_loss_weight * rtd_loss

    if aux_losses:
      total_loss += tf.add_n(aux_losses)

    metrics['total_loss'].update_state(total_loss)
    return total_loss
  def build_inputs(self, params, input_context=None):
    """Returns tf.data.Dataset for pretraining."""
    if params.input_path == 'dummy':

      def dummy_data(_):
        dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
        dummy_lm = tf.zeros((1, params.max_predictions_per_seq), dtype=tf.int32)
        return dict(
            input_word_ids=dummy_ids,
            input_mask=dummy_ids,
            input_type_ids=dummy_ids,
            masked_lm_positions=dummy_lm,
            masked_lm_ids=dummy_lm,
            masked_lm_weights=tf.cast(dummy_lm, dtype=tf.float32),
            next_sentence_labels=tf.zeros((1, 1), dtype=tf.int32))

      dataset = tf.data.Dataset.range(1)
      dataset = dataset.repeat()
      dataset = dataset.map(
          dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      return dataset

    return pretrain_dataloader.BertPretrainDataLoader(params).load(
        input_context)
  def build_metrics(self, training=None):
    del training
    metrics = [
        tf.keras.metrics.SparseCategoricalAccuracy(name='masked_lm_accuracy'),
        tf.keras.metrics.Mean(name='lm_example_loss'),
        tf.keras.metrics.SparseCategoricalAccuracy(
            name='discriminator_accuracy'),
    ]
    if self.task_config.train_data.use_next_sentence_label:
      metrics.append(
          tf.keras.metrics.SparseCategoricalAccuracy(
              name='next_sentence_accuracy'))
      metrics.append(tf.keras.metrics.Mean(name='next_sentence_loss'))
    metrics.append(tf.keras.metrics.Mean(name='discriminator_loss'))
    metrics.append(tf.keras.metrics.Mean(name='total_loss'))
    return metrics
  def process_metrics(self, metrics, labels, model_outputs):
    metrics = dict([(metric.name, metric) for metric in metrics])
    if 'masked_lm_accuracy' in metrics:
      metrics['masked_lm_accuracy'].update_state(labels['masked_lm_ids'],
                                                 model_outputs['lm_outputs'],
                                                 labels['masked_lm_weights'])
    if 'next_sentence_accuracy' in metrics:
      metrics['next_sentence_accuracy'].update_state(
          labels['next_sentence_labels'], model_outputs['sentence_outputs'])
    if 'discriminator_accuracy' in metrics:
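      # The discriminator emits one logit per position; build two-class logits
      # [-x, x] so SparseCategoricalAccuracy can score the 0/1 replaced-token
      # labels, weighted by input_mask to ignore padding positions.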
      disc_logits_expanded = tf.expand_dims(model_outputs['disc_logits'], -1)
      discrim_full_logits = tf.concat(
          [-1.0 * disc_logits_expanded, disc_logits_expanded], -1)
      metrics['discriminator_accuracy'].update_state(
          model_outputs['disc_label'], discrim_full_logits,
          labels['input_mask'])
  def train_step(self, inputs, model: tf.keras.Model,
                 optimizer: tf.keras.optimizers.Optimizer, metrics):
    """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    with tf.GradientTape() as tape:
      outputs = model(inputs, training=True)
      # Computes per-replica loss.
      loss = self.build_losses(
          labels=inputs,
          model_outputs=outputs,
          metrics=metrics,
          aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      # TODO(b/154564893): enable loss scaling.
      scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    optimizer.apply_gradients(list(zip(grads, tvars)))
    self.process_metrics(metrics, inputs, outputs)
    return {self.loss: loss}
  def validation_step(self, inputs, model: tf.keras.Model, metrics):
    """Validation step.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    outputs = model(inputs, training=False)
    loss = self.build_losses(
        labels=inputs,
        model_outputs=outputs,
        metrics=metrics,
        aux_losses=model.losses)
    self.process_metrics(metrics, inputs, outputs)
    return {self.loss: loss}
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.tasks.electra_task."""
import tensorflow as tf
from official.nlp.configs import bert
from official.nlp.configs import electra
from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.nlp.tasks import electra_task
class ELECTRAPretrainTaskTest(tf.test.TestCase):

  def test_task(self):
    config = electra_task.ELECTRAPretrainConfig(
        model=electra.ELECTRAPretrainerConfig(
            generator_encoder=encoders.TransformerEncoderConfig(
                vocab_size=30522, num_layers=1),
            discriminator_encoder=encoders.TransformerEncoderConfig(
                vocab_size=30522, num_layers=1),
            num_masked_tokens=20,
            sequence_length=128,
            cls_heads=[
                bert.ClsHeadConfig(
                    inner_dim=10, num_classes=2, name="next_sentence")
            ]),
        train_data=pretrain_dataloader.BertPretrainDataConfig(
            input_path="dummy",
            max_predictions_per_seq=20,
            seq_length=128,
            global_batch_size=1))
    task = electra_task.ELECTRAPretrainTask(config)
    model = task.build_model()
    metrics = task.build_metrics()
    dataset = task.build_inputs(config.train_data)
    iterator = iter(dataset)
    optimizer = tf.keras.optimizers.SGD(lr=0.1)
    task.train_step(next(iterator), model, optimizer, metrics=metrics)
    task.validation_step(next(iterator), model, metrics=metrics)


if __name__ == "__main__":
  tf.test.main()
...@@ -125,7 +125,7 @@ def run(flags_obj):
  per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
      flags_obj)
  if not flags_obj.steps_per_loop:
    steps_per_loop = per_epoch_steps
  elif flags_obj.steps_per_loop > per_epoch_steps:
    steps_per_loop = per_epoch_steps
......