"...git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "2e2620476ce4b6ccd684cd0275f1747c9bb910c5"
Commit e8e987a6 authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

internal change

PiperOrigin-RevId: 336042908
parent bd73276e
...@@ -169,11 +169,9 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy, ...@@ -169,11 +169,9 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
""" """
with distribution_strategy.scope(): with distribution_strategy.scope():
model = task.build_model()
trainer = train_utils.create_trainer( trainer = train_utils.create_trainer(
params, params,
task, task,
model=model,
model_dir=model_dir, model_dir=model_dir,
train='train' in mode, train='train' in mode,
evaluate=('eval' in mode) or run_post_eval, evaluate=('eval' in mode) or run_post_eval,
......
...@@ -34,7 +34,6 @@ from official.modeling.hyperparams import config_definitions ...@@ -34,7 +34,6 @@ from official.modeling.hyperparams import config_definitions
def create_trainer(params: config_definitions.ExperimentConfig, def create_trainer(params: config_definitions.ExperimentConfig,
task: base_task.Task, task: base_task.Task,
model: tf.keras.Model,
model_dir: str, model_dir: str,
train: bool, train: bool,
evaluate: bool, evaluate: bool,
...@@ -42,6 +41,7 @@ def create_trainer(params: config_definitions.ExperimentConfig, ...@@ -42,6 +41,7 @@ def create_trainer(params: config_definitions.ExperimentConfig,
"""Create trainer.""" """Create trainer."""
del model_dir del model_dir
logging.info('Running default trainer.') logging.info('Running default trainer.')
model = task.build_model()
trainer = base_trainer.Trainer( trainer = base_trainer.Trainer(
params, params,
task, task,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment