# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
from typing import Sequence
import torch

from detectron2.layers.nms import batched_nms
from detectron2.structures.instances import Instances

from densepose.vis.bounding_box import BoundingBoxVisualizer, ScoredBoundingBoxVisualizer
from densepose.vis.densepose import DensePoseResultsVisualizer

from .base import CompoundVisualizer

Scores = Sequence[float]


def extract_scores_from_instances(instances: Instances, select=None):
    """
    Return the ``scores`` field of ``instances``, optionally indexed by ``select``.

    Args:
        instances: detection results to read from
        select: optional index / mask applied to the scores; ``None`` keeps all
    Returns:
        The (selected) scores, or ``None`` if ``instances`` has no ``scores`` field.
    """
    if instances.has("scores"):
        return instances.scores if select is None else instances.scores[select]
    return None


def extract_boxes_xywh_from_instances(instances: Instances, select=None):
    """
    Return ``pred_boxes`` converted from (x0, y0, x1, y1) to (x, y, w, h).

    Args:
        instances: detection results to read from
        select: optional index / mask applied to the boxes; ``None`` keeps all
    Returns:
        A new tensor of (selected) XYWH boxes, or ``None`` if ``instances`` has
        no ``pred_boxes`` field.
    """
    if instances.has("pred_boxes"):
        # Clone so the in-place conversion below does not modify the boxes
        # stored in ``instances``.
        boxes_xywh = instances.pred_boxes.tensor.clone()
        boxes_xywh[:, 2] -= boxes_xywh[:, 0]
        boxes_xywh[:, 3] -= boxes_xywh[:, 1]
        return boxes_xywh if select is None else boxes_xywh[select]
    return None


def create_extractor(visualizer: object):
    """
    Create an extractor for the provided visualizer.

    The returned extractor is a callable ``extractor(instances, select=None)``.
    Compound visualizers get a ``CompoundExtractor`` with one sub-extractor
    per child visualizer. Unsupported visualizers are logged as an error and
    yield ``None``.
    """
    if isinstance(visualizer, CompoundVisualizer):
        extractors = [create_extractor(v) for v in visualizer.visualizers]
        return CompoundExtractor(extractors)
    elif isinstance(visualizer, DensePoseResultsVisualizer):
        return DensePoseResultExtractor()
    elif isinstance(visualizer, ScoredBoundingBoxVisualizer):
        return CompoundExtractor([extract_boxes_xywh_from_instances, extract_scores_from_instances])
    elif isinstance(visualizer, BoundingBoxVisualizer):
        return extract_boxes_xywh_from_instances
    else:
        logger = logging.getLogger(__name__)
        logger.error(f"Could not create extractor for {visualizer}")
        return None


class BoundingBoxExtractor(object):
    """
    Extracts bounding boxes (in XYWH format) from instances
    """

    def __call__(self, instances: Instances, select=None):
        # ``select`` is accepted so this extractor can be wrapped by
        # CompoundExtractor / NmsFilteredExtractor / ScoreThresholdedExtractor,
        # which always pass it through.
        return extract_boxes_xywh_from_instances(instances, select)


class ScoredBoundingBoxExtractor(object):
    """
    Extracts bounding boxes (in XYWH format) and their scores from instances
    """

    def __call__(self, instances: Instances, select=None):
        """
        Return ``(boxes_xywh, scores)``. Each element is ``None`` if the
        corresponding field is missing. ``select`` is applied to each element
        independently.
        """
        scores = extract_scores_from_instances(instances, select)
        boxes_xywh = extract_boxes_xywh_from_instances(instances, select)
        return (boxes_xywh, scores)


class DensePoseResultExtractor(object):
    """
    Extracts DensePose result from instances
    """

    def __call__(self, instances: Instances, select=None):
        """
        Return ``pred_densepose.to_result(boxes_xywh)`` for the selected
        instances, or ``None`` if ``pred_densepose`` or ``pred_boxes`` is missing.
        """
        boxes_xywh = extract_boxes_xywh_from_instances(instances)
        if not instances.has("pred_densepose") or boxes_xywh is None:
            return None
        dpout = instances.pred_densepose
        if select is not None:
            dpout = dpout[select]
            boxes_xywh = boxes_xywh[select]
        return dpout.to_result(boxes_xywh)


class CompoundExtractor(object):
    """
    Extracts data for CompoundVisualizer
    """

    def __init__(self, extractors):
        # Sequence of callables ``extractor(instances, select)``; their outputs
        # are returned in the same order.
        self.extractors = extractors

    def __call__(self, instances: Instances, select=None):
        return [extractor(instances, select) for extractor in self.extractors]


class NmsFilteredExtractor(object):
    """
    Extracts data in the format accepted by NmsFilteredVisualizer
    """

    def __init__(self, extractor, iou_threshold):
        self.extractor = extractor
        self.iou_threshold = iou_threshold

    def __call__(self, instances: Instances, select=None):
        """
        Run NMS over all predicted boxes. Combine the survivors with ``select``
        (assumed to be a boolean mask) and delegate to the wrapped extractor.
        Return ``None`` if boxes or scores are missing.
        """
        if not instances.has("pred_boxes"):
            return None
        scores = extract_scores_from_instances(instances)
        if scores is None:
            # NMS ranks boxes by score; without scores nothing can be filtered
            # (mirrors ScoreThresholdedExtractor's behavior).
            return None
        # batched_nms expects (x0, y0, x1, y1) boxes. ``pred_boxes.tensor`` is
        # stored in that layout (it is what extract_boxes_xywh_from_instances
        # converts from), so do NOT pass XYWH boxes here.
        boxes_xyxy = instances.pred_boxes.tensor
        device = boxes_xyxy.device
        # A single group index for all boxes: class-agnostic NMS.
        group_idxs = torch.zeros(len(scores), dtype=torch.int32, device=device)
        keep_idx = batched_nms(
            boxes_xyxy,
            scores,
            group_idxs,
            iou_threshold=self.iou_threshold,
        )
        select_local = torch.zeros(len(boxes_xyxy), dtype=torch.bool, device=device)
        select_local[keep_idx] = True
        select = select_local if select is None else (select & select_local)
        return self.extractor(instances, select=select)


class ScoreThresholdedExtractor(object):
    """
    Extracts data in the format accepted by ScoreThresholdedVisualizer
    """

    def __init__(self, extractor, min_score):
        self.extractor = extractor
        self.min_score = min_score

    def __call__(self, instances: Instances, select=None):
        """
        Keep instances whose score is strictly greater than ``min_score``.
        Combine that with ``select`` (assumed to be a boolean mask) and delegate
        to the wrapped extractor. Return ``None`` if scores are missing.
        """
        scores = extract_scores_from_instances(instances)
        if scores is None:
            return None
        select_local = scores > self.min_score
        select = select_local if select is None else (select & select_local)
        data = self.extractor(instances, select=select)
        return data