Commit db4acd91 authored by Le Hou's avatar Le Hou Committed by A. Unique TensorFlower
Browse files

code clean up.

PiperOrigin-RevId: 342179055
parent 292ec4cb
......@@ -25,9 +25,9 @@ from absl import logging
import orbit
import tensorflow as tf
from official.core import train_utils
from official.core import base_task
from official.core import config_definitions
from official.core import train_utils
class BestCheckpointExporter:
......@@ -172,7 +172,6 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
trainer = train_utils.create_trainer(
params,
task,
model_dir=model_dir,
train='train' in mode,
evaluate=('eval' in mode) or run_post_eval,
checkpoint_exporter=maybe_create_best_ckpt_exporter(params, model_dir))
......
......@@ -18,7 +18,7 @@
import json
import os
import pprint
from typing import Any, List, Optional
from typing import Any, List
from absl import logging
import dataclasses
......@@ -36,10 +36,8 @@ def create_trainer(params: config_definitions.ExperimentConfig,
task: base_task.Task,
train: bool,
evaluate: bool,
checkpoint_exporter: Any = None,
model_dir: Optional[str] = None) -> base_trainer.Trainer:
checkpoint_exporter: Any = None) -> base_trainer.Trainer:
"""Create trainer."""
del model_dir
logging.info('Running default trainer.')
model = task.build_model()
optimizer = base_trainer.create_optimizer(params.trainer, params.runtime)
......
......@@ -19,7 +19,6 @@ from absl import app
from absl import flags
import gin
from official.core import train_utils
from official.common import distribute_utils
# pylint: disable=unused-import
from official.common import registry_imports
......@@ -27,6 +26,7 @@ from official.common import registry_imports
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
FLAGS = flags.FLAGS
......
......@@ -19,7 +19,6 @@ from absl import app
from absl import flags
import gin
from official.core import train_utils
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
......@@ -27,6 +26,7 @@ from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
FLAGS = flags.FLAGS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment