Commit eb18c12c authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 331861386
parent 227bb207
@@ -14,7 +14,7 @@
 # limitations under the License.
 # ==============================================================================
 """TFM common training driver library."""
-
+# pytype: disable=attribute-error
 import copy
 import json
 import os
@@ -219,9 +219,14 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
   elif mode == 'eval':
     controller.evaluate(steps=params.trainer.validation_steps)
   elif mode == 'continuous_eval':
+    def timeout_fn():
+      if trainer.global_step.numpy() >= params.trainer.train_steps:
+        return True
+      return False
     controller.evaluate_continuously(
         steps=params.trainer.validation_steps,
-        timeout=params.trainer.continuous_eval_timeout)
+        timeout=params.trainer.continuous_eval_timeout,
+        timeout_fn=timeout_fn)
   else:
     raise NotImplementedError('The mode is not implemented: %s' % mode)
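The new timeout_fn lets the continuous-eval job stop polling for checkpoints as soon as the trainer's global step reaches params.trainer.train_steps, instead of idling until the full timeout expires. Below is a minimal, self-contained sketch of that pattern, not the Orbit/TFM implementation: the names make_timeout_fn, run_eval, and checkpoint_dir are illustrative, and the loop uses tf.train.checkpoints_iterator, whose timeout/timeout_fn semantics (stop waiting once timeout_fn returns True) mirror the arguments passed above; whether Orbit's Controller delegates to that exact utility is an assumption.

import tensorflow as tf


def make_timeout_fn(global_step: tf.Variable, train_steps: int):
  """Builds a timeout_fn that reports when training has finished."""

  def timeout_fn() -> bool:
    # Returning True means "no more checkpoints are coming", so the wait
    # loop exits immediately instead of sleeping out the full timeout.
    return int(global_step.numpy()) >= train_steps

  return timeout_fn


def continuous_eval(checkpoint_dir: str, run_eval, global_step, train_steps,
                    timeout: int = 60 * 60):
  # checkpoints_iterator yields each new checkpoint path as it appears and
  # returns when `timeout` seconds pass without a new checkpoint or when
  # timeout_fn() returns True.
  for ckpt_path in tf.train.checkpoints_iterator(
      checkpoint_dir,
      timeout=timeout,
      timeout_fn=make_timeout_fn(global_step, train_steps)):
    run_eval(ckpt_path)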
@@ -49,6 +49,7 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
             'train_steps': 10,
             'validation_steps': 5,
             'validation_interval': 10,
+            'continuous_eval_timeout': 1,
             'optimizer_config': {
                 'optimizer': {
                     'type': 'sgd',
@@ -97,9 +98,19 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
       self.assertEmpty(logs)
     self.assertNotEmpty(
         tf.io.gfile.glob(os.path.join(model_dir, 'params.yaml')))
-    if flag_mode != 'eval':
-      self.assertNotEmpty(
-          tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
+    if flag_mode == 'eval':
+      return
+    self.assertNotEmpty(
+        tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
+    # Tests continuous evaluation.
+    _, logs = train_lib.run_experiment(
+        distribution_strategy=distribution_strategy,
+        task=task,
+        mode='continuous_eval',
+        params=params,
+        model_dir=model_dir,
+        run_post_eval=run_post_eval)
+    print(logs)


 if __name__ == '__main__':
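The test sets 'continuous_eval_timeout': 1 so the new continuous_eval pass returns within about a second once the last checkpoint has been evaluated. As a rough sketch of how such an override dict is typically applied to an experiment config, assuming the official.core.config_definitions module path and the Config.override helper; the field names and values themselves come from this diff:

from official.core import config_definitions as cfg  # assumed module path

params = cfg.ExperimentConfig()
# Config.override (from the hyperparams base Config) applies a nested dict
# of overrides; the keys mirror the test config in this commit.
params.override({
    'trainer': {
        'train_steps': 10,
        'validation_steps': 5,
        'validation_interval': 10,
        'continuous_eval_timeout': 1,  # seconds; keeps the unit test fast
    }
})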
@@ -38,7 +38,7 @@ def create_trainer(
     model_dir: str,
     train: bool,
     evaluate: bool,
-    checkpoint_exporter: Any = None):
+    checkpoint_exporter: Any = None) -> base_trainer.Trainer:
   """Create trainer."""
   del model_dir
   logging.info('Running default trainer.')
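Annotating create_trainer with -> base_trainer.Trainer is a small typing improvement: run_experiment's new timeout_fn reads trainer.global_step, and a declared return type lets a static checker such as pytype resolve that attribute instead of treating the factory's result as untyped. A generic sketch with hypothetical stand-in classes, not the TFM ones:

import tensorflow as tf


class Trainer:
  """Stand-in for base_trainer.Trainer with just the attribute we need."""

  def __init__(self):
    self.global_step = tf.Variable(0, dtype=tf.int64)


def create_trainer() -> Trainer:
  # The explicit return annotation tells checkers which attributes exist
  # on the returned object.
  return Trainer()


trainer = create_trainer()
print(int(trainer.global_step.numpy()))  # resolves cleanly for type checkers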
@@ -189,7 +189,7 @@ class TrainerConfig(base_config.Config):
     continuous_eval_timeout: maximum number of seconds to wait between
       checkpoints, if set to None, continuous eval will wait indefinitely. This
       is only used continuous_train_and_eval and continuous_eval modes. Default
-      value is 24 hrs.
+      value is 1 hrs.
     train_steps: number of train steps.
     validation_steps: number of eval steps. If `None`, the entire eval dataset
       is used.
@@ -218,7 +218,7 @@ class TrainerConfig(base_config.Config):
   checkpoint_interval: int = 1000
   # Checkpoint manager.
   max_to_keep: int = 5
-  continuous_eval_timeout: int = 24 * 60 * 60
+  continuous_eval_timeout: int = 60 * 60
   # Train/Eval routines.
   train_steps: int = 0
   validation_steps: Optional[int] = None
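The default wait between checkpoints drops from 24 * 60 * 60 = 86400 seconds (24 hours) to 60 * 60 = 3600 seconds (1 hour); presumably the new timeout_fn makes a long fallback timeout unnecessary, since the eval job now exits as soon as training reports it is done. A small sketch, assuming the config_definitions module path and dataclass-style keyword construction of TrainerConfig, for experiments that still want the old behavior:

from official.core import config_definitions as cfg  # assumed module path

# New default: one hour of waiting between checkpoints.
assert cfg.TrainerConfig().continuous_eval_timeout == 60 * 60

# An experiment that prefers the previous 24-hour wait can set it explicitly.
trainer_cfg = cfg.TrainerConfig(continuous_eval_timeout=24 * 60 * 60)
assert trainer_cfg.continuous_eval_timeout == 86400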