Commit 3ce2f61b authored by Kaushik Shivakumar

Merge branch 'master' of https://github.com/tensorflow/models into context_tf2

parents bb16d5ca 8e9296ff
...@@ -10,11 +10,13 @@ can take full advantage of TensorFlow for their research and product development
| [official](official) | • A collection of example implementations for SOTA models using the latest TensorFlow 2's high-level APIs<br />• Officially maintained, supported, and kept up to date with the latest TensorFlow 2 APIs by TensorFlow<br />• Reasonably optimized for fast performance while still being easy to read |
| [research](research) | • A collection of research model implementations in TensorFlow 1 or 2 by researchers<br />• Maintained and supported by researchers |
| [community](community) | • A curated list of the GitHub repositories with machine learning models and implementations powered by TensorFlow 2 |
+| [orbit](orbit) | • A flexible and lightweight library that users can easily use or fork when writing customized training loop code in TensorFlow 2.x. It seamlessly integrates with `tf.distribute` and supports running on different device types (CPU, GPU, and TPU). |

## [Announcements](https://github.com/tensorflow/models/wiki/Announcements)

| Date | News |
|------|------|
+| July 10, 2020 | TensorFlow 2 meets the [Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection) ([Blog](https://blog.tensorflow.org/2020/07/tensorflow-2-meets-object-detection-api.html)) |
| June 30, 2020 | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://github.com/tensorflow/models/tree/master/official/vision/detection#train-a-spinenet-49-based-mask-r-cnn) released ([Tweet](https://twitter.com/GoogleAI/status/1278016712978264064)) |
| June 17, 2020 | [Context R-CNN: Long Term Temporal Context for Per-Camera Object Detection](https://github.com/tensorflow/models/tree/master/research/object_detection#june-17th-2020) released ([Tweet](https://twitter.com/GoogleAI/status/1276571419422253057)) |
| May 21, 2020 | [Unifying Deep Local and Global Features for Image Search (DELG)](https://github.com/tensorflow/models/tree/master/research/delf#delg) code released |
...@@ -23,12 +25,6 @@ can take full advantage of TensorFlow for their research and product development
| May 1, 2020 | [DELF: DEep Local Features](https://github.com/tensorflow/models/tree/master/research/delf) updated to support TensorFlow 2.1 |
| March 31, 2020 | [Introducing the Model Garden for TensorFlow 2](https://blog.tensorflow.org/2020/03/introducing-model-garden-for-tensorflow-2.html) ([Tweet](https://twitter.com/TensorFlow/status/1245029834633297921)) |
-## [Milestones](https://github.com/tensorflow/models/milestones)
-| Date | Milestone |
-|------|-----------|
-| July 8, 2020 | [![GitHub milestone](https://img.shields.io/github/milestones/progress/tensorflow/models/1)](https://github.com/tensorflow/models/milestone/1) |

## Contributions

[![help wanted:paper implementation](https://img.shields.io/github/issues/tensorflow/models/help%20wanted%3Apaper%20implementation)](https://github.com/tensorflow/models/labels/help%20wanted%3Apaper%20implementation)
......
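The new `orbit` entry above is about writing customized training loops on top of `tf.distribute`. As a rough, self-contained sketch of what such a loop looks like with plain `tf.distribute` (Orbit's own API is not shown in this diff, and the toy model and data below are assumptions for illustration):

```python
import tensorflow as tf

GLOBAL_BATCH_SIZE = 8
strategy = tf.distribute.MirroredStrategy()

# Model and optimizer must be created under the strategy scope.
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
  optimizer = tf.keras.optimizers.SGD(0.1)

dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform([64, 4]),
     tf.random.uniform([64], maxval=10, dtype=tf.int32))
).batch(GLOBAL_BATCH_SIZE).repeat()
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def train_step(dist_inputs):
  def step_fn(features, labels):
    with tf.GradientTape() as tape:
      logits = model(features, training=True)
      per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, logits, from_logits=True)
      loss = tf.nn.compute_average_loss(
          per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
  # strategy.run requires TF 2.2+; older releases used experimental_run_v2.
  per_replica_loss = strategy.run(step_fn, args=dist_inputs)
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)

iterator = iter(dist_dataset)
for step in range(10):
  print(step, float(train_step(next(iterator))))
```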
...@@ -17,12 +17,9 @@ with the same or improved speed and performance with each new TensorFlow build.
The team is actively developing new models.
In the near future, we will add:
-* State-of-the-art language understanding models:
-  More members in Transformer family
-* State-of-the-art image classification models:
-  EfficientNet, MnasNet, and variants
-* State-of-the-art object detection and instance segmentation models:
-  RetinaNet, Mask R-CNN, SpineNet, and variants
+* State-of-the-art language understanding models.
+* State-of-the-art image classification models.
+* State-of-the-art object detection and instance segmentation models.

## Table of Contents
......
...@@ -93,8 +93,11 @@ class Unet3DAccuracyBenchmark(keras_benchmark.KerasBenchmark):
    """Runs and reports the benchmark given the provided configuration."""
    params = unet_training_lib.extract_params(FLAGS)
    strategy = unet_training_lib.create_distribution_strategy(params)
-    if params.use_bfloat16:
-      policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
+    input_dtype = params.dtype
+    if input_dtype == 'float16' or input_dtype == 'bfloat16':
+      policy = tf.keras.mixed_precision.experimental.Policy(
+          'mixed_bfloat16' if input_dtype == 'bfloat16' else 'mixed_float16')
      tf.keras.mixed_precision.experimental.set_policy(policy)
    stats = {}
......
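The benchmark change above selects a Keras mixed-precision policy from the configured dtype instead of a boolean `use_bfloat16` flag. A minimal sketch of the same selection logic, using the `tf.keras.mixed_precision.experimental` API that the benchmark code relies on (the function name below is a hypothetical helper, not part of the commit):

```python
import tensorflow as tf

def maybe_set_mixed_precision_policy(dtype: str):
  """Maps a dtype string to a mixed-precision policy, mirroring the hunk above."""
  if dtype in ('float16', 'bfloat16'):
    policy_name = 'mixed_bfloat16' if dtype == 'bfloat16' else 'mixed_float16'
    policy = tf.keras.mixed_precision.experimental.Policy(policy_name)
    tf.keras.mixed_precision.experimental.set_policy(policy)

# 'bfloat16' is the usual choice on TPU, 'float16' on GPU; anything else
# leaves the default float32 policy in place.
maybe_set_mixed_precision_policy('bfloat16')
```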
This diff is collapsed.
This diff is collapsed.
...@@ -59,7 +59,7 @@ class Task(tf.Module):
  def initialize(self, model: tf.keras.Model):
    """A callback function used as CheckpointManager's init_fn.

-    This function will be called when no checkpoint found for the model.
+    This function will be called when no checkpoint is found for the model.
    If there is a checkpoint, the checkpoint will be loaded and this function
    will not be called. You can use this callback function to load a pretrained
    checkpoint, saved under a directory other than the model_dir.
...@@ -71,7 +71,7 @@ class Task(tf.Module):
  @abc.abstractmethod
  def build_model(self) -> tf.keras.Model:
-    """Creates the model architecture.
+    """Creates model architecture.

    Returns:
      A model instance.
...@@ -135,7 +135,7 @@ class Task(tf.Module):
    Args:
      labels: optional label tensors.
      model_outputs: a nested structure of output tensors.
-      aux_losses: auxiliarly loss tensors, i.e. `losses` in keras.Model.
+      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.

    Returns:
      The total loss tensor.
...@@ -232,7 +232,7 @@ class Task(tf.Module):
    return logs

  def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
-    """Validatation step.
+    """Validation step.

    With distribution strategies, this method runs on devices.
......
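The `Task.initialize` docstring above describes a callback passed to `tf.train.CheckpointManager` as its `init_fn`. A minimal sketch of that pattern, assuming a TF release whose `CheckpointManager` exposes `init_fn` and `restore_or_initialize()`; the directories are hypothetical:

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
checkpoint = tf.train.Checkpoint(model=model)

def init_from_pretrained():
  # Runs only when model_dir has no checkpoint yet, matching the
  # Task.initialize contract described above; the path is a placeholder.
  tf.train.Checkpoint(model=model).restore(
      '/tmp/pretrained/ckpt-1').expect_partial()

manager = tf.train.CheckpointManager(
    checkpoint, directory='/tmp/model_dir', max_to_keep=3,
    init_fn=init_from_pretrained)
# Restores the latest checkpoint in model_dir if one exists; otherwise it
# calls init_fn to warm-start from the pretrained checkpoint.
manager.restore_or_initialize()
```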
...@@ -171,6 +171,9 @@ class InputReader:
          as_supervised=self._tfds_as_supervised,
          decoders=decoders,
          read_config=read_config)
+      if self._is_training:
+        dataset = dataset.repeat()
      return dataset

  @property
......
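The `InputReader` change above repeats the dataset only when training, so a fixed number of train steps never exhausts the iterator, while evaluation still sees each example exactly once. A toy `tf.data` illustration:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10).shuffle(10).batch(4)
# Training: repeat indefinitely so the training loop controls when to stop.
train_dataset = dataset.repeat()
# Evaluation: keep the dataset finite so one pass covers the data once.
eval_dataset = dataset
```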
...@@ -126,10 +126,10 @@ class Config(params_dict.ParamsDict):
      subconfig_type = Config
      if k in cls.__annotations__:
        # Directly Config subtype.
-        type_annotation = cls.__annotations__[k]
+        type_annotation = cls.__annotations__[k]  # pytype: disable=invalid-annotation
        if (isinstance(type_annotation, type) and
            issubclass(type_annotation, Config)):
-          subconfig_type = cls.__annotations__[k]
+          subconfig_type = cls.__annotations__[k]  # pytype: disable=invalid-annotation
        else:
          # Check if the field is a sequence of subtypes.
          field_type = getattr(type_annotation, '__origin__', type(None))
......
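The `__origin__` probe in the hunk above is how `Config` tells a container annotation such as `List[SubConfig]` apart from a plain class. On Python 3.7+ the mechanism looks like this:

```python
from typing import List, Tuple

# Parameterized typing generics expose their runtime container as __origin__.
assert getattr(List[int], '__origin__', None) is list
assert getattr(Tuple[int, ...], '__origin__', None) is tuple
# A bare class has no __origin__, so the getattr default is returned instead.
assert getattr(int, '__origin__', type(None)) is type(None)
```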
...@@ -14,6 +14,7 @@
# limitations under the License.
# ==============================================================================
"""Common configuration settings."""

+from typing import Optional, Union
import dataclasses
...@@ -123,8 +124,8 @@ class RuntimeConfig(base_config.Config):
  task_index: int = -1
  all_reduce_alg: Optional[str] = None
  num_packs: int = 1
-  loss_scale: Optional[Union[str, float]] = None
  mixed_precision_dtype: Optional[str] = None
+  loss_scale: Optional[Union[str, float]] = None
  run_eagerly: bool = False
  batchnorm_spatial_persistent: bool = False
...@@ -172,23 +173,27 @@ class TrainerConfig(base_config.Config):
    eval_tf_function: whether or not to use tf_function for eval.
    steps_per_loop: number of steps per loop.
    summary_interval: number of steps between each summary.
-    checkpoint_intervals: number of steps between checkpoints.
+    checkpoint_interval: number of steps between checkpoints.
    max_to_keep: max checkpoints to keep.
    continuous_eval_timeout: maximum number of seconds to wait between
-      checkpoints, if set to None, continuous eval will wait indefinetely.
+      checkpoints, if set to None, continuous eval will wait indefinitely.
+    train_steps: number of train steps.
+    validation_steps: number of eval steps. If `None`, the entire eval dataset
+      is used.
+    validation_interval: number of training steps to run between evaluations.
  """
  optimizer_config: OptimizationConfig = OptimizationConfig()
-  train_steps: int = 0
-  validation_steps: Optional[int] = None
-  validation_interval: int = 100
+  train_tf_while_loop: bool = True
+  train_tf_function: bool = True
+  eval_tf_function: bool = True
  steps_per_loop: int = 1000
  summary_interval: int = 1000
  checkpoint_interval: int = 1000
  max_to_keep: int = 5
  continuous_eval_timeout: Optional[int] = None
-  train_tf_while_loop: bool = True
-  train_tf_function: bool = True
-  eval_tf_function: bool = True
+  train_steps: int = 0
+  validation_steps: Optional[int] = None
+  validation_interval: int = 1000

@dataclasses.dataclass
......
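For reference, the `TrainerConfig` fields documented above can be overridden as ordinary dataclass keyword arguments (keyword construction of `Config` subclasses is used in the tests later in this diff). The import path below is the one used elsewhere in this commit and may have moved in later Model Garden versions:

```python
from official.modeling.hyperparams import config_definitions as cfg

# A sketch of a trainer override; values are illustrative only.
trainer = cfg.TrainerConfig(
    train_steps=100000,
    validation_steps=None,      # None: evaluate over the full eval dataset.
    validation_interval=1000,   # train steps between evaluations
    steps_per_loop=1000,
    summary_interval=1000,
    checkpoint_interval=1000,
    max_to_keep=5)
```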
...@@ -20,6 +20,20 @@ import dataclasses
from official.modeling.hyperparams import base_config

+@dataclasses.dataclass
+class ConstantLrConfig(base_config.Config):
+  """Configuration for constant learning rate.
+
+  This class is a container for the constant learning rate decay configs.
+
+  Attributes:
+    name: The name of the learning rate schedule. Defaults to Constant.
+    learning_rate: A float. The learning rate. Defaults to 0.1.
+  """
+  name: str = 'Constant'
+  learning_rate: float = 0.1
+
@dataclasses.dataclass
class StepwiseLrConfig(base_config.Config):
  """Configuration for stepwise learning rate decay.
......
...@@ -55,12 +55,14 @@ class LrConfig(oneof.OneOfConfig):
  Attributes:
    type: 'str', type of lr schedule to be used, one of the fields below.
+    constant: constant learning rate config.
    stepwise: stepwise learning rate config.
    exponential: exponential learning rate config.
    polynomial: polynomial learning rate config.
    cosine: cosine learning rate config.
  """
  type: Optional[str] = None
+  constant: lr_cfg.ConstantLrConfig = lr_cfg.ConstantLrConfig()
  stepwise: lr_cfg.StepwiseLrConfig = lr_cfg.StepwiseLrConfig()
  exponential: lr_cfg.ExponentialLrConfig = lr_cfg.ExponentialLrConfig()
  polynomial: lr_cfg.PolynomialLrConfig = lr_cfg.PolynomialLrConfig()
......
...@@ -28,13 +28,11 @@ class SGDConfig(base_config.Config):
  Attributes:
    name: name of the optimizer.
-    learning_rate: learning_rate for SGD optimizer.
    decay: decay rate for SGD optimizer.
    nesterov: nesterov for SGD optimizer.
    momentum: momentum for SGD optimizer.
  """
  name: str = "SGD"
-  learning_rate: float = 0.01
  decay: float = 0.0
  nesterov: bool = False
  momentum: float = 0.0
...@@ -49,14 +47,12 @@ class RMSPropConfig(base_config.Config):
  Attributes:
    name: name of the optimizer.
-    learning_rate: learning_rate for RMSprop optimizer.
    rho: discounting factor for RMSprop optimizer.
    momentum: momentum for RMSprop optimizer.
    epsilon: epsilon value for RMSprop optimizer, helps with numerical stability.
    centered: Whether to normalize gradients or not.
  """
  name: str = "RMSprop"
-  learning_rate: float = 0.001
  rho: float = 0.9
  momentum: float = 0.0
  epsilon: float = 1e-7
...@@ -72,7 +68,6 @@ class AdamConfig(base_config.Config):
  Attributes:
    name: name of the optimizer.
-    learning_rate: learning_rate for Adam optimizer.
    beta_1: decay rate for 1st order moments.
    beta_2: decay rate for 2nd order moments.
    epsilon: epsilon value used for numerical stability in Adam optimizer.
...@@ -80,7 +75,6 @@ class AdamConfig(base_config.Config):
      the paper "On the Convergence of Adam and beyond".
  """
  name: str = "Adam"
-  learning_rate: float = 0.001
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-07
...@@ -93,7 +87,6 @@ class AdamWeightDecayConfig(base_config.Config):
  Attributes:
    name: name of the optimizer.
-    learning_rate: learning_rate for the optimizer.
    beta_1: decay rate for 1st order moments.
    beta_2: decay rate for 2nd order moments.
    epsilon: epsilon value used for numerical stability in the optimizer.
...@@ -106,7 +99,6 @@ class AdamWeightDecayConfig(base_config.Config):
      include in weight decay.
  """
  name: str = "AdamWeightDecay"
-  learning_rate: float = 0.001
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-07
...@@ -125,7 +117,6 @@ class LAMBConfig(base_config.Config):
  Attributes:
    name: name of the optimizer.
-    learning_rate: learning_rate for Adam optimizer.
    beta_1: decay rate for 1st order moments.
    beta_2: decay rate for 2nd order moments.
    epsilon: epsilon value used for numerical stability in LAMB optimizer.
...@@ -139,7 +130,6 @@ class LAMBConfig(base_config.Config):
      be excluded.
  """
  name: str = "LAMB"
-  learning_rate: float = 0.001
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-6
......
...@@ -60,7 +60,7 @@ class OptimizerFactory(object):
    params = {
        'optimizer': {
            'type': 'sgd',
-            'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
+            'sgd': {'momentum': 0.9}
        },
        'learning_rate': {
            'type': 'stepwise',
...@@ -88,12 +88,15 @@ class OptimizerFactory(object):
    self._optimizer_config = config.optimizer.get()
    self._optimizer_type = config.optimizer.type
-    if self._optimizer_config is None:
+    if self._optimizer_type is None:
      raise ValueError('Optimizer type must be specified')

    self._lr_config = config.learning_rate.get()
    self._lr_type = config.learning_rate.type
+    if self._lr_type is None:
+      raise ValueError('Learning rate type must be specified')

    self._warmup_config = config.warmup.get()
    self._warmup_type = config.warmup.type
...@@ -101,18 +104,15 @@ class OptimizerFactory(object):
    """Build learning rate.

    Builds learning rate from config. Learning rate schedule is built according
-    to the learning rate config. If there is no learning rate config, optimizer
-    learning rate is returned.
+    to the learning rate config. If the learning rate type is constant,
+    lr_config.learning_rate is returned.

    Returns:
-      tf.keras.optimizers.schedules.LearningRateSchedule instance. If no
-      learning rate schedule defined, optimizer_config.learning_rate is
-      returned.
+      tf.keras.optimizers.schedules.LearningRateSchedule instance. If the
+      learning rate type is constant, lr_config.learning_rate is returned.
    """
-    # TODO(arashwan): Explore if we want to only allow explicit const lr sched.
-    if not self._lr_config:
-      lr = self._optimizer_config.learning_rate
+    if self._lr_type == 'constant':
+      lr = self._lr_config.learning_rate
    else:
      lr = LR_CLS[self._lr_type](**self._lr_config.as_dict())
......
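Putting the factory changes above together: learning rates now always come from an explicit `learning_rate` config (including the new `constant` type) instead of the optimizer config. A sketch of the resulting flow; the import paths are assumptions based on the Model Garden layout at the time, and only `OptimizerFactory`, `OptimizationConfig`, and `build_learning_rate()` are taken from the code and tests in this diff:

```python
from official.modeling.optimization import optimizer_factory
from official.modeling.optimization.configs import optimization_config

params = {
    'optimizer': {'type': 'sgd', 'sgd': {'momentum': 0.9}},
    'learning_rate': {'type': 'constant', 'constant': {'learning_rate': 0.1}},
}
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
# For the 'constant' type this returns lr_config.learning_rate (0.1) directly;
# other types return a tf.keras.optimizers.schedules.LearningRateSchedule.
lr = opt_factory.build_learning_rate()
```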
...@@ -35,10 +35,17 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
    params = {
        'optimizer': {
            'type': optimizer_type
-        }
+        },
+        'learning_rate': {
+            'type': 'constant',
+            'constant': {
+                'learning_rate': 0.1
+            }
+        }
    }
    optimizer_cls = optimizer_factory.OPTIMIZERS_CLS[optimizer_type]
    expected_optimizer_config = optimizer_cls().get_config()
+    expected_optimizer_config['learning_rate'] = 0.1
    opt_config = optimization_config.OptimizationConfig(params)
    opt_factory = optimizer_factory.OptimizerFactory(opt_config)
...@@ -48,11 +55,32 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
    self.assertIsInstance(optimizer, optimizer_cls)
    self.assertEqual(expected_optimizer_config, optimizer.get_config())

+  def test_missing_types(self):
+    params = {
+        'optimizer': {
+            'type': 'sgd',
+            'sgd': {'momentum': 0.9}
+        }
+    }
+    with self.assertRaises(ValueError):
+      optimizer_factory.OptimizerFactory(
+          optimization_config.OptimizationConfig(params))
+    params = {
+        'learning_rate': {
+            'type': 'stepwise',
+            'stepwise': {'boundaries': [10000, 20000],
+                         'values': [0.1, 0.01, 0.001]}
+        }
+    }
+    with self.assertRaises(ValueError):
+      optimizer_factory.OptimizerFactory(
+          optimization_config.OptimizationConfig(params))
+
  def test_stepwise_lr_schedule(self):
    params = {
        'optimizer': {
            'type': 'sgd',
-            'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
+            'sgd': {'momentum': 0.9}
        },
        'learning_rate': {
            'type': 'stepwise',
...@@ -79,7 +107,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
    params = {
        'optimizer': {
            'type': 'sgd',
-            'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
+            'sgd': {'momentum': 0.9}
        },
        'learning_rate': {
            'type': 'stepwise',
...@@ -112,7 +140,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
    params = {
        'optimizer': {
            'type': 'sgd',
-            'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
+            'sgd': {'momentum': 0.9}
        },
        'learning_rate': {
            'type': 'exponential',
...@@ -142,7 +170,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
    params = {
        'optimizer': {
            'type': 'sgd',
-            'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
+            'sgd': {'momentum': 0.9}
        },
        'learning_rate': {
            'type': 'polynomial',
...@@ -166,7 +194,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
    params = {
        'optimizer': {
            'type': 'sgd',
-            'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
+            'sgd': {'momentum': 0.9}
        },
        'learning_rate': {
            'type': 'cosine',
...@@ -192,7 +220,13 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
    params = {
        'optimizer': {
            'type': 'sgd',
-            'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
+            'sgd': {'momentum': 0.9}
        },
+        'learning_rate': {
+            'type': 'constant',
+            'constant': {
+                'learning_rate': 0.1
+            }
+        },
        'warmup': {
            'type': 'linear',
...@@ -216,7 +250,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
    params = {
        'optimizer': {
            'type': 'sgd',
-            'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
+            'sgd': {'momentum': 0.9}
        },
        'learning_rate': {
            'type': 'stepwise',
......
...@@ -88,7 +88,6 @@ def is_special_none_tensor(tensor):
  return tensor.shape.ndims == 0 and tensor.dtype == tf.int32

-# TODO(hongkuny): consider moving custom string-map lookup to keras api.
def get_activation(identifier):
  """Maps an identifier to a Python function, e.g., "relu" => `tf.nn.relu`.
......
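The `get_activation` docstring above describes a plain string-to-callable lookup. A minimal usage sketch, assuming the `official.modeling.tf_utils` import path used elsewhere in this diff:

```python
import tensorflow as tf
from official.modeling import tf_utils

# Per the docstring above, "relu" resolves to tf.nn.relu.
act = tf_utils.get_activation('relu')
print(act(tf.constant([-1.0, 2.0])))  # -> [0., 2.]
```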
...@@ -14,23 +14,61 @@
# ==============================================================================
"""ALBERT classification finetuning runner in tf2.x."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
+import os

from absl import app
from absl import flags
+from absl import logging
import tensorflow as tf

from official.nlp.albert import configs as albert_configs
+from official.nlp.bert import bert_models
from official.nlp.bert import run_classifier as run_classifier_bert
from official.utils.misc import distribution_utils

FLAGS = flags.FLAGS
+def predict(strategy, albert_config, input_meta_data, predict_input_fn):
+  """Outputs both the ground truth and the predictions as .tsv files."""
+  with strategy.scope():
+    classifier_model = bert_models.classifier_model(
+        albert_config, input_meta_data['num_labels'])[0]
+    checkpoint = tf.train.Checkpoint(model=classifier_model)
+    latest_checkpoint_file = (
+        FLAGS.predict_checkpoint_path or
+        tf.train.latest_checkpoint(FLAGS.model_dir))
+    assert latest_checkpoint_file
+    logging.info('Checkpoint file %s found and restoring from '
+                 'checkpoint', latest_checkpoint_file)
+    checkpoint.restore(
+        latest_checkpoint_file).assert_existing_objects_matched()
+    preds, ground_truth = run_classifier_bert.get_predictions_and_labels(
+        strategy, classifier_model, predict_input_fn, return_probs=True)
+  output_predict_file = os.path.join(FLAGS.model_dir, 'test_results.tsv')
+  with tf.io.gfile.GFile(output_predict_file, 'w') as writer:
+    logging.info('***** Predict results *****')
+    for probabilities in preds:
+      output_line = '\t'.join(
+          str(class_probability)
+          for class_probability in probabilities) + '\n'
+      writer.write(output_line)
+  ground_truth_labels_file = os.path.join(FLAGS.model_dir,
+                                          'output_labels.tsv')
+  with tf.io.gfile.GFile(ground_truth_labels_file, 'w') as writer:
+    logging.info('***** Ground truth results *****')
+    for label in ground_truth:
+      output_line = str(label) + '\n'
+      writer.write(output_line)
+  return

def main(_):
  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))
...@@ -56,9 +94,14 @@ def main(_):
  albert_config = albert_configs.AlbertConfig.from_json_file(
      FLAGS.bert_config_file)
-  run_classifier_bert.run_bert(strategy, input_meta_data, albert_config,
-                               train_input_fn, eval_input_fn)
+  if FLAGS.mode == 'train_and_eval':
+    run_classifier_bert.run_bert(strategy, input_meta_data, albert_config,
+                                 train_input_fn, eval_input_fn)
+  elif FLAGS.mode == 'predict':
+    predict(strategy, albert_config, input_meta_data, eval_input_fn)
+  else:
+    raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
+  return

if __name__ == '__main__':
  flags.mark_flag_as_required('bert_config_file')
......
...@@ -79,7 +79,7 @@ def export_bert_tfhub(bert_config: configs.BertConfig,
      do_lower_case, vocab_file)
  core_model, encoder = create_bert_model(bert_config)
  checkpoint = tf.train.Checkpoint(model=encoder)
-  checkpoint.restore(model_checkpoint_path).assert_consumed()
+  checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
  core_model.vocab_file = tf.saved_model.Asset(vocab_file)
  core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
  core_model.save(hub_destination, include_optimizer=False, save_format="tf")
......
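The export code above restores with `assert_existing_objects_matched()` rather than `assert_consumed()` because the training checkpoint typically tracks more objects (for example optimizer slots) than the encoder being exported. A toy illustration of the difference, with hypothetical paths:

```python
import tensorflow as tf

# Save a checkpoint that tracks more objects than the model restored later.
full = tf.train.Checkpoint(
    dense=tf.keras.layers.Dense(2), counter=tf.Variable(0))
full.dense.build((None, 3))
path = full.save('/tmp/ckpt_demo/ckpt')

partial = tf.train.Checkpoint(dense=tf.keras.layers.Dense(2))
partial.dense.build((None, 3))
status = partial.restore(path)
status.assert_existing_objects_matched()  # passes: everything in `partial` matched
# status.assert_consumed()  # would raise: `counter` in the checkpoint is unused
```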
...@@ -99,7 +99,9 @@ def write_txt_summary(training_summary, summary_dir):

@deprecation.deprecated(
-    None, 'This function is deprecated. Please use Keras compile/fit instead.')
+    None, 'This function is deprecated and we do not expect to add new '
+    'functionality to it. Please do not have your code depend '
+    'on this library.')
def run_customized_training_loop(
    # pylint: disable=invalid-name
    _sentinel=None,
...@@ -557,7 +559,6 @@ def run_customized_training_loop(
        for metric in model.metrics:
          training_summary[metric.name] = _float_metric_value(metric)
        if eval_metrics:
-          # TODO(hongkuny): Cleans up summary reporting in text.
          training_summary['last_train_metrics'] = _float_metric_value(
              train_metrics[0])
          training_summary['eval_metrics'] = _float_metric_value(eval_metrics[0])
......
...@@ -24,7 +24,6 @@ import tensorflow as tf
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
-from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling.models import bert_pretrainer
...@@ -43,7 +42,6 @@ class ClsHeadConfig(base_config.Config):
@dataclasses.dataclass
class BertPretrainerConfig(base_config.Config):
  """BERT encoder configuration."""
-  num_masked_tokens: int = 76
  encoder: encoders.TransformerEncoderConfig = (
      encoders.TransformerEncoderConfig())
  cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)
...@@ -56,103 +54,18 @@ def instantiate_classification_heads_from_cfgs(
  ] if cls_head_configs else []

-def instantiate_bertpretrainer_from_cfg(
+def instantiate_pretrainer_from_cfg(
    config: BertPretrainerConfig,
    encoder_network: Optional[tf.keras.Model] = None
) -> bert_pretrainer.BertPretrainerV2:
  """Instantiates a BertPretrainer from the config."""
  encoder_cfg = config.encoder
  if encoder_network is None:
    encoder_network = encoders.instantiate_encoder_from_cfg(encoder_cfg)
  return bert_pretrainer.BertPretrainerV2(
-      config.num_masked_tokens,
      mlm_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
      mlm_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=encoder_cfg.initializer_range),
      encoder_network=encoder_network,
      classification_heads=instantiate_classification_heads_from_cfgs(
          config.cls_heads))
-@dataclasses.dataclass
-class BertPretrainDataConfig(cfg.DataConfig):
-  """Data config for BERT pretraining task (tasks/masked_lm)."""
-  input_path: str = ""
-  global_batch_size: int = 512
-  is_training: bool = True
-  seq_length: int = 512
-  max_predictions_per_seq: int = 76
-  use_next_sentence_label: bool = True
-  use_position_id: bool = False
-
-@dataclasses.dataclass
-class BertPretrainEvalDataConfig(BertPretrainDataConfig):
-  """Data config for the eval set in BERT pretraining task (tasks/masked_lm)."""
-  input_path: str = ""
-  global_batch_size: int = 512
-  is_training: bool = False
-
-@dataclasses.dataclass
-class SentencePredictionDataConfig(cfg.DataConfig):
-  """Data config for sentence prediction task (tasks/sentence_prediction)."""
-  input_path: str = ""
-  global_batch_size: int = 32
-  is_training: bool = True
-  seq_length: int = 128
-
-@dataclasses.dataclass
-class SentencePredictionDevDataConfig(cfg.DataConfig):
-  """Dev Data config for sentence prediction (tasks/sentence_prediction)."""
-  input_path: str = ""
-  global_batch_size: int = 32
-  is_training: bool = False
-  seq_length: int = 128
-  drop_remainder: bool = False
-
-@dataclasses.dataclass
-class QADataConfig(cfg.DataConfig):
-  """Data config for question answering task (tasks/question_answering)."""
-  input_path: str = ""
-  global_batch_size: int = 48
-  is_training: bool = True
-  seq_length: int = 384
-
-@dataclasses.dataclass
-class QADevDataConfig(cfg.DataConfig):
-  """Dev Data config for queston answering (tasks/question_answering)."""
-  input_path: str = ""
-  input_preprocessed_data_path: str = ""
-  version_2_with_negative: bool = False
-  doc_stride: int = 128
-  global_batch_size: int = 48
-  is_training: bool = False
-  seq_length: int = 384
-  query_length: int = 64
-  drop_remainder: bool = False
-  vocab_file: str = ""
-  tokenization: str = "WordPiece"  # WordPiece or SentencePiece
-  do_lower_case: bool = True
-
-@dataclasses.dataclass
-class TaggingDataConfig(cfg.DataConfig):
-  """Data config for tagging (tasks/tagging)."""
-  input_path: str = ""
-  global_batch_size: int = 48
-  is_training: bool = True
-  seq_length: int = 384
-
-@dataclasses.dataclass
-class TaggingDevDataConfig(cfg.DataConfig):
-  """Dev Data config for tagging (tasks/tagging)."""
-  input_path: str = ""
-  global_batch_size: int = 48
-  is_training: bool = False
-  seq_length: int = 384
-  drop_remainder: bool = False
...@@ -26,7 +26,7 @@ class BertModelsTest(tf.test.TestCase):
  def test_network_invocation(self):
    config = bert.BertPretrainerConfig(
        encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1))
-    _ = bert.instantiate_bertpretrainer_from_cfg(config)
+    _ = bert.instantiate_pretrainer_from_cfg(config)

    # Invokes with classification heads.
    config = bert.BertPretrainerConfig(
...@@ -35,7 +35,7 @@ class BertModelsTest(tf.test.TestCase):
        bert.ClsHeadConfig(
            inner_dim=10, num_classes=2, name="next_sentence")
    ])
-    _ = bert.instantiate_bertpretrainer_from_cfg(config)
+    _ = bert.instantiate_pretrainer_from_cfg(config)

    with self.assertRaises(ValueError):
      config = bert.BertPretrainerConfig(
...@@ -47,7 +47,7 @@ class BertModelsTest(tf.test.TestCase):
        bert.ClsHeadConfig(
            inner_dim=10, num_classes=2, name="next_sentence")
      ])
-      _ = bert.instantiate_bertpretrainer_from_cfg(config)
+      _ = bert.instantiate_pretrainer_from_cfg(config)

  def test_checkpoint_items(self):
    config = bert.BertPretrainerConfig(
...@@ -56,9 +56,10 @@ class BertModelsTest(tf.test.TestCase):
        bert.ClsHeadConfig(
            inner_dim=10, num_classes=2, name="next_sentence")
    ])
-    encoder = bert.instantiate_bertpretrainer_from_cfg(config)
-    self.assertSameElements(encoder.checkpoint_items.keys(),
-                            ["encoder", "next_sentence.pooler_dense"])
+    encoder = bert.instantiate_pretrainer_from_cfg(config)
+    self.assertSameElements(
+        encoder.checkpoint_items.keys(),
+        ["encoder", "masked_lm", "next_sentence.pooler_dense"])

if __name__ == "__main__":
......