"examples/vscode:/vscode.git/clone" did not exist on "0b2efc2adc8c5e01c1a4ef3a5ec6c9f5bac684be"
Commit 18b1714c authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Removes cls_outputs and box_outputs from postprocessing results.

PiperOrigin-RevId: 276586434
parent 8e91adaf
...@@ -95,9 +95,9 @@ class DetectionDistributedExecutor(executor.DistributedExecutor): ...@@ -95,9 +95,9 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
inputs, labels = inputs inputs, labels = inputs
model_outputs = model(inputs, training=False) model_outputs = model(inputs, training=False)
if self._predict_post_process_fn: if self._predict_post_process_fn:
labels, model_outputs = self._predict_post_process_fn( labels, prediction_outputs = self._predict_post_process_fn(
labels, model_outputs) labels, model_outputs)
return labels, model_outputs return labels, prediction_outputs
labels, outputs = strategy.experimental_run_v2( labels, outputs = strategy.experimental_run_v2(
_test_step_fn, args=(next(iterator),)) _test_step_fn, args=(next(iterator),))
......
...@@ -40,10 +40,8 @@ class COCOMetrics(object): ...@@ -40,10 +40,8 @@ class COCOMetrics(object):
self._evaluator = eval_factory.evaluator_generator(params.eval) self._evaluator = eval_factory.evaluator_generator(params.eval)
def update_state(self, y_true, y_pred): def update_state(self, y_true, y_pred):
labels, outputs = y_true, y_pred labels = tf.nest.map_structure(lambda x: x.numpy(), y_true)
outputs = tf.nest.map_structure(lambda x: x.numpy(), y_pred)
labels = tf.nest.map_structure(lambda x: x.numpy(), labels)
outputs = tf.nest.map_structure(lambda x: x.numpy(), outputs)
groundtruths = {} groundtruths = {}
predictions = {} predictions = {}
for key, val in outputs.items(): for key, val in outputs.items():
...@@ -161,14 +159,16 @@ class RetinanetModel(base_model.Model): ...@@ -161,14 +159,16 @@ class RetinanetModel(base_model.Model):
boxes, scores, classes, valid_detections = self._generate_detections_fn( boxes, scores, classes, valid_detections = self._generate_detections_fn(
inputs=(outputs['box_outputs'], outputs['cls_outputs'], inputs=(outputs['box_outputs'], outputs['cls_outputs'],
labels['anchor_boxes'], labels['image_info'][:, 1:2, :])) labels['anchor_boxes'], labels['image_info'][:, 1:2, :]))
outputs.update({ # Discards the old output tensors to save memory. The `cls_outputs` and
# `box_outputs` are pretty big and could potentiall lead to memory issue.
outputs = {
'source_id': labels['groundtruths']['source_id'], 'source_id': labels['groundtruths']['source_id'],
'image_info': labels['image_info'], 'image_info': labels['image_info'],
'num_detections': valid_detections, 'num_detections': valid_detections,
'detection_boxes': boxes, 'detection_boxes': boxes,
'detection_classes': classes, 'detection_classes': classes,
'detection_scores': scores, 'detection_scores': scores,
}) }
if 'groundtruths' in labels: if 'groundtruths' in labels:
labels['source_id'] = labels['groundtruths']['source_id'] labels['source_id'] = labels['groundtruths']['source_id']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment