Commit c732df65 authored by limm

push v0.1.3 version commit bd2ea47

parent 5b3792fc
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import argparse
import logging
import os
import sys
from timeit import default_timer as timer
from typing import Any, ClassVar, Dict, List
import torch
from fvcore.common.file_io import PathManager
from detectron2.data.catalog import DatasetCatalog
from detectron2.utils.logger import setup_logger
from densepose.data.structures import DensePoseDataRelative
from densepose.utils.dbhelper import EntrySelector
from densepose.utils.logger import verbosity_to_level
from densepose.vis.base import CompoundVisualizer
from densepose.vis.bounding_box import BoundingBoxVisualizer
from densepose.vis.densepose import (
DensePoseDataCoarseSegmentationVisualizer,
DensePoseDataPointsIVisualizer,
DensePoseDataPointsUVisualizer,
DensePoseDataPointsVisualizer,
DensePoseDataPointsVVisualizer,
)
DOC = """Query DB - a tool to print / visualize data from a database
"""
LOGGER_NAME = "query_db"
logger = logging.getLogger(LOGGER_NAME)
_ACTION_REGISTRY: Dict[str, "Action"] = {}
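# Example invocations (illustrative only; the dataset name comes from the registered
# Detectron2 datasets and the selector field/value below are hypothetical):
#   python query_db.py print densepose_coco_2014_train image_id:int=123 --max-entries 1
#   python query_db.py show densepose_coco_2014_train image_id:int=123 dp_segm,bbox --output vis.png
# The selector follows the field1[:type]=value1[,field2[:type]=value_min-value_max...]
# syntax described in the argument help below.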
class Action(object):
@classmethod
def add_arguments(cls: type, parser: argparse.ArgumentParser):
parser.add_argument(
"-v",
"--verbosity",
action="count",
help="Verbose mode. Multiple -v options increase the verbosity.",
)
def register_action(cls: type):
"""
Decorator for action classes to automate action registration
"""
global _ACTION_REGISTRY
_ACTION_REGISTRY[cls.COMMAND] = cls
return cls
class EntrywiseAction(Action):
@classmethod
def add_arguments(cls: type, parser: argparse.ArgumentParser):
super(EntrywiseAction, cls).add_arguments(parser)
parser.add_argument(
"dataset", metavar="<dataset>", help="Dataset name (e.g. densepose_coco_2014_train)"
)
parser.add_argument(
"selector",
metavar="<selector>",
help="Dataset entry selector in the form field1[:type]=value1[,"
"field2[:type]=value_min-value_max...] which selects all "
"entries from the dataset that satisfy the constraints",
)
parser.add_argument(
"--max-entries", metavar="N", help="Maximum number of entries to process", type=int
)
@classmethod
def execute(cls: type, args: argparse.Namespace):
dataset = setup_dataset(args.dataset)
entry_selector = EntrySelector.from_string(args.selector)
context = cls.create_context(args)
if args.max_entries is not None:
for _, entry in zip(range(args.max_entries), dataset):
if entry_selector(entry):
cls.execute_on_entry(entry, context)
else:
for entry in dataset:
if entry_selector(entry):
cls.execute_on_entry(entry, context)
@classmethod
def create_context(cls: type, args: argparse.Namespace) -> Dict[str, Any]:
context = {}
return context
@register_action
class PrintAction(EntrywiseAction):
"""
Print action that outputs selected entries to stdout
"""
COMMAND: ClassVar[str] = "print"
@classmethod
def add_parser(cls: type, subparsers: argparse._SubParsersAction):
parser = subparsers.add_parser(cls.COMMAND, help="Output selected entries to stdout. ")
cls.add_arguments(parser)
parser.set_defaults(func=cls.execute)
@classmethod
def add_arguments(cls: type, parser: argparse.ArgumentParser):
super(PrintAction, cls).add_arguments(parser)
@classmethod
def execute_on_entry(cls: type, entry: Dict[str, Any], context: Dict[str, Any]):
import pprint
printer = pprint.PrettyPrinter(indent=2, width=200, compact=True)
printer.pprint(entry)
@register_action
class ShowAction(EntrywiseAction):
"""
Show action that visualizes selected entries on an image
"""
COMMAND: ClassVar[str] = "show"
VISUALIZERS: ClassVar[Dict[str, object]] = {
"dp_segm": DensePoseDataCoarseSegmentationVisualizer(),
"dp_i": DensePoseDataPointsIVisualizer(),
"dp_u": DensePoseDataPointsUVisualizer(),
"dp_v": DensePoseDataPointsVVisualizer(),
"dp_pts": DensePoseDataPointsVisualizer(),
"bbox": BoundingBoxVisualizer(),
}
@classmethod
def add_parser(cls: type, subparsers: argparse._SubParsersAction):
parser = subparsers.add_parser(cls.COMMAND, help="Visualize selected entries")
cls.add_arguments(parser)
parser.set_defaults(func=cls.execute)
@classmethod
def add_arguments(cls: type, parser: argparse.ArgumentParser):
super(ShowAction, cls).add_arguments(parser)
parser.add_argument(
"visualizations",
metavar="<visualizations>",
help="Comma separated list of visualizations, possible values: "
"[{}]".format(",".join(sorted(cls.VISUALIZERS.keys()))),
)
parser.add_argument(
"--output",
metavar="<image_file>",
default="output.png",
help="File name to save output to",
)
@classmethod
def execute_on_entry(cls: type, entry: Dict[str, Any], context: Dict[str, Any]):
import cv2
import numpy as np
image_fpath = PathManager.get_local_path(entry["file_name"])
image = cv2.imread(image_fpath, cv2.IMREAD_GRAYSCALE)
image = np.tile(image[:, :, np.newaxis], [1, 1, 3])
datas = cls._extract_data_for_visualizers_from_entry(context["vis_specs"], entry)
visualizer = context["visualizer"]
image_vis = visualizer.visualize(image, datas)
entry_idx = context["entry_idx"] + 1
out_fname = cls._get_out_fname(entry_idx, context["out_fname"])
cv2.imwrite(out_fname, image_vis)
logger.info(f"Output saved to {out_fname}")
context["entry_idx"] += 1
@classmethod
def _get_out_fname(cls: type, entry_idx: int, fname_base: str):
base, ext = os.path.splitext(fname_base)
return base + ".{0:04d}".format(entry_idx) + ext
@classmethod
def create_context(cls: type, args: argparse.Namespace) -> Dict[str, Any]:
vis_specs = args.visualizations.split(",")
visualizers = []
for vis_spec in vis_specs:
vis = cls.VISUALIZERS[vis_spec]
visualizers.append(vis)
context = {
"vis_specs": vis_specs,
"visualizer": CompoundVisualizer(visualizers),
"out_fname": args.output,
"entry_idx": 0,
}
return context
@classmethod
def _extract_data_for_visualizers_from_entry(
cls: type, vis_specs: List[str], entry: Dict[str, Any]
):
dp_list = []
bbox_list = []
for annotation in entry["annotations"]:
is_valid, _ = DensePoseDataRelative.validate_annotation(annotation)
if not is_valid:
continue
bbox = torch.as_tensor(annotation["bbox"])
bbox_list.append(bbox)
dp_data = DensePoseDataRelative(annotation)
dp_list.append(dp_data)
datas = []
for vis_spec in vis_specs:
datas.append(bbox_list if "bbox" == vis_spec else (bbox_list, dp_list))
return datas
def setup_dataset(dataset_name):
logger.info("Loading dataset {}".format(dataset_name))
start = timer()
dataset = DatasetCatalog.get(dataset_name)
stop = timer()
logger.info("Loaded dataset {} in {:.3f}s".format(dataset_name, stop - start))
return dataset
def create_argument_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description=DOC,
formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=120),
)
parser.set_defaults(func=lambda _: parser.print_help(sys.stdout))
subparsers = parser.add_subparsers(title="Actions")
for _, action in _ACTION_REGISTRY.items():
action.add_parser(subparsers)
return parser
def main():
parser = create_argument_parser()
args = parser.parse_args()
verbosity = args.verbosity if hasattr(args, "verbosity") else None
global logger
logger = setup_logger(name=LOGGER_NAME)
logger.setLevel(verbosity_to_level(verbosity))
args.func(args)
if __name__ == "__main__":
main()
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import os
import torch
from detectron2.config import get_cfg
from detectron2.engine import default_setup
from detectron2.modeling import build_model
from densepose import add_dataset_category_config, add_densepose_config
_BASE_CONFIG_DIR = "configs"
_EVOLUTION_CONFIG_SUB_DIR = "evolution"
_QUICK_SCHEDULES_CONFIG_SUB_DIR = "quick_schedules"
_BASE_CONFIG_FILE_PREFIX = "Base-"
_CONFIG_FILE_EXT = ".yaml"
def _get_base_config_dir():
"""
Return the base directory for configurations
"""
return os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", _BASE_CONFIG_DIR)
def _get_evolution_config_dir():
"""
Return the base directory for evolution configurations
"""
return os.path.join(_get_base_config_dir(), _EVOLUTION_CONFIG_SUB_DIR)
def _get_quick_schedules_config_dir():
"""
Return the base directory for quick schedules configurations
"""
return os.path.join(_get_base_config_dir(), _QUICK_SCHEDULES_CONFIG_SUB_DIR)
def _collect_config_files(config_dir):
"""
Collect all configuration files (i.e. densepose_*.yaml) directly in the specified directory
"""
start = _get_base_config_dir()
results = []
for entry in os.listdir(config_dir):
path = os.path.join(config_dir, entry)
if not os.path.isfile(path):
continue
_, ext = os.path.splitext(entry)
if ext != _CONFIG_FILE_EXT:
continue
if entry.startswith(_BASE_CONFIG_FILE_PREFIX):
continue
config_file = os.path.relpath(path, start)
results.append(config_file)
return results
def get_config_files():
"""
Get all the configuration files (relative to the base configuration directory)
"""
return _collect_config_files(_get_base_config_dir())
def get_evolution_config_files():
"""
Get all the evolution configuration files (relative to the base configuration directory)
"""
return _collect_config_files(_get_evolution_config_dir())
def get_quick_schedules_config_files():
"""
Get all the quick schedules configuration files (relative to the base configuration directory)
"""
return _collect_config_files(_get_quick_schedules_config_dir())
def _get_model_config(config_file):
"""
Load and return the configuration from the specified file (relative to the base configuration
directory)
"""
cfg = get_cfg()
add_dataset_category_config(cfg)
add_densepose_config(cfg)
path = os.path.join(_get_base_config_dir(), config_file)
cfg.merge_from_file(path)
if not torch.cuda.is_available():
        cfg.MODEL.DEVICE = "cpu"
return cfg
def get_model(config_file):
"""
Get the model from the specified file (relative to the base configuration directory)
"""
cfg = _get_model_config(config_file)
return build_model(cfg)
def setup(config_file):
"""
Setup the configuration from the specified file (relative to the base configuration directory)
"""
cfg = _get_model_config(config_file)
cfg.freeze()
default_setup(cfg, {})
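# Minimal usage sketch (illustrative; the config file name below is the one referenced by the
# end-to-end test in this commit and is assumed to live under the base configs directory):
#   from .common import get_model, setup
#   model = get_model("densepose_rcnn_R_101_FPN_s1x.yaml")  # built on CPU when no GPU is available
#   setup("densepose_rcnn_R_101_FPN_s1x.yaml")              # merges, freezes and applies default_setup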
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import unittest
import torch
from detectron2.structures import BitMasks, Boxes, Instances
from .common import get_model
# TODO(plabatut): Modularize detectron2 tests and re-use
def make_model_inputs(image, instances=None):
if instances is None:
return {"image": image}
return {"image": image, "instances": instances}
def make_empty_instances(h, w):
instances = Instances((h, w))
instances.gt_boxes = Boxes(torch.rand(0, 4))
instances.gt_classes = torch.tensor([]).to(dtype=torch.int64)
instances.gt_masks = BitMasks(torch.rand(0, h, w))
return instances
class ModelE2ETest(unittest.TestCase):
CONFIG_PATH = ""
def setUp(self):
self.model = get_model(self.CONFIG_PATH)
def _test_eval(self, sizes):
inputs = [make_model_inputs(torch.rand(3, size[0], size[1])) for size in sizes]
self.model.eval()
self.model(inputs)
class DensePoseRCNNE2ETest(ModelE2ETest):
CONFIG_PATH = "densepose_rcnn_R_101_FPN_s1x.yaml"
def test_empty_data(self):
self._test_eval([(200, 250), (200, 249)])
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import unittest
from .common import (
get_config_files,
get_evolution_config_files,
get_quick_schedules_config_files,
setup,
)
class TestSetup(unittest.TestCase):
def _test_setup(self, config_file):
setup(config_file)
def test_setup_configs(self):
config_files = get_config_files()
for config_file in config_files:
self._test_setup(config_file)
def test_setup_evolution_configs(self):
config_files = get_evolution_config_files()
for config_file in config_files:
self._test_setup(config_file)
def test_setup_quick_schedules_configs(self):
config_files = get_quick_schedules_config_files()
for config_file in config_files:
self._test_setup(config_file)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import unittest
from densepose.data.structures import normalized_coords_transform
class TestStructures(unittest.TestCase):
def test_normalized_coords_transform(self):
bbox = (32, 24, 288, 216)
x0, y0, w, h = bbox
xmin, ymin, xmax, ymax = x0, y0, x0 + w, y0 + h
f = normalized_coords_transform(*bbox)
# Top-left
expected_p, actual_p = (-1, -1), f((xmin, ymin))
self.assertEqual(expected_p, actual_p)
# Top-right
expected_p, actual_p = (1, -1), f((xmax, ymin))
self.assertEqual(expected_p, actual_p)
# Bottom-left
expected_p, actual_p = (-1, 1), f((xmin, ymax))
self.assertEqual(expected_p, actual_p)
# Bottom-right
expected_p, actual_p = (1, 1), f((xmax, ymax))
self.assertEqual(expected_p, actual_p)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DensePose Training Script.
This script is similar to the training script in detectron2/tools.
It is an example of how a user might use detectron2 for a new project.
"""
import logging
import os
from collections import OrderedDict
from fvcore.common.file_io import PathManager
import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import CfgNode, get_cfg
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch
from detectron2.evaluation import COCOEvaluator, DatasetEvaluators, verify_results
from detectron2.modeling import DatasetMapperTTA
from detectron2.utils.logger import setup_logger
from densepose import (
DensePoseCOCOEvaluator,
DensePoseGeneralizedRCNNWithTTA,
add_dataset_category_config,
add_densepose_config,
load_from_cfg,
)
from densepose.data import DatasetMapper, build_detection_test_loader, build_detection_train_loader
class Trainer(DefaultTrainer):
@classmethod
def build_evaluator(cls, cfg: CfgNode, dataset_name, output_folder=None):
if output_folder is None:
output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
evaluators = [COCOEvaluator(dataset_name, cfg, True, output_folder)]
if cfg.MODEL.DENSEPOSE_ON:
evaluators.append(DensePoseCOCOEvaluator(dataset_name, True, output_folder))
return DatasetEvaluators(evaluators)
@classmethod
def build_test_loader(cls, cfg: CfgNode, dataset_name):
return build_detection_test_loader(cfg, dataset_name, mapper=DatasetMapper(cfg, False))
@classmethod
def build_train_loader(cls, cfg: CfgNode):
return build_detection_train_loader(cfg, mapper=DatasetMapper(cfg, True))
@classmethod
def test_with_TTA(cls, cfg: CfgNode, model):
logger = logging.getLogger("detectron2.trainer")
        # At the end of training, run an evaluation with TTA.
        # Only supports some R-CNN models.
logger.info("Running inference with test-time augmentation ...")
transform_data = load_from_cfg(cfg)
model = DensePoseGeneralizedRCNNWithTTA(cfg, model, transform_data, DatasetMapperTTA(cfg))
evaluators = [
cls.build_evaluator(
cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
)
for name in cfg.DATASETS.TEST
]
res = cls.test(cfg, model, evaluators)
res = OrderedDict({k + "_TTA": v for k, v in res.items()})
return res
def setup(args):
cfg = get_cfg()
add_dataset_category_config(cfg)
add_densepose_config(cfg)
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()
default_setup(cfg, args)
# Setup logger for "densepose" module
setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="densepose")
return cfg
def main(args):
cfg = setup(args)
# disable strict kwargs checking: allow one to specify path handle
# hints through kwargs, like timeout in DP evaluation
PathManager.set_strict_kwargs_checking(False)
if args.eval_only:
model = Trainer.build_model(cfg)
DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
cfg.MODEL.WEIGHTS, resume=args.resume
)
res = Trainer.test(cfg, model)
if cfg.TEST.AUG.ENABLED:
res.update(Trainer.test_with_TTA(cfg, model))
if comm.is_main_process():
verify_results(cfg, res)
return res
trainer = Trainer(cfg)
trainer.resume_or_load(resume=args.resume)
if cfg.TEST.AUG.ENABLED:
trainer.register_hooks(
[hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
)
return trainer.train()
if __name__ == "__main__":
args = default_argument_parser().parse_args()
print("Command Line Args:", args)
launch(
main,
args.num_gpus,
num_machines=args.num_machines,
machine_rank=args.machine_rank,
dist_url=args.dist_url,
args=(args,),
)
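# Example invocations (illustrative; the config file and checkpoint paths are placeholders,
# and any DensePose config such as configs/densepose_rcnn_R_101_FPN_s1x.yaml can be used):
#   python train_net.py --config-file configs/densepose_rcnn_R_101_FPN_s1x.yaml --num-gpus 8
#   python train_net.py --config-file configs/densepose_rcnn_R_101_FPN_s1x.yaml \
#       --eval-only MODEL.WEIGHTS /path/to/model_checkpoint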
# PointRend: Image Segmentation as Rendering
Alexander Kirillov, Yuxin Wu, Kaiming He, Ross Girshick
[[`arXiv`](https://arxiv.org/abs/1912.08193)] [[`BibTeX`](#CitingPointRend)]
<div align="center">
<img src="https://alexander-kirillov.github.io/images/kirillov2019pointrend.jpg"/>
</div><br/>
In this repository, we release code for PointRend in Detectron2. PointRend can be flexibly applied to both instance and semantic segmentation tasks by building on top of existing state-of-the-art models.
## Installation
Install Detectron2 following [INSTALL.md](https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md). You are ready to go!
## Quick start and visualization
This [Colab Notebook](https://colab.research.google.com/drive/1isGPL5h5_cKoPPhVL9XhMokRtHDvmMVL) tutorial contains examples of PointRend usage and visualizations of its point sampling stages.
## Training
To train a model with 8 GPUs run:
```bash
cd /path/to/detectron2/projects/PointRend
python train_net.py --config-file configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_coco.yaml --num-gpus 8
```
## Evaluation
Model evaluation can be done similarly:
```bash
cd /path/to/detectron2/projects/PointRend
python train_net.py --config-file configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_coco.yaml --eval-only MODEL.WEIGHTS /path/to/model_checkpoint
```
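The model can also be built and run programmatically. Below is a minimal, illustrative sketch (not part of the official tooling): it assumes this project's `point_rend` package is importable (e.g. `projects/PointRend` is on `PYTHONPATH`) and that `/path/to/model_checkpoint` is replaced with one of the checkpoints from the tables below.
```python
# Minimal sketch: build a PointRend predictor from a config in this project.
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor

from point_rend import add_pointrend_config  # registers the PointRend heads on import

cfg = get_cfg()
add_pointrend_config(cfg)  # add the PointRend-specific config keys before merging
cfg.merge_from_file("configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_coco.yaml")
cfg.MODEL.WEIGHTS = "/path/to/model_checkpoint"  # placeholder path
predictor = DefaultPredictor(cfg)
# outputs = predictor(bgr_image)  # bgr_image: HxWx3 uint8 numpy array in BGR order
```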
# Pretrained Models
## Instance Segmentation
#### COCO
<table><tbody>
<!-- START TABLE -->
<!-- TABLE HEADER -->
<th valign="bottom">Mask<br/>head</th>
<th valign="bottom">Backbone</th>
<th valign="bottom">lr<br/>sched</th>
<th valign="bottom">Output<br/>resolution</th>
<th valign="bottom">mask<br/>AP</th>
<th valign="bottom">mask<br/>AP&ast;</th>
<th valign="bottom">model id</th>
<th valign="bottom">download</th>
<!-- TABLE BODY -->
<tr><td align="left"><a href="configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_coco.yaml">PointRend</a></td>
<td align="center">R50-FPN</td>
<td align="center">1&times;</td>
<td align="center">224&times;224</td>
<td align="center">36.2</td>
<td align="center">39.7</td>
<td align="center">164254221</td>
<td align="center"><a href="https://dl.fbaipublicfiles.com/detectron2/PointRend/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_coco/164254221/model_final_88c6f8.pkl">model</a>&nbsp;|&nbsp;<a href="https://dl.fbaipublicfiles.com/detectron2/PointRend/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_coco/164254221/metrics.json">metrics</a></td>
</tr>
<tr><td align="left"><a href="configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml">PointRend</a></td>
<td align="center">R50-FPN</td>
<td align="center">3&times;</td>
<td align="center">224&times;224</td>
<td align="center">38.3</td>
<td align="center">41.6</td>
<td align="center">164955410</td>
<td align="center"><a href="https://dl.fbaipublicfiles.com/detectron2/PointRend/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco/164955410/model_final_3c3198.pkl">model</a>&nbsp;|&nbsp;<a href="https://dl.fbaipublicfiles.com/detectron2/PointRend/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco/164955410/metrics.json">metrics</a></td>
</tr>
</tbody></table>
AP&ast; is COCO mask AP evaluated against the higher-quality LVIS annotations; see the paper for details. Run `python detectron2/datasets/prepare_cocofied_lvis.py` to prepare GT files for AP&ast; evaluation. Since LVIS annotations are not exhaustive, `lvis-api` rather than `cocoapi` should be used to evaluate AP&ast;.
#### Cityscapes
The Cityscapes model is trained with ImageNet pretraining.
<table><tbody>
<!-- START TABLE -->
<!-- TABLE HEADER -->
<th valign="bottom">Mask<br/>head</th>
<th valign="bottom">Backbone</th>
<th valign="bottom">lr<br/>sched</th>
<th valign="bottom">Output<br/>resolution</th>
<th valign="bottom">mask<br/>AP</th>
<th valign="bottom">model id</th>
<th valign="bottom">download</th>
<!-- TABLE BODY -->
<tr><td align="left"><a href="configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_cityscapes.yaml">PointRend</a></td>
<td align="center">R50-FPN</td>
<td align="center">1&times;</td>
<td align="center">224&times;224</td>
<td align="center">35.9</td>
<td align="center">164255101</td>
<td align="center"><a href="https://dl.fbaipublicfiles.com/detectron2/PointRend/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_cityscapes/164255101/model_final_318a02.pkl">model</a>&nbsp;|&nbsp;<a href="https://dl.fbaipublicfiles.com/detectron2/PointRend/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_cityscapes/164255101/metrics.json">metrics</a></td>
</tr>
</tbody></table>
## Semantic Segmentation
#### Cityscapes
The Cityscapes model is trained with ImageNet pretraining.
<table><tbody>
<!-- START TABLE -->
<!-- TABLE HEADER -->
<th valign="bottom">Method</th>
<th valign="bottom">Backbone</th>
<th valign="bottom">Output<br/>resolution</th>
<th valign="bottom">mIoU</th>
<th valign="bottom">model id</th>
<th valign="bottom">download</th>
<!-- TABLE BODY -->
<tr><td align="left"><a href="configs/SemanticSegmentation/pointrend_semantic_R_101_FPN_1x_cityscapes.yaml">SemanticFPN + PointRend</a></td>
<td align="center">R101-FPN</td>
<td align="center">1024&times;2048</td>
<td align="center">78.6</td>
<td align="center">186480235</td>
<td align="center"><a href="https://dl.fbaipublicfiles.com/detectron2/PointRend/SemanticSegmentation/pointrend_semantic_R_101_FPN_1x_cityscapes/186480235/model_final_5f3665.pkl">model</a>&nbsp;|&nbsp;<a href="https://dl.fbaipublicfiles.com/detectron2/PointRend/SemanticSegmentation/pointrend_semantic_R_101_FPN_1x_cityscapes/186480235/metrics.json">metrics</a></td>
</tr>
</tbody></table>
## <a name="CitingPointRend"></a>Citing PointRend
If you use PointRend, please use the following BibTeX entry.
```BibTeX
@InProceedings{kirillov2019pointrend,
title={{PointRend}: Image Segmentation as Rendering},
author={Alexander Kirillov and Yuxin Wu and Kaiming He and Ross Girshick},
journal={ArXiv:1912.08193},
year={2019}
}
```
_BASE_: "../../../../configs/Base-RCNN-FPN.yaml"
MODEL:
ROI_HEADS:
NAME: "PointRendROIHeads"
IN_FEATURES: ["p2", "p3", "p4", "p5"]
ROI_BOX_HEAD:
TRAIN_ON_PRED_BOXES: True
ROI_MASK_HEAD:
NAME: "CoarseMaskHead"
FC_DIM: 1024
NUM_FC: 2
OUTPUT_SIDE_RESOLUTION: 7
IN_FEATURES: ["p2"]
POINT_HEAD_ON: True
POINT_HEAD:
FC_DIM: 256
NUM_FC: 3
IN_FEATURES: ["p2"]
INPUT:
  # PointRend for instance segmentation does not work with "polygon" mask_format.
MASK_FORMAT: "bitmask"
_BASE_: Base-PointRend-RCNN-FPN.yaml
MODEL:
WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl
MASK_ON: true
RESNETS:
DEPTH: 50
ROI_HEADS:
NUM_CLASSES: 8
POINT_HEAD:
NUM_CLASSES: 8
DATASETS:
TEST: ("cityscapes_fine_instance_seg_val",)
TRAIN: ("cityscapes_fine_instance_seg_train",)
SOLVER:
BASE_LR: 0.01
IMS_PER_BATCH: 8
MAX_ITER: 24000
STEPS: (18000,)
INPUT:
MAX_SIZE_TEST: 2048
MAX_SIZE_TRAIN: 2048
MIN_SIZE_TEST: 1024
MIN_SIZE_TRAIN: (800, 832, 864, 896, 928, 960, 992, 1024)
_BASE_: Base-PointRend-RCNN-FPN.yaml
MODEL:
WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl
MASK_ON: true
RESNETS:
DEPTH: 50
# To add COCO AP evaluation against the higher-quality LVIS annotations.
# DATASETS:
# TEST: ("coco_2017_val", "lvis_v0.5_val_cocofied")
_BASE_: Base-PointRend-RCNN-FPN.yaml
MODEL:
WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl
MASK_ON: true
RESNETS:
DEPTH: 50
SOLVER:
STEPS: (210000, 250000)
MAX_ITER: 270000
# To add COCO AP evaluation against the higher-quality LVIS annotations.
# DATASETS:
# TEST: ("coco_2017_val", "lvis_v0.5_val_cocofied")
_BASE_: "../../../../configs/Base-RCNN-FPN.yaml"
MODEL:
META_ARCHITECTURE: "SemanticSegmentor"
BACKBONE:
FREEZE_AT: 0
SEM_SEG_HEAD:
NAME: "PointRendSemSegHead"
POINT_HEAD:
NUM_CLASSES: 54
FC_DIM: 256
NUM_FC: 3
IN_FEATURES: ["p2"]
TRAIN_NUM_POINTS: 1024
SUBDIVISION_STEPS: 2
SUBDIVISION_NUM_POINTS: 8192
COARSE_SEM_SEG_HEAD_NAME: "SemSegFPNHead"
DATASETS:
TRAIN: ("coco_2017_train_panoptic_stuffonly",)
TEST: ("coco_2017_val_panoptic_stuffonly",)
_BASE_: Base-PointRend-Semantic-FPN.yaml
MODEL:
WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-101.pkl
RESNETS:
DEPTH: 101
SEM_SEG_HEAD:
NUM_CLASSES: 19
POINT_HEAD:
NUM_CLASSES: 19
TRAIN_NUM_POINTS: 2048
SUBDIVISION_NUM_POINTS: 8192
DATASETS:
TRAIN: ("cityscapes_fine_sem_seg_train",)
TEST: ("cityscapes_fine_sem_seg_val",)
SOLVER:
BASE_LR: 0.01
STEPS: (40000, 55000)
MAX_ITER: 65000
IMS_PER_BATCH: 32
INPUT:
MIN_SIZE_TRAIN: (512, 768, 1024, 1280, 1536, 1792, 2048)
MIN_SIZE_TRAIN_SAMPLING: "choice"
MIN_SIZE_TEST: 1024
MAX_SIZE_TRAIN: 4096
MAX_SIZE_TEST: 2048
CROP:
ENABLED: True
TYPE: "absolute"
SIZE: (512, 1024)
SINGLE_CATEGORY_MAX_AREA: 0.75
COLOR_AUG_SSD: True
DATALOADER:
NUM_WORKERS: 16
_BASE_: Base-PointRend-Semantic-FPN.yaml
MODEL:
WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl
RESNETS:
DEPTH: 50
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .config import add_pointrend_config
from .coarse_mask_head import CoarseMaskHead
from .roi_heads import PointRendROIHeads
from .dataset_mapper import SemSegDatasetMapper
from .semantic_seg import PointRendSemSegHead
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.layers import Conv2d, ShapeSpec
from detectron2.modeling import ROI_MASK_HEAD_REGISTRY
@ROI_MASK_HEAD_REGISTRY.register()
class CoarseMaskHead(nn.Module):
"""
A mask head with fully connected layers. Given pooled features it first reduces channels and
spatial dimensions with conv layers and then uses FC layers to predict coarse masks analogously
to the standard box head.
"""
def __init__(self, cfg, input_shape: ShapeSpec):
"""
The following attributes are parsed from config:
conv_dim: the output dimension of the conv layers
            fc_dim: the feature dimension of the FC layers
num_fc: the number of FC layers
output_side_resolution: side resolution of the output square mask prediction
"""
super(CoarseMaskHead, self).__init__()
# fmt: off
self.num_classes = cfg.MODEL.ROI_HEADS.NUM_CLASSES
conv_dim = cfg.MODEL.ROI_MASK_HEAD.CONV_DIM
self.fc_dim = cfg.MODEL.ROI_MASK_HEAD.FC_DIM
num_fc = cfg.MODEL.ROI_MASK_HEAD.NUM_FC
self.output_side_resolution = cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION
self.input_channels = input_shape.channels
self.input_h = input_shape.height
self.input_w = input_shape.width
# fmt: on
self.conv_layers = []
if self.input_channels > conv_dim:
self.reduce_channel_dim_conv = Conv2d(
self.input_channels,
conv_dim,
kernel_size=1,
stride=1,
padding=0,
bias=True,
activation=F.relu,
)
self.conv_layers.append(self.reduce_channel_dim_conv)
self.reduce_spatial_dim_conv = Conv2d(
conv_dim, conv_dim, kernel_size=2, stride=2, padding=0, bias=True, activation=F.relu
)
self.conv_layers.append(self.reduce_spatial_dim_conv)
input_dim = conv_dim * self.input_h * self.input_w
input_dim //= 4
self.fcs = []
for k in range(num_fc):
fc = nn.Linear(input_dim, self.fc_dim)
self.add_module("coarse_mask_fc{}".format(k + 1), fc)
self.fcs.append(fc)
input_dim = self.fc_dim
output_dim = self.num_classes * self.output_side_resolution * self.output_side_resolution
self.prediction = nn.Linear(self.fc_dim, output_dim)
# use normal distribution initialization for mask prediction layer
nn.init.normal_(self.prediction.weight, std=0.001)
nn.init.constant_(self.prediction.bias, 0)
for layer in self.conv_layers:
weight_init.c2_msra_fill(layer)
for layer in self.fcs:
weight_init.c2_xavier_fill(layer)
def forward(self, x):
# unlike BaseMaskRCNNHead, this head only outputs intermediate
# features, because the features will be used later by PointHead.
N = x.shape[0]
x = x.view(N, self.input_channels, self.input_h, self.input_w)
for layer in self.conv_layers:
x = layer(x)
x = torch.flatten(x, start_dim=1)
for layer in self.fcs:
x = F.relu(layer(x))
return self.prediction(x).view(
N, self.num_classes, self.output_side_resolution, self.output_side_resolution
)
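# Shape sketch (illustrative, assuming 14x14 RoI-pooled features with 256 channels, as in the
# default FPN setup): x of shape (N, 256, 14, 14) is spatially reduced to 7x7 by
# reduce_spatial_dim_conv, flattened, passed through NUM_FC fully connected layers of width
# FC_DIM, and projected to (N, num_classes, OUTPUT_SIDE_RESOLUTION, OUTPUT_SIDE_RESOLUTION),
# i.e. (N, num_classes, 7, 7) with the configs in this commit.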
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import numpy as np
import random
import cv2
from fvcore.transforms.transform import Transform
class ColorAugSSDTransform(Transform):
"""
A color related data augmentation used in Single Shot Multibox Detector (SSD).
Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy,
Scott Reed, Cheng-Yang Fu, Alexander C. Berg.
SSD: Single Shot MultiBox Detector. ECCV 2016.
Implementation based on:
https://github.com/weiliu89/caffe/blob
/4817bf8b4200b35ada8ed0dc378dceaf38c539e4
/src/caffe/util/im_transforms.cpp
https://github.com/chainer/chainercv/blob
/7159616642e0be7c5b3ef380b848e16b7e99355b/chainercv
/links/model/ssd/transforms.py
"""
def __init__(
self,
img_format,
brightness_delta=32,
contrast_low=0.5,
contrast_high=1.5,
saturation_low=0.5,
saturation_high=1.5,
hue_delta=18,
):
super().__init__()
assert img_format in ["BGR", "RGB"]
self.is_rgb = img_format == "RGB"
del img_format
self._set_attributes(locals())
def apply_coords(self, coords):
return coords
def apply_segmentation(self, segmentation):
return segmentation
def apply_image(self, img, interp=None):
if self.is_rgb:
img = img[:, :, [2, 1, 0]]
img = self.brightness(img)
if random.randrange(2):
img = self.contrast(img)
img = self.saturation(img)
img = self.hue(img)
else:
img = self.saturation(img)
img = self.hue(img)
img = self.contrast(img)
if self.is_rgb:
img = img[:, :, [2, 1, 0]]
return img
def convert(self, img, alpha=1, beta=0):
img = img.astype(np.float32) * alpha + beta
img = np.clip(img, 0, 255)
return img.astype(np.uint8)
def brightness(self, img):
if random.randrange(2):
return self.convert(
img, beta=random.uniform(-self.brightness_delta, self.brightness_delta)
)
return img
def contrast(self, img):
if random.randrange(2):
return self.convert(img, alpha=random.uniform(self.contrast_low, self.contrast_high))
return img
def saturation(self, img):
if random.randrange(2):
img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
img[:, :, 1] = self.convert(
img[:, :, 1], alpha=random.uniform(self.saturation_low, self.saturation_high)
)
return cv2.cvtColor(img, cv2.COLOR_HSV2BGR)
return img
def hue(self, img):
if random.randrange(2):
img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
img[:, :, 0] = (
img[:, :, 0].astype(int) + random.randint(-self.hue_delta, self.hue_delta)
) % 180
return cv2.cvtColor(img, cv2.COLOR_HSV2BGR)
return img
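# Usage sketch (illustrative): SemSegDatasetMapper below appends this transform to its list of
# transform generators when INPUT.COLOR_AUG_SSD is enabled; it can also be applied directly to
# an HxWx3 uint8 image array:
#   aug = ColorAugSSDTransform(img_format="BGR")
#   augmented = aug.apply_image(image)  # coordinates and segmentations are left unchanged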
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from detectron2.config import CfgNode as CN
def add_pointrend_config(cfg):
"""
Add config for PointRend.
"""
    # We retry random cropping until no single category in semantic segmentation GT occupies more
    # than a `SINGLE_CATEGORY_MAX_AREA` fraction of the crop.
cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
    # Color augmentation from SSD paper for semantic segmentation model during training.
cfg.INPUT.COLOR_AUG_SSD = False
# Names of the input feature maps to be used by a coarse mask head.
cfg.MODEL.ROI_MASK_HEAD.IN_FEATURES = ("p2",)
cfg.MODEL.ROI_MASK_HEAD.FC_DIM = 1024
cfg.MODEL.ROI_MASK_HEAD.NUM_FC = 2
# The side size of a coarse mask head prediction.
cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION = 7
# True if point head is used.
cfg.MODEL.ROI_MASK_HEAD.POINT_HEAD_ON = False
cfg.MODEL.POINT_HEAD = CN()
cfg.MODEL.POINT_HEAD.NAME = "StandardPointHead"
cfg.MODEL.POINT_HEAD.NUM_CLASSES = 80
# Names of the input feature maps to be used by a mask point head.
cfg.MODEL.POINT_HEAD.IN_FEATURES = ("p2",)
# Number of points sampled during training for a mask point head.
cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS = 14 * 14
# Oversampling parameter for PointRend point sampling during training. Parameter `k` in the
# original paper.
cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO = 3
    # Importance sampling parameter for PointRend point sampling during training. Parameter `beta` in
    # the original paper.
cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO = 0.75
# Number of subdivision steps during inference.
cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS = 5
# Maximum number of points selected at each subdivision step (N).
cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS = 28 * 28
cfg.MODEL.POINT_HEAD.FC_DIM = 256
cfg.MODEL.POINT_HEAD.NUM_FC = 3
cfg.MODEL.POINT_HEAD.CLS_AGNOSTIC_MASK = False
    # If True, then coarse prediction features are used as input for each layer in PointRend's MLP.
cfg.MODEL.POINT_HEAD.COARSE_PRED_EACH_LAYER = True
cfg.MODEL.POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME = "SemSegFPNHead"
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import logging
import numpy as np
import torch
from fvcore.common.file_io import PathManager
from fvcore.transforms.transform import CropTransform
from PIL import Image
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from .color_augmentation import ColorAugSSDTransform
"""
This file contains the mapping that's applied to "dataset dicts" for semantic segmentation models.
Unlike the default DatasetMapper, this mapper uses cropping as the last transformation.
"""
__all__ = ["SemSegDatasetMapper"]
class SemSegDatasetMapper:
"""
A callable which takes a dataset dict in Detectron2 Dataset format,
    and maps it into a format used by semantic segmentation models.
    The callable currently does the following:
    1. Reads the image from "file_name"
    2. Applies geometric transforms to the image and annotation
    3. Finds and applies a suitable crop to the image and annotation
    4. Prepares the image and annotation as Tensors
"""
def __init__(self, cfg, is_train=True):
if cfg.INPUT.CROP.ENABLED and is_train:
self.crop_gen = T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)
logging.getLogger(__name__).info("CropGen used in training: " + str(self.crop_gen))
else:
self.crop_gen = None
self.tfm_gens = utils.build_transform_gen(cfg, is_train)
if cfg.INPUT.COLOR_AUG_SSD:
self.tfm_gens.append(ColorAugSSDTransform(img_format=cfg.INPUT.FORMAT))
            logging.getLogger(__name__).info(
                "Color augmentation used in training: " + str(self.tfm_gens[-1])
            )
# fmt: off
self.img_format = cfg.INPUT.FORMAT
self.single_category_max_area = cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA
self.ignore_value = cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE
# fmt: on
self.is_train = is_train
def __call__(self, dataset_dict):
"""
Args:
dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
Returns:
dict: a format that builtin models in detectron2 accept
"""
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
utils.check_image_size(dataset_dict, image)
assert "sem_seg_file_name" in dataset_dict
image, transforms = T.apply_transform_gens(self.tfm_gens, image)
if self.is_train:
with PathManager.open(dataset_dict.pop("sem_seg_file_name"), "rb") as f:
sem_seg_gt = Image.open(f)
sem_seg_gt = np.asarray(sem_seg_gt, dtype="uint8")
sem_seg_gt = transforms.apply_segmentation(sem_seg_gt)
if self.crop_gen:
image, sem_seg_gt = crop_transform(
image,
sem_seg_gt,
self.crop_gen,
self.single_category_max_area,
self.ignore_value,
)
dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))
# Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
# but not efficient on large generic data structures due to the use of pickle & mp.Queue.
# Therefore it's important to use torch.Tensor.
dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
if not self.is_train:
dataset_dict.pop("sem_seg_file_name", None)
return dataset_dict
return dataset_dict
def crop_transform(image, sem_seg, crop_gen, single_category_max_area, ignore_value):
"""
    Find a cropping window such that no single category occupies more than a
    `single_category_max_area` fraction of `sem_seg`. The function retries random cropping
    at most 10 times.
"""
if single_category_max_area >= 1.0:
crop_tfm = crop_gen.get_transform(image)
sem_seg_temp = crop_tfm.apply_segmentation(sem_seg)
else:
h, w = sem_seg.shape
crop_size = crop_gen.get_crop_size((h, w))
for _ in range(10):
y0 = np.random.randint(h - crop_size[0] + 1)
x0 = np.random.randint(w - crop_size[1] + 1)
sem_seg_temp = sem_seg[y0 : y0 + crop_size[0], x0 : x0 + crop_size[1]]
labels, cnt = np.unique(sem_seg_temp, return_counts=True)
cnt = cnt[labels != ignore_value]
if len(cnt) > 1 and np.max(cnt) / np.sum(cnt) < single_category_max_area:
break
crop_tfm = CropTransform(x0, y0, crop_size[1], crop_size[0])
image = crop_tfm.apply_image(image)
return image, sem_seg_temp
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
from torch.nn import functional as F
from detectron2.layers import cat
from detectron2.structures import Boxes
"""
Shape shorthand in this module:
    N: minibatch dimension size, i.e. the number of RoIs for instance segmentation or the
        number of images for semantic segmentation.
R: number of ROIs, combined over all images, in the minibatch
P: number of points
"""
def point_sample(input, point_coords, **kwargs):
"""
A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.
    Unlike :function:`torch.nn.functional.grid_sample`, it assumes `point_coords` to lie inside
    the [0, 1] x [0, 1] square.
Args:
        input (Tensor): A tensor of shape (N, C, H, W) that contains a feature map on an H x W grid.
point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains
[0, 1] x [0, 1] normalized point coordinates.
Returns:
output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains
features for points in `point_coords`. The features are obtained via bilinear
            interpolation from `input` the same way as :function:`torch.nn.functional.grid_sample`.
"""
add_dim = False
if point_coords.dim() == 3:
add_dim = True
point_coords = point_coords.unsqueeze(2)
output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)
if add_dim:
output = output.squeeze(3)
return output
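# Example (illustrative shapes): sampling an (N, C, H, W) feature map at P points per sample
# yields an (N, C, P) tensor:
#   features = torch.rand(2, 256, 32, 32)
#   points = torch.rand(2, 100, 2)  # [0, 1] x [0, 1] normalized coordinates
#   sampled = point_sample(features, points, align_corners=False)  # shape (2, 256, 100)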
def generate_regular_grid_point_coords(R, side_size, device):
"""
Generate regular square grid of points in [0, 1] x [0, 1] coordinate space.
Args:
R (int): The number of grids to sample, one for each region.
side_size (int): The side size of the regular grid.
device (torch.device): Desired device of returned tensor.
Returns:
(Tensor): A tensor of shape (R, side_size^2, 2) that contains coordinates
for the regular grids.
"""
aff = torch.tensor([[[0.5, 0, 0.5], [0, 0.5, 0.5]]], device=device)
r = F.affine_grid(aff, torch.Size((1, 1, side_size, side_size)), align_corners=False)
return r.view(1, -1, 2).expand(R, -1, -1)
def get_uncertain_point_coords_with_randomness(
coarse_logits, uncertainty_func, num_points, oversample_ratio, importance_sample_ratio
):
"""
    Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The uncertainties
    are calculated for each point using the 'uncertainty_func' function, which takes the point's
    logit prediction as input.
See PointRend paper for details.
Args:
coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for
class-specific or class-agnostic prediction.
uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that
contains logit predictions for P points and returns their uncertainties as a Tensor of
shape (N, 1, P).
num_points (int): The number of points P to sample.
oversample_ratio (int): Oversampling parameter.
        importance_sample_ratio (float): Ratio of points that are sampled via importance sampling.
Returns:
point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P
sampled points.
"""
assert oversample_ratio >= 1
assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0
num_boxes = coarse_logits.shape[0]
num_sampled = int(num_points * oversample_ratio)
point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device)
point_logits = point_sample(coarse_logits, point_coords, align_corners=False)
# It is crucial to calculate uncertainty based on the sampled prediction value for the points.
# Calculating uncertainties of the coarse predictions first and sampling them for points leads
# to incorrect results.
# To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between
# two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value.
# However, if we calculate uncertainties for the coarse predictions first,
# both will have -1 uncertainty, and the sampled point will get -1 uncertainty.
point_uncertainties = uncertainty_func(point_logits)
num_uncertain_points = int(importance_sample_ratio * num_points)
num_random_points = num_points - num_uncertain_points
idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device)
idx += shift[:, None]
point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
num_boxes, num_uncertain_points, 2
)
if num_random_points > 0:
point_coords = cat(
[
point_coords,
torch.rand(num_boxes, num_random_points, 2, device=coarse_logits.device),
],
dim=1,
)
return point_coords
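# Example (illustrative): with the -|logit| uncertainty measure described in the comment above
# and the training defaults from config.py (TRAIN_NUM_POINTS=14*14, OVERSAMPLE_RATIO=3,
# IMPORTANCE_SAMPLE_RATIO=0.75):
#   coarse_logits = torch.rand(8, 1, 7, 7) * 2 - 1  # hypothetical class-agnostic coarse predictions
#   point_coords = get_uncertain_point_coords_with_randomness(
#       coarse_logits, lambda logits: -torch.abs(logits), 14 * 14, 3, 0.75
#   )  # shape (8, 196, 2)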
def get_uncertain_point_coords_on_grid(uncertainty_map, num_points):
"""
Find `num_points` most uncertain points from `uncertainty_map` grid.
Args:
uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains uncertainty
values for a set of points on a regular H x W grid.
num_points (int): The number of points P to select.
Returns:
point_indices (Tensor): A tensor of shape (N, P) that contains indices from
[0, H x W) of the most uncertain points.
point_coords (Tensor): A tensor of shape (N, P, 2) that contains [0, 1] x [0, 1] normalized
coordinates of the most uncertain points from the H x W grid.
"""
R, _, H, W = uncertainty_map.shape
h_step = 1.0 / float(H)
w_step = 1.0 / float(W)
num_points = min(H * W, num_points)
point_indices = torch.topk(uncertainty_map.view(R, H * W), k=num_points, dim=1)[1]
point_coords = torch.zeros(R, num_points, 2, dtype=torch.float, device=uncertainty_map.device)
point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
return point_indices, point_coords
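# Example (illustrative): given an (R, 1, H, W) uncertainty map, e.g. -|logit| of the predicted
# class at each grid location, select the most uncertain grid points:
#   point_indices, point_coords = get_uncertain_point_coords_on_grid(uncertainty_map, 28 * 28)
#   # point_indices: (R, P) flat indices into the H x W grid; point_coords: (R, P, 2) normalized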
def point_sample_fine_grained_features(features_list, feature_scales, boxes, point_coords):
"""
Get features from feature maps in `features_list` that correspond to specific point coordinates
inside each bounding box from `boxes`.
Args:
features_list (list[Tensor]): A list of feature map tensors to get features from.
feature_scales (list[float]): A list of scales for tensors in `features_list`.
boxes (list[Boxes]): A list of I Boxes objects that contain R_1 + ... + R_I = R boxes all
together.
point_coords (Tensor): A tensor of shape (R, P, 2) that contains
[0, 1] x [0, 1] box-normalized coordinates of the P sampled points.
Returns:
point_features (Tensor): A tensor of shape (R, C, P) that contains features sampled
            from all feature maps in `features_list` for P sampled points for all R boxes in `boxes`.
point_coords_wrt_image (Tensor): A tensor of shape (R, P, 2) that contains image-level
coordinates of P points.
"""
cat_boxes = Boxes.cat(boxes)
num_boxes = [len(b) for b in boxes]
point_coords_wrt_image = get_point_coords_wrt_image(cat_boxes.tensor, point_coords)
split_point_coords_wrt_image = torch.split(point_coords_wrt_image, num_boxes)
point_features = []
for idx_img, point_coords_wrt_image_per_image in enumerate(split_point_coords_wrt_image):
point_features_per_image = []
for idx_feature, feature_map in enumerate(features_list):
h, w = feature_map.shape[-2:]
scale = torch.tensor([w, h], device=feature_map.device) / feature_scales[idx_feature]
point_coords_scaled = point_coords_wrt_image_per_image / scale
point_features_per_image.append(
point_sample(
feature_map[idx_img].unsqueeze(0),
point_coords_scaled.unsqueeze(0),
align_corners=False,
)
.squeeze(0)
.transpose(1, 0)
)
point_features.append(cat(point_features_per_image, dim=1))
return cat(point_features, dim=0), point_coords_wrt_image
def get_point_coords_wrt_image(boxes_coords, point_coords):
"""
    Convert box-normalized [0, 1] x [0, 1] point coordinates to image-level coordinates.
Args:
        boxes_coords (Tensor): A tensor of shape (R, 4) that contains bounding box coordinates.
point_coords (Tensor): A tensor of shape (R, P, 2) that contains
[0, 1] x [0, 1] box-normalized coordinates of the P sampled points.
Returns:
point_coords_wrt_image (Tensor): A tensor of shape (R, P, 2) that contains
            image-level coordinates of P sampled points.
"""
with torch.no_grad():
point_coords_wrt_image = point_coords.clone()
point_coords_wrt_image[:, :, 0] = point_coords_wrt_image[:, :, 0] * (
boxes_coords[:, None, 2] - boxes_coords[:, None, 0]
)
point_coords_wrt_image[:, :, 1] = point_coords_wrt_image[:, :, 1] * (
boxes_coords[:, None, 3] - boxes_coords[:, None, 1]
)
point_coords_wrt_image[:, :, 0] += boxes_coords[:, None, 0]
point_coords_wrt_image[:, :, 1] += boxes_coords[:, None, 1]
return point_coords_wrt_image
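# Example (illustrative numbers): for a box (x0, y0, x1, y1) = (10, 20, 50, 100), the
# box-normalized point (0.5, 0.5) maps to image coordinates
# (0.5 * (50 - 10) + 10, 0.5 * (100 - 20) + 20) = (30.0, 60.0).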