Commit eb18c12c authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 331861386
parent 227bb207
@@ -14,7 +14,7 @@
# limitations under the License.
# ==============================================================================
"""TFM common training driver library."""
# pytype: disable=attribute-error
import copy
import json
import os
@@ -219,9 +219,14 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
  elif mode == 'eval':
    controller.evaluate(steps=params.trainer.validation_steps)
  elif mode == 'continuous_eval':
+    def timeout_fn():
+      if trainer.global_step.numpy() >= params.trainer.train_steps:
+        return True
+      return False
    controller.evaluate_continuously(
        steps=params.trainer.validation_steps,
-        timeout=params.trainer.continuous_eval_timeout)
+        timeout=params.trainer.continuous_eval_timeout,
+        timeout_fn=timeout_fn)
  else:
    raise NotImplementedError('The mode is not implemented: %s' % mode)
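
For context, the new `timeout_fn` hook lets continuous evaluation exit as soon as training is known to be finished (here, once `trainer.global_step` reaches `params.trainer.train_steps`) instead of idling until `continuous_eval_timeout` expires. The sketch below is not the controller implementation; it is a minimal illustration of the loop semantics, and `continuous_eval_loop`, `evaluate_fn`, `latest_checkpoint_fn`, and `poll_interval` are hypothetical names invented for the example.

```python
import time
from typing import Callable, Optional


def continuous_eval_loop(evaluate_fn: Callable[[str], None],
                         latest_checkpoint_fn: Callable[[], Optional[str]],
                         timeout: float,
                         timeout_fn: Optional[Callable[[], bool]] = None,
                         poll_interval: float = 1.0) -> None:
  """Evaluates each new checkpoint until told to stop or the wait times out."""
  last_evaluated = None
  waited = 0.0
  while True:
    checkpoint = latest_checkpoint_fn()
    if checkpoint is not None and checkpoint != last_evaluated:
      evaluate_fn(checkpoint)
      last_evaluated = checkpoint
      waited = 0.0  # A new checkpoint resets the timeout window.
    elif timeout_fn is not None and timeout_fn():
      return  # e.g. global_step already reached train_steps: stop right away.
    elif waited >= timeout:
      return  # No new checkpoint arrived within `timeout` seconds.
    else:
      time.sleep(poll_interval)
      waited += poll_interval
```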
@@ -49,6 +49,7 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
        'train_steps': 10,
        'validation_steps': 5,
        'validation_interval': 10,
+        'continuous_eval_timeout': 1,
        'optimizer_config': {
            'optimizer': {
                'type': 'sgd',
@@ -97,9 +98,19 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
    self.assertEmpty(logs)
    self.assertNotEmpty(
        tf.io.gfile.glob(os.path.join(model_dir, 'params.yaml')))
-    if flag_mode != 'eval':
+    if flag_mode == 'eval':
+      return
    self.assertNotEmpty(
        tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
+    # Tests continuous evaluation.
+    _, logs = train_lib.run_experiment(
+        distribution_strategy=distribution_strategy,
+        task=task,
+        mode='continuous_eval',
+        params=params,
+        model_dir=model_dir,
+        run_post_eval=run_post_eval)
+    print(logs)

if __name__ == '__main__':
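
As a rough usage sketch (not part of this commit), the test above mirrors how the two modes are meant to be paired: one job runs `mode='train'` and writes checkpoints to `model_dir`, while a second job runs `mode='continuous_eval'` against the same directory. The helper name below and the way `strategy`, `task`, and `params` are obtained are assumptions; only the `run_experiment` call signature comes from the test.

```python
from official.core import train_lib  # Module path assumed; adjust to your checkout.


def run_sidecar_evaluator(strategy, task, params, model_dir):
  """Evaluates checkpoints written by a separate training job (sketch).

  Exits once trainer.global_step reaches params.trainer.train_steps (via the
  new timeout_fn) or after params.trainer.continuous_eval_timeout seconds pass
  without a new checkpoint appearing in model_dir.
  """
  _, eval_logs = train_lib.run_experiment(
      distribution_strategy=strategy,
      task=task,
      mode='continuous_eval',
      params=params,
      model_dir=model_dir,
      run_post_eval=True)
  return eval_logs
```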
@@ -38,7 +38,7 @@ def create_trainer(
    model_dir: str,
    train: bool,
    evaluate: bool,
-    checkpoint_exporter: Any = None):
+    checkpoint_exporter: Any = None) -> base_trainer.Trainer:
  """Create trainer."""
  del model_dir
  logging.info('Running default trainer.')
@@ -189,7 +189,7 @@ class TrainerConfig(base_config.Config):
    continuous_eval_timeout: maximum number of seconds to wait between
      checkpoints; if set to None, continuous eval will wait indefinitely. This
      is only used in continuous_train_and_eval and continuous_eval modes. Default
-      value is 24 hrs.
+      value is 1 hour.
    train_steps: number of train steps.
    validation_steps: number of eval steps. If `None`, the entire eval dataset
      is used.
@@ -218,7 +218,7 @@ class TrainerConfig(base_config.Config):
  checkpoint_interval: int = 1000
  # Checkpoint manager.
  max_to_keep: int = 5
-  continuous_eval_timeout: int = 24 * 60 * 60
+  continuous_eval_timeout: int = 60 * 60
  # Train/Eval routines.
  train_steps: int = 0
  validation_steps: Optional[int] = None
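
Since the default drops from 24 hours to 1 hour, jobs that rely on a longer wait now need to set `continuous_eval_timeout` explicitly. A hedged sketch of such an override is below; constructing `TrainerConfig` directly and the `official.core.config_definitions` import path are assumptions about how the config is wired up, not part of this change.

```python
from official.core import config_definitions as cfg  # Import path assumed.

# Wait up to two hours between checkpoints before continuous eval gives up.
trainer_config = cfg.TrainerConfig(
    train_steps=10000,
    validation_steps=100,
    continuous_eval_timeout=2 * 60 * 60)
```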