Commit d3320242 authored by Will Cromar, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 304643558
parent edcb2146
@@ -19,7 +19,6 @@ from __future__ import division
 # from __future__ import google_type_annotations
 from __future__ import print_function
 
-import json
 import os
 
 from absl import flags
@@ -32,6 +31,7 @@ import tensorflow as tf
 from typing import Optional, Dict, List, Text, Callable, Union, Iterator, Any
 from official.modeling.hyperparams import params_dict
 from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils
 from official.utils import hyperparams_flags
 
 FLAGS = flags.FLAGS
@@ -330,10 +330,12 @@ class DistributedExecutor(object):
     eval_metric_fn = eval_metric_fn or _no_metric
     if custom_callbacks and iterations_per_loop != 1:
-      logging.error(
+      logging.warning(
           'It is sematically wrong to run callbacks when '
           'iterations_per_loop is not one (%s)', iterations_per_loop)
+    custom_callbacks = custom_callbacks or []
+
     def _run_callbacks_on_batch_begin(batch):
       """Runs custom callbacks at the start of every step."""
       if not custom_callbacks:
@@ -402,6 +404,11 @@ class DistributedExecutor(object):
       test_summary_writer = summary_writer_fn(model_dir, 'eval_test')
       self.eval_summary_writer = test_summary_writer.writer
 
+    # Use training summary writer in TimeHistory if it's in use
+    for cb in custom_callbacks:
+      if isinstance(cb, keras_utils.TimeHistory):
+        cb.summary_writer = self.train_summary_writer
+
     # Continue training loop.
     train_step = self._create_train_step(
         strategy=strategy,
@@ -422,11 +429,12 @@ class DistributedExecutor(object):
       _run_callbacks_on_batch_begin(current_step)
       train_loss = train_step(train_iterator,
                               tf.convert_to_tensor(num_steps, dtype=tf.int32))
-      _run_callbacks_on_batch_end(current_step)
       current_step += num_steps
 
       train_loss = tf.nest.map_structure(lambda x: x.numpy().astype(float),
                                          train_loss)
+      _run_callbacks_on_batch_end(current_step - 1)
+
       if not isinstance(train_loss, dict):
         train_loss = {'total_loss': train_loss}
       if np.isnan(train_loss['total_loss']):
@@ -493,6 +501,9 @@ class DistributedExecutor(object):
         test_summary_writer(
             metrics=eval_metric_result, step=optimizer.iterations)
 
+    self.train_summary_writer.close()
+    self.eval_summary_writer.close()
+
     return train_loss, eval_metric_result
 
   def _run_evaluation(self, test_step, current_training_step, metric,
...
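The executor hunks above do two things: they point any keras_utils.TimeHistory callback at the training summary writer, and they move the batch-end callback to after the multi-step inner loop, reporting the last step actually executed (current_step - 1). A minimal standalone sketch of that callback timing follows; the ToyCallback class and the step counts are hypothetical stand-ins, not part of this commit:

    import tensorflow as tf

    class ToyCallback(tf.keras.callbacks.Callback):
      """Hypothetical stand-in for keras_utils.TimeHistory."""

      def on_batch_begin(self, batch, logs=None):
        print('begin at step', batch)

      def on_batch_end(self, batch, logs=None):
        print('end at step', batch)

    custom_callbacks = [ToyCallback()]
    current_step, num_steps, total_steps = 0, 10, 30  # illustrative values
    while current_step < total_steps:
      # Mirrors the loop above: begin fires with the first step of the inner
      # loop; end fires with the last step the inner train_step executed.
      for cb in custom_callbacks:
        cb.on_batch_begin(current_step)
      current_step += num_steps  # train_step() runs num_steps steps internally
      for cb in custom_callbacks:
        cb.on_batch_end(current_step - 1)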
@@ -35,10 +35,12 @@ from official.vision.detection.dataloader import input_reader
 from official.vision.detection.dataloader import mode_keys as ModeKeys
 from official.vision.detection.executor.detection_executor import DetectionDistributedExecutor
 from official.vision.detection.modeling import factory as model_factory
+from official.utils.flags import core as flags_core
 from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
 
 hyperparams_flags.initialize_common_flags()
+flags_core.define_log_steps()
 
 flags.DEFINE_bool(
     'enable_xla',
@@ -224,6 +226,17 @@ def run(callbacks=None):
       mode=input_reader.ModeKeys.PREDICT_WITH_GT,
       batch_size=params.eval.batch_size,
       num_examples=params.eval.eval_samples)
 
+  if callbacks is None:
+    callbacks = []
+
+  if FLAGS.log_steps:
+    callbacks.append(
+        keras_utils.TimeHistory(
+            batch_size=params.train.batch_size,
+            log_steps=FLAGS.log_steps,
+        ))
+
   return run_executor(
       params,
       train_input_fn=train_input_fn,
...
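The main.py hunks gate the new TimeHistory callback behind the --log_steps flag that flags_core.define_log_steps() registers. A hedged usage sketch, assuming the Model Garden official package is on PYTHONPATH; the batch size below is an illustrative stand-in for params.train.batch_size:

    from absl import app
    from absl import flags
    from official.utils.flags import core as flags_core
    from official.utils.misc import keras_utils

    FLAGS = flags.FLAGS
    flags_core.define_log_steps()  # defines the --log_steps flag used above

    def main(_):
      callbacks = []
      if FLAGS.log_steps:
        # Same wiring as run(): log step timing every --log_steps steps.
        callbacks.append(
            keras_utils.TimeHistory(batch_size=64, log_steps=FLAGS.log_steps))
      print('callbacks:', callbacks)

    if __name__ == '__main__':
      app.run(main)  # e.g. python this_sketch.py --log_steps=100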