Commit 356c98bd authored by Kaushik Shivakumar
Browse files

Merge remote-tracking branch 'upstream/master' into detr-push-3

parents d31aba8a b9785623
...@@ -237,7 +237,7 @@ def convert_groundtruths_to_coco_dataset(groundtruths, label_map=None): ...@@ -237,7 +237,7 @@ def convert_groundtruths_to_coco_dataset(groundtruths, label_map=None):
(boxes[j, k, 3] - boxes[j, k, 1]) * (boxes[j, k, 3] - boxes[j, k, 1]) *
(boxes[j, k, 2] - boxes[j, k, 0])) (boxes[j, k, 2] - boxes[j, k, 0]))
if 'masks' in groundtruths: if 'masks' in groundtruths:
mask = Image.open(six.StringIO(groundtruths['masks'][i][j, k])) mask = Image.open(six.BytesIO(groundtruths['masks'][i][j, k]))
width, height = mask.size width, height = mask.size
np_mask = ( np_mask = (
np.array(mask.getdata()).reshape(height, width).astype(np.uint8)) np.array(mask.getdata()).reshape(height, width).astype(np.uint8))
......
...@@ -77,11 +77,13 @@ def multilevel_features_generator(params): ...@@ -77,11 +77,13 @@ def multilevel_features_generator(params):
def retinanet_head_generator(params): def retinanet_head_generator(params):
"""Generator function for RetinaNet head architecture.""" """Generator function for RetinaNet head architecture."""
head_params = params.retinanet_head head_params = params.retinanet_head
anchors_per_location = params.anchor.num_scales * len(
params.anchor.aspect_ratios)
return heads.RetinanetHead( return heads.RetinanetHead(
params.architecture.min_level, params.architecture.min_level,
params.architecture.max_level, params.architecture.max_level,
params.architecture.num_classes, params.architecture.num_classes,
head_params.anchors_per_location, anchors_per_location,
head_params.num_convs, head_params.num_convs,
head_params.num_filters, head_params.num_filters,
head_params.use_separable_conv, head_params.use_separable_conv,
...@@ -91,10 +93,12 @@ def retinanet_head_generator(params): ...@@ -91,10 +93,12 @@ def retinanet_head_generator(params):
def rpn_head_generator(params): def rpn_head_generator(params):
"""Generator function for RPN head architecture.""" """Generator function for RPN head architecture."""
head_params = params.rpn_head head_params = params.rpn_head
anchors_per_location = params.anchor.num_scales * len(
params.anchor.aspect_ratios)
return heads.RpnHead( return heads.RpnHead(
params.architecture.min_level, params.architecture.min_level,
params.architecture.max_level, params.architecture.max_level,
head_params.anchors_per_location, anchors_per_location,
head_params.num_convs, head_params.num_convs,
head_params.num_filters, head_params.num_filters,
head_params.use_separable_conv, head_params.use_separable_conv,
......
...@@ -23,6 +23,7 @@ from tensorflow.python.keras import backend ...@@ -23,6 +23,7 @@ from tensorflow.python.keras import backend
try: try:
from tensorflow.python.keras.engine import keras_tensor # pylint: disable=g-import-not-at-top,unused-import from tensorflow.python.keras.engine import keras_tensor # pylint: disable=g-import-not-at-top,unused-import
keras_tensor.disable_keras_tensors()
except ImportError: except ImportError:
keras_tensor = None keras_tensor = None
......
...@@ -449,7 +449,7 @@ class RetinanetBoxLoss(object): ...@@ -449,7 +449,7 @@ class RetinanetBoxLoss(object):
num_positives: number of positive examples in the minibatch. num_positives: number of positive examples in the minibatch.
Returns: Returns:
an integar tensor representing total box regression loss. an integer tensor representing total box regression loss.
""" """
# Sums all positives in a batch for normalization and avoids zero # Sums all positives in a batch for normalization and avoids zero
# num_positives_sum, which would lead to inf loss during training # num_positives_sum, which would lead to inf loss during training
...@@ -457,7 +457,6 @@ class RetinanetBoxLoss(object): ...@@ -457,7 +457,6 @@ class RetinanetBoxLoss(object):
box_losses = [] box_losses = []
for level in box_outputs.keys(): for level in box_outputs.keys():
# Onehot encoding for classification labels.
box_targets_l = labels[level] box_targets_l = labels[level]
box_losses.append( box_losses.append(
self.box_loss(box_outputs[level], box_targets_l, num_positives_sum)) self.box_loss(box_outputs[level], box_targets_l, num_positives_sum))
......
...@@ -59,11 +59,8 @@ class RetinanetModel(base_model.Model): ...@@ -59,11 +59,8 @@ class RetinanetModel(base_model.Model):
self._transpose_input = params.train.transpose_input self._transpose_input = params.train.transpose_input
assert not self._transpose_input, 'Transpose input is not supported.' assert not self._transpose_input, 'Transpose input is not supported.'
# Input layer. # Input layer.
input_shape = (
params.retinanet_parser.output_size +
[params.retinanet_parser.num_channels])
self._input_layer = tf.keras.layers.Input( self._input_layer = tf.keras.layers.Input(
shape=input_shape, name='', shape=(None, None, params.retinanet_parser.num_channels), name='',
dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32) dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32)
def build_outputs(self, inputs, mode): def build_outputs(self, inputs, mode):
......
...@@ -40,8 +40,6 @@ model: ...@@ -40,8 +40,6 @@ model:
momentum: 0.9 momentum: 0.9
decay: 0.9 decay: 0.9
epsilon: 0.001 epsilon: 0.001
learning_rate:
name: 'piecewise_constant_with_warmup'
loss: loss:
label_smoothing: 0.1 label_smoothing: 0.1
train: train:
......
...@@ -43,8 +43,6 @@ model: ...@@ -43,8 +43,6 @@ model:
epsilon: 0.001 epsilon: 0.001
moving_average_decay: 0. moving_average_decay: 0.
lookahead: False lookahead: False
learning_rate:
name: 'piecewise_constant_with_warmup'
loss: loss:
label_smoothing: 0.1 label_smoothing: 0.1
train: train:
......
...@@ -18,7 +18,7 @@ from __future__ import absolute_import ...@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from typing import Any, List, Mapping from typing import Any, Mapping, Optional
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
...@@ -32,23 +32,33 @@ class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule): ...@@ -32,23 +32,33 @@ class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
def __init__( def __init__(
self, self,
lr_schedule: tf.keras.optimizers.schedules.LearningRateSchedule, lr_schedule: tf.keras.optimizers.schedules.LearningRateSchedule,
warmup_steps: int): warmup_steps: int,
warmup_lr: Optional[float] = None):
"""Add warmup decay to a learning rate schedule. """Add warmup decay to a learning rate schedule.
Args: Args:
lr_schedule: base learning rate scheduler lr_schedule: base learning rate scheduler
warmup_steps: number of warmup steps warmup_steps: number of warmup steps
warmup_lr: an optional field for the final warmup learning rate. This
should be provided if the base `lr_schedule` does not contain this
field.
""" """
super(WarmupDecaySchedule, self).__init__() super(WarmupDecaySchedule, self).__init__()
self._lr_schedule = lr_schedule self._lr_schedule = lr_schedule
self._warmup_steps = warmup_steps self._warmup_steps = warmup_steps
self._warmup_lr = warmup_lr
def __call__(self, step: int): def __call__(self, step: int):
lr = self._lr_schedule(step) lr = self._lr_schedule(step)
if self._warmup_steps: if self._warmup_steps:
initial_learning_rate = tf.convert_to_tensor( if self._warmup_lr is not None:
self._lr_schedule.initial_learning_rate, name="initial_learning_rate") initial_learning_rate = tf.convert_to_tensor(
self._warmup_lr, name="initial_learning_rate")
else:
initial_learning_rate = tf.convert_to_tensor(
self._lr_schedule.initial_learning_rate,
name="initial_learning_rate")
dtype = initial_learning_rate.dtype dtype = initial_learning_rate.dtype
global_step_recomp = tf.cast(step, dtype) global_step_recomp = tf.cast(step, dtype)
warmup_steps = tf.cast(self._warmup_steps, dtype) warmup_steps = tf.cast(self._warmup_steps, dtype)
...@@ -62,65 +72,11 @@ class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule): ...@@ -62,65 +72,11 @@ class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
config = self._lr_schedule.get_config() config = self._lr_schedule.get_config()
config.update({ config.update({
"warmup_steps": self._warmup_steps, "warmup_steps": self._warmup_steps,
"warmup_lr": self._warmup_lr,
}) })
return config return config
# TODO(b/149030439) - refactor this with
# tf.keras.optimizers.schedules.PiecewiseConstantDecay + WarmupDecaySchedule.
class PiecewiseConstantDecayWithWarmup(
tf.keras.optimizers.schedules.LearningRateSchedule):
"""Piecewise constant decay with warmup schedule."""
def __init__(self,
batch_size: int,
epoch_size: int,
warmup_epochs: int,
boundaries: List[int],
multipliers: List[float]):
"""Piecewise constant decay with warmup.
Args:
batch_size: The training batch size used in the experiment.
epoch_size: The size of an epoch, or the number of examples in an epoch.
warmup_epochs: The number of warmup epochs to apply.
boundaries: The list of floats with strictly increasing entries.
multipliers: The list of multipliers/learning rates to use for the
piecewise portion. The length must be 1 more than that of boundaries.
"""
super(PiecewiseConstantDecayWithWarmup, self).__init__()
if len(boundaries) != len(multipliers) - 1:
raise ValueError("The length of boundaries must be 1 less than the "
"length of multipliers")
base_lr_batch_size = 256
steps_per_epoch = epoch_size // batch_size
self._rescaled_lr = BASE_LEARNING_RATE * batch_size / base_lr_batch_size
self._step_boundaries = [float(steps_per_epoch) * x for x in boundaries]
self._lr_values = [self._rescaled_lr * m for m in multipliers]
self._warmup_steps = warmup_epochs * steps_per_epoch
def __call__(self, step: int):
"""Compute learning rate at given step."""
def warmup_lr():
return self._rescaled_lr * (
step / tf.cast(self._warmup_steps, tf.float32))
def piecewise_lr():
return tf.compat.v1.train.piecewise_constant(
tf.cast(step, tf.float32), self._step_boundaries, self._lr_values)
return tf.cond(step < self._warmup_steps, warmup_lr, piecewise_lr)
def get_config(self) -> Mapping[str, Any]:
return {
"rescaled_lr": self._rescaled_lr,
"step_boundaries": self._step_boundaries,
"lr_values": self._lr_values,
"warmup_steps": self._warmup_steps,
}
class CosineDecayWithWarmup(tf.keras.optimizers.schedules.LearningRateSchedule): class CosineDecayWithWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Class to generate learning rate tensor.""" """Class to generate learning rate tensor."""
......
...@@ -46,44 +46,6 @@ class LearningRateTests(tf.test.TestCase): ...@@ -46,44 +46,6 @@ class LearningRateTests(tf.test.TestCase):
self.assertAllClose(self.evaluate(lr(step)), self.assertAllClose(self.evaluate(lr(step)),
step / warmup_steps * initial_lr) step / warmup_steps * initial_lr)
def test_piecewise_constant_decay_with_warmup(self):
"""Basic computational test for piecewise constant decay with warmup."""
boundaries = [1, 2, 3]
warmup_epochs = boundaries[0]
learning_rate_multipliers = [1.0, 0.1, 0.001]
expected_keys = [
'rescaled_lr', 'step_boundaries', 'lr_values', 'warmup_steps',
]
expected_lrs = [0.0, 0.1, 0.1]
lr = learning_rate.PiecewiseConstantDecayWithWarmup(
batch_size=256,
epoch_size=256,
warmup_epochs=warmup_epochs,
boundaries=boundaries[1:],
multipliers=learning_rate_multipliers)
step = 0
config = lr.get_config()
self.assertAllInSet(list(config.keys()), expected_keys)
for boundary, expected_lr in zip(boundaries, expected_lrs):
for _ in range(step, boundary):
self.assertAllClose(self.evaluate(lr(step)), expected_lr)
step += 1
def test_piecewise_constant_decay_invalid_boundaries(self):
with self.assertRaisesRegex(ValueError,
'The length of boundaries must be 1 less '):
learning_rate.PiecewiseConstantDecayWithWarmup(
batch_size=256,
epoch_size=256,
warmup_epochs=1,
boundaries=[1, 2],
multipliers=[1, 2])
def test_cosine_decay_with_warmup(self): def test_cosine_decay_with_warmup(self):
"""Basic computational test for cosine decay with warmup.""" """Basic computational test for cosine decay with warmup."""
expected_lrs = [0.0, 0.1, 0.05, 0.0] expected_lrs = [0.0, 0.1, 0.05, 0.0]
......
...@@ -370,29 +370,26 @@ def build_learning_rate(params: base_configs.LearningRateConfig, ...@@ -370,29 +370,26 @@ def build_learning_rate(params: base_configs.LearningRateConfig,
decay_steps=decay_steps, decay_steps=decay_steps,
decay_rate=decay_rate, decay_rate=decay_rate,
staircase=params.staircase) staircase=params.staircase)
elif decay_type == 'piecewise_constant_with_warmup': elif decay_type == 'stepwise':
logging.info('Using Piecewise constant decay with warmup. ' steps_per_epoch = params.examples_per_epoch // batch_size
'Parameters: batch_size: %d, epoch_size: %d, ' boundaries = [boundary * steps_per_epoch for boundary in params.boundaries]
'warmup_epochs: %d, boundaries: %s, multipliers: %s', multipliers = [batch_size * multiplier for multiplier in params.multipliers]
batch_size, params.examples_per_epoch, logging.info('Using stepwise learning rate. Parameters: '
params.warmup_epochs, params.boundaries, 'boundaries: %s, values: %s',
params.multipliers) boundaries, multipliers)
lr = learning_rate.PiecewiseConstantDecayWithWarmup( lr = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
batch_size=batch_size, boundaries=boundaries,
epoch_size=params.examples_per_epoch, values=multipliers)
warmup_epochs=params.warmup_epochs,
boundaries=params.boundaries,
multipliers=params.multipliers)
elif decay_type == 'cosine_with_warmup': elif decay_type == 'cosine_with_warmup':
lr = learning_rate.CosineDecayWithWarmup( lr = learning_rate.CosineDecayWithWarmup(
batch_size=batch_size, batch_size=batch_size,
total_steps=train_epochs * train_steps, total_steps=train_epochs * train_steps,
warmup_steps=warmup_steps) warmup_steps=warmup_steps)
if warmup_steps > 0: if warmup_steps > 0:
if decay_type not in [ if decay_type not in ['cosine_with_warmup']:
'piecewise_constant_with_warmup', 'cosine_with_warmup'
]:
logging.info('Applying %d warmup steps to the learning rate', logging.info('Applying %d warmup steps to the learning rate',
warmup_steps) warmup_steps)
lr = learning_rate.WarmupDecaySchedule(lr, warmup_steps) lr = learning_rate.WarmupDecaySchedule(lr,
warmup_steps,
warmup_lr=base_lr)
return lr return lr
...@@ -93,7 +93,6 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -93,7 +93,6 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters( @parameterized.named_parameters(
('exponential', 'exponential'), ('exponential', 'exponential'),
('piecewise_constant_with_warmup', 'piecewise_constant_with_warmup'),
('cosine_with_warmup', 'cosine_with_warmup')) ('cosine_with_warmup', 'cosine_with_warmup'))
def test_learning_rate_with_decay_and_warmup(self, lr_decay_type): def test_learning_rate_with_decay_and_warmup(self, lr_decay_type):
"""Basic smoke test for syntax.""" """Basic smoke test for syntax."""
......
...@@ -18,22 +18,12 @@ from __future__ import absolute_import ...@@ -18,22 +18,12 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from typing import Any, Mapping
import dataclasses import dataclasses
from official.modeling.hyperparams import base_config from official.modeling.hyperparams import base_config
from official.vision.image_classification.configs import base_configs from official.vision.image_classification.configs import base_configs
_RESNET_LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
(1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
]
_RESNET_LR_BOUNDARIES = list(p[1] for p in _RESNET_LR_SCHEDULE[1:])
_RESNET_LR_MULTIPLIERS = list(p[0] for p in _RESNET_LR_SCHEDULE)
_RESNET_LR_WARMUP_EPOCHS = _RESNET_LR_SCHEDULE[0][1]
@dataclasses.dataclass @dataclasses.dataclass
class ResNetModelConfig(base_configs.ModelConfig): class ResNetModelConfig(base_configs.ModelConfig):
"""Configuration for the ResNet model.""" """Configuration for the ResNet model."""
...@@ -56,8 +46,13 @@ class ResNetModelConfig(base_configs.ModelConfig): ...@@ -56,8 +46,13 @@ class ResNetModelConfig(base_configs.ModelConfig):
moving_average_decay=None) moving_average_decay=None)
learning_rate: base_configs.LearningRateConfig = ( learning_rate: base_configs.LearningRateConfig = (
base_configs.LearningRateConfig( base_configs.LearningRateConfig(
name='piecewise_constant_with_warmup', name='stepwise',
initial_lr=0.1,
examples_per_epoch=1281167, examples_per_epoch=1281167,
warmup_epochs=_RESNET_LR_WARMUP_EPOCHS, boundaries=[30, 60, 80],
boundaries=_RESNET_LR_BOUNDARIES, warmup_epochs=5,
multipliers=_RESNET_LR_MULTIPLIERS)) scale_by_batch_size=1. / 256.,
multipliers=[0.1 / 256,
0.01 / 256,
0.001 / 256,
0.0001 / 256]))
...@@ -167,6 +167,7 @@ def run(flags_obj): ...@@ -167,6 +167,7 @@ def run(flags_obj):
steps_per_loop=steps_per_loop, steps_per_loop=steps_per_loop,
checkpoint_manager=checkpoint_manager, checkpoint_manager=checkpoint_manager,
summary_interval=summary_interval, summary_interval=summary_interval,
summary_dir=flags_obj.model_dir,
eval_summary_dir=os.path.join(flags_obj.model_dir, 'eval')) eval_summary_dir=os.path.join(flags_obj.model_dir, 'eval'))
time_callback.on_train_begin() time_callback.on_train_begin()
......
...@@ -107,9 +107,12 @@ class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator): ...@@ -107,9 +107,12 @@ class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator):
.datasets_num_private_threads, .datasets_num_private_threads,
dtype=self.dtype, dtype=self.dtype,
drop_remainder=True) drop_remainder=True)
orbit.StandardTrainer.__init__(self, train_dataset, orbit.StandardTrainer.__init__(
flags_obj.use_tf_while_loop, self,
flags_obj.use_tf_function) train_dataset,
options=orbit.StandardTrainerOptions(
use_tf_while_loop=flags_obj.use_tf_while_loop,
use_tf_function=flags_obj.use_tf_function))
if not flags_obj.skip_eval: if not flags_obj.skip_eval:
eval_dataset = orbit.utils.make_distributed_dataset( eval_dataset = orbit.utils.make_distributed_dataset(
self.strategy, self.strategy,
...@@ -119,8 +122,11 @@ class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator): ...@@ -119,8 +122,11 @@ class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator):
batch_size=self.batch_size, batch_size=self.batch_size,
parse_record_fn=imagenet_preprocessing.parse_record, parse_record_fn=imagenet_preprocessing.parse_record,
dtype=self.dtype) dtype=self.dtype)
orbit.StandardEvaluator.__init__(self, eval_dataset, orbit.StandardEvaluator.__init__(
flags_obj.use_tf_function) self,
eval_dataset,
options=orbit.StandardEvaluatorOptions(
use_tf_function=flags_obj.use_tf_function))
def train_loop_begin(self): def train_loop_begin(self):
"""See base class.""" """See base class."""
......
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved. # Copyright 2020 The Orbit Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
......
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved. # Copyright 2020 The Orbit Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
...@@ -16,8 +15,9 @@ ...@@ -16,8 +15,9 @@
"""Lightweight utilities to train TF2 models.""" """Lightweight utilities to train TF2 models."""
import time import time
from typing import Callable, Optional, Text, Union from typing import Callable, Dict, Optional, Text, Union
from absl import logging from absl import logging
import numpy as np
from orbit import runner from orbit import runner
from orbit import utils from orbit import utils
...@@ -71,9 +71,11 @@ class Controller: ...@@ -71,9 +71,11 @@ class Controller:
`trainer.train` function will always be enabled. If set, the value `trainer.train` function will always be enabled. If set, the value
should be divisible by steps_per_loop. should be divisible by steps_per_loop.
summary_dir: The directory to restore and write checkpoints and summaries. summary_dir: The directory to restore and write checkpoints and summaries.
If None, it will be set to `checkpoint_manager.directory`. For example, you can set it to `checkpoint_manager.directory`.
If None, it will not write training summaries.
eval_summary_dir: The directory to write eval summaries. If None, it will eval_summary_dir: The directory to write eval summaries. If None, it will
be set to `summary_dir`. be set to `summary_dir`. If both `summary_dir` and `eval_summary_dir`
are None, it will not write evaluation summaries.
Raises: Raises:
ValueError: If both `trainer` and `evaluator` are None. ValueError: If both `trainer` and `evaluator` are None.
...@@ -108,9 +110,6 @@ class Controller: ...@@ -108,9 +110,6 @@ class Controller:
self.global_step = global_step self.global_step = global_step
self.checkpoint_manager = checkpoint_manager self.checkpoint_manager = checkpoint_manager
if summary_dir is None and checkpoint_manager:
summary_dir = checkpoint_manager.directory
if self.trainer is not None: if self.trainer is not None:
self.step_timer = None self.step_timer = None
self.steps_per_loop = steps_per_loop self.steps_per_loop = steps_per_loop
...@@ -118,7 +117,6 @@ class Controller: ...@@ -118,7 +117,6 @@ class Controller:
self.summary_manager = utils.SummaryManager( self.summary_manager = utils.SummaryManager(
summary_dir, tf.summary.scalar, global_step=self.global_step) summary_dir, tf.summary.scalar, global_step=self.global_step)
eval_summary_writer = None
if self.evaluator is not None: if self.evaluator is not None:
eval_summary_dir = eval_summary_dir or summary_dir eval_summary_dir = eval_summary_dir or summary_dir
if eval_summary_dir == summary_dir and self.trainer is not None: if eval_summary_dir == summary_dir and self.trainer is not None:
...@@ -177,7 +175,7 @@ class Controller: ...@@ -177,7 +175,7 @@ class Controller:
if checkpoint_at_completion: if checkpoint_at_completion:
self.save_checkpoint() self.save_checkpoint()
def evaluate(self, steps: int = None): def evaluate(self, steps: int = None) -> Optional[Dict[Text, np.number]]:
"""Runs evaluation. """Runs evaluation.
This method calls the `evaluate` method on the Evaluator object for `steps` This method calls the `evaluate` method on the Evaluator object for `steps`
...@@ -186,10 +184,12 @@ class Controller: ...@@ -186,10 +184,12 @@ class Controller:
Args: Args:
steps: The number of steps to evaluate for. steps: The number of steps to evaluate for.
Returns:
The evaluation results as a dictionary of numpy values.
Raises: Raises:
ValueError: If no checkpoint found in `self.checkpoint_manager.directory`. ValueError: If no checkpoint found in `self.checkpoint_manager.directory`.
ValueError: If `evaluator` is not provided. ValueError: If `evaluator` is not provided.
""" """
if self.evaluator is None: if self.evaluator is None:
raise ValueError("`evaluator` must be provided to call `evaluate()` " raise ValueError("`evaluator` must be provided to call `evaluate()` "
...@@ -204,7 +204,7 @@ class Controller: ...@@ -204,7 +204,7 @@ class Controller:
else: else:
logging.info("Evaluating at train step: %s", current_step) logging.info("Evaluating at train step: %s", current_step)
with self.eval_summary_manager.summary_writer.as_default(): with self.eval_summary_manager.summary_writer().as_default():
eval_outputs = self.evaluator.evaluate(steps) eval_outputs = self.evaluator.evaluate(steps)
if eval_outputs: if eval_outputs:
...@@ -217,6 +217,8 @@ class Controller: ...@@ -217,6 +217,8 @@ class Controller:
self.eval_summary_manager.write_summaries(eval_outputs) self.eval_summary_manager.write_summaries(eval_outputs)
self.eval_summary_manager.flush() self.eval_summary_manager.flush()
return eval_outputs
def restore_checkpoint(self, checkpoint_path: Text = None): def restore_checkpoint(self, checkpoint_path: Text = None):
"""Restore or initialize the model. """Restore or initialize the model.
...@@ -334,7 +336,7 @@ class Controller: ...@@ -334,7 +336,7 @@ class Controller:
current_step += num_steps current_step += num_steps
num_steps = tf.convert_to_tensor(num_steps, dtype=tf.int32) num_steps = tf.convert_to_tensor(num_steps, dtype=tf.int32)
with self.summary_manager.summary_writer.as_default(): with self.summary_manager.summary_writer().as_default():
# Create a lambda that returns true when summaries should be written. # Create a lambda that returns true when summaries should be written.
should_record = False # Allows static optimization in no-summary cases. should_record = False # Allows static optimization in no-summary cases.
if self.summary_interval: if self.summary_interval:
......
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved. # Copyright 2020 The Orbit Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
...@@ -158,6 +157,57 @@ class TestEvaluator(standard_runner.StandardEvaluator): ...@@ -158,6 +157,57 @@ class TestEvaluator(standard_runner.StandardEvaluator):
} }
class TestEvaluatorWithNestedSummary(standard_runner.StandardEvaluator):
"""Implements the training and evaluation APIs for the test model."""
def __init__(self):
self.strategy = tf.distribute.get_strategy()
self.model = create_model()
dataset = self.strategy.experimental_distribute_datasets_from_function(
dataset_fn)
dataset2 = self.strategy.experimental_distribute_datasets_from_function(
dataset_fn)
self.loss = tf.keras.metrics.Mean("loss", dtype=tf.float32)
self.accuracy = tf.keras.metrics.CategoricalAccuracy(
"accuracy", dtype=tf.float32)
self.loss2 = tf.keras.metrics.Mean("loss", dtype=tf.float32)
self.accuracy2 = tf.keras.metrics.CategoricalAccuracy(
"accuracy", dtype=tf.float32)
standard_runner.StandardEvaluator.__init__(
self, eval_dataset={
"dataset": dataset,
"dataset2": dataset2
})
def eval_step(self, iterator):
def _replicated_step(loss, accuracy, inputs):
"""Replicated evaluation step."""
inputs, targets = inputs
outputs = self.model(inputs)
loss.update_state(tf.keras.losses.MSE(targets, outputs))
accuracy.update_state(targets, outputs)
self.strategy.run(
lambda inputs: _replicated_step(self.loss, self.accuracy, inputs),
args=(next(iterator["dataset"]),))
self.strategy.run(
lambda inputs: _replicated_step(self.loss2, self.accuracy2, inputs),
args=(next(iterator["dataset2"]),))
def eval_end(self):
return {
"dataset": {
"loss": self.loss.result(),
"accuracy": self.accuracy.result()
},
"dataset2": {
"loss": self.loss2.result(),
"accuracy": self.accuracy2.result()
},
}
class TestTrainerWithSummaries(standard_runner.StandardTrainer): class TestTrainerWithSummaries(standard_runner.StandardTrainer):
"""A Trainer model with summaries for testing purposes.""" """A Trainer model with summaries for testing purposes."""
...@@ -171,7 +221,10 @@ class TestTrainerWithSummaries(standard_runner.StandardTrainer): ...@@ -171,7 +221,10 @@ class TestTrainerWithSummaries(standard_runner.StandardTrainer):
self.strategy.experimental_distribute_datasets_from_function(dataset_fn) self.strategy.experimental_distribute_datasets_from_function(dataset_fn)
) )
standard_runner.StandardTrainer.__init__( standard_runner.StandardTrainer.__init__(
self, train_dataset, use_tpu_summary_optimization=True) self,
train_dataset,
options=standard_runner.StandardTrainerOptions(
use_tpu_summary_optimization=True))
def build_train_dataset(self): def build_train_dataset(self):
return self.strategy.experimental_distribute_datasets_from_function( return self.strategy.experimental_distribute_datasets_from_function(
...@@ -241,6 +294,56 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -241,6 +294,56 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
train_steps=10, eval_steps=2, eval_interval=6) train_steps=10, eval_steps=2, eval_interval=6)
self.assertEqual(test_runner.global_step, 10) self.assertEqual(test_runner.global_step, 10)
def test_has_checkpoint_no_summaries(self):
test_runner = TestRunner()
# Has checkpoint, but no summary directories.
checkpoint = tf.train.Checkpoint(model=test_runner.model)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step)
test_controller = controller.Controller(
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
checkpoint_manager=checkpoint_manager,
steps_per_loop=2)
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
self.assertEqual(test_runner.global_step, 10)
# No summaries are saved.
self.assertEmpty(tf.io.gfile.glob(
os.path.join(checkpoint_manager.directory, "events.*")))
def test_has_checkpoint_eval_summary_only(self):
test_runner = TestRunner()
# Has checkpoint and an eval summary directory, but no training summary directory.
checkpoint = tf.train.Checkpoint(model=test_runner.model)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step)
test_controller = controller.Controller(
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
steps_per_loop=2)
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
self.assertEqual(test_runner.global_step, 10)
# Training summaries are not saved.
self.assertEmpty(tf.io.gfile.glob(
os.path.join(checkpoint_manager.directory, "events.*")))
# Evaluation summaries are saved.
self.assertNotEmpty(tf.io.gfile.glob(
os.path.join(self.model_dir, "summaries/eval/events.*")))
@parameterized.named_parameters(("return_numpy", True), @parameterized.named_parameters(("return_numpy", True),
("return_tensor", False)) ("return_tensor", False))
def test_train_and_evaluate(self, return_numpy): def test_train_and_evaluate(self, return_numpy):
...@@ -329,7 +432,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -329,7 +432,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
checkpoint_manager=checkpoint_manager, checkpoint_manager=checkpoint_manager,
summary_dir=os.path.join(self.model_dir, "summaries/train"), summary_dir=os.path.join(self.model_dir, "summaries/train"),
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval")) eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
test_controller.evaluate(steps=2) eval_results = test_controller.evaluate(steps=2)
# Only eval summaries are written # Only eval summaries are written
self.assertFalse( self.assertFalse(
...@@ -339,6 +442,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -339,6 +442,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
self.assertNotEmpty( self.assertNotEmpty(
summaries_with_matching_keyword( summaries_with_matching_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval"))) "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
self.assertIn("eval_loss", eval_results)
# Tests continuous eval with timeout and timeout_fn. # Tests continuous eval with timeout and timeout_fn.
done_file = os.path.join(self.model_dir, "summaries/eval/Done") done_file = os.path.join(self.model_dir, "summaries/eval/Done")
...@@ -558,7 +662,8 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -558,7 +662,8 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
evaluator=test_runner, evaluator=test_runner,
global_step=test_runner.global_step, global_step=test_runner.global_step,
steps_per_loop=10, steps_per_loop=10,
checkpoint_manager=checkpoint_manager) checkpoint_manager=checkpoint_manager,
summary_dir=self.model_dir)
test_controller.train_and_evaluate( test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=5) train_steps=10, eval_steps=2, eval_interval=5)
...@@ -569,6 +674,31 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -569,6 +674,31 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
self.assertLen( self.assertLen(
summaries_with_matching_keyword("eval_loss", self.model_dir), 2) summaries_with_matching_keyword("eval_loss", self.model_dir), 2)
def test_evaluate_with_nested_summaries(self):
test_evaluator = TestEvaluatorWithNestedSummary()
test_controller = controller.Controller(
evaluator=test_evaluator,
global_step=tf.Variable(0, dtype=tf.int64),
eval_summary_dir=self.model_dir)
test_controller.evaluate(steps=5)
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "dataset")))
self.assertNotEmpty(
summaries_with_matching_keyword(
"loss", os.path.join(self.model_dir, "dataset")))
self.assertNotEmpty(
summaries_with_matching_keyword(
"accuracy", os.path.join(self.model_dir, "dataset")))
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "dataset2")))
self.assertNotEmpty(
summaries_with_matching_keyword(
"loss", os.path.join(self.model_dir, "dataset2")))
self.assertNotEmpty(
summaries_with_matching_keyword(
"accuracy", os.path.join(self.model_dir, "dataset2")))
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved. # Copyright 2020 The Orbit Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
...@@ -35,7 +34,7 @@ class AbstractTrainer(tf.Module, metaclass=abc.ABCMeta): ...@@ -35,7 +34,7 @@ class AbstractTrainer(tf.Module, metaclass=abc.ABCMeta):
large in Eager mode. It is usually encouraged to create a host training loop large in Eager mode. It is usually encouraged to create a host training loop
(e.g. using a `tf.range` wrapping `strategy.run` inside a (e.g. using a `tf.range` wrapping `strategy.run` inside a
`tf.function`) in the TPU case. For the cases that don't require host `tf.function`) in the TPU case. For the cases that don't require host
training loop to acheive peak performance, users can just implement a simple training loop to achieve peak performance, users can just implement a simple
python loop to drive each step. python loop to drive each step.
Args: Args:
...@@ -45,7 +44,8 @@ class AbstractTrainer(tf.Module, metaclass=abc.ABCMeta): ...@@ -45,7 +44,8 @@ class AbstractTrainer(tf.Module, metaclass=abc.ABCMeta):
Returns: Returns:
The function may return a dictionary of `Tensors` or numpy arrays, which The function may return a dictionary of `Tensors` or numpy arrays, which
will be written to logs and as TensorBoard summaries. will be written to logs and as TensorBoard summaries. It can also be a
nested dictionary, yielding a hierarchy of summary directories.
""" """
pass pass
...@@ -67,6 +67,7 @@ class AbstractEvaluator(tf.Module, metaclass=abc.ABCMeta): ...@@ -67,6 +67,7 @@ class AbstractEvaluator(tf.Module, metaclass=abc.ABCMeta):
Returns: Returns:
The function may return a dictionary of `Tensors` or numpy arrays, which The function may return a dictionary of `Tensors` or numpy arrays, which
will be written to logs and as TensorBoard summaries. will be written to logs and as TensorBoard summaries. It can also be a
nested dictionary, yielding a hierarchy of summary directories.
""" """
pass pass
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved. # Copyright 2020 The Orbit Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
...@@ -24,20 +23,22 @@ import tensorflow as tf ...@@ -24,20 +23,22 @@ import tensorflow as tf
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class TrainerOverrides: class StandardTrainerOptions:
"""Advanced overrides for Orbit trainers. """Advanced options for `orbit.StandardTrainer`.
Attributes: Attributes:
use_tf_while_loop: A boolean indicates whether to wrap the train step with use_tf_while_loop: A boolean indicating whether to run the training loop
a `tf.while_loop`. using a `tf.while_loop`. If `True`, `use_tf_function` must also be `True`.
use_tf_function: A boolean indicates whether a `tf.function` will be used. use_tf_function: A boolean indicating whether to apply `tf.function` to the
If False, training will run on pure eager mode. training loop. This will only affect the body of the loop (involving
use_tpu_summary_optimization: A boolean indicates whether to enable the `train_step`); `train_loop_begin` and `train_loop_end` will always be run
performance optimization for summaries in TPUs. In TPUs, writing in eager mode.
summaries with outside compilation inside train step is slow. If True, use_tpu_summary_optimization: A boolean indicating whether to enable a
it creates two `tf.function` with two XLA programs: one with summaries performance optimization for summaries in TPUs. Writing summaries
and one without, and run the program with summaries (slow one) only if conditionally with outside compilation on TPUs can be extremely slow. If
necessary. `True`, this optimization creates two `tf.function`s with two XLA programs
(one with summary calls, and one without). The program with summaries runs
only for one step when summaries should be recorded.
""" """
use_tf_while_loop: bool = True use_tf_while_loop: bool = True
use_tf_function: bool = True use_tf_function: bool = True
...@@ -47,39 +48,29 @@ class TrainerOverrides: ...@@ -47,39 +48,29 @@ class TrainerOverrides:
class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta): class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
"""Implements the standard functionality of AbstractTrainer APIs.""" """Implements the standard functionality of AbstractTrainer APIs."""
def __init__(self, def __init__(self, train_dataset, options: StandardTrainerOptions = None):
train_dataset,
use_tf_while_loop=True,
use_tf_function=True,
use_tpu_summary_optimization=False):
"""Construct a `StandardTrainer` object. """Construct a `StandardTrainer` object.
Args: Args:
train_dataset: A tf.nest-compatible structure of tf.data.Dataset or train_dataset: A tf.nest-compatible structure of tf.data.Dataset or
DistributedDataset. DistributedDataset.
use_tf_while_loop: A boolean indicates whether to wrap the train step with options: An `orbit.StandardTrainerOptions` instance.
a `tf.while_loop`.
use_tf_function: A boolean indicates whether a `tf.function` will be used.
If False, training will run on pure eager mode.
use_tpu_summary_optimization: A boolean indicates whether to enable the
performance optimization for summaries in TPUs. In TPUs, writing
summaries with outside compilation inside train step is slow. If True,
it creates two `tf.function` with two XLA programs: one with summaries
and one without, and run the program with summaries (slow one) only if
necessary.
""" """
if use_tf_while_loop and not use_tf_function: options = options or StandardTrainerOptions()
if options.use_tf_while_loop and not options.use_tf_function:
raise ValueError("`use_tf_while_loop=True` and `use_tf_function=False` " raise ValueError("`use_tf_while_loop=True` and `use_tf_function=False` "
"is not supported") "is not supported")
if use_tpu_summary_optimization and not use_tf_while_loop: if options.use_tpu_summary_optimization and not options.use_tf_while_loop:
raise ValueError("`use_tpu_summary_optimization=True` and " raise ValueError("`use_tpu_summary_optimization=True` and "
"`use_tf_while_loop=False` is not supported") "`use_tf_while_loop=False` is not supported")
self._use_tf_while_loop = use_tf_while_loop
self._use_tf_function = use_tf_function self._use_tf_while_loop = options.use_tf_while_loop
self._use_tf_function = options.use_tf_function
self._use_tpu_summary_optimization = options.use_tpu_summary_optimization
self._train_dataset = train_dataset self._train_dataset = train_dataset
self._train_iter = None self._train_iter = None
self._train_loop_fn = None self._train_loop_fn = None
self._use_tpu_summary_optimization = use_tpu_summary_optimization
def train(self, def train(self,
num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]: num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]:
...@@ -144,7 +135,8 @@ class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta): ...@@ -144,7 +135,8 @@ class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
Returns: Returns:
The function may return a dictionary of `Tensors`, which will be The function may return a dictionary of `Tensors`, which will be
written to logs and as TensorBoard summaries. written to logs and as TensorBoard summaries. It can also be a
nested dictionary, yielding a hierarchy of summary directories.
""" """
pass pass
...@@ -168,12 +160,14 @@ class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta): ...@@ -168,12 +160,14 @@ class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class EvaluatorOverrides: class StandardEvaluatorOptions:
"""Advanced overrides for Orbit evaluators. """Advanced options for the `orbit.StandardEvaluator`.
Attributes: Attributes:
use_tf_function: A boolean indicates whether a `tf.function` will be used. use_tf_function: A boolean indicating whether to apply `tf.function` to the
If False, training will run on pure eager mode. evaluation loop. This will only affect the body of the loop (involving
`eval_step`); `eval_begin` and `eval_end` will always be run
in eager mode.
""" """
use_tf_function: bool = True use_tf_function: bool = True
...@@ -181,16 +175,16 @@ class EvaluatorOverrides: ...@@ -181,16 +175,16 @@ class EvaluatorOverrides:
class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta): class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
"""Implements the standard functionality of AbstractEvaluator APIs.""" """Implements the standard functionality of AbstractEvaluator APIs."""
def __init__(self, eval_dataset, use_tf_function=True): def __init__(self, eval_dataset, options: StandardEvaluatorOptions = None):
"""Construct a `StandardEvaluator` object. """Construct a `StandardEvaluator` object.
Args: Args:
eval_dataset: A tf.nest-compatible structure of tf.data.Dataset or eval_dataset: A tf.nest-compatible structure of tf.data.Dataset or
DistributedDataset. DistributedDataset.
use_tf_function: A boolean indicates whether a `tf.function` will be used. options: An `orbit.StandardEvaluatorOptions` instance.
If False, evaluation will run on pure eager mode.
""" """
self._eval_use_tf_function = use_tf_function options = options or StandardEvaluatorOptions()
self._eval_use_tf_function = options.use_tf_function
self._eval_dataset = eval_dataset self._eval_dataset = eval_dataset
self._eval_loop_fn = None self._eval_loop_fn = None
...@@ -261,7 +255,8 @@ class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta): ...@@ -261,7 +255,8 @@ class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
Returns: Returns:
The function may return a dictionary of `Tensors`, which will be The function may return a dictionary of `Tensors`, which will be
written to logs and as TensorBoard summaries. written to logs and as TensorBoard summaries. It can also be a
nested dictionary, yielding a hierarchy of summary directories.
""" """
pass pass
......
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved. # Copyright 2020 The Orbit Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
...@@ -14,9 +13,9 @@ ...@@ -14,9 +13,9 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Tests for orbit.standard_runner.""" """Tests for orbit.standard_runner."""
# pylint: disable=g-bad-import-order
from orbit import standard_runner from orbit import standard_runner
from orbit import utils
import tensorflow as tf import tensorflow as tf
...@@ -34,46 +33,49 @@ def dataset_fn(input_context=None): ...@@ -34,46 +33,49 @@ def dataset_fn(input_context=None):
return dataset return dataset
class TestRunner(standard_runner.StandardTrainer, class TestTrainer(standard_runner.StandardTrainer):
standard_runner.StandardEvaluator): """A StandardTrainer subclass for tests."""
"""Implements the training and evaluation APIs for tests."""
def __init__(self): def __init__(self, options=None):
self.strategy = tf.distribute.get_strategy() self.strategy = tf.distribute.get_strategy()
self.global_step = tf.Variable( self.global_step = utils.create_global_step()
0, distribute = self.strategy.experimental_distribute_datasets_from_function
trainable=False, dataset = distribute(dataset_fn)
dtype=tf.int64, super().__init__(train_dataset=dataset, options=options)
name='global_step',
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
standard_runner.StandardTrainer.__init__(self, train_dataset=None)
standard_runner.StandardEvaluator.__init__(self, eval_dataset=None)
def train_loop_begin(self): def train_loop_begin(self):
self.train_dataset = ( self.global_step.assign(0)
self.strategy.experimental_distribute_datasets_from_function(dataset_fn)
)
def train_step(self, iterator): def train_step(self, iterator):
def _replicated_step(_): def replica_step(_):
self.global_step.assign_add(1) self.global_step.assign_add(1)
self.strategy.run(_replicated_step, args=(next(iterator),)) self.strategy.run(replica_step, args=(next(iterator),))
def train_loop_end(self): def train_loop_end(self):
return self.global_step.numpy() return self.global_step.numpy()
class TestEvaluator(standard_runner.StandardEvaluator):
"""A StandardEvaluator subclass for tests."""
def __init__(self, options=None):
self.strategy = tf.distribute.get_strategy()
self.global_step = utils.create_global_step()
distribute = self.strategy.experimental_distribute_datasets_from_function
dataset = distribute(dataset_fn)
super().__init__(eval_dataset=dataset, options=options)
def eval_begin(self): def eval_begin(self):
self.eval_dataset = self.strategy.experimental_distribute_datasets_from_function( self.global_step.assign(0)
dataset_fn)
def eval_step(self, iterator): def eval_step(self, iterator):
def _replicated_step(_): def replica_step(_):
self.global_step.assign_add(1) self.global_step.assign_add(1)
self.strategy.run(_replicated_step, args=(next(iterator),)) self.strategy.run(replica_step, args=(next(iterator),))
def eval_end(self): def eval_end(self):
return self.global_step.numpy() return self.global_step.numpy()
...@@ -81,15 +83,19 @@ class TestRunner(standard_runner.StandardTrainer, ...@@ -81,15 +83,19 @@ class TestRunner(standard_runner.StandardTrainer,
class StandardRunnerTest(tf.test.TestCase): class StandardRunnerTest(tf.test.TestCase):
def test_train(self): def test_default_trainer(self):
test_runner = TestRunner() trainer = TestTrainer()
self.assertEqual( self.assertEqual(trainer.train(tf.constant(10)), 10)
test_runner.train(tf.convert_to_tensor(10, dtype=tf.int32)), 10)
def test_trainer_with_tpu_summary_optimization(self):
options = standard_runner.StandardTrainerOptions(
use_tpu_summary_optimization=True)
trainer = TestTrainer(options)
self.assertEqual(trainer.train(tf.constant(10)), 10)
def test_eval(self): def test_default_evaluator(self):
test_runner = TestRunner() evaluator = TestEvaluator()
self.assertEqual( self.assertEqual(evaluator.evaluate(tf.constant(10)), 10)
test_runner.evaluate(tf.convert_to_tensor(10, dtype=tf.int32)), 10)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment