Commit 91a1ce9b authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Code cleanup.

PiperOrigin-RevId: 283837279
parent 5b25005c
......@@ -42,6 +42,35 @@ from official.vision.detection.evaluation import coco_utils
from official.vision.detection.utils import class_utils
class MetricWrapper(object):
# This is only a wrapper for COCO metric and works on for numpy array. So it
# doesn't inherit from tf.keras.layers.Layer or tf.keras.metrics.Metric.
def __init__(self, evaluator):
self._evaluator = evaluator
def update_state(self, y_true, y_pred):
labels = tf.nest.map_structure(lambda x: x.numpy(), y_true)
outputs = tf.nest.map_structure(lambda x: x.numpy(), y_pred)
groundtruths = {}
predictions = {}
for key, val in outputs.items():
if isinstance(val, tuple):
val = np.concatenate(val)
predictions[key] = val
for key, val in labels.items():
if isinstance(val, tuple):
val = np.concatenate(val)
groundtruths[key] = val
self._evaluator.update(predictions, groundtruths)
def result(self):
return self._evaluator.evaluate()
def reset_states(self):
return self._evaluator.reset()
class COCOEvaluator(object):
"""COCO evaluation metric class."""
......
......@@ -32,4 +32,4 @@ def evaluator_generator(params):
else:
raise ValueError('Evaluator %s is not supported.' % params.type)
return evaluator
return coco_evaluator.MetricWrapper(evaluator)
......@@ -32,35 +32,6 @@ from official.vision.detection.modeling.architecture import factory
from official.vision.detection.ops import postprocess_ops
class COCOMetrics(object):
# This is only a wrapper for COCO metric and works on for numpy array. So it
# doesn't inherit from tf.keras.layers.Layer or tf.keras.metrics.Metric.
def __init__(self, params):
self._evaluator = eval_factory.evaluator_generator(params.eval)
def update_state(self, y_true, y_pred):
labels = tf.nest.map_structure(lambda x: x.numpy(), y_true)
outputs = tf.nest.map_structure(lambda x: x.numpy(), y_pred)
groundtruths = {}
predictions = {}
for key, val in outputs.items():
if isinstance(val, tuple):
val = np.concatenate(val)
predictions[key] = val
for key, val in labels.items():
if isinstance(val, tuple):
val = np.concatenate(val)
groundtruths[key] = val
self._evaluator.update(predictions, groundtruths)
def result(self):
return self._evaluator.evaluate()
def reset_states(self):
return self._evaluator.reset()
class RetinanetModel(base_model.Model):
"""RetinaNet model function."""
......@@ -97,6 +68,11 @@ class RetinanetModel(base_model.Model):
dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32)
def build_outputs(self, inputs, mode):
# If the input image is transposed (from NHWC to HWCN), we need to revert it
# back to the original shape before it's used in the computation.
if self._transpose_input:
inputs = tf.transpose(inputs, [3, 0, 1, 2])
backbone_features = self._backbone_fn(
inputs, is_training=(mode == mode_keys.TRAIN))
fpn_features = self._fpn_fn(
......@@ -192,4 +168,4 @@ class RetinanetModel(base_model.Model):
return labels, outputs
def eval_metrics(self):
return COCOMetrics(self._params)
return eval_factory.evaluator_generator(self._params.eval)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment