Commit f1bcd9bb authored by Hongkun Yu, committed by A. Unique TensorFlower

Initial update for the task interface.

Remove type annotation for params.
Make the trainer consume both model and task.

PiperOrigin-RevId: 334964198
parent 0ab5dcbf
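
In practice, the second change means the caller now builds the model and hands it to the trainer, rather than the trainer falling back to task.build_model() internally. A minimal sketch of the new calling convention, mirroring the test changes below (the task, config, and module names are illustrative placeholders, not an exact API reference):

  task = mock_task.MockTask()
  model = task.build_model()
  trainer = trainer_lib.Trainer(config, task, model=model)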
@@ -21,8 +21,6 @@ from typing import Any, Callable, Optional
 from absl import logging
 import tensorflow as tf
 
-from official.modeling.hyperparams import config_definitions as cfg
-
 
 class Task(tf.Module, metaclass=abc.ABCMeta):
   """A single-replica view of training procedure.
@@ -35,11 +33,12 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
   # Special keys in train/validate step returned logs.
   loss = "loss"
 
-  def __init__(self, params: cfg.TaskConfig, logging_dir: str = None):
+  def __init__(self, params, logging_dir: str = None):
     """Task initialization.
 
     Args:
-      params: cfg.TaskConfig instance.
+      params: the task configuration instance, which can be any of
+        dataclass, ConfigDict, namedtuple, etc.
       logging_dir: a string pointing to where the model, summaries etc. will be
        saved. You can also write additional stuff in this directory.
     """
@@ -47,7 +46,7 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
     self._logging_dir = logging_dir
 
   @property
-  def task_config(self) -> cfg.TaskConfig:
+  def task_config(self):
     return self._task_config
 
   @property
@@ -55,7 +54,7 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
     return self._logging_dir
 
   def initialize(self, model: tf.keras.Model):
-    """A callback function used as CheckpointManager's init_fn.
+    """[Optional] A callback function used as CheckpointManager's init_fn.
 
     This function will be called when no checkpoint is found for the model.
     If there is a checkpoint, the checkpoint will be loaded and this function
@@ -83,9 +82,8 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
     logging.info("Finished loading pretrained checkpoint from %s",
                  ckpt_dir_or_file)
 
-  @abc.abstractmethod
   def build_model(self) -> tf.keras.Model:
-    """Creates model architecture.
+    """[Optional] Creates model architecture.
 
     Returns:
       A model instance.
@@ -128,7 +126,7 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
   @abc.abstractmethod
   def build_inputs(self,
-                   params: cfg.DataConfig,
+                   params,
                    input_context: Optional[tf.distribute.InputContext] = None):
     """Returns a dataset or a nested structure of dataset functions.
@@ -136,7 +134,8 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
     With distributed training, this method runs on remote hosts.
 
     Args:
-      params: hyperparams to create input pipelines.
+      params: hyperparams to create input pipelines, which can be any of
+        dataclass, ConfigDict, namedtuple, etc.
       input_context: optional distribution input pipeline context.
 
     Returns:
......
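
For context on the Task-side change: with the type annotation removed, `params` can be any config-like object (dataclass, ConfigDict, namedtuple), and `build_model` is no longer abstract since the trainer can receive a model directly. A minimal sketch of a concrete task under the new interface; the dataclass config and the `base_task` import path are assumptions for illustration only:

import dataclasses

import tensorflow as tf

from official.core import base_task  # assumed import path for the Task base class


@dataclasses.dataclass
class ToyDataConfig:
  """Illustrative config object; any dataclass/ConfigDict/namedtuple works now."""
  global_batch_size: int = 32


class ToyTask(base_task.Task):
  """Toy task: only build_inputs is still required by the interface."""

  def build_model(self) -> tf.keras.Model:
    # Optional override; the caller may also build a model and pass it to the trainer.
    return tf.keras.Sequential([tf.keras.layers.Dense(1)])

  def build_inputs(self, params, input_context=None):
    # `params` is whatever config object the caller supplies, e.g. ToyDataConfig.
    features = tf.random.uniform([128, 4])
    labels = tf.random.uniform([128, 1])
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    return dataset.batch(params.global_batch_size)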
@@ -39,9 +39,9 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
   def __init__(self,
                config: ExperimentConfig,
                task: base_task.Task,
+               model: tf.keras.Model,
                train: bool = True,
                evaluate: bool = True,
-               model=None,
                optimizer=None,
                checkpoint_exporter=None):
     """Initialize common trainer for TensorFlow models.
@@ -49,12 +49,12 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
     Args:
       config: An `ExperimentConfig` instance specifying experiment config.
       task: A base_task.Task instance.
+      model: tf.keras.Model instance. If provided, it will be used instead of
+        building model using task.build_model(). Default to None.
       train: bool, whether or not this trainer will be used for training.
        default to True.
       evaluate: bool, whether or not this trainer will be used for evaluation.
        default to True.
-      model: tf.keras.Model instance. If provided, it will be used instead of
-        building model using task.build_model(). Default to None.
       optimizer: tf.keras.optimizers.Optimizer instance. If provided, it will
        used instead of the optimizer from config. Default to None.
       checkpoint_exporter: an object that has the `maybe_export_checkpoint`
@@ -65,8 +65,7 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
     self._strategy = tf.distribute.get_strategy()
     self._config = config
     self._task = task
-    self._model = model or task.build_model()
+    self._model = model
     if optimizer is None:
       opt_factory = optimization.OptimizerFactory(
......
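
Because the `model or task.build_model()` fallback is gone, the model must exist before the trainer is constructed, typically built under the distribution strategy scope. A minimal sketch of the updated construction, assuming an ExperimentConfig named config and an existing task instance (names are illustrative):

strategy = tf.distribute.MirroredStrategy()  # or whichever strategy is in use
with strategy.scope():
  model = task.build_model()                 # the caller builds the model now
  trainer = base_trainer.Trainer(
      config,                                # ExperimentConfig instance
      task,                                  # base_task.Task instance
      model=model,
      train=True,
      evaluate=True)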
@@ -54,15 +54,15 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
             }
         })))
 
-  def create_test_trainer(self):
+  def create_test_trainer(self, config):
     task = mock_task.MockTask()
-    trainer = trainer_lib.Trainer(self._config, task)
+    trainer = trainer_lib.Trainer(config, task, model=task.build_model())
     return trainer
 
   @combinations.generate(all_strategy_combinations())
   def test_trainer_train(self, distribution):
     with distribution.scope():
-      trainer = self.create_test_trainer()
+      trainer = self.create_test_trainer(self._config)
       logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
     self.assertIn('training_loss', logs)
     self.assertIn('learning_rate', logs)
@@ -70,7 +70,7 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
   @combinations.generate(all_strategy_combinations())
   def test_trainer_validate(self, distribution):
     with distribution.scope():
-      trainer = self.create_test_trainer()
+      trainer = self.create_test_trainer(self._config)
       logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
     self.assertIn('validation_loss', logs)
     self.assertEqual(logs['acc'], 5. * distribution.num_replicas_in_sync)
@@ -93,8 +93,7 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
             'type': 'constant'
         }
     })))
-    task = mock_task.MockTask()
-    trainer = trainer_lib.Trainer(config, task)
+    trainer = self.create_test_trainer(config)
     if mixed_precision_dtype != 'float16':
       self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
     elif mixed_precision_dtype == 'float16' and loss_scale is None:
@@ -125,11 +124,14 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
     task = mock_task.MockTask(config.task, logging_dir=model_dir)
     ckpt_exporter = train_lib.maybe_create_best_ckpt_exporter(config, model_dir)
     trainer = trainer_lib.Trainer(
-        config, task, checkpoint_exporter=ckpt_exporter)
+        config,
+        task,
+        model=task.build_model(),
+        checkpoint_exporter=ckpt_exporter)
     trainer.train(tf.convert_to_tensor(1, dtype=tf.int32))
     trainer.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
-    self.assertTrue(tf.io.gfile.exists(
-        os.path.join(model_dir, 'best_ckpt', 'info.json')))
+    self.assertTrue(
+        tf.io.gfile.exists(os.path.join(model_dir, 'best_ckpt', 'info.json')))
 
 
 if __name__ == '__main__':
......
@@ -37,10 +37,7 @@ class BestCheckpointExporter:
   together with orbit once this functionality is ready.
   """
 
-  def __init__(self,
-               export_dir: str,
-               metric_name: str,
-               metric_comp: str):
+  def __init__(self, export_dir: str, metric_name: str, metric_comp: str):
     """Initialization.
 
     Arguments:
@@ -53,8 +50,7 @@ class BestCheckpointExporter:
     self._metric_name = metric_name
     self._metric_comp = metric_comp
     if self._metric_comp not in ('lower', 'higher'):
-      raise ValueError(
-          'best checkpoint metric comp must be one of '
+      raise ValueError('best checkpoint metric comp must be one of '
                        'higher, lower. Got: {}'.format(self._metric_comp))
     tf.io.gfile.makedirs(os.path.dirname(self.best_ckpt_logs_path))
     self._best_ckpt_logs = self._maybe_load_best_eval_metric()
@@ -65,8 +61,8 @@ class BestCheckpointExporter:
     if self._best_ckpt_logs is None or self._new_metric_is_better(
         self._best_ckpt_logs, eval_logs):
       self._best_ckpt_logs = eval_logs
-      self._export_best_eval_metric(
-          checkpoint, self._best_ckpt_logs, global_step)
+      self._export_best_eval_metric(checkpoint, self._best_ckpt_logs,
+                                    global_step)
 
   def _maybe_load_best_eval_metric(self):
     if not tf.io.gfile.exists(self.best_ckpt_logs_path):
@@ -77,8 +73,7 @@ class BestCheckpointExporter:
   def _new_metric_is_better(self, old_logs, new_logs):
     """Check if the metric in new_logs is better than the metric in old_logs."""
     if self._metric_name not in old_logs or self._metric_name not in new_logs:
-      raise KeyError(
-          'best checkpoint eval metric name {} is not valid. '
+      raise KeyError('best checkpoint eval metric name {} is not valid. '
                      'old_logs: {}, new_logs: {}'.format(
                          self._metric_name, old_logs, new_logs))
     old_value = float(orbit.utils.get_value(old_logs[self._metric_name]))
@@ -126,8 +121,7 @@ class BestCheckpointExporter:
     return os.path.join(self._export_dir, 'best_ckpt')
 
 
-def maybe_create_best_ckpt_exporter(
-    params: config_definitions.ExperimentConfig,
-    data_dir: str) -> Any:
+def maybe_create_best_ckpt_exporter(params: config_definitions.ExperimentConfig,
                                     data_dir: str) -> Any:
   """Maybe create a BestCheckpointExporter object, according to the config."""
   export_subdir = params.trainer.best_checkpoint_export_subdir
@@ -135,13 +129,14 @@ def maybe_create_best_ckpt_exporter(
   metric_comp = params.trainer.best_checkpoint_metric_comp
   if data_dir and export_subdir and metric_name:
     best_ckpt_dir = os.path.join(data_dir, export_subdir)
-    best_ckpt_exporter = BestCheckpointExporter(
-        best_ckpt_dir, metric_name, metric_comp)
+    best_ckpt_exporter = BestCheckpointExporter(best_ckpt_dir, metric_name,
                                                 metric_comp)
   else:
     best_ckpt_exporter = None
-    logging.info('Not exporting the best checkpoint. '
-                 'data_dir: %s, export_subdir: %s, metric_name: %s',
-                 data_dir, export_subdir, metric_name)
+    logging.info(
+        'Not exporting the best checkpoint. '
+        'data_dir: %s, export_subdir: %s, metric_name: %s', data_dir,
+        export_subdir, metric_name)
 
   return best_ckpt_exporter
@@ -174,10 +169,12 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
   """
   with distribution_strategy.scope():
+    model = task.build_model()
     trainer = train_utils.create_trainer(
         params,
         task,
-        model_dir,
+        model=model,
+        model_dir=model_dir,
         train='train' in mode,
         evaluate=('eval' in mode) or run_post_eval,
         checkpoint_exporter=maybe_create_best_ckpt_exporter(params, model_dir))
@@ -200,12 +197,11 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
       global_step=trainer.global_step,
       steps_per_loop=params.trainer.steps_per_loop,
       checkpoint_manager=checkpoint_manager,
-      summary_dir=os.path.join(model_dir, 'train') if (
-          save_summary) else None,
-      eval_summary_dir=os.path.join(model_dir, 'validation') if (
-          save_summary) else None,
-      summary_interval=params.trainer.summary_interval if (
-          save_summary) else None)
+      summary_dir=os.path.join(model_dir, 'train') if (save_summary) else None,
+      eval_summary_dir=os.path.join(model_dir, 'validation') if
+      (save_summary) else None,
+      summary_interval=params.trainer.summary_interval if
+      (save_summary) else None)
 
   logging.info('Starts to execute mode: %s', mode)
   with distribution_strategy.scope():
@@ -219,10 +215,12 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
   elif mode == 'eval':
     controller.evaluate(steps=params.trainer.validation_steps)
   elif mode == 'continuous_eval':
+
     def timeout_fn():
       if trainer.global_step.numpy() >= params.trainer.train_steps:
         return True
       return False
+
     controller.evaluate_continuously(
         steps=params.trainer.validation_steps,
         timeout=params.trainer.continuous_eval_timeout,
......
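
As a usage note on the exporter touched above: it is constructed with an export directory, a metric name, and a comparison direction, and maybe_export_checkpoint keeps a checkpoint only when the metric improves (the test above asserts that an info.json with the winning metrics ends up under the best_ckpt directory). A sketch of direct use; the path, metric name, and the checkpoint/eval_logs/global_step objects are placeholders, not part of this commit:

exporter = BestCheckpointExporter(
    '/tmp/model_dir/best_ckpt',   # export_dir (illustrative path)
    'validation_loss',            # metric_name, must appear in the eval logs
    'lower')                      # metric_comp: 'lower' or 'higher'

# Called after each evaluation; the checkpoint is exported only if
# eval_logs['validation_loss'] beats the best value seen so far.
exporter.maybe_export_checkpoint(checkpoint, eval_logs, global_step)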
@@ -32,9 +32,9 @@ from official.modeling import hyperparams
 from official.modeling.hyperparams import config_definitions
 
 
-def create_trainer(
-    params: config_definitions.ExperimentConfig,
+def create_trainer(params: config_definitions.ExperimentConfig,
                    task: base_task.Task,
+                   model: tf.keras.Model,
                    model_dir: str,
                    train: bool,
                    evaluate: bool,
@@ -43,7 +43,11 @@ def create_trainer(
   del model_dir
   logging.info('Running default trainer.')
   trainer = base_trainer.Trainer(
-      params, task, train=train, evaluate=evaluate,
+      params,
+      task,
+      train=train,
+      evaluate=evaluate,
+      model=model,
       checkpoint_exporter=checkpoint_exporter)
   return trainer
@@ -129,8 +133,8 @@ def read_global_step_from_checkpoint(ckpt_file_path):
                      'make sure that your pretrain model writes '
                      'global_step in its checkpoints.'.format(ckpt_file_path))
   global_step_restored = global_step.numpy()
-  logging.info('get global_step %d from checkpoint %s',
-               global_step_restored, ckpt_file_path)
+  logging.info('get global_step %d from checkpoint %s', global_step_restored,
               ckpt_file_path)
   return global_step_restored
@@ -143,8 +147,8 @@ def write_json_summary(log_dir, global_step, eval_metrics):
     else:
       serializable_dict[name] = str(value)
   output_json = os.path.join(log_dir, 'metrics-{}.json'.format(global_step))
-  logging.info('Evaluation results at pretrain step %d: %s',
-               global_step, serializable_dict)
+  logging.info('Evaluation results at pretrain step %d: %s', global_step,
               serializable_dict)
   with tf.io.gfile.GFile(output_json, 'w') as writer:
     writer.write(json.dumps(serializable_dict, indent=4) + '\n')
......
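
The same threading applies at the train_utils level: create_trainer now takes the model explicitly (and still ignores model_dir via `del model_dir`). A sketch mirroring the run_experiment call above, assuming params, task, and model_dir already exist in the caller:

model = task.build_model()
trainer = train_utils.create_trainer(
    params,
    task,
    model=model,
    model_dir=model_dir,
    train=True,
    evaluate=True,
    checkpoint_exporter=None)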