Commit db4acd91 authored by Le Hou's avatar Le Hou Committed by A. Unique TensorFlower
Browse files

code clean up.

PiperOrigin-RevId: 342179055
parent 292ec4cb
...@@ -25,9 +25,9 @@ from absl import logging ...@@ -25,9 +25,9 @@ from absl import logging
import orbit import orbit
import tensorflow as tf import tensorflow as tf
from official.core import train_utils
from official.core import base_task from official.core import base_task
from official.core import config_definitions from official.core import config_definitions
from official.core import train_utils
class BestCheckpointExporter: class BestCheckpointExporter:
...@@ -172,7 +172,6 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy, ...@@ -172,7 +172,6 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
trainer = train_utils.create_trainer( trainer = train_utils.create_trainer(
params, params,
task, task,
model_dir=model_dir,
train='train' in mode, train='train' in mode,
evaluate=('eval' in mode) or run_post_eval, evaluate=('eval' in mode) or run_post_eval,
checkpoint_exporter=maybe_create_best_ckpt_exporter(params, model_dir)) checkpoint_exporter=maybe_create_best_ckpt_exporter(params, model_dir))
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
import json import json
import os import os
import pprint import pprint
from typing import Any, List, Optional from typing import Any, List
from absl import logging from absl import logging
import dataclasses import dataclasses
...@@ -36,10 +36,8 @@ def create_trainer(params: config_definitions.ExperimentConfig, ...@@ -36,10 +36,8 @@ def create_trainer(params: config_definitions.ExperimentConfig,
task: base_task.Task, task: base_task.Task,
train: bool, train: bool,
evaluate: bool, evaluate: bool,
checkpoint_exporter: Any = None, checkpoint_exporter: Any = None) -> base_trainer.Trainer:
model_dir: Optional[str] = None) -> base_trainer.Trainer:
"""Create trainer.""" """Create trainer."""
del model_dir
logging.info('Running default trainer.') logging.info('Running default trainer.')
model = task.build_model() model = task.build_model()
optimizer = base_trainer.create_optimizer(params.trainer, params.runtime) optimizer = base_trainer.create_optimizer(params.trainer, params.runtime)
......
...@@ -19,7 +19,6 @@ from absl import app ...@@ -19,7 +19,6 @@ from absl import app
from absl import flags from absl import flags
import gin import gin
from official.core import train_utils
from official.common import distribute_utils from official.common import distribute_utils
# pylint: disable=unused-import # pylint: disable=unused-import
from official.common import registry_imports from official.common import registry_imports
...@@ -27,6 +26,7 @@ from official.common import registry_imports ...@@ -27,6 +26,7 @@ from official.common import registry_imports
from official.common import flags as tfm_flags from official.common import flags as tfm_flags
from official.core import task_factory from official.core import task_factory
from official.core import train_lib from official.core import train_lib
from official.core import train_utils
from official.modeling import performance from official.modeling import performance
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
...@@ -19,7 +19,6 @@ from absl import app ...@@ -19,7 +19,6 @@ from absl import app
from absl import flags from absl import flags
import gin import gin
from official.core import train_utils
# pylint: disable=unused-import # pylint: disable=unused-import
from official.common import registry_imports from official.common import registry_imports
# pylint: enable=unused-import # pylint: enable=unused-import
...@@ -27,6 +26,7 @@ from official.common import distribute_utils ...@@ -27,6 +26,7 @@ from official.common import distribute_utils
from official.common import flags as tfm_flags from official.common import flags as tfm_flags
from official.core import task_factory from official.core import task_factory
from official.core import train_lib from official.core import train_lib
from official.core import train_utils
from official.modeling import performance from official.modeling import performance
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment