# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for orbit.controller."""

import os

from absl import logging
from absl.testing import parameterized

import numpy as np

from orbit import controller
from orbit import standard_runner

import tensorflow as tf


def create_model():
  x = tf.keras.layers.Input(shape=(3,), name="input")
  y = tf.keras.layers.Dense(4, name="dense")(x)
  model = tf.keras.Model(x, y)
  return model


def summaries_with_matching_keyword(keyword, summary_dir):
  """Yields summary protos matching given keyword from event file."""
  event_paths = tf.io.gfile.glob(os.path.join(summary_dir, "events*"))
  for event in tf.compat.v1.train.summary_iterator(event_paths[-1]):
    if event.summary is not None:
      for value in event.summary.value:
        if keyword in value.tag:
          logging.info(event)
          yield event.summary


def check_eventfile_for_keyword(keyword, summary_dir):
  """Checks event files for the keyword."""
  return any(summaries_with_matching_keyword(keyword, summary_dir))


def dataset_fn(ctx):
  del ctx
  inputs = np.zeros((10, 3), dtype=np.float32)
  targets = np.ones((10, 4), dtype=np.float32)
  dataset = tf.data.Dataset.from_tensor_slices((inputs, targets))
  dataset = dataset.repeat(100)
  dataset = dataset.batch(10, drop_remainder=True)
  return dataset


class TestRunner(standard_runner.StandardTrainer,
                 standard_runner.StandardEvaluator):
  """Implements the training and evaluation APIs for the test model."""

  def __init__(self, return_numpy=False):
    self.strategy = tf.distribute.get_strategy()
    self.model = create_model()
    self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.1)
    self.global_step = self.optimizer.iterations
    self.train_loss = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
    self.eval_loss = tf.keras.metrics.Mean("eval_loss", dtype=tf.float32)
    self.return_numpy = return_numpy
    train_dataset = (
        self.strategy.experimental_distribute_datasets_from_function(
            dataset_fn))
    eval_dataset = (
        self.strategy.experimental_distribute_datasets_from_function(
            dataset_fn))
    standard_runner.StandardTrainer.__init__(self, train_dataset)
    standard_runner.StandardEvaluator.__init__(self, eval_dataset)

  def train_step(self, iterator):

    def _replicated_step(inputs):
      """Replicated training step."""
      inputs, targets = inputs
      with tf.GradientTape() as tape:
        outputs = self.model(inputs)
        loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
      grads = tape.gradient(loss, self.model.variables)
      self.optimizer.apply_gradients(zip(grads, self.model.variables))
      self.train_loss.update_state(loss)

    self.strategy.run(_replicated_step, args=(next(iterator),))

  def train_loop_end(self):
    train_loss = self.train_loss.result()
    return {
        "loss": train_loss.numpy() if self.return_numpy else train_loss,
    }

  def build_eval_dataset(self):
    return self.strategy.experimental_distribute_datasets_from_function(
        dataset_fn)

  def eval_begin(self):
    self.eval_loss.reset_states()

  def eval_step(self, iterator):

    def _replicated_step(inputs):
      """Replicated evaluation step."""
      inputs, targets = inputs
      outputs = self.model(inputs)
      loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
      self.eval_loss.update_state(loss)

    self.strategy.run(_replicated_step, args=(next(iterator),))

  def eval_end(self):
    eval_loss = self.eval_loss.result()
    return {
        "eval_loss": eval_loss.numpy() if self.return_numpy else eval_loss,
    }


class TestEvaluator(standard_runner.StandardEvaluator):
  """Implements the evaluation APIs for the test model."""

  def __init__(self):
    self.strategy = tf.distribute.get_strategy()
    self.model = create_model()
    eval_dataset = self.strategy.experimental_distribute_datasets_from_function(
        dataset_fn)
    standard_runner.StandardEvaluator.__init__(self, eval_dataset)

  def eval_reduce(self, state, output):
    state.append(output)
    return state

  def eval_begin(self):
    return []

  def eval_step(self, iterator):

    def _replicated_step(inputs):
      """Replicated evaluation step."""
      inputs, targets = inputs
      outputs = self.model(inputs)
      loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
      return loss

    per_replica_losses = self.strategy.run(
        _replicated_step, args=(next(iterator),))
    mean_loss = self.strategy.reduce(
        tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
    return mean_loss

  def eval_end(self, outputs):
    return {
        "eval_loss": tf.reduce_mean(outputs),
    }


class TestTrainerWithSummaries(standard_runner.StandardTrainer):
  """A trainer that writes summaries inside its train step, for testing."""

  def __init__(self):
    self.strategy = tf.distribute.get_strategy()
    self.model = create_model()
    self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.1)
    self.global_step = self.optimizer.iterations
    self.train_loss = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
    train_dataset = (
        self.strategy.experimental_distribute_datasets_from_function(
            dataset_fn))
    standard_runner.StandardTrainer.__init__(
        self, train_dataset, use_tpu_summary_optimization=True)

  def build_train_dataset(self):
    return self.strategy.experimental_distribute_datasets_from_function(
        dataset_fn)

  def train_step(self, iterator):

    def _replicated_step(inputs):
      """Replicated training step."""
      inputs, targets = inputs
      with tf.GradientTape() as tape:
        outputs = self.model(inputs)
        loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
      tf.summary.scalar("loss", loss)
      grads = tape.gradient(loss, self.model.variables)
      self.optimizer.apply_gradients(zip(grads, self.model.variables))
      self.train_loss.update_state(loss)

    self.strategy.run(_replicated_step, args=(next(iterator),))


class ControllerTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.model_dir = self.get_temp_dir()

  def test_no_checkpoint(self):
    test_runner = TestRunner()
    # No checkpoint manager and no explicit distribution strategy.
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    self.assertEqual(test_runner.global_step, 10)

    # Loss values should be written into summaries.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
    self.assertTrue(
        check_eventfile_for_keyword(
            "loss", os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
    self.assertTrue(
        check_eventfile_for_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries/eval")))

    # No checkpoint, so global step starts from 0.
    test_runner.global_step.assign(0)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    self.assertEqual(test_runner.global_step, 10)

  def test_no_checkpoint_and_summaries(self):
    test_runner = TestRunner()
    # No checkpoint manager and no summary directories.
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    self.assertEqual(test_runner.global_step, 10)

  @parameterized.named_parameters(("return_numpy", True),
                                  ("return_tensor", False))
  def test_train_and_evaluate(self, return_numpy):
    test_runner = TestRunner(return_numpy=return_numpy)

    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)

    # Checkpoints are saved.
    self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))

    # Loss values should be written into summaries.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
    self.assertTrue(
        check_eventfile_for_keyword(
            "loss", os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
    self.assertTrue(
        check_eventfile_for_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries/eval")))

  def test_train_only(self):
    test_runner = TestRunner()

    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)
    test_controller = controller.Controller(
        trainer=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
    )
    test_controller.train(steps=10)

    # Checkpoints are saved.
    self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))

    # Only train summaries are written.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
    self.assertTrue(
        check_eventfile_for_keyword(
            "loss", os.path.join(self.model_dir, "summaries/train")))
    self.assertFalse(
        tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))

  def test_evaluate_only(self):
    test_runner = TestRunner()

    checkpoint = tf.train.Checkpoint(model=test_runner.model)
    checkpoint.save(os.path.join(self.model_dir, "ckpt"))

    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        evaluator=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    test_controller.evaluate(steps=2)

    # Only eval summaries are written.
    self.assertFalse(
        tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
    self.assertTrue(
        check_eventfile_for_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries/eval")))

    # Tests continuous eval with timeout and timeout_fn.
    done_file = os.path.join(self.model_dir, "summaries/eval/Done")

    def timeout_fn():
      with tf.io.gfile.GFile(done_file, "w") as f:
        f.write("DONE")
      return True

    test_controller = controller.Controller(
        evaluator=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    test_controller.evaluate_continuously(
        timeout=1, timeout_fn=timeout_fn, steps=2)
    self.assertNotEmpty(tf.io.gfile.glob(done_file))

  def test_no_eval_steps(self):
    test_runner = TestRunner()

    checkpoint = tf.train.Checkpoint(model=test_runner.model)
    checkpoint.save(os.path.join(self.model_dir, "ckpt"))

    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        evaluator=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager)
    test_controller.evaluate()

  def test_already_trained_model(self):
    test_runner = TestRunner()
    test_runner.global_step.assign(10)

    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)
    test_controller = controller.Controller(
        trainer=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        checkpoint_manager=checkpoint_manager)
    # `global_step` is already at `train_steps`, so this should be a no-op.
    test_controller.train(steps=10)

  def test_summaries_inside_train_fn(self):
    test_runner = TestTrainerWithSummaries()

    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        trainer=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        summary_interval=2,
        checkpoint_manager=checkpoint_manager,
    )
    test_controller.train(steps=10)

    # Checkpoints are not saved, since no `checkpoint_interval` was set.
    self.assertEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))

    # Only train summaries are written.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
    self.assertTrue(
        check_eventfile_for_keyword(
            "loss", os.path.join(self.model_dir, "summaries/train")))
    self.assertFalse(
        tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))

  def test_train_and_evaluate_with_same_summary_dir(self):
    test_runner = TestRunner()

    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries"),
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries"))
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)

    # Loss values should be written into summaries.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries")))
    self.assertTrue(
        check_eventfile_for_keyword("loss",
                                    os.path.join(self.model_dir, "summaries")))
    self.assertTrue(
        check_eventfile_for_keyword("eval_loss",
                                    os.path.join(self.model_dir, "summaries")))

  def test_early_stop_on_eval_loss(self):
    test_runner = TestRunner()

    class EarlyStopController(controller.Controller):
      """A subclass of Controller that supports early stopping."""

      def train_and_evaluate(self,
                             train_steps: int = None,
                             eval_steps: int = None,
                             eval_interval: int = None):
        while self.global_step.numpy() < train_steps:
          interval = min(train_steps - self.global_step.numpy(), eval_interval)
          num_steps = self.global_step.numpy() + interval
          self.train(steps=num_steps, checkpoint_at_completion=False)
          self.evaluate(steps=eval_steps)
          # Early stop condition.
          if test_runner.eval_loss.result() < 0.1:
            logging.info(
                "Training early stopped as eval_loss %s is less than 0.1",
                test_runner.eval_loss.result())
            return

    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)
    test_controller = EarlyStopController(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        checkpoint_manager=checkpoint_manager)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=6, eval_interval=2)

    self.assertLess(test_runner.global_step, 10)

  def test_evaluate_with_loss_outputs(self):
    test_evaluator = TestEvaluator()

    checkpoint = tf.train.Checkpoint(model=test_evaluator.model)
    checkpoint.save(os.path.join(self.model_dir, "ckpt"))

    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint, self.model_dir, max_to_keep=None)
    test_controller = controller.Controller(
        evaluator=test_evaluator,
        global_step=tf.Variable(0, dtype=tf.int64),
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    test_controller.evaluate(steps=5)

    # Only eval summaries are written.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
    self.assertTrue(
        check_eventfile_for_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries/eval")))

  def test_train_and_evaluate_reset_datasets(self):
    test_runner = TestRunner()

    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)

    train_dataset = (
        test_runner.strategy.experimental_distribute_datasets_from_function(
            dataset_fn))
    eval_dataset = (
        test_runner.strategy.experimental_distribute_datasets_from_function(
            dataset_fn))
    test_runner.train_dataset = train_dataset
    test_runner.eval_dataset = eval_dataset

    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)


if __name__ == "__main__":
  tf.test.main()