Unverified Commit ccd1b27d authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Add Faster R-CNN and Mask R-CNN (#898)

* [Remove] Use stride in 1x1 in resnet

This is temporary

* Move files to torchvision

Inference works

* Now seems to give same results

Was using the wrong number of total iterations in the end...

* Distributed evaluation seems to work

* Factor out transforms into its own file

* Enabling horizontal flips

* MultiStepLR and preparing for launches

* Add warmup

* Clip gt boxes to images

Seems to be crucial to avoid divergence. Also reduces the losses over different processes for better logging

* Single-GPU batch-size 1 of CocoEvaluator works

* Multi-GPU CocoEvaluator works

Gives the exact same results as the other one, and also supports batch size > 1

* Silence prints from pycocotools

* Commenting unneeded code for run

* Fixes

* Improvements and cleanups

* Remove scales from Pooler

It was not a free parameter, and depended only on the feature map dimensions

* Cleanups

* More cleanups

* Add misc ops and totally remove maskrcnn_benchmark

* nit

* Move Pooler to ops

* Make FPN slightly more generic

* Minor improvements or FPN

* Move FPN to ops

* Move functions to utils

* Lint fixes

* More lint

* Minor cleanups

* Add FasterRCNN

* Remove modifications to resnet

* Fixes for Python2

* More lint fixes

* Add aspect ratio grouping

* Move functions around

* Make evaluation use all images for mAP, even those without annotations

* Bugfix with DDP introduced in last commit

* [Check] Remove category mapping

* Lint

* Make GroupedBatchSampler prioritize largest clusters in the end of iteration

* Bugfix for selecting the iou_types during evaluation

Also switch to using the torchvision normalization now on, given that we are using torchvision base models

* More lint

* Add barrier after init_process_group

Better be safe than sorry

* Make evaluation only use one CPU thread per process

When doing multi-gpu evaluation, paste_masks_in_image is multithreaded and throttles evaluation altogether. Also change default for aspect ratio group to match Detectron

* Fix bug in GroupedBatchSampler

After the first epoch, the number of batch elements could be larger than batch_size, because they got accumulated from the previous iteration. Fix this and also rename some variables for more clarity

* Start adding KeypointRCNN

Currently runs and perform inference, need to do full training

* Remove use of opencv in keypoint inference

PyTorch 1.1 adds support for bicubic interpolation which matches opencv (except for empty boxes, where one of the dimensions is 1, but that's fine)

* Remove Masker

Towards having mask postprocessing done inside the model

* Bugfixes in previous change plus cleanups

* Preparing to run keypoint training

* Zero initialize bias for mask heads

* Minor improvements on print

* Towards moving resize to model

Also remove class mapping specific to COCO

* Remove zero init in bias for mask head

Checking if it decreased accuracy

* [CHECK] See if this change brings back expected accuracy

* Cleanups on model and training script

* Remove BatchCollator

* Some cleanups in coco_eval

* Move postprocess to transform

* Revert back scaling and start adding conversion to coco api

The scaling didn't seem to matter

* Use decorator instead of context manager in evaluate

* Move training and evaluation functions to a separate file

Also adds support for obtaining a coco API object from our dataset

* Remove unused code

* Update location of lr_scheduler

Its behavior has changed in PyTorch 1.1

* Remove debug code

* Typo

* Bugfix

* Move image normalization to model

* Remove legacy tensor constructors

Also move away from Int and instead use int64

* Bugfix in MultiscaleRoiAlign

* Move transforms to its own file

* Add missing file

* Lint

* More lint

* Add some basic test for detection models

* More lint
parent 6272c412
import json
import tempfile
import numpy as np
import copy
import time
import torch
import torch._six
from pycocotools.cocoeval import COCOeval
from pycocotools.coco import COCO
import pycocotools.mask as mask_util
from collections import defaultdict
import utils
class CocoEvaluator(object):
def __init__(self, coco_gt, iou_types):
assert isinstance(iou_types, (list, tuple))
coco_gt = copy.deepcopy(coco_gt)
self.coco_gt = coco_gt
self.iou_types = iou_types
self.coco_eval = {}
for iou_type in iou_types:
self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)
self.img_ids = []
self.eval_imgs = {k: [] for k in iou_types}
def update(self, predictions):
img_ids = list(np.unique(list(predictions.keys())))
self.img_ids.extend(img_ids)
for iou_type in self.iou_types:
results = self.prepare(predictions, iou_type)
coco_dt = loadRes(self.coco_gt, results) if results else COCO()
coco_eval = self.coco_eval[iou_type]
coco_eval.cocoDt = coco_dt
coco_eval.params.imgIds = list(img_ids)
img_ids, eval_imgs = evaluate(coco_eval)
self.eval_imgs[iou_type].append(eval_imgs)
def synchronize_between_processes(self):
for iou_type in self.iou_types:
self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])
def accumulate(self):
for coco_eval in self.coco_eval.values():
coco_eval.accumulate()
def summarize(self):
for iou_type, coco_eval in self.coco_eval.items():
print("IoU metric: {}".format(iou_type))
coco_eval.summarize()
def prepare(self, predictions, iou_type):
if iou_type == "bbox":
return self.prepare_for_coco_detection(predictions)
elif iou_type == "segm":
return self.prepare_for_coco_segmentation(predictions)
elif iou_type == "keypoints":
return self.prepare_for_coco_keypoint(predictions)
else:
raise ValueError("Unknown iou type {}".format(iou_type))
def prepare_for_coco_detection(self, predictions):
coco_results = []
for original_id, prediction in predictions.items():
if len(prediction) == 0:
continue
boxes = prediction["boxes"]
boxes = convert_to_xywh(boxes).tolist()
scores = prediction["scores"].tolist()
labels = prediction["labels"].tolist()
coco_results.extend(
[
{
"image_id": original_id,
"category_id": labels[k],
"bbox": box,
"score": scores[k],
}
for k, box in enumerate(boxes)
]
)
return coco_results
def prepare_for_coco_segmentation(self, predictions):
coco_results = []
for original_id, prediction in predictions.items():
if len(prediction) == 0:
continue
scores = prediction["scores"]
labels = prediction["labels"]
masks = prediction["mask"]
masks = masks > 0.5
scores = prediction["scores"].tolist()
labels = prediction["labels"].tolist()
rles = [
mask_util.encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0]
for mask in masks
]
for rle in rles:
rle["counts"] = rle["counts"].decode("utf-8")
coco_results.extend(
[
{
"image_id": original_id,
"category_id": labels[k],
"segmentation": rle,
"score": scores[k],
}
for k, rle in enumerate(rles)
]
)
return coco_results
def prepare_for_coco_keypoint(self, predictions):
coco_results = []
for original_id, prediction in predictions.items():
if len(prediction) == 0:
continue
boxes = prediction["boxes"]
boxes = convert_to_xywh(boxes).tolist()
scores = prediction["scores"].tolist()
labels = prediction["labels"].tolist()
keypoints = prediction["keypoints"]
keypoints = keypoints.flatten(start_dim=1).tolist()
coco_results.extend(
[
{
"image_id": original_id,
"category_id": labels[k],
'keypoints': keypoint,
"score": scores[k],
}
for k, keypoint in enumerate(keypoints)
]
)
return coco_results
def convert_to_xywh(boxes):
xmin, ymin, xmax, ymax = boxes.unbind(1)
return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)
def merge(img_ids, eval_imgs):
all_img_ids = utils.all_gather(img_ids)
all_eval_imgs = utils.all_gather(eval_imgs)
merged_img_ids = []
for p in all_img_ids:
merged_img_ids.extend(p)
merged_eval_imgs = []
for p in all_eval_imgs:
merged_eval_imgs.append(p)
merged_img_ids = np.array(merged_img_ids)
merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)
# keep only unique (and in sorted order) images
merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
merged_eval_imgs = merged_eval_imgs[..., idx]
return merged_img_ids, merged_eval_imgs
def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
img_ids, eval_imgs = merge(img_ids, eval_imgs)
img_ids = list(img_ids)
eval_imgs = list(eval_imgs.flatten())
coco_eval.evalImgs = eval_imgs
coco_eval.params.imgIds = img_ids
coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
#################################################################
# From pycocotools, just removed the prints and fixed
# a Python3 bug about unicode not defined
#################################################################
# Ideally, pycocotools wouldn't have hard-coded prints
# so that we could avoid copy-pasting those two functions
def createIndex(self):
# create index
# print('creating index...')
anns, cats, imgs = {}, {}, {}
imgToAnns, catToImgs = defaultdict(list), defaultdict(list)
if 'annotations' in self.dataset:
for ann in self.dataset['annotations']:
imgToAnns[ann['image_id']].append(ann)
anns[ann['id']] = ann
if 'images' in self.dataset:
for img in self.dataset['images']:
imgs[img['id']] = img
if 'categories' in self.dataset:
for cat in self.dataset['categories']:
cats[cat['id']] = cat
if 'annotations' in self.dataset and 'categories' in self.dataset:
for ann in self.dataset['annotations']:
catToImgs[ann['category_id']].append(ann['image_id'])
# print('index created!')
# create class members
self.anns = anns
self.imgToAnns = imgToAnns
self.catToImgs = catToImgs
self.imgs = imgs
self.cats = cats
maskUtils = mask_util
def loadRes(self, resFile):
"""
Load result file and return a result api object.
:param resFile (str) : file name of result file
:return: res (obj) : result api object
"""
res = COCO()
res.dataset['images'] = [img for img in self.dataset['images']]
# print('Loading and preparing results...')
# tic = time.time()
if isinstance(resFile, torch._six.string_classes):
anns = json.load(open(resFile))
elif type(resFile) == np.ndarray:
anns = self.loadNumpyAnnotations(resFile)
else:
anns = resFile
assert type(anns) == list, 'results in not an array of objects'
annsImgIds = [ann['image_id'] for ann in anns]
assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \
'Results do not correspond to current coco set'
if 'caption' in anns[0]:
imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns])
res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds]
for id, ann in enumerate(anns):
ann['id'] = id + 1
elif 'bbox' in anns[0] and not anns[0]['bbox'] == []:
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
for id, ann in enumerate(anns):
bb = ann['bbox']
x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]]
if 'segmentation' not in ann:
ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
ann['area'] = bb[2] * bb[3]
ann['id'] = id + 1
ann['iscrowd'] = 0
elif 'segmentation' in anns[0]:
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
for id, ann in enumerate(anns):
# now only support compressed RLE format as segmentation results
ann['area'] = maskUtils.area(ann['segmentation'])
if 'bbox' not in ann:
ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
ann['id'] = id + 1
ann['iscrowd'] = 0
elif 'keypoints' in anns[0]:
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
for id, ann in enumerate(anns):
s = ann['keypoints']
x = s[0::3]
y = s[1::3]
x0, x1, y0, y1 = np.min(x), np.max(x), np.min(y), np.max(y)
ann['area'] = (x1 - x0) * (y1 - y0)
ann['id'] = id + 1
ann['bbox'] = [x0, y0, x1 - x0, y1 - y0]
# print('DONE (t={:0.2f}s)'.format(time.time()- tic))
res.dataset['annotations'] = anns
createIndex(res)
return res
def evaluate(self):
'''
Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
:return: None
'''
# tic = time.time()
# print('Running per image evaluation...')
p = self.params
# add backward compatibility if useSegm is specified in params
if p.useSegm is not None:
p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
# print('Evaluate annotation type *{}*'.format(p.iouType))
p.imgIds = list(np.unique(p.imgIds))
if p.useCats:
p.catIds = list(np.unique(p.catIds))
p.maxDets = sorted(p.maxDets)
self.params = p
self._prepare()
# loop through images, area range, max detection number
catIds = p.catIds if p.useCats else [-1]
if p.iouType == 'segm' or p.iouType == 'bbox':
computeIoU = self.computeIoU
elif p.iouType == 'keypoints':
computeIoU = self.computeOks
self.ious = {
(imgId, catId): computeIoU(imgId, catId)
for imgId in p.imgIds
for catId in catIds}
evaluateImg = self.evaluateImg
maxDet = p.maxDets[-1]
evalImgs = [
evaluateImg(imgId, catId, areaRng, maxDet)
for catId in catIds
for areaRng in p.areaRng
for imgId in p.imgIds
]
# this is NOT in the pycocotools code, but could be done outside
evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
self._paramsEval = copy.deepcopy(self.params)
# toc = time.time()
# print('DONE (t={:0.2f}s).'.format(toc-tic))
return p.imgIds, evalImgs
#################################################################
# end of straight copy from pycocotools, just removing the prints
#################################################################
import copy
import os
from PIL import Image
import torch
import torch.utils.data
import torchvision
from pycocotools import mask as coco_mask
from pycocotools.coco import COCO
import transforms as T
class FilterAndRemapCocoCategories(object):
def __init__(self, categories, remap=True):
self.categories = categories
self.remap = remap
def __call__(self, image, target):
anno = target["annotations"]
anno = [obj for obj in anno if obj["category_id"] in self.categories]
if not self.remap:
target["annotations"] = anno
return image, target
anno = copy.deepcopy(anno)
for obj in anno:
obj["category_id"] = self.categories.index(obj["category_id"])
target["annotations"] = anno
return image, target
def convert_coco_poly_to_mask(segmentations, height, width):
masks = []
for polygons in segmentations:
rles = coco_mask.frPyObjects(polygons, height, width)
mask = coco_mask.decode(rles)
if len(mask.shape) < 3:
mask = mask[..., None]
mask = torch.as_tensor(mask, dtype=torch.uint8)
mask = mask.any(dim=2)
masks.append(mask)
if masks:
masks = torch.stack(masks, dim=0)
else:
masks = torch.zeros((0, height, width), dtype=torch.uint8)
return masks
class ConvertCocoPolysToMask(object):
def __call__(self, image, target):
w, h = image.size
image_id = target["image_id"]
image_id = torch.tensor([image_id])
anno = target["annotations"]
anno = [obj for obj in anno if obj['iscrowd'] == 0]
boxes = [obj["bbox"] for obj in anno]
# guard against no boxes via resizing
boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
boxes[:, 2:] += boxes[:, :2]
boxes[:, 0::2].clamp_(min=0, max=w)
boxes[:, 1::2].clamp_(min=0, max=h)
classes = [obj["category_id"] for obj in anno]
classes = torch.tensor(classes, dtype=torch.int64)
segmentations = [obj["segmentation"] for obj in anno]
masks = convert_coco_poly_to_mask(segmentations, h, w)
keypoints = None
if anno and "keypoints" in anno[0]:
keypoints = [obj["keypoints"] for obj in anno]
keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
num_keypoints = keypoints.shape[0]
if num_keypoints:
keypoints = keypoints.view(num_keypoints, -1, 3)
keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
boxes = boxes[keep]
classes = classes[keep]
masks = masks[keep]
if keypoints is not None:
keypoints = keypoints[keep]
target = {}
target["boxes"] = boxes
target["labels"] = classes
target["masks"] = masks
target["image_id"] = image_id
if keypoints is not None:
target["keypoints"] = keypoints
# for conversion to coco api
area = torch.tensor([obj["area"] for obj in anno])
iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])
target["area"] = area
target["iscrowd"] = iscrowd
return image, target
def _coco_remove_images_without_annotations(dataset, cat_list=None):
def _has_only_empty_bbox(anno):
return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
def _count_visible_keypoints(anno):
return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
min_keypoints_per_image = 10
def _has_valid_annotation(anno):
# if it's empty, there is no annotation
if len(anno) == 0:
return False
# if all boxes have close to zero area, there is no annotation
if _has_only_empty_bbox(anno):
return False
# keypoints task have a slight different critera for considering
# if an annotation is valid
if "keypoints" not in anno[0]:
return True
# for keypoint detection tasks, only consider valid images those
# containing at least min_keypoints_per_image
if _count_visible_keypoints(anno) >= min_keypoints_per_image:
return True
return False
assert isinstance(dataset, torchvision.datasets.CocoDetection)
ids = []
for ds_idx, img_id in enumerate(dataset.ids):
ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
anno = dataset.coco.loadAnns(ann_ids)
if cat_list:
anno = [obj for obj in anno if obj["category_id"] in cat_list]
if _has_valid_annotation(anno):
ids.append(ds_idx)
dataset = torch.utils.data.Subset(dataset, ids)
return dataset
def convert_to_coco_api(ds):
coco_ds = COCO()
ann_id = 0
dataset = {'images': [], 'categories': [], 'annotations': []}
categories = set()
for img_idx in range(len(ds)):
# find better way to get target
# targets = ds.get_annotations(img_idx)
_, targets = ds[img_idx]
image_id = targets["image_id"].item()
img_dict = {}
img_dict['id'] = image_id
dataset['images'].append(img_dict)
bboxes = targets["boxes"]
bboxes[:, 2:] -= bboxes[:, :2]
bboxes = bboxes.tolist()
labels = targets['labels'].tolist()
areas = targets['area'].tolist()
iscrowd = targets['iscrowd'].tolist()
# TODO need to add masks as well
num_objs = len(bboxes)
for i in range(num_objs):
ann = {}
ann['image_id'] = image_id
ann['bbox'] = bboxes[i]
ann['category_id'] = labels[i]
categories.add(labels[i])
ann['area'] = areas[i]
ann['iscrowd'] = iscrowd[i]
ann['id'] = ann_id
dataset['annotations'].append(ann)
ann_id += 1
dataset['categories'] = [{'id': i} for i in sorted(categories)]
coco_ds.dataset = dataset
coco_ds.createIndex()
return coco_ds
def get_coco_api_from_dataset(dataset):
for i in range(10):
if isinstance(dataset, torchvision.datasets.CocoDetection):
break
if isinstance(dataset, torch.utils.data.Subset):
dataset = dataset.dataset
if isinstance(dataset, torchvision.datasets.CocoDetection):
return dataset.coco
return convert_to_coco_api(dataset)
class CocoDetection(torchvision.datasets.CocoDetection):
def __init__(self, img_folder, ann_file, transforms):
super(CocoDetection, self).__init__(img_folder, ann_file)
self._transforms = transforms
def __getitem__(self, idx):
img, target = super(CocoDetection, self).__getitem__(idx)
image_id = self.ids[idx]
target = dict(image_id=image_id, annotations=target)
if self._transforms is not None:
img, target = self._transforms(img, target)
return img, target
def get_coco(root, image_set, transforms, mode='instances'):
anno_file_template = "{}_{}2017.json"
PATHS = {
"train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
"val": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))),
# "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val")))
}
t = [ConvertCocoPolysToMask()]
if transforms is not None:
t.append(transforms)
transforms = T.Compose(t)
img_folder, ann_file = PATHS[image_set]
img_folder = os.path.join(root, img_folder)
ann_file = os.path.join(root, ann_file)
dataset = CocoDetection(img_folder, ann_file, transforms=transforms)
if image_set == "train":
dataset = _coco_remove_images_without_annotations(dataset)
# dataset = torch.utils.data.Subset(dataset, [i for i in range(500)])
return dataset
def get_coco_kp(root, image_set, transforms):
return get_coco(root, image_set, transforms, mode="person_keypoints")
import math
import sys
import time
import torch
import torchvision.models.detection.mask_rcnn
from coco_utils import get_coco_api_from_dataset
from coco_eval import CocoEvaluator
import utils
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
header = 'Epoch: [{}]'.format(epoch)
lr_scheduler = None
if epoch == 0:
warmup_factor = 1. / 1000
warmup_iters = min(1000, len(data_loader) - 1)
lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)
for images, targets in metric_logger.log_every(data_loader, print_freq, header):
images = list(image.to(device) for image in images)
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
loss_dict = model(images, targets)
losses = sum(loss for loss in loss_dict.values())
# reduce losses over all GPUs for logging purposes
loss_dict_reduced = utils.reduce_dict(loss_dict)
losses_reduced = sum(loss for loss in loss_dict_reduced.values())
loss_value = losses_reduced.item()
if not math.isfinite(loss_value):
print("Loss is {}, stopping training".format(loss_value))
print(loss_dict_reduced)
sys.exit(1)
optimizer.zero_grad()
losses.backward()
optimizer.step()
if lr_scheduler is not None:
lr_scheduler.step()
metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
def _get_iou_types(model):
model_without_ddp = model
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
model_without_ddp = model.module
iou_types = ["bbox"]
if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN):
iou_types.append("segm")
if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN):
iou_types.append("keypoints")
return iou_types
@torch.no_grad()
def evaluate(model, data_loader, device):
n_threads = torch.get_num_threads()
# FIXME remove this and make paste_masks_in_image run on the GPU
torch.set_num_threads(1)
cpu_device = torch.device("cpu")
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
coco = get_coco_api_from_dataset(data_loader.dataset)
iou_types = _get_iou_types(model)
coco_evaluator = CocoEvaluator(coco, iou_types)
for image, targets in metric_logger.log_every(data_loader, 100, header):
image = list(img.to(device) for img in image)
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
torch.cuda.synchronize()
model_time = time.time()
outputs = model(image)
outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
model_time = time.time() - model_time
res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
evaluator_time = time.time()
coco_evaluator.update(res)
evaluator_time = time.time() - evaluator_time
metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
coco_evaluator.synchronize_between_processes()
# accumulate predictions from all images
coco_evaluator.accumulate()
coco_evaluator.summarize()
torch.set_num_threads(n_threads)
return coco_evaluator
import bisect
from collections import defaultdict
import copy
import numpy as np
import torch
import torch.utils.data
from torch.utils.data.sampler import BatchSampler, Sampler
from torch.utils.model_zoo import tqdm
import torchvision
from PIL import Image
class GroupedBatchSampler(BatchSampler):
"""
Wraps another sampler to yield a mini-batch of indices.
It enforces that the batch only contain elements from the same group.
It also tries to provide mini-batches which follows an ordering which is
as close as possible to the ordering from the original sampler.
Arguments:
sampler (Sampler): Base sampler.
group_ids (list[int]): If the sampler produces indices in range [0, N),
`group_ids` must be a list of `N` ints which contains the group id of each sample.
The group ids must be a continuous set of integers starting from
0, i.e. they must be in the range [0, num_groups).
batch_size (int): Size of mini-batch.
"""
def __init__(self, sampler, group_ids, batch_size):
if not isinstance(sampler, Sampler):
raise ValueError(
"sampler should be an instance of "
"torch.utils.data.Sampler, but got sampler={}".format(sampler)
)
self.sampler = sampler
self.group_ids = group_ids
self.batch_size = batch_size
def __iter__(self):
buffer_per_group = defaultdict(list)
samples_per_group = defaultdict(list)
num_batches = 0
for idx in self.sampler:
group_id = self.group_ids[idx]
buffer_per_group[group_id].append(idx)
samples_per_group[group_id].append(idx)
if len(buffer_per_group[group_id]) == self.batch_size:
yield buffer_per_group[group_id]
num_batches += 1
del buffer_per_group[group_id]
assert len(buffer_per_group[group_id]) < self.batch_size
# now we have run out of elements that satisfy
# the group criteria, let's return the remaining
# elements so that the size of the sampler is
# deterministic
expected_num_batches = len(self)
num_remaining = expected_num_batches - num_batches
if num_remaining > 0:
# for the remaining batches, take first the buffers with largest number
# of elements
for group_id, _ in sorted(buffer_per_group.items(),
key=lambda x: len(x[1]), reverse=True):
remaining = self.batch_size - len(buffer_per_group[group_id])
buffer_per_group[group_id].extend(
samples_per_group[group_id][:remaining])
assert len(buffer_per_group[group_id]) == self.batch_size
yield buffer_per_group[group_id]
num_remaining -= 1
if num_remaining == 0:
break
assert num_remaining == 0
def __len__(self):
return len(self.sampler) // self.batch_size
def _compute_aspect_ratios_slow(dataset, indices=None):
print("Your dataset doesn't support the fast path for "
"computing the aspect ratios, so will iterate over "
"the full dataset and load every image instead. "
"This might take some time...")
if indices is None:
indices = range(len(dataset))
class SubsetSampler(Sampler):
def __init__(self, indices):
self.indices = indices
def __iter__(self):
return iter(self.indices)
def __len__(self):
return len(self.indices)
sampler = SubsetSampler(indices)
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=1, sampler=sampler,
num_workers=14, # you might want to increase it for faster processing
collate_fn=lambda x: x[0])
aspect_ratios = []
with tqdm(total=len(dataset)) as pbar:
for i, (img, _) in enumerate(data_loader):
pbar.update(1)
height, width = img.shape[-2:]
aspect_ratio = float(height) / float(width)
aspect_ratios.append(aspect_ratio)
return aspect_ratios
def _compute_aspect_ratios_custom_dataset(dataset, indices=None):
if indices is None:
indices = range(len(dataset))
aspect_ratios = []
for i in indices:
height, width = dataset.get_height_and_width(i)
aspect_ratio = float(height) / float(width)
aspect_ratios.append(aspect_ratio)
return aspect_ratios
def _compute_aspect_ratios_coco_dataset(dataset, indices=None):
if indices is None:
indices = range(len(dataset))
aspect_ratios = []
for i in indices:
img_info = dataset.coco.imgs[dataset.ids[i]]
aspect_ratio = float(img_info["height"]) / float(img_info["width"])
aspect_ratios.append(aspect_ratio)
return aspect_ratios
def _compute_aspect_ratios_voc_dataset(dataset, indices=None):
if indices is None:
indices = range(len(dataset))
aspect_ratios = []
for i in indices:
# this doesn't load the data into memory, because PIL loads it lazily
width, height = Image.open(dataset.images[i]).size
aspect_ratio = float(height) / float(width)
aspect_ratios.append(aspect_ratio)
return aspect_ratios
def _compute_aspect_ratios_subset_dataset(dataset, indices=None):
if indices is None:
indices = range(len(dataset))
ds_indices = [dataset.indices[i] for i in indices]
return compute_aspect_ratios(dataset.dataset, ds_indices)
def compute_aspect_ratios(dataset, indices=None):
if hasattr(dataset, "get_height_and_width"):
return _compute_aspect_ratios_custom_dataset(dataset, indices)
if isinstance(dataset, torchvision.datasets.CocoDetection):
return _compute_aspect_ratios_coco_dataset(dataset, indices)
if isinstance(dataset, torchvision.datasets.VOCDetection):
return _compute_aspect_ratios_voc_dataset(dataset, indices)
if isinstance(dataset, torch.utils.data.Subset):
return _compute_aspect_ratios_subset_dataset(dataset, indices)
# slow path
return _compute_aspect_ratios_slow(dataset, indices)
def _quantize(x, bins):
bins = copy.deepcopy(bins)
bins = sorted(bins)
quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
return quantized
def create_aspect_ratio_groups(dataset, k=0):
aspect_ratios = compute_aspect_ratios(dataset)
bins = (2 ** np.linspace(-1, 1, 2 * k + 1)).tolist() if k > 0 else [1.0]
groups = _quantize(aspect_ratios, bins)
# count number of elements per group
counts = np.unique(groups, return_counts=True)[1]
fbins = [0] + bins + [np.inf]
print("Using {} as bins for aspect ratio quantization".format(fbins))
print("Count of instances per bin: {}".format(counts))
return groups
import datetime
import os
import time
import torch
import torch.utils.data
from torch import nn
import torchvision
import torchvision.models.detection
import torchvision.models.detection.mask_rcnn
from torchvision import transforms
from coco_utils import get_coco, get_coco_kp
from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
from engine import train_one_epoch, evaluate
import utils
import transforms as T
def get_dataset(name, image_set, transform):
paths = {
"coco": ('/datasets01/COCO/022719/', get_coco, 91),
"coco_kp": ('/datasets01/COCO/022719/', get_coco_kp, 2)
}
p, ds_fn, num_classes = paths[name]
ds = ds_fn(p, image_set=image_set, transforms=transform)
return ds, num_classes
def get_transform(train):
transforms = []
transforms.append(T.ToTensor())
if train:
transforms.append(T.RandomHorizontalFlip(0.5))
return T.Compose(transforms)
def main(args):
utils.init_distributed_mode(args)
print(args)
device = torch.device(args.device)
# Data loading code
print("Loading data")
dataset, num_classes = get_dataset(args.dataset, "train", get_transform(train=True))
dataset_test, _ = get_dataset(args.dataset, "val", get_transform(train=False))
print("Creating data loaders")
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
else:
train_sampler = torch.utils.data.RandomSampler(dataset)
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
if args.aspect_ratio_group_factor >= 0:
group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor)
train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
else:
train_batch_sampler = torch.utils.data.BatchSampler(
train_sampler, args.batch_size, drop_last=True)
data_loader = torch.utils.data.DataLoader(
dataset, batch_sampler=train_batch_sampler, num_workers=args.workers,
collate_fn=utils.collate_fn)
data_loader_test = torch.utils.data.DataLoader(
dataset_test, batch_size=1,
sampler=test_sampler, num_workers=args.workers,
collate_fn=utils.collate_fn)
print("Creating model")
model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes)
model.to(device)
model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
# lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
if args.test_only:
evaluate(model, data_loader_test, device=device)
return
print("Start training")
start_time = time.time()
for epoch in range(args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq)
lr_scheduler.step()
if args.output_dir:
utils.save_on_master({
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'args': args},
os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
# evaluate after every epoch
evaluate(model, data_loader_test, device=device)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='PyTorch Detection Training')
parser.add_argument('--data-path', default='/datasets01/COCO/022719/', help='dataset')
parser.add_argument('--dataset', default='coco', help='dataset')
parser.add_argument('--model', default='maskrcnn_resnet50_fpn', help='model')
parser.add_argument('--device', default='cuda', help='device')
parser.add_argument('-b', '--batch-size', default=2, type=int)
parser.add_argument('--epochs', default=13, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 16)')
parser.add_argument('--lr', default=0.02, type=float, help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)',
dest='weight_decay')
parser.add_argument('--lr-step-size', default=8, type=int, help='decrease lr every step-size epochs')
parser.add_argument('--lr-steps', default=[8, 11], nargs='+', type=int, help='decrease lr every step-size epochs')
parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
parser.add_argument('--print-freq', default=20, type=int, help='print frequency')
parser.add_argument('--output-dir', default='.', help='path where to save')
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--aspect-ratio-group-factor', default=0, type=int)
parser.add_argument(
"--test-only",
dest="test_only",
help="Only test the model",
action="store_true",
)
# distributed training parameters
parser.add_argument('--world-size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
args = parser.parse_args()
if args.output_dir:
utils.mkdir(args.output_dir)
main(args)
import random
import torch
from torchvision.transforms import functional as F
def _flip_coco_person_keypoints(kps, width):
flip_inds = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
flipped_data = kps[:, flip_inds]
flipped_data[..., 0] = width - flipped_data[..., 0]
# Maintain COCO convention that if visibility == 0, then x, y = 0
inds = flipped_data[..., 2] == 0
flipped_data[inds] = 0
return flipped_data
class Compose(object):
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, image, target):
for t in self.transforms:
image, target = t(image, target)
return image, target
class RandomHorizontalFlip(object):
def __init__(self, prob):
self.prob = prob
def __call__(self, image, target):
if random.random() < self.prob:
height, width = image.shape[-2:]
image = image.flip(-1)
bbox = target["boxes"]
bbox[:, [0, 2]] = width - bbox[:, [2, 0]]
target["boxes"] = bbox
if "masks" in target:
target["masks"] = target["masks"].flip(-1)
if "keypoints" in target:
keypoints = target["keypoints"]
keypoints = _flip_coco_person_keypoints(keypoints, width)
target["keypoints"] = keypoints
return image, target
class ToTensor(object):
def __call__(self, image, target):
image = F.to_tensor(image)
return image, target
from __future__ import print_function
from collections import defaultdict, deque
import datetime
import pickle
import time
import torch
import torch.distributed as dist
import errno
import os
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
def all_gather(data):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors)
Args:
data: any picklable object
Returns:
list[data]: list of data gathered from each rank
"""
world_size = get_world_size()
if world_size == 1:
return [data]
# serialized to a Tensor
buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to("cuda")
# obtain Tensor size of each rank
local_size = torch.tensor([tensor.numel()], device="cuda")
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
# receiving Tensor from all ranks
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
tensor_list = []
for _ in size_list:
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
if local_size != max_size:
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
tensor = torch.cat((tensor, padding), dim=0)
dist.all_gather(tensor_list, tensor)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
def reduce_dict(input_dict, average=True):
"""
Args:
input_dict (dict): all the values will be reduced
average (bool): whether to do average or sum
Reduce the values in the dictionary from all processes so that all processes
have the averaged results. Returns a dict with the same fields as
input_dict, after reduction.
"""
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.no_grad():
names = []
values = []
# sort the keys so that they are consistent across processes
for k in sorted(input_dict.keys()):
names.append(k)
values.append(input_dict[k])
values = torch.stack(values, dim=0)
dist.all_reduce(values)
if average:
values /= world_size
reduced_dict = {k: v for k, v in zip(names, values)}
return reduced_dict
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ''
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt='{avg:.4f}')
data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
log_msg = self.delimiter.join([
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}',
'max mem: {memory:.0f}'
])
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB))
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {} ({:.4f} s / it)'.format(
header, total_time_str, total_time / len(iterable)))
def collate_fn(batch):
return tuple(zip(*batch))
def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):
def f(x):
if x >= warmup_iters:
return 1
alpha = float(x) / warmup_iters
return warmup_factor * (1 - alpha) + alpha
return torch.optim.lr_scheduler.LambdaLR(optimizer, f)
def mkdir(path):
try:
os.makedirs(path)
except OSError as e:
if e.errno != errno.EEXIST:
raise
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
elif 'SLURM_PROCID' in os.environ:
args.rank = int(os.environ['SLURM_PROCID'])
args.gpu = args.rank % torch.cuda.device_count()
else:
print('Not using distributed mode')
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}'.format(
args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
......@@ -15,6 +15,11 @@ def get_available_segmentation_models():
return [k for k, v in models.segmentation.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
def get_available_detection_models():
# TODO add a registration mechanism to torchvision.models
return [k for k, v in models.detection.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
class Tester(unittest.TestCase):
def _test_classification_model(self, name, input_shape):
# passing num_class equal to a number other than 1000 helps in making the test
......@@ -35,6 +40,17 @@ class Tester(unittest.TestCase):
out = model(x)
self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300))
def _test_detection_model(self, name):
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False)
model.eval()
input_shape = (3, 300, 300)
x = torch.rand(input_shape)
out = model([x])
self.assertEqual(len(out), 1)
self.assertTrue("boxes" in out[0])
self.assertTrue("scores" in out[0])
self.assertTrue("labels" in out[0])
def _make_sliced_model(self, model, stop_layer):
layers = OrderedDict()
for name, layer in model.named_children():
......@@ -77,5 +93,14 @@ for model_name in get_available_segmentation_models():
setattr(Tester, "test_" + model_name, do_test)
for model_name in get_available_detection_models():
# for-loop bodies don't define scopes, so we have to save the variables
# we want to close over in some way
def do_test(self, model_name=model_name):
self._test_detection_model(model_name)
setattr(Tester, "test_" + model_name, do_test)
if __name__ == '__main__':
unittest.main()
......@@ -8,3 +8,4 @@ from .googlenet import *
from .mobilenet import *
from .shufflenetv2 import *
from . import segmentation
from . import detection
from __future__ import division
import math
import torch
class BalancedPositiveNegativeSampler(object):
"""
This class samples batches, ensuring that they contain a fixed proportion of positives
"""
def __init__(self, batch_size_per_image, positive_fraction):
"""
Arguments:
batch_size_per_image (int): number of elements to be selected per image
positive_fraction (float): percentace of positive elements per batch
"""
self.batch_size_per_image = batch_size_per_image
self.positive_fraction = positive_fraction
def __call__(self, matched_idxs):
"""
Arguments:
matched idxs: list of tensors containing -1, 0 or positive values.
Each tensor corresponds to a specific image.
-1 values are ignored, 0 are considered as negatives and > 0 as
positives.
Returns:
pos_idx (list[tensor])
neg_idx (list[tensor])
Returns two lists of binary masks for each image.
The first list contains the positive elements that were selected,
and the second list the negative example.
"""
pos_idx = []
neg_idx = []
for matched_idxs_per_image in matched_idxs:
positive = torch.nonzero(matched_idxs_per_image >= 1).squeeze(1)
negative = torch.nonzero(matched_idxs_per_image == 0).squeeze(1)
num_pos = int(self.batch_size_per_image * self.positive_fraction)
# protect against not enough positive examples
num_pos = min(positive.numel(), num_pos)
num_neg = self.batch_size_per_image - num_pos
# protect against not enough negative examples
num_neg = min(negative.numel(), num_neg)
# randomly select positive and negative examples
perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]
pos_idx_per_image = positive[perm1]
neg_idx_per_image = negative[perm2]
# create binary mask from indices
pos_idx_per_image_mask = torch.zeros_like(
matched_idxs_per_image, dtype=torch.uint8
)
neg_idx_per_image_mask = torch.zeros_like(
matched_idxs_per_image, dtype=torch.uint8
)
pos_idx_per_image_mask[pos_idx_per_image] = 1
neg_idx_per_image_mask[neg_idx_per_image] = 1
pos_idx.append(pos_idx_per_image_mask)
neg_idx.append(neg_idx_per_image_mask)
return pos_idx, neg_idx
@torch.jit.script
def encode_boxes(reference_boxes, proposals, weights):
# type: (torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
"""
Encode a set of proposals with respect to some
reference boxes
Arguments:
reference_boxes (Tensor): reference boxes
proposals (Tensor): boxes to be encoded
"""
# perform some unpacking to make it JIT-fusion friendly
wx = weights[0]
wy = weights[1]
ww = weights[2]
wh = weights[3]
proposals_x1 = proposals[:, 0].unsqueeze(1)
proposals_y1 = proposals[:, 1].unsqueeze(1)
proposals_x2 = proposals[:, 2].unsqueeze(1)
proposals_y2 = proposals[:, 3].unsqueeze(1)
reference_boxes_x1 = reference_boxes[:, 0].unsqueeze(1)
reference_boxes_y1 = reference_boxes[:, 1].unsqueeze(1)
reference_boxes_x2 = reference_boxes[:, 2].unsqueeze(1)
reference_boxes_y2 = reference_boxes[:, 3].unsqueeze(1)
# implementation starts here
ex_widths = proposals_x2 - proposals_x1
ex_heights = proposals_y2 - proposals_y1
ex_ctr_x = proposals_x1 + 0.5 * ex_widths
ex_ctr_y = proposals_y1 + 0.5 * ex_heights
gt_widths = reference_boxes_x2 - reference_boxes_x1
gt_heights = reference_boxes_y2 - reference_boxes_y1
gt_ctr_x = reference_boxes_x1 + 0.5 * gt_widths
gt_ctr_y = reference_boxes_y1 + 0.5 * gt_heights
targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
targets_dw = ww * torch.log(gt_widths / ex_widths)
targets_dh = wh * torch.log(gt_heights / ex_heights)
targets = torch.cat((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
return targets
class BoxCoder(object):
"""
This class encodes and decodes a set of bounding boxes into
the representation used for training the regressors.
"""
def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
"""
Arguments:
weights (4-element tuple)
bbox_xform_clip (float)
"""
self.weights = weights
self.bbox_xform_clip = bbox_xform_clip
def encode(self, reference_boxes, proposals):
boxes_per_image = [len(b) for b in reference_boxes]
reference_boxes = torch.cat(reference_boxes, dim=0)
proposals = torch.cat(proposals, dim=0)
targets = self.encode_single(reference_boxes, proposals)
return targets.split(boxes_per_image, 0)
def encode_single(self, reference_boxes, proposals):
"""
Encode a set of proposals with respect to some
reference boxes
Arguments:
reference_boxes (Tensor): reference boxes
proposals (Tensor): boxes to be encoded
"""
dtype = reference_boxes.dtype
device = reference_boxes.device
weights = torch.as_tensor(self.weights, dtype=dtype, device=device)
targets = encode_boxes(reference_boxes, proposals, weights)
return targets
def decode(self, rel_codes, boxes):
assert isinstance(boxes, (list, tuple))
if isinstance(rel_codes, (list, tuple)):
rel_codes = torch.cat(rel_codes, dim=0)
assert isinstance(rel_codes, torch.Tensor)
boxes_per_image = [len(b) for b in boxes]
concat_boxes = torch.cat(boxes, dim=0)
pred_boxes = self.decode_single(
rel_codes.reshape(sum(boxes_per_image), -1), concat_boxes
)
return pred_boxes.reshape(sum(boxes_per_image), -1, 4)
def decode_single(self, rel_codes, boxes):
"""
From a set of original boxes and encoded relative box offsets,
get the decoded boxes.
Arguments:
rel_codes (Tensor): encoded boxes
boxes (Tensor): reference boxes.
"""
boxes = boxes.to(rel_codes.dtype)
widths = boxes[:, 2] - boxes[:, 0]
heights = boxes[:, 3] - boxes[:, 1]
ctr_x = boxes[:, 0] + 0.5 * widths
ctr_y = boxes[:, 1] + 0.5 * heights
wx, wy, ww, wh = self.weights
dx = rel_codes[:, 0::4] / wx
dy = rel_codes[:, 1::4] / wy
dw = rel_codes[:, 2::4] / ww
dh = rel_codes[:, 3::4] / wh
# Prevent sending too large values into torch.exp()
dw = torch.clamp(dw, max=self.bbox_xform_clip)
dh = torch.clamp(dh, max=self.bbox_xform_clip)
pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
pred_w = torch.exp(dw) * widths[:, None]
pred_h = torch.exp(dh) * heights[:, None]
pred_boxes = torch.zeros_like(rel_codes)
# x1
pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w
# y1
pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h
# x2
pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w
# y2
pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h
return pred_boxes
class Matcher(object):
"""
This class assigns to each predicted "element" (e.g., a box) a ground-truth
element. Each predicted element will have exactly zero or one matches; each
ground-truth element may be assigned to zero or more predicted elements.
Matching is based on the MxN match_quality_matrix, that characterizes how well
each (ground-truth, predicted)-pair match. For example, if the elements are
boxes, the matrix may contain box IoU overlap values.
The matcher returns a tensor of size N containing the index of the ground-truth
element m that matches to prediction n. If there is no match, a negative value
is returned.
"""
BELOW_LOW_THRESHOLD = -1
BETWEEN_THRESHOLDS = -2
def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False):
"""
Args:
high_threshold (float): quality values greater than or equal to
this value are candidate matches.
low_threshold (float): a lower quality threshold used to stratify
matches into three levels:
1) matches >= high_threshold
2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold)
3) BELOW_LOW_THRESHOLD matches in [0, low_threshold)
allow_low_quality_matches (bool): if True, produce additional matches
for predictions that have only low-quality match candidates. See
set_low_quality_matches_ for more details.
"""
assert low_threshold <= high_threshold
self.high_threshold = high_threshold
self.low_threshold = low_threshold
self.allow_low_quality_matches = allow_low_quality_matches
def __call__(self, match_quality_matrix):
"""
Args:
match_quality_matrix (Tensor[float]): an MxN tensor, containing the
pairwise quality between M ground-truth elements and N predicted elements.
Returns:
matches (Tensor[int64]): an N tensor where N[i] is a matched gt in
[0, M - 1] or a negative value indicating that prediction i could not
be matched.
"""
if match_quality_matrix.numel() == 0:
# empty targets or proposals not supported during training
if match_quality_matrix.shape[0] == 0:
raise ValueError(
"No ground-truth boxes available for one of the images "
"during training")
else:
raise ValueError(
"No proposal boxes available for one of the images "
"during training")
# match_quality_matrix is M (gt) x N (predicted)
# Max over gt elements (dim 0) to find best gt candidate for each prediction
matched_vals, matches = match_quality_matrix.max(dim=0)
if self.allow_low_quality_matches:
all_matches = matches.clone()
# Assign candidate matches with low quality to negative (unassigned) values
below_low_threshold = matched_vals < self.low_threshold
between_thresholds = (matched_vals >= self.low_threshold) & (
matched_vals < self.high_threshold
)
matches[below_low_threshold] = Matcher.BELOW_LOW_THRESHOLD
matches[between_thresholds] = Matcher.BETWEEN_THRESHOLDS
if self.allow_low_quality_matches:
self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)
return matches
def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):
"""
Produce additional matches for predictions that have only low-quality matches.
Specifically, for each ground-truth find the set of predictions that have
maximum overlap with it (including ties); for each prediction in that set, if
it is unmatched, then match it to the ground-truth with which it has the highest
quality value.
"""
# For each gt, find the prediction with which it has highest quality
highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
# Find highest quality match available, even if it is low, including ties
gt_pred_pairs_of_highest_quality = torch.nonzero(
match_quality_matrix == highest_quality_foreach_gt[:, None]
)
# Example gt_pred_pairs_of_highest_quality:
# tensor([[ 0, 39796],
# [ 1, 32055],
# [ 1, 32070],
# [ 2, 39190],
# [ 2, 40255],
# [ 3, 40390],
# [ 3, 41455],
# [ 4, 45470],
# [ 5, 45325],
# [ 5, 46390]])
# Each row is a (gt index, prediction index)
# Note how gt items 1, 2, 3, and 5 each have two ties
pred_inds_to_update = gt_pred_pairs_of_highest_quality[:, 1]
matches[pred_inds_to_update] = all_matches[pred_inds_to_update]
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""
Implements the Generalized R-CNN framework
"""
from collections import OrderedDict
import torch
from torch import nn
class GeneralizedRCNN(nn.Module):
"""
Main class for Generalized R-CNN.
Arguments:
backbone (nn.Module):
rpn (nn.Module):
heads (nn.Module): takes the features + the proposals from the RPN and computes
detections / masks from it.
transform (nn.Module): performs the data transformation from the inputs to feed into
the model
"""
def __init__(self, backbone, rpn, roi_heads, transform):
super(GeneralizedRCNN, self).__init__()
self.transform = transform
self.backbone = backbone
self.rpn = rpn
self.roi_heads = roi_heads
def forward(self, images, targets=None):
"""
Arguments:
images (list[Tensor]): images to be processed
targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
Returns:
result (list[BoxList] or dict[Tensor]): the output from the model.
During training, it returns a dict[Tensor] which contains the losses.
During testing, it returns list[BoxList] contains additional fields
like `scores`, `labels` and `mask` (for Mask R-CNN models).
"""
if self.training and targets is None:
raise ValueError("In training mode, targets should be passed")
original_image_sizes = [img.shape[-2:] for img in images]
images, targets = self.transform(images, targets)
features = self.backbone(images.tensors)
if isinstance(features, torch.Tensor):
features = OrderedDict([(0, features)])
proposals, proposal_losses = self.rpn(images, features, targets)
detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
losses = {}
losses.update(detector_losses)
losses.update(proposal_losses)
if self.training:
return losses
return detections
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from __future__ import division
import torch
class ImageList(object):
"""
Structure that holds a list of images (of possibly
varying sizes) as a single tensor.
This works by padding the images to the same size,
and storing in a field the original sizes of each image
"""
def __init__(self, tensors, image_sizes):
"""
Arguments:
tensors (tensor)
image_sizes (list[tuple[int, int]])
"""
self.tensors = tensors
self.image_sizes = image_sizes
def to(self, *args, **kwargs):
cast_tensor = self.tensors.to(*args, **kwargs)
return ImageList(cast_tensor, self.image_sizes)
from collections import OrderedDict
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops import MultiScaleRoIAlign
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool
from .generalized_rcnn import GeneralizedRCNN
from .rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
from .roi_heads import RoIHeads
from .transform import GeneralizedRCNNTransform
from .._utils import IntermediateLayerGetter
__all__ = [
"FasterRCNN", "MaskRCNN", "fasterrcnn_resnet50_fpn", "maskrcnn_resnet50_fpn",
"KeypointRCNN", "keypointrcnn_resnet50_fpn"
]
class BackboneWithFPN(nn.Sequential):
def __init__(self, backbone, return_layers, in_channels_list, out_channels):
body = IntermediateLayerGetter(backbone, return_layers=return_layers)
fpn = FeaturePyramidNetwork(
in_channels_list=in_channels_list,
out_channels=out_channels,
extra_blocks=LastLevelMaxPool(),
)
super(BackboneWithFPN, self).__init__(OrderedDict(
[("body", body), ("fpn", fpn)]))
self.out_channels = out_channels
class FasterRCNN(GeneralizedRCNN):
def __init__(self, backbone, num_classes=None,
# transform parameters
min_size=800, max_size=1333,
image_mean=None, image_std=None,
# RPN parameters
rpn_anchor_generator=None, rpn_head=None,
rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000,
rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000,
rpn_nms_thresh=0.7,
rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
# Box parameters
box_roi_pool=None, box_head=None, box_predictor=None,
box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,
box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5,
box_batch_size_per_image=512, box_positive_fraction=0.25,
bbox_reg_weights=None):
if not hasattr(backbone, "out_channels"):
raise ValueError(
"backbone should contain an attribute out_channels "
"specifying the number of output channels (assumed to be the "
"same for all the levels)")
assert isinstance(rpn_anchor_generator, (AnchorGenerator, type(None)))
assert isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None)))
if num_classes is not None:
if box_predictor is not None:
raise ValueError("num_classes should be None when box_predictor is specified")
else:
if box_predictor is None:
raise ValueError("num_classes should not be None when box_predictor "
"is not specified")
out_channels = backbone.out_channels
if rpn_anchor_generator is None:
anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
rpn_anchor_generator = AnchorGenerator(
anchor_sizes, aspect_ratios
)
if rpn_head is None:
rpn_head = RPNHead(
out_channels, rpn_anchor_generator.num_anchors_per_location()[0]
)
rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)
rpn = RegionProposalNetwork(
rpn_anchor_generator, rpn_head,
rpn_fg_iou_thresh, rpn_bg_iou_thresh,
rpn_batch_size_per_image, rpn_positive_fraction,
rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh)
if box_roi_pool is None:
box_roi_pool = MultiScaleRoIAlign(
featmap_names=[0, 1, 2, 3],
output_size=7,
sampling_ratio=2)
if box_head is None:
resolution = box_roi_pool.output_size[0]
representation_size = 1024
box_head = TwoMLPHead(
out_channels * resolution ** 2,
representation_size)
if box_predictor is None:
representation_size = 1024
box_predictor = FastRCNNPredictor(
representation_size,
num_classes)
roi_heads = RoIHeads(
# Box
box_roi_pool, box_head, box_predictor,
box_fg_iou_thresh, box_bg_iou_thresh,
box_batch_size_per_image, box_positive_fraction,
bbox_reg_weights,
box_score_thresh, box_nms_thresh, box_detections_per_img)
if image_mean is None:
image_mean = [0.485, 0.456, 0.406]
if image_std is None:
image_std = [0.229, 0.224, 0.225]
transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
super(FasterRCNN, self).__init__(backbone, rpn, roi_heads, transform)
class MaskRCNN(FasterRCNN):
def __init__(self, backbone, num_classes=None,
# transform parameters
min_size=800, max_size=1333,
image_mean=None, image_std=None,
# RPN parameters
rpn_anchor_generator=None, rpn_head=None,
rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000,
rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000,
rpn_nms_thresh=0.7,
rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
# Box parameters
box_roi_pool=None, box_head=None, box_predictor=None,
box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,
box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5,
box_batch_size_per_image=512, box_positive_fraction=0.25,
bbox_reg_weights=None,
# Mask parameters
mask_roi_pool=None, mask_head=None, mask_predictor=None,
mask_discretization_size=28):
assert isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None)))
if num_classes is not None:
if mask_predictor is not None:
raise ValueError("num_classes should be None when mask_predictor is specified")
out_channels = backbone.out_channels
if mask_roi_pool is None:
mask_roi_pool = MultiScaleRoIAlign(
featmap_names=[0, 1, 2, 3],
output_size=14,
sampling_ratio=2)
if mask_head is None:
mask_layers = (256, 256, 256, 256)
mask_dilation = 1
mask_head = MaskRCNNHeads(out_channels, mask_layers, mask_dilation)
if mask_predictor is None:
mask_dim_reduced = 256 # == mask_layers[-1]
mask_predictor = MaskRCNNC4Predictor(out_channels, mask_dim_reduced, num_classes)
super(MaskRCNN, self).__init__(
backbone, num_classes,
# transform parameters
min_size, max_size,
image_mean, image_std,
# RPN-specific parameters
rpn_anchor_generator, rpn_head,
rpn_pre_nms_top_n_train, rpn_pre_nms_top_n_test,
rpn_post_nms_top_n_train, rpn_post_nms_top_n_test,
rpn_nms_thresh,
rpn_fg_iou_thresh, rpn_bg_iou_thresh,
rpn_batch_size_per_image, rpn_positive_fraction,
# Box parameters
box_roi_pool, box_head, box_predictor,
box_score_thresh, box_nms_thresh, box_detections_per_img,
box_fg_iou_thresh, box_bg_iou_thresh,
box_batch_size_per_image, box_positive_fraction,
bbox_reg_weights)
self.roi_heads.mask_roi_pool = mask_roi_pool
self.roi_heads.mask_head = mask_head
self.roi_heads.mask_predictor = mask_predictor
self.roi_heads.mask_discretization_size = mask_discretization_size
class KeypointRCNN(FasterRCNN):
def __init__(self, backbone, num_classes=None,
# transform parameters
min_size=800, max_size=1333,
image_mean=None, image_std=None,
# RPN parameters
rpn_anchor_generator=None, rpn_head=None,
rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000,
rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000,
rpn_nms_thresh=0.7,
rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
# Box parameters
box_roi_pool=None, box_head=None, box_predictor=None,
box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,
box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5,
box_batch_size_per_image=512, box_positive_fraction=0.25,
bbox_reg_weights=None,
# keypoint parameters
keypoint_roi_pool=None, keypoint_head=None, keypoint_predictor=None,
keypoint_discretization_size=56,
num_keypoints=17):
assert isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None)))
if num_classes is not None:
if keypoint_predictor is not None:
raise ValueError("num_classes should be None when keypoint_predictor is specified")
out_channels = backbone.out_channels
if keypoint_roi_pool is None:
keypoint_roi_pool = MultiScaleRoIAlign(
featmap_names=[0, 1, 2, 3],
output_size=14,
sampling_ratio=2)
if keypoint_head is None:
keypoint_layers = tuple(512 for _ in range(8))
keypoint_head = KeypointRCNNHeads(out_channels, keypoint_layers)
if keypoint_predictor is None:
keypoint_dim_reduced = 512 # == keypoint_layers[-1]
keypoint_predictor = KeypointRCNNPredictor(keypoint_dim_reduced, num_keypoints)
super(KeypointRCNN, self).__init__(
backbone, num_classes,
# transform parameters
min_size, max_size,
image_mean, image_std,
# RPN-specific parameters
rpn_anchor_generator, rpn_head,
rpn_pre_nms_top_n_train, rpn_pre_nms_top_n_test,
rpn_post_nms_top_n_train, rpn_post_nms_top_n_test,
rpn_nms_thresh,
rpn_fg_iou_thresh, rpn_bg_iou_thresh,
rpn_batch_size_per_image, rpn_positive_fraction,
# Box parameters
box_roi_pool, box_head, box_predictor,
box_score_thresh, box_nms_thresh, box_detections_per_img,
box_fg_iou_thresh, box_bg_iou_thresh,
box_batch_size_per_image, box_positive_fraction,
bbox_reg_weights)
self.roi_heads.keypoint_roi_pool = keypoint_roi_pool
self.roi_heads.keypoint_head = keypoint_head
self.roi_heads.keypoint_predictor = keypoint_predictor
self.roi_heads.keypoint_discretization_size = keypoint_discretization_size
class TwoMLPHead(nn.Module):
"""
Heads for FPN for classification
"""
def __init__(self, in_channels, representation_size):
super(TwoMLPHead, self).__init__()
self.fc6 = nn.Linear(in_channels, representation_size)
self.fc7 = nn.Linear(representation_size, representation_size)
def forward(self, x):
x = x.flatten(start_dim=1)
x = F.relu(self.fc6(x))
x = F.relu(self.fc7(x))
return x
class FastRCNNPredictor(nn.Module):
def __init__(self, in_channels, num_classes):
super(FastRCNNPredictor, self).__init__()
self.cls_score = nn.Linear(in_channels, num_classes)
self.bbox_pred = nn.Linear(in_channels, num_classes * 4)
def forward(self, x):
if x.ndimension() == 4:
assert list(x.shape[2:]) == [1, 1]
x = x.flatten(start_dim=1)
scores = self.cls_score(x)
bbox_deltas = self.bbox_pred(x)
return scores, bbox_deltas
class MaskRCNNHeads(nn.Sequential):
def __init__(self, in_channels, layers, dilation):
"""
Arguments:
num_classes (int): number of output classes
input_size (int): number of channels of the input once it's flattened
representation_size (int): size of the intermediate representation
"""
d = OrderedDict()
next_feature = in_channels
for layer_idx, layer_features in enumerate(layers, 1):
d["mask_fcn{}".format(layer_idx)] = misc_nn_ops.Conv2d(
next_feature, layer_features, kernel_size=3,
stride=1, padding=dilation, dilation=dilation)
d["relu{}".format(layer_idx)] = nn.ReLU(inplace=True)
next_feature = layer_features
super(MaskRCNNHeads, self).__init__(d)
for name, param in self.named_parameters():
if "weight" in name:
nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")
# elif "bias" in name:
# nn.init.constant_(param, 0)
class MaskRCNNC4Predictor(nn.Sequential):
def __init__(self, in_channels, dim_reduced, num_classes):
super(MaskRCNNC4Predictor, self).__init__(OrderedDict([
("conv5_mask", misc_nn_ops.ConvTranspose2d(in_channels, dim_reduced, 2, 2, 0)),
("relu", nn.ReLU(inplace=True)),
("mask_fcn_logits", misc_nn_ops.Conv2d(dim_reduced, num_classes, 1, 1, 0)),
]))
for name, param in self.named_parameters():
if "weight" in name:
nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")
# elif "bias" in name:
# nn.init.constant_(param, 0)
class KeypointRCNNHeads(nn.Sequential):
def __init__(self, in_channels, layers):
d = []
next_feature = in_channels
for l in layers:
d.append(misc_nn_ops.Conv2d(next_feature, l, 3, stride=1, padding=1))
d.append(nn.ReLU(inplace=True))
next_feature = l
super(KeypointRCNNHeads, self).__init__(*d)
for m in self.children():
if isinstance(m, misc_nn_ops.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
nn.init.constant_(m.bias, 0)
class KeypointRCNNPredictor(nn.Module):
def __init__(self, in_channels, num_keypoints):
super(KeypointRCNNPredictor, self).__init__()
input_features = in_channels
deconv_kernel = 4
self.kps_score_lowres = misc_nn_ops.ConvTranspose2d(
input_features,
num_keypoints,
deconv_kernel,
stride=2,
padding=deconv_kernel // 2 - 1,
)
nn.init.kaiming_normal_(
self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu"
)
nn.init.constant_(self.kps_score_lowres.bias, 0)
self.up_scale = 2
self.out_channels = num_keypoints
def forward(self, x):
x = self.kps_score_lowres(x)
x = misc_nn_ops.interpolate(
x, scale_factor=self.up_scale, mode="bilinear", align_corners=False
)
return x
def _resnet_fpn_backbone(backbone_name, pretrained):
from .. import resnet
backbone = resnet.__dict__[backbone_name](
pretrained=pretrained,
norm_layer=misc_nn_ops.FrozenBatchNorm2d)
# freeze layers
for name, parameter in backbone.named_parameters():
if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
parameter.requires_grad_(False)
return_layers = {'layer1': 0, 'layer2': 1, 'layer3': 2, 'layer4': 3}
in_channels_stage2 = 256
in_channels_list = [
in_channels_stage2,
in_channels_stage2 * 2,
in_channels_stage2 * 4,
in_channels_stage2 * 8,
]
out_channels = 256
return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels)
def fasterrcnn_resnet50_fpn(pretrained=False, num_classes=81, pretrained_backbone=True, **kwargs):
backbone = _resnet_fpn_backbone('resnet50', pretrained_backbone)
model = FasterRCNN(backbone, num_classes, **kwargs)
if pretrained:
pass
return model
def maskrcnn_resnet50_fpn(pretrained=False, num_classes=81, pretrained_backbone=True, **kwargs):
backbone = _resnet_fpn_backbone('resnet50', pretrained_backbone)
model = MaskRCNN(backbone, num_classes, **kwargs)
if pretrained:
pass
return model
def keypointrcnn_resnet50_fpn(pretrained=False, num_classes=2, num_keypoints=17,
pretrained_backbone=True, **kwargs):
backbone = _resnet_fpn_backbone('resnet50', pretrained_backbone)
model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
if pretrained:
pass
return model
This diff is collapsed.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from torch.nn import functional as F
from torch import nn
from torchvision.ops import boxes as box_ops
from . import _utils as det_utils
class AnchorGenerator(nn.Module):
"""
For a set of image sizes and feature maps, computes a set
of anchors
"""
def __init__(
self,
sizes=(128, 256, 512),
aspect_ratios=(0.5, 1.0, 2.0),
):
super(AnchorGenerator, self).__init__()
if not isinstance(sizes[0], (list, tuple)):
# TODO change this
sizes = tuple((s,) for s in sizes)
if not isinstance(aspect_ratios[0], (list, tuple)):
aspect_ratios = (aspect_ratios,) * len(sizes)
assert len(sizes) == len(aspect_ratios)
self.sizes = sizes
self.aspect_ratios = aspect_ratios
self.cell_anchors = None
self._cache = {}
@staticmethod
def generate_anchors(scales, aspect_ratios, device="cpu"):
scales = torch.as_tensor(scales, dtype=torch.float32, device=device)
aspect_ratios = torch.as_tensor(aspect_ratios, dtype=torch.float32, device=device)
h_ratios = torch.sqrt(aspect_ratios)
w_ratios = 1 / h_ratios
ws = (w_ratios[:, None] * scales[None, :]).view(-1)
hs = (h_ratios[:, None] * scales[None, :]).view(-1)
base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
return base_anchors.round()
def set_cell_anchors(self, device):
if self.cell_anchors is not None:
return self.cell_anchors
cell_anchors = [
self.generate_anchors(
sizes,
aspect_ratios,
device
)
for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios)
]
self.cell_anchors = cell_anchors
def num_anchors_per_location(self):
return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]
def grid_anchors(self, grid_sizes, strides):
anchors = []
for size, stride, base_anchors in zip(
grid_sizes, strides, self.cell_anchors
):
grid_height, grid_width = size
stride_height, stride_width = stride
device = base_anchors.device
shifts_x = torch.arange(
0, grid_width, dtype=torch.float32, device=device
) * stride_width
shifts_y = torch.arange(
0, grid_height, dtype=torch.float32, device=device
) * stride_height
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
anchors.append(
(shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)
)
return anchors
def cached_grid_anchors(self, grid_sizes, strides):
key = tuple(grid_sizes) + tuple(strides)
if key in self._cache:
return self._cache[key]
anchors = self.grid_anchors(grid_sizes, strides)
self._cache[key] = anchors
return anchors
def forward(self, image_list, feature_maps):
grid_sizes = tuple([feature_map.shape[-2:] for feature_map in feature_maps])
image_size = image_list.tensors.shape[-2:]
strides = tuple((image_size[0] / g[0], image_size[1] / g[1]) for g in grid_sizes)
self.set_cell_anchors(feature_maps[0].device)
anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
anchors = []
for i, (image_height, image_width) in enumerate(image_list.image_sizes):
anchors_in_image = []
for anchors_per_feature_map in anchors_over_all_feature_maps:
anchors_in_image.append(anchors_per_feature_map)
anchors.append(anchors_in_image)
anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
return anchors
class RPNHead(nn.Module):
"""
Adds a simple RPN Head with classification and regression heads
"""
def __init__(self, in_channels, num_anchors):
"""
Arguments:
in_channels (int): number of channels of the input feature
num_anchors (int): number of anchors to be predicted
"""
super(RPNHead, self).__init__()
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
self.bbox_pred = nn.Conv2d(
in_channels, num_anchors * 4, kernel_size=1, stride=1
)
for l in self.children():
torch.nn.init.normal_(l.weight, std=0.01)
torch.nn.init.constant_(l.bias, 0)
def forward(self, x):
logits = []
bbox_reg = []
for feature in x:
t = F.relu(self.conv(feature))
logits.append(self.cls_logits(t))
bbox_reg.append(self.bbox_pred(t))
return logits, bbox_reg
def permute_and_flatten(layer, N, A, C, H, W):
layer = layer.view(N, -1, C, H, W)
layer = layer.permute(0, 3, 4, 1, 2)
layer = layer.reshape(N, -1, C)
return layer
def concat_box_prediction_layers(box_cls, box_regression):
box_cls_flattened = []
box_regression_flattened = []
# for each feature level, permute the outputs to make them be in the
# same format as the labels. Note that the labels are computed for
# all feature levels concatenated, so we keep the same representation
# for the objectness and the box_regression
for box_cls_per_level, box_regression_per_level in zip(
box_cls, box_regression
):
N, AxC, H, W = box_cls_per_level.shape
Ax4 = box_regression_per_level.shape[1]
A = Ax4 // 4
C = AxC // A
box_cls_per_level = permute_and_flatten(
box_cls_per_level, N, A, C, H, W
)
box_cls_flattened.append(box_cls_per_level)
box_regression_per_level = permute_and_flatten(
box_regression_per_level, N, A, 4, H, W
)
box_regression_flattened.append(box_regression_per_level)
# concatenate on the first dimension (representing the feature levels), to
# take into account the way the labels were generated (with all feature maps
# being concatenated as well)
box_cls = torch.cat(box_cls_flattened, dim=1).reshape(-1, C)
box_regression = torch.cat(box_regression_flattened, dim=1).reshape(-1, 4)
return box_cls, box_regression
class RegionProposalNetwork(torch.nn.Module):
def __init__(self,
anchor_generator,
head,
#
fg_iou_thresh, bg_iou_thresh,
batch_size_per_image, positive_fraction,
#
pre_nms_top_n, post_nms_top_n, nms_thresh):
"""
Arguments:
"""
super(RegionProposalNetwork, self).__init__()
self.anchor_generator = anchor_generator
self.head = head
self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
# used during training
self.box_similarity = box_ops.box_iou
self.proposal_matcher = det_utils.Matcher(
fg_iou_thresh,
bg_iou_thresh,
allow_low_quality_matches=True,
)
self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(
batch_size_per_image, positive_fraction
)
# used during testing
self._pre_nms_top_n = pre_nms_top_n
self._post_nms_top_n = post_nms_top_n
self.nms_thresh = nms_thresh
self.min_size = 0
@property
def pre_nms_top_n(self):
if self.training:
return self._pre_nms_top_n['training']
return self._pre_nms_top_n['testing']
@property
def post_nms_top_n(self):
if self.training:
return self._post_nms_top_n['training']
return self._post_nms_top_n['testing']
def assign_targets_to_anchors(self, anchors, targets):
labels = []
matched_gt_boxes = []
for anchors_per_image, targets_per_image in zip(anchors, targets):
gt_boxes = targets_per_image["boxes"]
match_quality_matrix = self.box_similarity(gt_boxes, anchors_per_image)
matched_idxs = self.proposal_matcher(match_quality_matrix)
# get the targets corresponding GT for each proposal
# NB: need to clamp the indices because we can have a single
# GT in the image, and matched_idxs can be -2, which goes
# out of bounds
matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]
labels_per_image = matched_idxs >= 0
labels_per_image = labels_per_image.to(dtype=torch.float32)
# Background (negative examples)
bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD
labels_per_image[bg_indices] = 0
# discard indices that are between thresholds
inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS
labels_per_image[inds_to_discard] = -1
labels.append(labels_per_image)
matched_gt_boxes.append(matched_gt_boxes_per_image)
return labels, matched_gt_boxes
def _get_top_n_idx(self, objectness, num_anchors_per_level):
r = []
offset = 0
for ob in objectness.split(num_anchors_per_level, 1):
num_anchors = ob.shape[1]
pre_nms_top_n = min(self.pre_nms_top_n, num_anchors)
_, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
r.append(top_n_idx + offset)
offset += num_anchors
return torch.cat(r, dim=1)
def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level):
num_images = proposals.shape[0]
device = proposals.device
# do not backprop throught objectness
objectness = objectness.detach()
objectness = objectness.reshape(num_images, -1)
levels = [
torch.full((n,), idx, dtype=torch.int64, device=device)
for idx, n in enumerate(num_anchors_per_level)
]
levels = torch.cat(levels, 0)
levels = levels.reshape(1, -1).expand_as(objectness)
# select top_n boxes independently per level before applying nms
top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level)
batch_idx = torch.arange(num_images, device=device)[:, None]
objectness = objectness[batch_idx, top_n_idx]
levels = levels[batch_idx, top_n_idx]
proposals = proposals[batch_idx, top_n_idx]
final_boxes = []
final_scores = []
for boxes, scores, lvl, img_shape in zip(proposals, objectness, levels, image_shapes):
boxes = box_ops.clip_boxes_to_image(boxes, img_shape)
keep = box_ops.remove_small_boxes(boxes, self.min_size)
boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]
# non-maximum suppression, independently done per level
keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)
# keep only topk scoring predictions
keep = keep[:self.post_nms_top_n]
boxes, scores = boxes[keep], scores[keep]
final_boxes.append(boxes)
final_scores.append(scores)
return final_boxes, final_scores
def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets):
"""
Arguments:
anchors (list[list[BoxList]])
objectness (list[Tensor])
pred_bbox_deltas (list[Tensor])
targets (list[BoxList])
Returns:
objectness_loss (Tensor)
box_loss (Tensor
"""
sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
sampled_pos_inds = torch.nonzero(torch.cat(sampled_pos_inds, dim=0)).squeeze(1)
sampled_neg_inds = torch.nonzero(torch.cat(sampled_neg_inds, dim=0)).squeeze(1)
sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)
objectness = objectness.flatten()
labels = torch.cat(labels, dim=0)
regression_targets = torch.cat(regression_targets, dim=0)
box_loss = F.l1_loss(
pred_bbox_deltas[sampled_pos_inds],
regression_targets[sampled_pos_inds],
reduction="sum",
) / (sampled_inds.numel())
objectness_loss = F.binary_cross_entropy_with_logits(
objectness[sampled_inds], labels[sampled_inds]
)
return objectness_loss, box_loss
def forward(self, images, features, targets=None):
"""
Arguments:
images (ImageList): images for which we want to compute the predictions
features (list[Tensor]): features computed from the images that are
used for computing the predictions. Each tensor in the list
correspond to different feature levels
targets (list[BoxList): ground-truth boxes present in the image (optional)
Returns:
boxes (list[BoxList]): the predicted boxes from the RPN, one BoxList per
image.
losses (dict[Tensor]): the losses for the model during training. During
testing, it is an empty dict.
"""
# RPN uses all feature maps that are available
features = list(features.values())
objectness, pred_bbox_deltas = self.head(features)
anchors = self.anchor_generator(images, features)
num_images = len(anchors)
num_anchors_per_level = [o[0].numel() for o in objectness]
objectness, pred_bbox_deltas = \
concat_box_prediction_layers(objectness, pred_bbox_deltas)
# apply pred_bbox_deltas to anchors to obtain the decoded proposals
# note that we detach the deltas because Faster R-CNN do not backprop through
# the proposals
proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
proposals = proposals.view(num_images, -1, 4)
boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)
losses = {}
if self.training:
labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
loss_objectness, loss_rpn_box_reg = self.compute_loss(
objectness, pred_bbox_deltas, labels, regression_targets)
losses = {
"loss_objectness": loss_objectness,
"loss_rpn_box_reg": loss_rpn_box_reg,
}
return boxes, losses
import math
import torch
from torch import nn
from torchvision.ops import misc as misc_nn_ops
from .image_list import ImageList
from .roi_heads import paste_masks_in_image
class GeneralizedRCNNTransform(nn.Module):
def __init__(self, min_size, max_size, image_mean, image_std):
super(GeneralizedRCNNTransform, self).__init__()
self.min_size = float(min_size)
self.max_size = float(max_size)
self.image_mean = image_mean
self.image_std = image_std
def forward(self, images, targets=None):
for i in range(len(images)):
image = images[i]
target = targets[i] if targets is not None else targets
if image.dim() != 3:
raise ValueError("images is expected to be a list of 3d tensors "
"of shape [C, H, W], got {}".format(image.shape))
image = self.normalize(image)
image, target = self.resize(image, target)
images[i] = image
if targets is not None:
targets[i] = target
image_sizes = [img.shape[-2:] for img in images]
images = self.batch_images(images)
image_list = ImageList(images, image_sizes)
return image_list, targets
def normalize(self, image):
dtype, device = image.dtype, image.device
mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
return (image - mean[:, None, None]) / std[:, None, None]
def resize(self, image, target):
h, w = image.shape[-2:]
min_size = min(image.shape[-2:])
max_size = max(image.shape[-2:])
scale_factor = self.min_size / min_size
if max_size * scale_factor > self.max_size:
scale_factor = self.max_size / max_size
image = torch.nn.functional.interpolate(
image[None], scale_factor=scale_factor, mode='bilinear', align_corners=False)[0]
if target is None:
return image, target
bbox = target["boxes"]
bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
target["boxes"] = bbox
if "masks" in target:
mask = target["masks"]
mask = misc_nn_ops.interpolate(mask[None].float(), scale_factor=scale_factor)[0].byte()
target["masks"] = mask
if "keypoints" in target:
keypoints = target["keypoints"]
keypoints = resize_keypoints(keypoints, (h, w), image.shape[-2:])
target["keypoints"] = keypoints
return image, target
def batch_images(self, images, size_divisible=32):
# concatenate
max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
stride = size_divisible
max_size = list(max_size)
max_size[1] = int(math.ceil(max_size[1] / stride) * stride)
max_size[2] = int(math.ceil(max_size[2] / stride) * stride)
max_size = tuple(max_size)
batch_shape = (len(images),) + max_size
batched_imgs = images[0].new(*batch_shape).zero_()
for img, pad_img in zip(images, batched_imgs):
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
return batched_imgs
def postprocess(self, result, image_shapes, original_image_sizes):
if self.training:
return result
for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
boxes = pred["boxes"]
boxes = resize_boxes(boxes, im_s, o_im_s)
result[i]["boxes"] = boxes
if "mask" in pred:
masks = pred["mask"]
masks = paste_masks_in_image(masks, boxes, o_im_s)
result[i]["mask"] = masks
if "keypoints" in pred:
keypoints = pred["keypoints"]
keypoints = resize_keypoints(keypoints, im_s, o_im_s)
result[i]["keypoints"] = keypoints
return result
def resize_keypoints(keypoints, original_size, new_size):
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size))
ratio_h, ratio_w = ratios
resized_data = keypoints.clone()
resized_data[..., 0] *= ratio_w
resized_data[..., 1] *= ratio_h
return resized_data
def resize_boxes(boxes, original_size, new_size):
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size))
ratio_height, ratio_width = ratios
xmin, ymin, xmax, ymax = boxes.unbind(1)
xmin = xmin * ratio_width
xmax = xmax * ratio_width
ymin = ymin * ratio_height
ymax = ymax * ratio_height
return torch.stack((xmin, ymin, xmax, ymax), dim=1)
from .boxes import nms, box_iou
from .roi_align import roi_align, RoIAlign
from .roi_pool import roi_pool, RoIPool
from .poolers import MultiScaleRoIAlign
from .feature_pyramid_network import FeaturePyramidNetwork
__all__ = [
'nms', 'roi_align', 'RoIAlign', 'roi_pool', 'RoIPool'
'nms', 'roi_align', 'RoIAlign', 'roi_pool', 'RoIPool',
'MultiScaleRoIAlign', 'FeaturePyramidNetwork'
]
......@@ -45,6 +45,8 @@ def batched_nms(boxes, scores, idxs, iou_threshold):
the elements that have been kept by NMS, sorted
in decreasing order of scores
"""
if boxes.numel() == 0:
return torch.empty((0,), dtype=torch.int64, device=boxes.device)
# strategy: in order to perform NMS independently per class.
# we add an offset to all the boxes. The offset is dependent
# only on the class idx, and is large enough so that boxes
......
from collections import OrderedDict
import torch
import torch.nn.functional as F
from torch import nn
class FeaturePyramidNetwork(nn.Module):
"""
Module that adds a FPN on top of a list of feature maps.
The feature maps are currently supposed to be in increasing depth
order, and must be consecutive
"""
def __init__(
self, in_channels_list, out_channels, extra_blocks=None
):
"""
Arguments:
in_channels_list (list[int]): number of channels for each feature map that
will be fed
out_channels (int): number of channels of the FPN representation
extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
be performed. It is expected to take the fpn features, the original
features and the names of the original features as input, and returns
a new list of feature maps and their corresponding names
"""
super(FeaturePyramidNetwork, self).__init__()
self.inner_blocks = nn.ModuleList()
self.layer_blocks = nn.ModuleList()
for in_channels in in_channels_list:
if in_channels == 0:
continue
inner_block_module = nn.Conv2d(in_channels, out_channels, 1)
layer_block_module = nn.Conv2d(out_channels, out_channels, 3, padding=1)
self.inner_blocks.append(inner_block_module)
self.layer_blocks.append(layer_block_module)
# initialize parameters now to avoid modifying the initialization of top_blocks
for m in self.children():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_uniform_(m.weight, a=1)
nn.init.constant_(m.bias, 0)
if extra_blocks is not None:
assert isinstance(extra_blocks, ExtraFPNBlock)
self.extra_blocks = extra_blocks
def forward(self, x):
"""
Arguments:
x (OrderedDict[Tensor]): feature maps for each feature level.
Returns:
results (OrderedDict[Tensor]): feature maps after FPN layers.
They are ordered from highest resolution first.
"""
# unpack OrderedDict into two lists for easier handling
names = list(x.keys())
x = list(x.values())
last_inner = self.inner_blocks[-1](x[-1])
results = []
results.append(self.layer_blocks[-1](last_inner))
for feature, inner_block, layer_block in zip(
x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1]
):
if not inner_block:
continue
inner_lateral = inner_block(feature)
feat_shape = inner_lateral.shape[-2:]
inner_top_down = F.interpolate(last_inner, size=feat_shape, mode="nearest")
last_inner = inner_lateral + inner_top_down
results.insert(0, layer_block(last_inner))
if self.extra_blocks is not None:
results, names = self.extra_blocks(results, x, names)
# make it back an OrderedDict
out = OrderedDict([(k, v) for k, v in zip(names, results)])
return out
class ExtraFPNBlock(nn.Module):
def forward(self, results, x, names):
pass
class LastLevelMaxPool(ExtraFPNBlock):
def forward(self, x, y, names):
names.append("pool")
x.append(F.max_pool2d(x[-1], 1, 2, 0))
return x, names
class LastLevelP6P7(ExtraFPNBlock):
"""
This module is used in RetinaNet to generate extra layers, P6 and P7.
"""
def __init__(self, in_channels, out_channels):
super(LastLevelP6P7, self).__init__()
self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
for module in [self.p6, self.p7]:
nn.init.kaiming_uniform_(module.weight, a=1)
nn.init.constant_(module.bias, 0)
self.use_P5 = in_channels == out_channels
def forward(self, p, c, names):
p5, c5 = p[-1], c[-1]
x = p5 if self.use_P5 else c5
p6 = self.p6(x)
p7 = self.p7(F.relu(p6))
p.extend([p6, p7])
names.extend(["p6", "p7"])
return p, names
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment