Commit f1bcd9bb authored by Hongkun Yu, committed by A. Unique TensorFlower

Initial update for the task interface.

Remove type annotation for params.
Make the trainer consume both model and task.

PiperOrigin-RevId: 334964198
parent 0ab5dcbf
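
In short, the Trainer no longer builds the model itself; callers now build it from the task and pass both objects in. A minimal sketch of the new convention (the import path below is an assumption; only `Trainer(config, task, model=...)` is shown in this diff):

```python
# Sketch of the calling convention introduced by this commit: the caller builds
# the model from the task and hands both to the Trainer. The module path is an
# assumption, not something this diff confirms.
import tensorflow as tf
from official.core import base_trainer as trainer_lib  # assumed path

def build_trainer(config, task):
  strategy = tf.distribute.get_strategy()
  with strategy.scope():
    # Previously the Trainer called task.build_model() internally; now the
    # model is constructed by the caller and passed in explicitly.
    model = task.build_model()
    return trainer_lib.Trainer(config, task, model=model)
```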
......@@ -21,8 +21,6 @@ from typing import Any, Callable, Optional
from absl import logging
import tensorflow as tf
from official.modeling.hyperparams import config_definitions as cfg
class Task(tf.Module, metaclass=abc.ABCMeta):
"""A single-replica view of training procedure.
......@@ -35,11 +33,12 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
# Special keys in train/validate step returned logs.
loss = "loss"
def __init__(self, params: cfg.TaskConfig, logging_dir: str = None):
def __init__(self, params, logging_dir: str = None):
"""Task initialization.
Args:
params: cfg.TaskConfig instance.
params: the task configuration instance, which can be any of
dataclass, ConfigDict, namedtuple, etc.
logging_dir: a string pointing to where the model, summaries etc. will be
saved. You can also write additional stuff in this directory.
"""
......@@ -47,7 +46,7 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
self._logging_dir = logging_dir
@property
def task_config(self) -> cfg.TaskConfig:
def task_config(self):
return self._task_config
@property
......@@ -55,7 +54,7 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
return self._logging_dir
def initialize(self, model: tf.keras.Model):
"""A callback function used as CheckpointManager's init_fn.
"""[Optional] A callback function used as CheckpointManager's init_fn.
This function will be called when no checkpoint is found for the model.
If there is a checkpoint, the checkpoint will be loaded and this function
......@@ -83,9 +82,8 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
logging.info("Finished loading pretrained checkpoint from %s",
ckpt_dir_or_file)
@abc.abstractmethod
def build_model(self) -> tf.keras.Model:
"""Creates model architecture.
"""[Optional] Creates model architecture.
Returns:
A model instance.
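
The `initialize` hook documented above is meant to be wired into checkpoint restoration. A minimal sketch of that wiring, with `task`, `model`, and `model_dir` as caller-supplied placeholders:

```python
# Sketch of plugging Task.initialize into tf.train.CheckpointManager, per the
# "[Optional] A callback function used as CheckpointManager's init_fn"
# docstring above. All names passed in are placeholders.
import functools
import tensorflow as tf

def make_checkpoint_manager(task, model, model_dir):
  checkpoint = tf.train.Checkpoint(model=model)
  return tf.train.CheckpointManager(
      checkpoint,
      directory=model_dir,
      max_to_keep=3,
      # init_fn is invoked only when no checkpoint is found in model_dir;
      # otherwise the latest checkpoint is restored and the callback is skipped.
      init_fn=functools.partial(task.initialize, model))
```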
......@@ -128,7 +126,7 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
@abc.abstractmethod
def build_inputs(self,
params: cfg.DataConfig,
params,
input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a dataset or a nested structure of dataset functions.
......@@ -136,7 +134,8 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
With distributed training, this method runs on remote hosts.
Args:
params: hyperparams to create input pipelines.
params: hyperparams to create input pipelines, which can be any of
dataclass, ConfigDict, namedtuple, etc.
input_context: optional distribution input pipeline context.
Returns:
......
......@@ -39,9 +39,9 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
def __init__(self,
config: ExperimentConfig,
task: base_task.Task,
model: tf.keras.Model,
train: bool = True,
evaluate: bool = True,
model=None,
optimizer=None,
checkpoint_exporter=None):
"""Initialize common trainer for TensorFlow models.
......@@ -49,12 +49,12 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
Args:
config: An `ExperimentConfig` instance specifying experiment config.
task: A base_task.Task instance.
model: tf.keras.Model instance. If provided, it will be used instead of
building model using task.build_model(). Default to None.
train: bool, whether or not this trainer will be used for training.
default to True.
evaluate: bool, whether or not this trainer will be used for evaluation.
default to True.
model: tf.keras.Model instance. If provided, it will be used instead of
building model using task.build_model(). Default to None.
optimizer: tf.keras.optimizers.Optimizer instance. If provided, it will be
used instead of the optimizer from config. Default to None.
checkpoint_exporter: an object that has the `maybe_export_checkpoint`
......@@ -65,8 +65,7 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
self._strategy = tf.distribute.get_strategy()
self._config = config
self._task = task
self._model = model or task.build_model()
self._model = model
if optimizer is None:
opt_factory = optimization.OptimizerFactory(
......
......@@ -54,15 +54,15 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
}
})))
def create_test_trainer(self):
def create_test_trainer(self, config):
task = mock_task.MockTask()
trainer = trainer_lib.Trainer(self._config, task)
trainer = trainer_lib.Trainer(config, task, model=task.build_model())
return trainer
@combinations.generate(all_strategy_combinations())
def test_trainer_train(self, distribution):
with distribution.scope():
trainer = self.create_test_trainer()
trainer = self.create_test_trainer(self._config)
logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('training_loss', logs)
self.assertIn('learning_rate', logs)
......@@ -70,7 +70,7 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
@combinations.generate(all_strategy_combinations())
def test_trainer_validate(self, distribution):
with distribution.scope():
trainer = self.create_test_trainer()
trainer = self.create_test_trainer(self._config)
logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('validation_loss', logs)
self.assertEqual(logs['acc'], 5. * distribution.num_replicas_in_sync)
......@@ -93,8 +93,7 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
'type': 'constant'
}
})))
task = mock_task.MockTask()
trainer = trainer_lib.Trainer(config, task)
trainer = self.create_test_trainer(config)
if mixed_precision_dtype != 'float16':
self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
elif mixed_precision_dtype == 'float16' and loss_scale is None:
......@@ -125,11 +124,14 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
task = mock_task.MockTask(config.task, logging_dir=model_dir)
ckpt_exporter = train_lib.maybe_create_best_ckpt_exporter(config, model_dir)
trainer = trainer_lib.Trainer(
config, task, checkpoint_exporter=ckpt_exporter)
config,
task,
model=task.build_model(),
checkpoint_exporter=ckpt_exporter)
trainer.train(tf.convert_to_tensor(1, dtype=tf.int32))
trainer.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
self.assertTrue(tf.io.gfile.exists(
os.path.join(model_dir, 'best_ckpt', 'info.json')))
self.assertTrue(
tf.io.gfile.exists(os.path.join(model_dir, 'best_ckpt', 'info.json')))
if __name__ == '__main__':
......
......@@ -37,10 +37,7 @@ class BestCheckpointExporter:
together with orbit once this functionality is ready.
"""
def __init__(self,
export_dir: str,
metric_name: str,
metric_comp: str):
def __init__(self, export_dir: str, metric_name: str, metric_comp: str):
"""Initialization.
Arguments:
......@@ -53,9 +50,8 @@ class BestCheckpointExporter:
self._metric_name = metric_name
self._metric_comp = metric_comp
if self._metric_comp not in ('lower', 'higher'):
raise ValueError(
'best checkpoint metric comp must be one of '
'higher, lower. Got: {}'.format(self._metric_comp))
raise ValueError('best checkpoint metric comp must be one of '
'higher, lower. Got: {}'.format(self._metric_comp))
tf.io.gfile.makedirs(os.path.dirname(self.best_ckpt_logs_path))
self._best_ckpt_logs = self._maybe_load_best_eval_metric()
......@@ -65,8 +61,8 @@ class BestCheckpointExporter:
if self._best_ckpt_logs is None or self._new_metric_is_better(
self._best_ckpt_logs, eval_logs):
self._best_ckpt_logs = eval_logs
self._export_best_eval_metric(
checkpoint, self._best_ckpt_logs, global_step)
self._export_best_eval_metric(checkpoint, self._best_ckpt_logs,
global_step)
def _maybe_load_best_eval_metric(self):
if not tf.io.gfile.exists(self.best_ckpt_logs_path):
......@@ -77,10 +73,9 @@ class BestCheckpointExporter:
def _new_metric_is_better(self, old_logs, new_logs):
"""Check if the metric in new_logs is better than the metric in old_logs."""
if self._metric_name not in old_logs or self._metric_name not in new_logs:
raise KeyError(
'best checkpoint eval metric name {} is not valid. '
'old_logs: {}, new_logs: {}'.format(
self._metric_name, old_logs, new_logs))
raise KeyError('best checkpoint eval metric name {} is not valid. '
'old_logs: {}, new_logs: {}'.format(
self._metric_name, old_logs, new_logs))
old_value = float(orbit.utils.get_value(old_logs[self._metric_name]))
new_value = float(orbit.utils.get_value(new_logs[self._metric_name]))
......@@ -126,22 +121,22 @@ class BestCheckpointExporter:
return os.path.join(self._export_dir, 'best_ckpt')
def maybe_create_best_ckpt_exporter(
params: config_definitions.ExperimentConfig,
data_dir: str) -> Any:
def maybe_create_best_ckpt_exporter(params: config_definitions.ExperimentConfig,
data_dir: str) -> Any:
"""Maybe create a BestCheckpointExporter object, according to the config."""
export_subdir = params.trainer.best_checkpoint_export_subdir
metric_name = params.trainer.best_checkpoint_eval_metric
metric_comp = params.trainer.best_checkpoint_metric_comp
if data_dir and export_subdir and metric_name:
best_ckpt_dir = os.path.join(data_dir, export_subdir)
best_ckpt_exporter = BestCheckpointExporter(
best_ckpt_dir, metric_name, metric_comp)
best_ckpt_exporter = BestCheckpointExporter(best_ckpt_dir, metric_name,
metric_comp)
else:
best_ckpt_exporter = None
logging.info('Not exporting the best checkpoint. '
'data_dir: %s, export_subdir: %s, metric_name: %s',
data_dir, export_subdir, metric_name)
logging.info(
'Not exporting the best checkpoint. '
'data_dir: %s, export_subdir: %s, metric_name: %s', data_dir,
export_subdir, metric_name)
return best_ckpt_exporter
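
Putting the pieces together, export is enabled only when all three trainer config fields are set. A small illustrative snippet (the module path and the `'validation_loss'` metric key are assumptions, not mandated by this code):

```python
# Illustrative use of the best-checkpoint exporter defined above. The import
# path and the metric key are assumptions; metric_comp must be 'lower' or
# 'higher', otherwise __init__ raises ValueError.
from official.core import train_lib  # assumed module path

exporter = train_lib.BestCheckpointExporter(
    export_dir='/tmp/model_dir/best_ckpt',
    metric_name='validation_loss',
    metric_comp='lower')
# After each evaluation, something like:
# exporter.maybe_export_checkpoint(checkpoint, eval_logs, global_step)
```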
......@@ -174,10 +169,12 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
"""
with distribution_strategy.scope():
model = task.build_model()
trainer = train_utils.create_trainer(
params,
task,
model_dir,
model=model,
model_dir=model_dir,
train='train' in mode,
evaluate=('eval' in mode) or run_post_eval,
checkpoint_exporter=maybe_create_best_ckpt_exporter(params, model_dir))
......@@ -200,12 +197,11 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
global_step=trainer.global_step,
steps_per_loop=params.trainer.steps_per_loop,
checkpoint_manager=checkpoint_manager,
summary_dir=os.path.join(model_dir, 'train') if (
save_summary) else None,
eval_summary_dir=os.path.join(model_dir, 'validation') if (
save_summary) else None,
summary_interval=params.trainer.summary_interval if (
save_summary) else None)
summary_dir=os.path.join(model_dir, 'train') if (save_summary) else None,
eval_summary_dir=os.path.join(model_dir, 'validation') if
(save_summary) else None,
summary_interval=params.trainer.summary_interval if
(save_summary) else None)
logging.info('Starts to execute mode: %s', mode)
with distribution_strategy.scope():
......@@ -219,10 +215,12 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
elif mode == 'eval':
controller.evaluate(steps=params.trainer.validation_steps)
elif mode == 'continuous_eval':
def timeout_fn():
if trainer.global_step.numpy() >= params.trainer.train_steps:
return True
return False
controller.evaluate_continuously(
steps=params.trainer.validation_steps,
timeout=params.trainer.continuous_eval_timeout,
......
......@@ -32,18 +32,22 @@ from official.modeling import hyperparams
from official.modeling.hyperparams import config_definitions
def create_trainer(
params: config_definitions.ExperimentConfig,
task: base_task.Task,
model_dir: str,
train: bool,
evaluate: bool,
checkpoint_exporter: Any = None) -> base_trainer.Trainer:
def create_trainer(params: config_definitions.ExperimentConfig,
task: base_task.Task,
model: tf.keras.Model,
model_dir: str,
train: bool,
evaluate: bool,
checkpoint_exporter: Any = None) -> base_trainer.Trainer:
"""Create trainer."""
del model_dir
logging.info('Running default trainer.')
trainer = base_trainer.Trainer(
params, task, train=train, evaluate=evaluate,
params,
task,
train=train,
evaluate=evaluate,
model=model,
checkpoint_exporter=checkpoint_exporter)
return trainer
......@@ -129,8 +133,8 @@ def read_global_step_from_checkpoint(ckpt_file_path):
'make sure that your pretrain model writes '
'global_step in its checkpoints.'.format(ckpt_file_path))
global_step_restored = global_step.numpy()
logging.info('get global_step %d from checkpoint %s',
global_step_restored, ckpt_file_path)
logging.info('get global_step %d from checkpoint %s', global_step_restored,
ckpt_file_path)
return global_step_restored
......@@ -143,8 +147,8 @@ def write_json_summary(log_dir, global_step, eval_metrics):
else:
serializable_dict[name] = str(value)
output_json = os.path.join(log_dir, 'metrics-{}.json'.format(global_step))
logging.info('Evaluation results at pretrain step %d: %s',
global_step, serializable_dict)
logging.info('Evaluation results at pretrain step %d: %s', global_step,
serializable_dict)
with tf.io.gfile.GFile(output_json, 'w') as writer:
writer.write(json.dumps(serializable_dict, indent=4) + '\n')
......