Commit acea25b9 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 295849975
parent f3600cd1
@@ -64,7 +64,7 @@ class SummaryWriter(object):
   """Simple SummaryWriter for writing dictionary of metrics.
 
   Attributes:
-    _writer: The tf.SummaryWriter.
+    writer: The tf.SummaryWriter.
   """
 
   def __init__(self, model_dir: Text, name: Text):
@@ -74,7 +74,7 @@ class SummaryWriter(object):
       model_dir: the model folder path.
       name: the summary subfolder name.
     """
-    self._writer = tf.summary.create_file_writer(os.path.join(model_dir, name))
+    self.writer = tf.summary.create_file_writer(os.path.join(model_dir, name))
 
   def __call__(self, metrics: Union[Dict[Text, float], float], step: int):
     """Write metrics to summary with the given writer.
@@ -88,10 +88,10 @@ class SummaryWriter(object):
       logging.warning('Warning: summary writer prefer metrics as dictionary.')
       metrics = {'metric': metrics}
-    with self._writer.as_default():
+    with self.writer.as_default():
       for k, v in metrics.items():
         tf.summary.scalar(k, v, step=step)
-      self._writer.flush()
+      self.writer.flush()
 
 
 class DistributedExecutor(object):
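Note: with the writer attribute made public, other components can reuse the
underlying tf.summary writer to emit additional summary types to the same log
directory. A minimal usage sketch, assuming SummaryWriter lives in
official.modeling.training.distributed_executor and a hypothetical /tmp/model
directory:

    import tensorflow.compat.v2 as tf
    from official.modeling.training import distributed_executor as executor

    # Hypothetical model dir and metric names, for illustration only.
    writer = executor.SummaryWriter('/tmp/model', 'eval')
    writer({'loss': 0.25, 'box_ap': 0.38}, step=100)

    # The now-public attribute exposes the raw tf.summary writer, so callers
    # can write summary types that SummaryWriter itself does not handle.
    with writer.writer.as_default():
      tf.summary.text('note', 'evaluation finished', step=100)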
@@ -122,6 +122,9 @@ class DistributedExecutor(object):
     self._strategy = strategy
     self._checkpoint_name = 'ctl_step_{step}.ckpt'
     self._is_multi_host = is_multi_host
+    self.train_summary_writer = None
+    self.eval_summary_writer = None
+    self.global_train_step = None
 
   @property
   def checkpoint_name(self):
@@ -395,7 +398,10 @@ class DistributedExecutor(object):
     eval_metric = eval_metric_fn()
     train_metric = train_metric_fn()
     train_summary_writer = summary_writer_fn(model_dir, 'eval_train')
+    self.train_summary_writer = train_summary_writer.writer
     test_summary_writer = summary_writer_fn(model_dir, 'eval_test')
+    self.eval_summary_writer = test_summary_writer.writer
 
     # Continue training loop.
     train_step = self._create_train_step(
@@ -406,6 +412,7 @@ class DistributedExecutor(object):
         metric=train_metric)
     test_step = None
     if eval_input_fn and eval_metric:
+      self.global_train_step = model.optimizer.iterations
       test_step = self._create_test_step(strategy, model, metric=eval_metric)
 
     logging.info('Training started')
@@ -549,6 +556,7 @@ class DistributedExecutor(object):
       return True
 
     summary_writer = summary_writer_fn(model_dir, 'eval')
+    self.eval_summary_writer = summary_writer.writer
 
     # Read checkpoints from the given model directory
     # until `eval_timeout` seconds elapses.
@@ -615,6 +623,7 @@ class DistributedExecutor(object):
                    'checkpoint', checkpoint_path)
       checkpoint.restore(checkpoint_path)
+      self.global_train_step = model.optimizer.iterations
 
       eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
       eval_metric_result = self._run_evaluation(test_step, current_step,
                                                 eval_metric, eval_iterator)
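Note: both the training and the standalone-evaluation paths publish
model.optimizer.iterations as global_train_step. That is the tf.Variable a
Keras optimizer increments once per apply_gradients() call, so image summaries
written during evaluation get tagged with the training step of the restored
checkpoint. A minimal sketch of that relationship, assuming any Keras
optimizer:

    import tensorflow.compat.v2 as tf

    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
    v = tf.Variable(1.0)

    with tf.GradientTape() as tape:
      loss = v * v
    optimizer.apply_gradients(zip(tape.gradient(loss, [v]), [v]))

    # `iterations` advanced by one; usable as a global step for summaries.
    print(int(optimizer.iterations))  # -> 1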
...
@@ -70,6 +70,9 @@ RETINANET_CFG = {
         'val_json_file': '',
         'eval_file_pattern': '',
         'input_sharding': True,
+        # When visualizing images, set evaluation batch size to 40 to avoid
+        # potential OOM.
+        'num_images_to_visualize': 0,
     },
     'predict': {
         'predict_batch_size': 8,
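Note: visualization is off by default (num_images_to_visualize: 0). A minimal
override sketch, assuming the config is consumed as the plain nested dict
shown above and that eval.batch_size is the key the OOM comment refers to:

    # Hypothetical override: visualize 40 eval images and cap the eval batch
    # size at 40, per the OOM note in the config.
    RETINANET_CFG['eval']['num_images_to_visualize'] = 40
    RETINANET_CFG['eval']['batch_size'] = 40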
...
@@ -25,6 +25,7 @@ import os
 import json
 
 import tensorflow.compat.v2 as tf
 from official.modeling.training import distributed_executor as executor
+from official.vision.detection.utils import box_utils
 
 class DetectionDistributedExecutor(executor.DistributedExecutor):
@@ -38,13 +39,19 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
                trainable_variables_filter=None,
                **kwargs):
     super(DetectionDistributedExecutor, self).__init__(**kwargs)
+    params = kwargs['params']
     if predict_post_process_fn:
       assert callable(predict_post_process_fn)
     if trainable_variables_filter:
       assert callable(trainable_variables_filter)
     self._predict_post_process_fn = predict_post_process_fn
     self._trainable_variables_filter = trainable_variables_filter
+    self.eval_steps = tf.Variable(
+        0,
+        trainable=False,
+        dtype=tf.int32,
+        synchronization=tf.VariableSynchronization.ON_READ,
+        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
+        shape=[])
 
   def _create_replicated_step(self,
                               strategy,
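Note: eval_steps uses ON_READ synchronization with ONLY_FIRST_REPLICA
aggregation, so each replica updates its own copy inside a distributed step
without cross-device synchronization, and a read outside the replica context
returns the first replica's value. A self-contained sketch of the same
variable pattern, assuming a MirroredStrategy:

    import tensorflow.compat.v2 as tf

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
      counter = tf.Variable(
          0,
          trainable=False,
          dtype=tf.int32,
          synchronization=tf.VariableSynchronization.ON_READ,
          aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)

    @tf.function
    def step():
      def replica_fn():
        counter.assign_add(1)  # each replica bumps its local copy

      strategy.experimental_run_v2(replica_fn)

    step()
    print(int(counter))  # first replica's count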
@@ -90,24 +97,41 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
     """Creates a distributed test step."""
 
     @tf.function
-    def test_step(iterator):
+    def test_step(iterator, eval_steps):
       """Calculates evaluation metrics on distributed devices."""
 
-      def _test_step_fn(inputs):
+      def _test_step_fn(inputs, eval_steps):
         """Replicated accuracy calculation."""
         inputs, labels = inputs
         model_outputs = model(inputs, training=False)
         if self._predict_post_process_fn:
           labels, prediction_outputs = self._predict_post_process_fn(
               labels, model_outputs)
+          num_remaining_visualizations = (
+              self._params.eval.num_images_to_visualize - eval_steps)
+          # If there are visualizations remaining to be done, add the next
+          # batch's outputs for visualization.
+          #
+          # TODO(hongjunchoi): Once dynamic slicing is supported on TPU, only
+          # write the correct slice of outputs to the summary file.
+          if num_remaining_visualizations > 0:
+            box_utils.visualize_images_with_bounding_boxes(
+                inputs, prediction_outputs['detection_boxes'],
+                self.global_train_step, self.eval_summary_writer)
         return labels, prediction_outputs
 
       labels, outputs = strategy.experimental_run_v2(
-          _test_step_fn, args=(next(iterator),))
+          _test_step_fn, args=(
+              next(iterator),
+              eval_steps,
+          ))
       outputs = tf.nest.map_structure(strategy.experimental_local_results,
                                       outputs)
       labels = tf.nest.map_structure(strategy.experimental_local_results,
                                      labels)
+      eval_steps.assign_add(self._params.eval.batch_size)
       return labels, outputs
 
     return test_step
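Note the gating pattern: eval_steps counts images already processed (it
advances by eval.batch_size after every step), and a batch is visualized only
while that count is below num_images_to_visualize. Because the counter is a
tf.Variable passed into the tf.function, the comparison happens in-graph. A
minimal sketch of the same pattern, with a hypothetical batch size of 8 and a
budget of 20 images:

    import tensorflow.compat.v2 as tf

    BATCH_SIZE = 8            # stands in for params.eval.batch_size
    IMAGES_TO_VISUALIZE = 20  # stands in for params.eval.num_images_to_visualize

    images_seen = tf.Variable(0, trainable=False, dtype=tf.int32)

    @tf.function
    def maybe_visualize(step_counter):
      # Visualize only while the image budget is not exhausted.
      if IMAGES_TO_VISUALIZE - step_counter > 0:
        tf.print('visualizing batch at count', step_counter)
      step_counter.assign_add(BATCH_SIZE)

    for _ in range(5):
      maybe_visualize(images_seen)  # prints at counts 0, 8, 16; then stops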
@@ -115,6 +139,7 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
   def _run_evaluation(self, test_step, current_training_step, metric,
                       test_iterator):
     """Runs validation steps and aggregate metrics."""
+    self.eval_steps.assign(0)
     if not test_iterator or not metric:
       logging.warning(
           'Both test_iterator (%s) and metrics (%s) must not be None.',
@@ -123,7 +148,7 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
     logging.info('Running evaluation after step: %s.', current_training_step)
     while True:
       try:
-        labels, outputs = test_step(test_iterator)
+        labels, outputs = test_step(test_iterator, self.eval_steps)
         if metric:
           metric.update_state(labels, outputs)
       except (StopIteration, tf.errors.OutOfRangeError):
...
@@ -239,4 +239,5 @@ def main(argv):
 if __name__ == '__main__':
   assert tf.version.VERSION.startswith('2.')
+  tf.config.set_soft_device_placement(True)
   app.run(main)
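Note: tf.config.set_soft_device_placement(True) makes TensorFlow fall back to
a supported device instead of raising when an op has no kernel for the
requested one; presumably it keeps the new summary-image ops, which do not run
on TPU, from failing evaluation. A minimal sketch of the behavior, assuming a
machine without a GPU:

    import tensorflow.compat.v2 as tf

    tf.config.set_soft_device_placement(True)

    # Pinning an op to a missing device no longer raises; TensorFlow silently
    # places it on an available device (e.g. the CPU).
    with tf.device('/GPU:0'):
      x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
      y = tf.matmul(x, x)
    print(y.numpy())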
@@ -26,6 +26,22 @@ EPSILON = 1e-8
 BBOX_XFORM_CLIP = np.log(1000. / 16.)
 
 
+def visualize_images_with_bounding_boxes(images, box_outputs, step,
+                                         summary_writer):
+  """Records subset of evaluation images with bounding boxes."""
+  image_shape = tf.shape(images[0])
+  image_height = tf.cast(image_shape[0], tf.float32)
+  image_width = tf.cast(image_shape[1], tf.float32)
+  normalized_boxes = normalize_boxes(box_outputs, [image_height, image_width])
+
+  bounding_box_color = tf.constant([[1.0, 1.0, 0.0, 1.0]])
+  image_summary = tf.image.draw_bounding_boxes(images, normalized_boxes,
+                                               bounding_box_color)
+  with summary_writer.as_default():
+    tf.summary.image('bounding_box_summary', image_summary, step=step)
+    summary_writer.flush()
+
+
 def yxyx_to_xywh(boxes):
   """Converts boxes from ymin, xmin, ymax, xmax to xmin, ymin, width, height.
...