Commit 45da63a9 authored by Hongkun Yu, committed by A. Unique TensorFlower

Move trainers to core/

Move mock_task to utils/testing/

PiperOrigin-RevId: 325275356
parent 02b874a1
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow_models.core.base_task."""
import functools
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.utils.testing import mock_task
def all_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
mode='eager',
)
class TaskKerasTest(tf.test.TestCase, parameterized.TestCase):
@combinations.generate(all_strategy_combinations())
def test_task_with_step_override(self, distribution):
with distribution.scope():
task = mock_task.MockTask()
model = task.build_model()
model = task.compile_model(
model,
optimizer=tf.keras.optimizers.SGD(learning_rate=1e-3),
metrics=task.build_metrics(),
train_step=task.train_step,
validation_step=task.validation_step)
dataset = task.build_inputs(params=None)
logs = model.fit(dataset, epochs=1, steps_per_epoch=2)
self.assertIn('loss', logs.history)
self.assertIn('acc', logs.history)
# Without specifying metrics through compile.
with distribution.scope():
train_metrics = task.build_metrics(training=True)
val_metrics = task.build_metrics(training=False)
model = task.build_model()
model = task.compile_model(
model,
optimizer=tf.keras.optimizers.SGD(learning_rate=1e-3),
train_step=functools.partial(task.train_step, metrics=train_metrics),
validation_step=functools.partial(
task.validation_step, metrics=val_metrics))
logs = model.fit(dataset, epochs=1, steps_per_epoch=2)
self.assertIn('loss', logs.history)
self.assertIn('acc', logs.history)
def test_task_with_fit(self):
task = mock_task.MockTask()
model = task.build_model()
model = task.compile_model(
model,
optimizer=tf.keras.optimizers.SGD(learning_rate=1e-3),
loss=tf.keras.losses.CategoricalCrossentropy(),
metrics=task.build_metrics())
dataset = task.build_inputs(params=None)
logs = model.fit(dataset, epochs=1, steps_per_epoch=2)
self.assertIn('loss', logs.history)
self.assertIn('acc', logs.history)
self.assertLen(model.evaluate(dataset, steps=1), 2)
def test_task_invalid_compile(self):
task = mock_task.MockTask()
model = task.build_model()
with self.assertRaises(ValueError):
_ = task.compile_model(
model,
optimizer=tf.keras.optimizers.SGD(learning_rate=1e-3),
loss=tf.keras.losses.CategoricalCrossentropy(),
metrics=task.build_metrics(),
train_step=task.train_step)
if __name__ == '__main__':
tf.test.main()
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Standard Trainer implementation.
The base trainer implements the Orbit `StandardTrainable` and
`StandardEvaluable` interfaces. Trainers inside this project should be
interchangeable and independent of model architectures and tasks.
"""
import gin
import orbit
import tensorflow as tf
from official.core import base_task
from official.modeling import optimization
from official.modeling import performance
from official.modeling.hyperparams import config_definitions
ExperimentConfig = config_definitions.ExperimentConfig
@gin.configurable
class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
"""Implements the common trainer shared for TensorFlow models."""
def __init__(self,
config: ExperimentConfig,
task: base_task.Task,
train: bool = True,
evaluate: bool = True,
model=None,
optimizer=None):
"""Initialize common trainer for TensorFlow models.
Args:
config: An `ExperimentConfig` instance specifying experiment config.
task: A base_task.Task instance.
train: bool, whether or not this trainer will be used for training.
Defaults to True.
evaluate: bool, whether or not this trainer will be used for evaluation.
Defaults to True.
model: tf.keras.Model instance. If provided, it will be used instead of
building the model via task.build_model(). Defaults to None.
optimizer: tf.keras.optimizers.Optimizer instance. If provided, it will be
used instead of the optimizer created from the config. Defaults to None.
"""
# Gets the current distribution strategy. If not inside any strategy scope,
# it gets a single-replica no-op strategy.
self._strategy = tf.distribute.get_strategy()
self._config = config
self._task = task
self._model = model or task.build_model()
if optimizer is None:
opt_factory = optimization.OptimizerFactory(
config.trainer.optimizer_config)
self._optimizer = opt_factory.build_optimizer(
opt_factory.build_learning_rate())
else:
self._optimizer = optimizer
# Configures the optimizer when loss_scale is set in the runtime config. This
# helps avoid overflow/underflow for float16 computations.
if config.runtime.loss_scale:
self._optimizer = performance.configure_optimizer(
self._optimizer,
use_float16=config.runtime.mixed_precision_dtype == 'float16',
loss_scale=config.runtime.loss_scale)
# global_step increases by 1 after each training iteration.
# We should have global_step.numpy() == self.optimizer.iterations.numpy()
# when there is only 1 optimizer.
self._global_step = orbit.utils.create_global_step()
if hasattr(self.model, 'checkpoint_items'):
checkpoint_items = self.model.checkpoint_items
else:
checkpoint_items = {}
self._checkpoint = tf.train.Checkpoint(
global_step=self.global_step, model=self.model,
optimizer=self.optimizer, **checkpoint_items)
self._train_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
self._validation_loss = tf.keras.metrics.Mean(
'validation_loss', dtype=tf.float32)
self._train_metrics = self.task.build_metrics(
training=True) + self.model.metrics
self._validation_metrics = self.task.build_metrics(
training=False) + self.model.metrics
if train:
train_dataset = orbit.utils.make_distributed_dataset(
self.strategy, self.task.build_inputs, self.config.task.train_data)
orbit.StandardTrainer.__init__(
self,
train_dataset,
options=orbit.StandardTrainerOptions(
use_tf_while_loop=config.trainer.train_tf_while_loop,
use_tf_function=config.trainer.train_tf_function,
use_tpu_summary_optimization=config.trainer.allow_tpu_summary))
if evaluate:
eval_dataset = orbit.utils.make_distributed_dataset(
self.strategy, self.task.build_inputs,
self.config.task.validation_data)
orbit.StandardEvaluator.__init__(
self,
eval_dataset,
options=orbit.StandardEvaluatorOptions(
use_tf_function=config.trainer.eval_tf_function))
@property
def strategy(self):
return self._strategy
@property
def config(self):
return self._config
@property
def task(self):
return self._task
@property
def model(self):
return self._model
@property
def optimizer(self):
return self._optimizer
@property
def global_step(self):
return self._global_step
@property
def train_loss(self):
"""Accesses the training loss metric object."""
return self._train_loss
@property
def validation_loss(self):
"""Accesses the validation loss metric object."""
return self._validation_loss
@property
def train_metrics(self):
"""Accesses all training metric objects."""
return self._train_metrics
@property
def validation_metrics(self):
"""Accesses all validation metric metric objects."""
return self._validation_metrics
def initialize(self):
"""A callback function.
This function will be called when no checkpoint is found for the model.
If there is a checkpoint, the checkpoint will be loaded and this function
will not be called. Tasks may use this callback function to load a
pretrained checkpoint, saved under a directory other than the model_dir.
"""
self.task.initialize(self.model)
@property
def checkpoint(self):
"""Accesses the training checkpoint."""
return self._checkpoint
def train_loop_end(self):
"""See base class."""
logs = {}
for metric in self.train_metrics + [self.train_loss]:
logs[metric.name] = metric.result()
metric.reset_states()
if callable(self.optimizer.learning_rate):
logs['learning_rate'] = self.optimizer.learning_rate(self.global_step)
else:
logs['learning_rate'] = self.optimizer.learning_rate
return logs
def train_step(self, iterator):
"""See base class."""
def step_fn(inputs):
logs = self.task.train_step(
inputs,
model=self.model,
optimizer=self.optimizer,
metrics=self.train_metrics)
self._train_loss.update_state(logs[self.task.loss])
self.global_step.assign_add(1)
self.strategy.run(step_fn, args=(next(iterator),))
def eval_begin(self):
"""Sets up metrics."""
for metric in self.validation_metrics + [self.validation_loss]:
metric.reset_states()
def eval_step(self, iterator):
"""See base class."""
def step_fn(inputs):
logs = self.task.validation_step(
inputs, model=self.model, metrics=self.validation_metrics)
self._validation_loss.update_state(logs[self.task.loss])
return logs
distributed_outputs = self.strategy.run(step_fn, args=(next(iterator),))
return tf.nest.map_structure(self.strategy.experimental_local_results,
distributed_outputs)
def eval_end(self, aggregated_logs=None):
"""Processes evaluation results."""
logs = {}
for metric in self.validation_metrics + [self.validation_loss]:
logs[metric.name] = metric.result()
if aggregated_logs:
metrics = self.task.reduce_aggregated_logs(aggregated_logs)
logs.update(metrics)
return logs
def eval_reduce(self, state=None, step_outputs=None):
return self.task.aggregate_logs(state, step_outputs)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow_models.core.trainers.trainer."""
# pylint: disable=g-direct-tensorflow-import
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import base_trainer as trainer_lib
from official.modeling.hyperparams import config_definitions as cfg
from official.utils.testing import mock_task
def all_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
mode='eager',
)
class TrainerTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self._config = cfg.ExperimentConfig(
trainer=cfg.TrainerConfig(
optimizer_config=cfg.OptimizationConfig(
{'optimizer': {
'type': 'sgd'
},
'learning_rate': {
'type': 'constant'
}})))
def create_test_trainer(self):
task = mock_task.MockTask()
trainer = trainer_lib.Trainer(self._config, task)
return trainer
@combinations.generate(all_strategy_combinations())
def test_trainer_train(self, distribution):
with distribution.scope():
trainer = self.create_test_trainer()
logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('training_loss', logs)
self.assertIn('learning_rate', logs)
@combinations.generate(all_strategy_combinations())
def test_trainer_validate(self, distribution):
with distribution.scope():
trainer = self.create_test_trainer()
logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('validation_loss', logs)
self.assertEqual(logs['acc'], 5. * distribution.num_replicas_in_sync)
@combinations.generate(
combinations.combine(
mixed_precision_dtype=['float32', 'bfloat16', 'float16'],
loss_scale=[None, 'dynamic', 128, 256],
))
def test_configure_optimizer(self, mixed_precision_dtype, loss_scale):
config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(
mixed_precision_dtype=mixed_precision_dtype, loss_scale=loss_scale),
trainer=cfg.TrainerConfig(
optimizer_config=cfg.OptimizationConfig(
{'optimizer': {
'type': 'sgd'
},
'learning_rate': {
'type': 'constant'
}})))
task = mock_task.MockTask()
trainer = trainer_lib.Trainer(config, task)
if mixed_precision_dtype != 'float16':
self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
elif mixed_precision_dtype == 'float16' and loss_scale is None:
self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
else:
self.assertIsInstance(
trainer.optimizer,
tf.keras.mixed_precision.experimental.LossScaleOptimizer)
metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('training_loss', metrics)
if __name__ == '__main__':
tf.test.main()
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mock task for testing."""
import dataclasses
import numpy as np
import tensorflow as tf
from official.core import base_task
from official.core import exp_factory
from official.core import task_factory
from official.modeling.hyperparams import config_definitions as cfg
class MockModel(tf.keras.Model):
def __init__(self, network):
super().__init__()
self.network = network
def call(self, inputs):
outputs = self.network(inputs)
self.add_loss(tf.reduce_mean(outputs))
return outputs
@dataclasses.dataclass
class MockTaskConfig(cfg.TaskConfig):
pass
@task_factory.register_task_cls(MockTaskConfig)
class MockTask(base_task.Task):
"""Mock task object for testing."""
def __init__(self, params=None, logging_dir=None):
super().__init__(params=params, logging_dir=logging_dir)
def build_model(self, *args, **kwargs):
inputs = tf.keras.layers.Input(shape=(2,), name="random", dtype=tf.float32)
outputs = tf.keras.layers.Dense(1)(inputs)
network = tf.keras.Model(inputs=inputs, outputs=outputs)
return MockModel(network)
def build_metrics(self, training: bool = True):
del training
return [tf.keras.metrics.Accuracy(name="acc")]
def build_inputs(self, params):
def generate_data(_):
x = tf.zeros(shape=(2,), dtype=tf.float32)
label = tf.zeros([1], dtype=tf.int32)
return x, label
dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
dataset = dataset.map(
generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset.prefetch(buffer_size=1).batch(2, drop_remainder=True)
def aggregate_logs(self, state, step_outputs):
if state is None:
state = {}
for key, value in step_outputs.items():
if key not in state:
state[key] = []
state[key].append(
np.concatenate([np.expand_dims(v.numpy(), axis=0) for v in value]))
return state
def reduce_aggregated_logs(self, aggregated_logs):
for k, v in aggregated_logs.items():
aggregated_logs[k] = np.sum(np.stack(v, axis=0))
return aggregated_logs
@exp_factory.register_config_factory("mock")
def mock_experiment() -> cfg.ExperimentConfig:
config = cfg.ExperimentConfig(
task=MockTaskConfig(), trainer=cfg.TrainerConfig())
return config
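# A small usage sketch for the registered mock experiment (assumptions flagged:
# `exp_factory.get_exp_config` is assumed to be the lookup counterpart of
# `register_config_factory`; adjust if the accessor differs in your tree):
#
#   params = exp_factory.get_exp_config("mock")
#   task = MockTask(params=params.task)
#   model = task.build_model()
#   metrics = task.build_metrics()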