Commit 0ca9cd6a authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 337251846
parent bdf8299a
...@@ -209,6 +209,8 @@ class MaskRCNNTask(cfg.TaskConfig): ...@@ -209,6 +209,8 @@ class MaskRCNNTask(cfg.TaskConfig):
annotation_file: Optional[str] = None annotation_file: Optional[str] = None
gradient_clip_norm: float = 0.0 gradient_clip_norm: float = 0.0
per_category_metrics: bool = False per_category_metrics: bool = False
# If set, we only use masks for the specified class IDs.
allowed_mask_class_ids: Optional[List[int]] = None
COCO_INPUT_PATH_BASE = 'coco' COCO_INPUT_PATH_BASE = 'coco'
......
...@@ -164,44 +164,69 @@ class COCOEvaluator(object): ...@@ -164,44 +164,69 @@ class COCOEvaluator(object):
metrics_dict[name] = metrics[i].astype(np.float32) metrics_dict[name] = metrics[i].astype(np.float32)
# Adds metrics per category. # Adds metrics per category.
if self._per_category_metrics and hasattr(coco_eval, 'category_stats'): if self._per_category_metrics:
metrics_dict.update(self._retrieve_per_category_metrics(coco_eval))
if self._include_mask:
metrics_dict.update(self._retrieve_per_category_metrics(
mcoco_eval, prefix='mask'))
return metrics_dict
def _retrieve_per_category_metrics(self, coco_eval, prefix=''):
"""Retrieves and per-category metrics and retuns them in a dict.
Args:
coco_eval: a cocoeval.COCOeval object containing evaluation data.
prefix: str, A string used to prefix metric names.
Returns:
metrics_dict: A dictionary with per category metrics.
"""
metrics_dict = {}
if prefix:
prefix = prefix + ' '
if hasattr(coco_eval, 'category_stats'):
for category_index, category_id in enumerate(coco_eval.params.catIds): for category_index, category_id in enumerate(coco_eval.params.catIds):
metrics_dict['Precision mAP ByCategory/{}'.format( metrics_dict[prefix + 'Precision mAP ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[0][category_index].astype( category_id)] = coco_eval.category_stats[0][category_index].astype(
np.float32) np.float32)
metrics_dict['Precision mAP ByCategory@50IoU/{}'.format( metrics_dict[prefix + 'Precision mAP ByCategory@50IoU/{}'.format(
category_id)] = coco_eval.category_stats[1][category_index].astype( category_id)] = coco_eval.category_stats[1][category_index].astype(
np.float32) np.float32)
metrics_dict['Precision mAP ByCategory@75IoU/{}'.format( metrics_dict[prefix + 'Precision mAP ByCategory@75IoU/{}'.format(
category_id)] = coco_eval.category_stats[2][category_index].astype( category_id)] = coco_eval.category_stats[2][category_index].astype(
np.float32) np.float32)
metrics_dict['Precision mAP ByCategory (small) /{}'.format( metrics_dict[prefix +'Precision mAP ByCategory (small) /{}'.format(
category_id)] = coco_eval.category_stats[3][category_index].astype( category_id)] = coco_eval.category_stats[3][category_index].astype(
np.float32) np.float32)
metrics_dict['Precision mAP ByCategory (medium) /{}'.format( metrics_dict[prefix +'Precision mAP ByCategory (medium) /{}'.format(
category_id)] = coco_eval.category_stats[4][category_index].astype( category_id)] = coco_eval.category_stats[4][category_index].astype(
np.float32) np.float32)
metrics_dict['Precision mAP ByCategory (large) /{}'.format( metrics_dict[prefix + 'Precision mAP ByCategory (large) /{}'.format(
category_id)] = coco_eval.category_stats[5][category_index].astype( category_id)] = coco_eval.category_stats[5][category_index].astype(
np.float32) np.float32)
metrics_dict['Recall AR@1 ByCategory/{}'.format( metrics_dict[prefix + 'Recall AR@1 ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[6][category_index].astype( category_id)] = coco_eval.category_stats[6][category_index].astype(
np.float32) np.float32)
metrics_dict['Recall AR@10 ByCategory/{}'.format( metrics_dict[prefix + 'Recall AR@10 ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[7][category_index].astype( category_id)] = coco_eval.category_stats[7][category_index].astype(
np.float32) np.float32)
metrics_dict['Recall AR@100 ByCategory/{}'.format( metrics_dict[prefix + 'Recall AR@100 ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[8][category_index].astype( category_id)] = coco_eval.category_stats[8][category_index].astype(
np.float32) np.float32)
metrics_dict['Recall AR (small) ByCategory/{}'.format( metrics_dict[prefix + 'Recall AR (small) ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[9][category_index].astype( category_id)] = coco_eval.category_stats[9][category_index].astype(
np.float32) np.float32)
metrics_dict['Recall AR (medium) ByCategory/{}'.format( metrics_dict[prefix + 'Recall AR (medium) ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[10][category_index].astype( category_id)] = coco_eval.category_stats[10][category_index].astype(
np.float32) np.float32)
metrics_dict['Recall AR (large) ByCategory/{}'.format( metrics_dict[prefix + 'Recall AR (large) ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[11][category_index].astype( category_id)] = coco_eval.category_stats[11][category_index].astype(
np.float32) np.float32)
return metrics_dict return metrics_dict
def _process_predictions(self, predictions): def _process_predictions(self, predictions):
......
...@@ -29,6 +29,29 @@ from official.vision.beta.losses import maskrcnn_losses ...@@ -29,6 +29,29 @@ from official.vision.beta.losses import maskrcnn_losses
from official.vision.beta.modeling import factory from official.vision.beta.modeling import factory
def zero_out_disallowed_class_ids(batch_class_ids, allowed_class_ids):
"""Zero out IDs of classes not in allowed_class_ids.
Args:
batch_class_ids: A [batch_size, num_instances] int tensor of input
class IDs.
allowed_class_ids: A python list of class IDs which we want to allow.
Returns:
filtered_class_ids: A [batch_size, num_instances] int tensor with any
class ID not in allowed_class_ids set to 0.
"""
allowed_class_ids = tf.constant(allowed_class_ids,
dtype=batch_class_ids.dtype)
match_ids = (batch_class_ids[:, :, tf.newaxis] ==
allowed_class_ids[tf.newaxis, tf.newaxis, :])
match_ids = tf.reduce_any(match_ids, axis=2)
return tf.where(match_ids, batch_class_ids, tf.zeros_like(batch_class_ids))
@task_factory.register_task_cls(exp_cfg.MaskRCNNTask) @task_factory.register_task_cls(exp_cfg.MaskRCNNTask)
class MaskRCNNTask(base_task.Task): class MaskRCNNTask(base_task.Task):
"""A single-replica view of training procedure. """A single-replica view of training procedure.
...@@ -154,11 +177,17 @@ class MaskRCNNTask(base_task.Task): ...@@ -154,11 +177,17 @@ class MaskRCNNTask(base_task.Task):
if params.model.include_mask: if params.model.include_mask:
mask_loss_fn = maskrcnn_losses.MaskrcnnLoss() mask_loss_fn = maskrcnn_losses.MaskrcnnLoss()
mask_class_targets = outputs['mask_class_targets']
if self._task_config.allowed_mask_class_ids is not None:
# Classes with ID=0 are ignored by mask_loss_fn in loss computation.
mask_class_targets = zero_out_disallowed_class_ids(
mask_class_targets, self._task_config.allowed_mask_class_ids)
mask_loss = tf.reduce_mean( mask_loss = tf.reduce_mean(
mask_loss_fn( mask_loss_fn(
outputs['mask_outputs'], outputs['mask_outputs'],
outputs['mask_targets'], outputs['mask_targets'],
outputs['mask_class_targets'])) mask_class_targets))
else: else:
mask_loss = 0.0 mask_loss = 0.0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment