Unverified Commit 8233c9cd authored by Nicolas Hug, committed by GitHub

Properly handle maskrcnn and keypoints w.r.t. V2 in detection references (#7742)


Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent 1402eb8e
references/detection/coco_utils.py:

-import copy
 import os

 import torch
@@ -10,24 +9,6 @@ from pycocotools.coco import COCO
 from torchvision.datasets import wrap_dataset_for_transforms_v2


-class FilterAndRemapCocoCategories:
-    def __init__(self, categories, remap=True):
-        self.categories = categories
-        self.remap = remap
-
-    def __call__(self, image, target):
-        anno = target["annotations"]
-        anno = [obj for obj in anno if obj["category_id"] in self.categories]
-        if not self.remap:
-            target["annotations"] = anno
-            return image, target
-        anno = copy.deepcopy(anno)
-        for obj in anno:
-            obj["category_id"] = self.categories.index(obj["category_id"])
-        target["annotations"] = anno
-        return image, target
-
-
 def convert_coco_poly_to_mask(segmentations, height, width):
     masks = []
     for polygons in segmentations:
@@ -219,7 +200,7 @@ class CocoDetection(torchvision.datasets.CocoDetection):
         return img, target


-def get_coco(root, image_set, transforms, mode="instances", use_v2=False):
+def get_coco(root, image_set, transforms, mode="instances", use_v2=False, with_masks=False):
     anno_file_template = "{}_{}2017.json"
     PATHS = {
         "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
@@ -233,9 +214,12 @@ def get_coco(root, image_set, transforms, mode="instances", use_v2=False):
     if use_v2:
         dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
-        # TODO: need to update target_keys to handle masks for segmentation!
-        dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"boxes", "labels", "image_id"})
+        target_keys = ["boxes", "labels", "image_id"]
+        if with_masks:
+            target_keys += ["masks"]
+        dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
     else:
+        # TODO: handle with_masks for V1?
         t = [ConvertCocoPolysToMask()]
         if transforms is not None:
             t.append(transforms)
@@ -249,9 +233,3 @@ def get_coco(root, image_set, transforms, mode="instances", use_v2=False):
     # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)])

     return dataset
-
-
-def get_coco_kp(root, image_set, transforms, use_v2=False):
-    if use_v2:
-        raise ValueError("KeyPoints aren't supported by transforms V2 yet.")
-    return get_coco(root, image_set, transforms, mode="person_keypoints")
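With get_coco_kp gone, all three dataset flavours are built through get_coco directly. The following is a minimal usage sketch, not part of the commit: the COCO root is a placeholder and transforms are omitted.

# Sketch only: calling the updated get_coco from the detection references.
from coco_utils import get_coco

coco_root = "/path/to/COCO"  # placeholder path, point this at a local COCO copy

# Plain detection with the V2 wrapper: targets carry boxes, labels and image_id.
det_ds = get_coco(coco_root, image_set="train", transforms=None, use_v2=True)

# Instance segmentation (e.g. Mask R-CNN): with_masks=True also requests "masks".
seg_ds = get_coco(coco_root, image_set="train", transforms=None, use_v2=True, with_masks=True)

# Keypoint detection stays on the V1 path; V2 transforms don't support keypoints yet.
kp_ds = get_coco(coco_root, image_set="val", transforms=None, mode="person_keypoints", use_v2=False)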
references/detection/train.py:

@@ -28,7 +28,7 @@ import torchvision
 import torchvision.models.detection
 import torchvision.models.detection.mask_rcnn
 import utils
-from coco_utils import get_coco, get_coco_kp
+from coco_utils import get_coco
 from engine import evaluate, train_one_epoch
 from group_by_aspect_ratio import create_aspect_ratio_groups, GroupedBatchSampler
 from torchvision.transforms import InterpolationMode
@@ -42,10 +42,16 @@ def copypaste_collate_fn(batch):
 def get_dataset(is_train, args):
     image_set = "train" if is_train else "val"
-    paths = {"coco": (args.data_path, get_coco, 91), "coco_kp": (args.data_path, get_coco_kp, 2)}
-    p, ds_fn, num_classes = paths[args.dataset]
-
-    ds = ds_fn(p, image_set=image_set, transforms=get_transform(is_train, args), use_v2=args.use_v2)
+    num_classes, mode = {"coco": (91, "instances"), "coco_kp": (2, "person_keypoints")}[args.dataset]
+    with_masks = "mask" in args.model
+    ds = get_coco(
+        root=args.data_path,
+        image_set=image_set,
+        transforms=get_transform(is_train, args),
+        mode=mode,
+        use_v2=args.use_v2,
+        with_masks=with_masks,
+    )
     return ds, num_classes
@@ -68,7 +74,12 @@ def get_args_parser(add_help=True):
     parser = argparse.ArgumentParser(description="PyTorch Detection Training", add_help=add_help)

     parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
-    parser.add_argument("--dataset", default="coco", type=str, help="dataset name")
+    parser.add_argument(
+        "--dataset",
+        default="coco",
+        type=str,
+        help="dataset name. Use coco for object detection and instance segmentation and coco_kp for Keypoint detection",
+    )
     parser.add_argument("--model", default="maskrcnn_resnet50_fpn", type=str, help="model name")
     parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
     parser.add_argument(
@@ -171,6 +182,12 @@
 def main(args):
     if args.backend.lower() == "datapoint" and not args.use_v2:
         raise ValueError("Use --use-v2 if you want to use the datapoint backend.")
+    if args.dataset not in ("coco", "coco_kp"):
+        raise ValueError(f"Dataset should be coco or coco_kp, got {args.dataset}")
+    if "keypoint" in args.model and args.dataset != "coco_kp":
+        raise ValueError("Oops, if you want Keypoint detection, set --dataset coco_kp")
+    if args.dataset == "coco_kp" and args.use_v2:
+        raise ValueError("KeyPoint detection doesn't support V2 transforms yet")

     if args.output_dir:
         utils.mkdir(args.output_dir)
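For completeness, a hedged sanity check of the new wiring, assuming the training script above is references/detection/train.py and that a COCO copy exists at the placeholder path; every other option falls back to the parser defaults.

# Sketch only: with a mask model and --use-v2, get_dataset should request "masks" from the V2 wrapper.
from train import get_args_parser, get_dataset  # assumes the script is importable as train.py

args = get_args_parser().parse_args(
    ["--data-path", "/path/to/COCO", "--model", "maskrcnn_resnet50_fpn", "--use-v2"]
)
dataset, num_classes = get_dataset(is_train=True, args=args)
img, target = dataset[0]
print(num_classes)            # 91 for the default --dataset coco
print(sorted(target.keys()))  # expected: ['boxes', 'image_id', 'labels', 'masks']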