Commit ea30986e authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 337251846
parent a3369cd4
......@@ -209,6 +209,8 @@ class MaskRCNNTask(cfg.TaskConfig):
annotation_file: Optional[str] = None
gradient_clip_norm: float = 0.0
per_category_metrics: bool = False
# If set, we only use masks for the specified class IDs.
allowed_mask_class_ids: Optional[List[int]] = None
COCO_INPUT_PATH_BASE = 'coco'
......
......@@ -164,44 +164,69 @@ class COCOEvaluator(object):
metrics_dict[name] = metrics[i].astype(np.float32)
# Adds metrics per category.
if self._per_category_metrics and hasattr(coco_eval, 'category_stats'):
if self._per_category_metrics:
metrics_dict.update(self._retrieve_per_category_metrics(coco_eval))
if self._include_mask:
metrics_dict.update(self._retrieve_per_category_metrics(
mcoco_eval, prefix='mask'))
return metrics_dict
def _retrieve_per_category_metrics(self, coco_eval, prefix=''):
"""Retrieves and per-category metrics and retuns them in a dict.
Args:
coco_eval: a cocoeval.COCOeval object containing evaluation data.
prefix: str, A string used to prefix metric names.
Returns:
metrics_dict: A dictionary with per category metrics.
"""
metrics_dict = {}
if prefix:
prefix = prefix + ' '
if hasattr(coco_eval, 'category_stats'):
for category_index, category_id in enumerate(coco_eval.params.catIds):
metrics_dict['Precision mAP ByCategory/{}'.format(
metrics_dict[prefix + 'Precision mAP ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[0][category_index].astype(
np.float32)
metrics_dict['Precision mAP ByCategory@50IoU/{}'.format(
metrics_dict[prefix + 'Precision mAP ByCategory@50IoU/{}'.format(
category_id)] = coco_eval.category_stats[1][category_index].astype(
np.float32)
metrics_dict['Precision mAP ByCategory@75IoU/{}'.format(
metrics_dict[prefix + 'Precision mAP ByCategory@75IoU/{}'.format(
category_id)] = coco_eval.category_stats[2][category_index].astype(
np.float32)
metrics_dict['Precision mAP ByCategory (small) /{}'.format(
metrics_dict[prefix +'Precision mAP ByCategory (small) /{}'.format(
category_id)] = coco_eval.category_stats[3][category_index].astype(
np.float32)
metrics_dict['Precision mAP ByCategory (medium) /{}'.format(
metrics_dict[prefix +'Precision mAP ByCategory (medium) /{}'.format(
category_id)] = coco_eval.category_stats[4][category_index].astype(
np.float32)
metrics_dict['Precision mAP ByCategory (large) /{}'.format(
metrics_dict[prefix + 'Precision mAP ByCategory (large) /{}'.format(
category_id)] = coco_eval.category_stats[5][category_index].astype(
np.float32)
metrics_dict['Recall AR@1 ByCategory/{}'.format(
metrics_dict[prefix + 'Recall AR@1 ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[6][category_index].astype(
np.float32)
metrics_dict['Recall AR@10 ByCategory/{}'.format(
metrics_dict[prefix + 'Recall AR@10 ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[7][category_index].astype(
np.float32)
metrics_dict['Recall AR@100 ByCategory/{}'.format(
metrics_dict[prefix + 'Recall AR@100 ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[8][category_index].astype(
np.float32)
metrics_dict['Recall AR (small) ByCategory/{}'.format(
metrics_dict[prefix + 'Recall AR (small) ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[9][category_index].astype(
np.float32)
metrics_dict['Recall AR (medium) ByCategory/{}'.format(
metrics_dict[prefix + 'Recall AR (medium) ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[10][category_index].astype(
np.float32)
metrics_dict['Recall AR (large) ByCategory/{}'.format(
metrics_dict[prefix + 'Recall AR (large) ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[11][category_index].astype(
np.float32)
return metrics_dict
def _process_predictions(self, predictions):
......
......@@ -29,6 +29,29 @@ from official.vision.beta.losses import maskrcnn_losses
from official.vision.beta.modeling import factory
def zero_out_disallowed_class_ids(batch_class_ids, allowed_class_ids):
"""Zero out IDs of classes not in allowed_class_ids.
Args:
batch_class_ids: A [batch_size, num_instances] int tensor of input
class IDs.
allowed_class_ids: A python list of class IDs which we want to allow.
Returns:
filtered_class_ids: A [batch_size, num_instances] int tensor with any
class ID not in allowed_class_ids set to 0.
"""
allowed_class_ids = tf.constant(allowed_class_ids,
dtype=batch_class_ids.dtype)
match_ids = (batch_class_ids[:, :, tf.newaxis] ==
allowed_class_ids[tf.newaxis, tf.newaxis, :])
match_ids = tf.reduce_any(match_ids, axis=2)
return tf.where(match_ids, batch_class_ids, tf.zeros_like(batch_class_ids))
@task_factory.register_task_cls(exp_cfg.MaskRCNNTask)
class MaskRCNNTask(base_task.Task):
"""A single-replica view of training procedure.
......@@ -154,11 +177,17 @@ class MaskRCNNTask(base_task.Task):
if params.model.include_mask:
mask_loss_fn = maskrcnn_losses.MaskrcnnLoss()
mask_class_targets = outputs['mask_class_targets']
if self._task_config.allowed_mask_class_ids is not None:
# Classes with ID=0 are ignored by mask_loss_fn in loss computation.
mask_class_targets = zero_out_disallowed_class_ids(
mask_class_targets, self._task_config.allowed_mask_class_ids)
mask_loss = tf.reduce_mean(
mask_loss_fn(
outputs['mask_outputs'],
outputs['mask_targets'],
outputs['mask_class_targets']))
mask_class_targets))
else:
mask_loss = 0.0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment