Commit 90810a0e authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 342910130
parent a01346f0
@@ -15,8 +15,6 @@
 # ==============================================================================
 """TFM common training driver library."""
 # pytype: disable=attribute-error
-import copy
-import json
 import os
 from typing import Any, Mapping, Tuple
@@ -29,96 +27,7 @@ from official.core import base_task
 from official.core import config_definitions
 from official.core import train_utils
 
+BestCheckpointExporter = train_utils.BestCheckpointExporter
+
-class BestCheckpointExporter:
-  """Keeps track of the best result, and saves its checkpoint.
-
-  Orbit will support an API for checkpoint exporter. This class will be used
-  together with orbit once this functionality is ready.
-  """
-
-  def __init__(self, export_dir: str, metric_name: str, metric_comp: str):
-    """Initialization.
-
-    Arguments:
-      export_dir: The directory that will contain exported checkpoints.
-      metric_name: Indicates which metric to look at, when determining which
-        result is better.
-      metric_comp: Indicates how to compare results. Either `lower` or `higher`.
-    """
-    self._export_dir = export_dir
-    self._metric_name = metric_name
-    self._metric_comp = metric_comp
-    if self._metric_comp not in ('lower', 'higher'):
-      raise ValueError('best checkpoint metric comp must be one of '
-                       'higher, lower. Got: {}'.format(self._metric_comp))
-    tf.io.gfile.makedirs(os.path.dirname(self.best_ckpt_logs_path))
-    self._best_ckpt_logs = self._maybe_load_best_eval_metric()
-
-  def maybe_export_checkpoint(self, checkpoint, eval_logs, global_step):
-    logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d',
-                 eval_logs, global_step)
-    if self._best_ckpt_logs is None or self._new_metric_is_better(
-        self._best_ckpt_logs, eval_logs):
-      self._best_ckpt_logs = eval_logs
-      self._export_best_eval_metric(checkpoint, self._best_ckpt_logs,
-                                    global_step)
-
-  def _maybe_load_best_eval_metric(self):
-    if not tf.io.gfile.exists(self.best_ckpt_logs_path):
-      return None
-    with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'r') as reader:
-      return json.loads(reader.read())
-
-  def _new_metric_is_better(self, old_logs, new_logs):
-    """Check if the metric in new_logs is better than the metric in old_logs."""
-    if self._metric_name not in old_logs or self._metric_name not in new_logs:
-      raise KeyError('best checkpoint eval metric name {} is not valid. '
-                     'old_logs: {}, new_logs: {}'.format(
-                         self._metric_name, old_logs, new_logs))
-    old_value = float(orbit.utils.get_value(old_logs[self._metric_name]))
-    new_value = float(orbit.utils.get_value(new_logs[self._metric_name]))
-    logging.info('[BestCheckpointExporter] comparing results. old: %f, new: %f',
-                 old_value, new_value)
-    if self._metric_comp == 'higher':
-      if new_value > old_value:
-        logging.info('[BestCheckpointExporter] '
-                     'the new number is better since it is higher.')
-        return True
-    else:  # self._metric_comp == 'lower':
-      if new_value < old_value:
-        logging.info('[BestCheckpointExporter] '
-                     'the new number is better since it is lower.')
-        return True
-    return False
-
-  def _export_best_eval_metric(self, checkpoint, eval_logs, global_step):
-    """Export evaluation results of the best checkpoint into a json file."""
-    eval_logs_ext = copy.copy(eval_logs)
-    eval_logs_ext['best_ckpt_global_step'] = global_step
-    for name, value in eval_logs_ext.items():
-      eval_logs_ext[name] = str(orbit.utils.get_value(value))
-    # Saving json file is very fast.
-    with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer:
-      writer.write(json.dumps(eval_logs_ext, indent=4) + '\n')
-    # Saving the best checkpoint might be interrupted if the job got killed.
-    for file_to_remove in tf.io.gfile.glob(self.best_ckpt_path + '*'):
-      tf.io.gfile.rmtree(file_to_remove)
-    checkpoint.save(self.best_ckpt_path)
-
-  @property
-  def best_ckpt_logs(self):
-    return self._best_ckpt_logs
-
-  @property
-  def best_ckpt_logs_path(self):
-    return os.path.join(self._export_dir, 'info.json')
-
-  @property
-  def best_ckpt_path(self):
-    return os.path.join(self._export_dir, 'best_ckpt')
-
-
 def maybe_create_best_ckpt_exporter(params: config_definitions.ExperimentConfig,
@@ -129,8 +38,8 @@ def maybe_create_best_ckpt_exporter(params: config_definitions.ExperimentConfig,
   metric_comp = params.trainer.best_checkpoint_metric_comp
   if data_dir and export_subdir and metric_name:
     best_ckpt_dir = os.path.join(data_dir, export_subdir)
-    best_ckpt_exporter = BestCheckpointExporter(best_ckpt_dir, metric_name,
-                                                metric_comp)
+    best_ckpt_exporter = BestCheckpointExporter(
+        best_ckpt_dir, metric_name, metric_comp)
   else:
     best_ckpt_exporter = None
   logging.info(
...
@@ -14,14 +14,15 @@
 # limitations under the License.
 # ==============================================================================
 """Training utils."""
+import copy
 import json
 import os
 import pprint
-from typing import Any, List
+from typing import List, Optional
 from absl import logging
 import dataclasses
+import gin
 import orbit
 import tensorflow as tf
@@ -32,16 +33,109 @@ from official.core import exp_factory
 from official.modeling import hyperparams
 
+
+class BestCheckpointExporter:
+  """Keeps track of the best result, and saves its checkpoint.
+
+  Orbit will support an API for checkpoint exporter. This class will be used
+  together with orbit once this functionality is ready.
+  """
+
+  def __init__(self, export_dir: str, metric_name: str, metric_comp: str):
+    """Initialization.
+
+    Arguments:
+      export_dir: The directory that will contain exported checkpoints.
+      metric_name: Indicates which metric to look at, when determining which
+        result is better.
+      metric_comp: Indicates how to compare results. Either `lower` or `higher`.
+    """
+    self._export_dir = export_dir
+    self._metric_name = metric_name
+    self._metric_comp = metric_comp
+    if self._metric_comp not in ('lower', 'higher'):
+      raise ValueError('best checkpoint metric comp must be one of '
+                       'higher, lower. Got: {}'.format(self._metric_comp))
+    tf.io.gfile.makedirs(os.path.dirname(self.best_ckpt_logs_path))
+    self._best_ckpt_logs = self._maybe_load_best_eval_metric()
+
+  def maybe_export_checkpoint(self, checkpoint, eval_logs, global_step):
+    logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d',
+                 eval_logs, global_step)
+    if self._best_ckpt_logs is None or self._new_metric_is_better(
+        self._best_ckpt_logs, eval_logs):
+      self._best_ckpt_logs = eval_logs
+      self._export_best_eval_metric(checkpoint, self._best_ckpt_logs,
+                                    global_step)
+
+  def _maybe_load_best_eval_metric(self):
+    if not tf.io.gfile.exists(self.best_ckpt_logs_path):
+      return None
+    with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'r') as reader:
+      return json.loads(reader.read())
+
+  def _new_metric_is_better(self, old_logs, new_logs):
+    """Check if the metric in new_logs is better than the metric in old_logs."""
+    if self._metric_name not in old_logs or self._metric_name not in new_logs:
+      raise KeyError('best checkpoint eval metric name {} is not valid. '
+                     'old_logs: {}, new_logs: {}'.format(
+                         self._metric_name, old_logs, new_logs))
+    old_value = float(orbit.utils.get_value(old_logs[self._metric_name]))
+    new_value = float(orbit.utils.get_value(new_logs[self._metric_name]))
+    logging.info('[BestCheckpointExporter] comparing results. old: %f, new: %f',
+                 old_value, new_value)
+    if self._metric_comp == 'higher':
+      if new_value > old_value:
+        logging.info('[BestCheckpointExporter] '
+                     'the new number is better since it is higher.')
+        return True
+    else:  # self._metric_comp == 'lower':
+      if new_value < old_value:
+        logging.info('[BestCheckpointExporter] '
+                     'the new number is better since it is lower.')
+        return True
+    return False
+
+  def _export_best_eval_metric(self, checkpoint, eval_logs, global_step):
+    """Export evaluation results of the best checkpoint into a json file."""
+    eval_logs_ext = copy.copy(eval_logs)
+    eval_logs_ext['best_ckpt_global_step'] = global_step
+    for name, value in eval_logs_ext.items():
+      eval_logs_ext[name] = str(orbit.utils.get_value(value))
+    # Saving json file is very fast.
+    with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer:
+      writer.write(json.dumps(eval_logs_ext, indent=4) + '\n')
+    # Saving the best checkpoint might be interrupted if the job got killed.
+    for file_to_remove in tf.io.gfile.glob(self.best_ckpt_path + '*'):
+      tf.io.gfile.rmtree(file_to_remove)
+    checkpoint.save(self.best_ckpt_path)
+
+  @property
+  def best_ckpt_logs(self):
+    return self._best_ckpt_logs
+
+  @property
+  def best_ckpt_logs_path(self):
+    return os.path.join(self._export_dir, 'info.json')
+
+  @property
+  def best_ckpt_path(self):
+    return os.path.join(self._export_dir, 'best_ckpt')
+
+
+@gin.configurable
 def create_trainer(params: config_definitions.ExperimentConfig,
                    task: base_task.Task,
                    train: bool,
                    evaluate: bool,
-                   checkpoint_exporter: Any = None) -> base_trainer.Trainer:
+                   checkpoint_exporter: Optional[BestCheckpointExporter] = None,
+                   trainer_cls=base_trainer.Trainer) -> base_trainer.Trainer:
   """Create trainer."""
   logging.info('Running default trainer.')
   model = task.build_model()
   optimizer = base_trainer.create_optimizer(params.trainer, params.runtime)
-  trainer = base_trainer.Trainer(
+  return trainer_cls(
       params,
       task,
       model=model,
@@ -49,7 +143,6 @@ def create_trainer(params: config_definitions.ExperimentConfig,
       train=train,
       evaluate=evaluate,
       checkpoint_exporter=checkpoint_exporter)
-  return trainer
 
 
 @dataclasses.dataclass
...
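
For orientation: this commit moves BestCheckpointExporter out of the training driver library into the training utils module and re-exports it under the old name, so existing imports keep working. The exporter is driven by handing it a tf.train.Checkpoint together with the latest eval metrics after each evaluation; it rewrites info.json and the best_ckpt* files only when the tracked metric improves. The sketch below is not part of this commit and only illustrates that flow; it assumes the official package from tensorflow/models is importable, and the export directory, metric name, and metric value are made up.

# Usage sketch (illustrative only, not from this commit).
import tensorflow as tf
from official.core import train_utils

exporter = train_utils.BestCheckpointExporter(
    export_dir='/tmp/best_ckpt_demo',  # hypothetical export directory
    metric_name='accuracy',            # must be a key of the eval logs dict
    metric_comp='higher')              # keep the checkpoint with the highest value

# Anything trackable can be checkpointed; a single variable keeps the sketch small.
checkpoint = tf.train.Checkpoint(step=tf.Variable(0))

# After an evaluation at step 1000, offer the result to the exporter. On the
# first call, or whenever the metric beats the stored best, it writes
# info.json and saves the checkpoint under <export_dir>/best_ckpt*.
exporter.maybe_export_checkpoint(checkpoint, {'accuracy': 0.83}, 1000)
print(exporter.best_ckpt_logs)  # {'accuracy': 0.83}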
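
The other functional change is that create_trainer is now gin-configurable and accepts a trainer_cls argument instead of always constructing base_trainer.Trainer, so a project can swap in its own Trainer subclass. A sketch of how that hook could be used follows; the MyTrainer class is hypothetical and not part of this commit.

# Sketch only; MyTrainer is a hypothetical subclass used for illustration.
import gin
from official.core import base_trainer
from official.core import train_utils  # importing registers create_trainer with gin

class MyTrainer(base_trainer.Trainer):
  """Project-specific trainer that could override train_step/eval_step."""

# Option 1: pass the class explicitly when building the trainer.
# trainer = train_utils.create_trainer(
#     params, task, train=True, evaluate=True, trainer_cls=MyTrainer)

# Option 2: bind it through gin so existing call sites pick it up unchanged.
gin.bind_parameter('create_trainer.trainer_cls', MyTrainer)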