Unverified Commit b9b7cfc6 authored by Nicolas Hug, committed by GitHub

Add --backend and --use-v2 support for segmentation references (#7743)

parent 8233c9cd
......@@ -6,7 +6,6 @@ import torchvision
import transforms as T
from pycocotools import mask as coco_mask
from pycocotools.coco import COCO
from torchvision.datasets import wrap_dataset_for_transforms_v2
def convert_coco_poly_to_mask(segmentations, height, width):
......@@ -213,6 +212,8 @@ def get_coco(root, image_set, transforms, mode="instances", use_v2=False, with_m
ann_file = os.path.join(root, ann_file)
if use_v2:
from torchvision.datasets import wrap_dataset_for_transforms_v2
dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
target_keys = ["boxes", "labels", "image_id"]
if with_masks:
......
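For context, a minimal sketch (not part of the diff) of what the now-local import enables: wrapping CocoDetection so samples come back as v2 datapoints restricted to the requested target keys, while keeping the import inside the function so merely importing the module never triggers the transforms-v2 warning. The helper name and arguments below are hypothetical.

import torchvision


def build_wrapped_coco(img_folder, ann_file, with_masks=False):
    # Local import, mirroring the hunk above, so the v1-only path stays warning-free.
    from torchvision.datasets import wrap_dataset_for_transforms_v2

    dataset = torchvision.datasets.CocoDetection(img_folder, ann_file)
    target_keys = ["boxes", "labels", "image_id"]
    if with_masks:
        target_keys += ["masks"]
    return wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)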
......@@ -68,11 +68,6 @@ def _coco_remove_images_without_annotations(dataset, cat_list=None):
# if more than 1k pixels occupied in the image
return sum(obj["area"] for obj in anno) > 1000
if not isinstance(dataset, torchvision.datasets.CocoDetection):
raise TypeError(
f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
)
ids = []
for ds_idx, img_id in enumerate(dataset.ids):
ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
......@@ -86,7 +81,7 @@ def _coco_remove_images_without_annotations(dataset, cat_list=None):
return dataset
def get_coco(root, image_set, transforms):
def get_coco(root, image_set, transforms, use_v2=False):
PATHS = {
"train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
"val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
......@@ -94,13 +89,24 @@ def get_coco(root, image_set, transforms):
}
CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72]
transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms])
img_folder, ann_file = PATHS[image_set]
img_folder = os.path.join(root, img_folder)
ann_file = os.path.join(root, ann_file)
dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
# The 2 "Compose" below achieve the same thing: converting coco detection
# samples into segmentation-compatible samples. They just do it with
# slightly different implementations. We could refactor and unify, but
# keeping them separate helps keeping the v2 version clean
if use_v2:
import v2_extras
from torchvision.datasets import wrap_dataset_for_transforms_v2
transforms = Compose([v2_extras.CocoDetectionToVOCSegmentation(), transforms])
dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"masks", "labels"})
else:
transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms])
dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
if image_set == "train":
dataset = _coco_remove_images_without_annotations(dataset, CAT_LIST)
......
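As the comment above notes, the v1 and v2 pipelines produce equivalent segmentation samples. A rough usage sketch of the updated get_coco (not part of the diff), assuming the standard COCO 2017 layout under a hypothetical /datasets/coco and the v2 preset from presets.py:

import presets
from coco_utils import get_coco

transforms = presets.SegmentationPresetTrain(base_size=520, crop_size=480, use_v2=True)
dataset = get_coco("/datasets/coco", image_set="train", transforms=transforms, use_v2=True)
img, target = dataset[0]  # on the v2 path the target is a single segmentation Mask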
from collections import defaultdict
import torch
import transforms as T
def get_modules(use_v2):
# We need a protected import to avoid the V2 warning in case just V1 is used
if use_v2:
import torchvision.datapoints
import torchvision.transforms.v2
import v2_extras
return torchvision.transforms.v2, torchvision.datapoints, v2_extras
else:
import transforms
return transforms, None, None
class SegmentationPresetTrain:
def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
min_size = int(0.5 * base_size)
max_size = int(2.0 * base_size)
def __init__(
self,
*,
base_size,
crop_size,
hflip_prob=0.5,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
backend="pil",
use_v2=False,
):
T, datapoints, v2_extras = get_modules(use_v2)
transforms = []
backend = backend.lower()
if backend == "datapoint":
transforms.append(T.ToImageTensor())
elif backend == "tensor":
transforms.append(T.PILToTensor())
elif backend != "pil":
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
transforms += [T.RandomResize(min_size=int(0.5 * base_size), max_size=int(2.0 * base_size))]
trans = [T.RandomResize(min_size, max_size)]
if hflip_prob > 0:
trans.append(T.RandomHorizontalFlip(hflip_prob))
trans.extend(
[
T.RandomCrop(crop_size),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
transforms += [T.RandomHorizontalFlip(hflip_prob)]
if use_v2:
# We need a custom pad transform here, since the padding we want to perform here is fundamentally
# different from the padding in `RandomCrop` if `pad_if_needed=True`.
transforms += [v2_extras.PadIfSmaller(crop_size, fill=defaultdict(lambda: 0, {datapoints.Mask: 255}))]
transforms += [T.RandomCrop(crop_size)]
if backend == "pil":
transforms += [T.PILToTensor()]
if use_v2:
img_type = datapoints.Image if backend == "datapoint" else torch.Tensor
transforms += [
T.ToDtype(dtype={img_type: torch.float32, datapoints.Mask: torch.int64, "others": None}, scale=True)
]
)
self.transforms = T.Compose(trans)
else:
# No need to explicitly convert masks as they're magically int64 already
transforms += [T.ConvertImageDtype(torch.float)]
transforms += [T.Normalize(mean=mean, std=std)]
self.transforms = T.Compose(transforms)
def __call__(self, img, target):
return self.transforms(img, target)
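A quick sketch (not part of the diff) of the train preset on the v2 path with the default PIL backend, applied to a dummy sample; sizes and values are purely illustrative:

import torch
from PIL import Image
from torchvision import datapoints
import presets

preset = presets.SegmentationPresetTrain(base_size=520, crop_size=480, backend="pil", use_v2=True)
img = Image.new("RGB", (640, 480))
mask = datapoints.Mask(torch.zeros(480, 640, dtype=torch.uint8))
img_t, mask_t = preset(img, mask)
# img_t: float32 tensor of shape (3, 480, 480); mask_t: int64 Mask of shape (480, 480),
# with 255 in any region that PadIfSmaller had to fill.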
class SegmentationPresetEval:
def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
self.transforms = T.Compose(
[
T.RandomResize(base_size, base_size),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]
)
def __init__(
self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), backend="pil", use_v2=False
):
T, _, _ = get_modules(use_v2)
transforms = []
backend = backend.lower()
if backend == "tensor":
transforms += [T.PILToTensor()]
elif backend == "datapoint":
transforms += [T.ToImageTensor()]
elif backend != "pil":
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
if use_v2:
transforms += [T.Resize(size=(base_size, base_size))]
else:
transforms += [T.RandomResize(min_size=base_size, max_size=base_size)]
if backend == "pil":
# Note: we could just convert to pure tensors even in v2?
transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
transforms += [
T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]
self.transforms = T.Compose(transforms)
def __call__(self, img, target):
return self.transforms(img, target)
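And the eval preset with the new "datapoint" backend, again only a sketch and not part of the diff. Note that the v2 eval path resizes to an exact (520, 520) via Resize, whereas the v1 path resizes the shorter edge to 520 via RandomResize:

import torch
from PIL import Image
from torchvision import datapoints
import presets

preset = presets.SegmentationPresetEval(base_size=520, backend="datapoint", use_v2=True)
img = Image.new("RGB", (640, 480))
mask = datapoints.Mask(torch.zeros(480, 640, dtype=torch.uint8))
img_t, mask_t = preset(img, mask)  # both come out at 520x520 on the v2 path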
......@@ -14,24 +14,30 @@ from torch.optim.lr_scheduler import PolynomialLR
from torchvision.transforms import functional as F, InterpolationMode
def get_dataset(dir_path, name, image_set, transform):
def get_dataset(args, is_train):
def sbd(*args, **kwargs):
kwargs.pop("use_v2")
return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs)
def voc(*args, **kwargs):
kwargs.pop("use_v2")
return torchvision.datasets.VOCSegmentation(*args, **kwargs)
paths = {
"voc": (dir_path, torchvision.datasets.VOCSegmentation, 21),
"voc_aug": (dir_path, sbd, 21),
"coco": (dir_path, get_coco, 21),
"voc": (args.data_path, voc, 21),
"voc_aug": (args.data_path, sbd, 21),
"coco": (args.data_path, get_coco, 21),
}
p, ds_fn, num_classes = paths[name]
p, ds_fn, num_classes = paths[args.dataset]
ds = ds_fn(p, image_set=image_set, transforms=transform)
image_set = "train" if is_train else "val"
ds = ds_fn(p, image_set=image_set, transforms=get_transform(is_train, args), use_v2=args.use_v2)
return ds, num_classes
def get_transform(train, args):
if train:
return presets.SegmentationPresetTrain(base_size=520, crop_size=480)
def get_transform(is_train, args):
if is_train:
return presets.SegmentationPresetTrain(base_size=520, crop_size=480, backend=args.backend, use_v2=args.use_v2)
elif args.weights and args.test_only:
weights = torchvision.models.get_weight(args.weights)
trans = weights.transforms()
......@@ -44,7 +50,7 @@ def get_transform(train, args):
return preprocessing
else:
return presets.SegmentationPresetEval(base_size=520)
return presets.SegmentationPresetEval(base_size=520, backend=args.backend, use_v2=args.use_v2)
def criterion(inputs, target):
......@@ -120,6 +126,12 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi
def main(args):
if args.backend.lower() != "pil" and not args.use_v2:
# TODO: Support tensor backend in V1?
raise ValueError("Use --use-v2 if you want to use the datapoint or tensor backend.")
if args.use_v2 and args.dataset != "coco":
raise ValueError("v2 is only support supported for coco dataset for now.")
if args.output_dir:
utils.mkdir(args.output_dir)
......@@ -134,8 +146,8 @@ def main(args):
else:
torch.backends.cudnn.benchmark = True
dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(True, args))
dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(False, args))
dataset, num_classes = get_dataset(args, is_train=True)
dataset_test, _ = get_dataset(args, is_train=False)
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
......@@ -307,6 +319,8 @@ def get_args_parser(add_help=True):
# Mixed precision training parameters
parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
return parser
......
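A sketch (not part of the diff) of how the new flags travel from the parser into dataset construction. The Namespace below stands in for a real command line, and get_dataset is assumed importable from this train.py:

import argparse

from train import get_dataset  # run from references/segmentation/

args = argparse.Namespace(
    data_path="/datasets/coco",  # hypothetical path
    dataset="coco",
    backend="datapoint",         # the parser lower-cases this via type=str.lower
    use_v2=True,
    weights=None,
    test_only=False,
)
dataset, num_classes = get_dataset(args, is_train=True)
dataset_test, _ = get_dataset(args, is_train=False)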
......@@ -35,7 +35,7 @@ class RandomResize:
def __call__(self, image, target):
size = random.randint(self.min_size, self.max_size)
image = F.resize(image, size)
image = F.resize(image, size, antialias=True)
target = F.resize(target, size, interpolation=T.InterpolationMode.NEAREST)
return image, target
......
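For reference, a minimal illustration (not part of the diff) of the antialias change above: passing antialias=True when downscaling brings the tensor backend closer to PIL's resampling, while the mask keeps NEAREST so label values are never interpolated.

import torch
from torchvision.transforms import functional as F, InterpolationMode

img = torch.rand(3, 512, 512)
mask = torch.randint(0, 21, (1, 512, 512), dtype=torch.uint8)
img_small = F.resize(img, 256, antialias=True)
mask_small = F.resize(mask, 256, interpolation=InterpolationMode.NEAREST)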
......@@ -267,9 +267,9 @@ def init_distributed_mode(args):
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])
elif "SLURM_PROCID" in os.environ:
args.rank = int(os.environ["SLURM_PROCID"])
args.gpu = args.rank % torch.cuda.device_count()
# elif "SLURM_PROCID" in os.environ:
# args.rank = int(os.environ["SLURM_PROCID"])
# args.gpu = args.rank % torch.cuda.device_count()
elif hasattr(args, "rank"):
pass
else:
......
"""This file only exists to be lazy-imported and avoid V2-related import warnings when just using V1."""
import torch
from torchvision import datapoints
from torchvision.transforms import v2
class PadIfSmaller(v2.Transform):
def __init__(self, size, fill=0):
super().__init__()
self.size = size
self.fill = v2._geometry._setup_fill_arg(fill)
def _get_params(self, sample):
_, height, width = v2.utils.query_chw(sample)
padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
needs_padding = any(padding)
return dict(padding=padding, needs_padding=needs_padding)
def _transform(self, inpt, params):
if not params["needs_padding"]:
return inpt
fill = self.fill[type(inpt)]
fill = v2._utils._convert_fill_arg(fill)
return v2.functional.pad(inpt, padding=params["padding"], fill=fill)
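A small sketch (not part of the diff) of the PadIfSmaller transform defined above, used in isolation with the same fill mapping as the train preset; the input sizes are made up:

from collections import defaultdict

import torch
from torchvision import datapoints

pad = PadIfSmaller(480, fill=defaultdict(lambda: 0, {datapoints.Mask: 255}))
img = datapoints.Image(torch.rand(3, 300, 400))
mask = datapoints.Mask(torch.zeros(300, 400, dtype=torch.int64))
img_p, mask_p = pad(img, mask)
# Both are padded on the right/bottom up to 480x480; padded mask pixels become 255
# (the ignore index), padded image pixels become 0.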
class CocoDetectionToVOCSegmentation(v2.Transform):
"""Turn samples from datasets.CocoDetection into the same format as VOCSegmentation.
This is achieved in two steps:
1. COCO differentiates between 91 categories while VOC only supports 21, including background for both. Fortunately,
the COCO categories are a superset of the VOC ones and thus can be mapped. Instances of the 70 categories not
present in VOC are dropped and replaced by background.
2. COCO only offers detection masks, i.e. a (N, H, W) bool-ish tensor, where the truthy values in each individual
mask denote the instance. However, a segmentation mask is a (H, W) integer tensor (typically torch.uint8), where
the value of each pixel denotes the category it belongs to. The detection masks are merged into one segmentation
mask while pixels that belong to multiple detection masks are marked as invalid.
"""
COCO_TO_VOC_LABEL_MAP = dict(
zip(
[0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72],
range(21),
)
)
INVALID_VALUE = 255
def _coco_detection_masks_to_voc_segmentation_mask(self, target):
if "masks" not in target:
return None
instance_masks, instance_labels_coco = target["masks"], target["labels"]
valid_labels_voc = [
(idx, label_voc)
for idx, label_coco in enumerate(instance_labels_coco.tolist())
if (label_voc := self.COCO_TO_VOC_LABEL_MAP.get(label_coco)) is not None
]
if not valid_labels_voc:
return None
valid_voc_category_idcs, instance_labels_voc = zip(*valid_labels_voc)
instance_masks = instance_masks[list(valid_voc_category_idcs)].to(torch.uint8)
instance_labels_voc = torch.tensor(instance_labels_voc, dtype=torch.uint8)
# Calling `.max()` on the stacked detection masks works fine to separate background from foreground as long as
# there is at most a single instance per pixel. Overlapping instances will be filtered out in the next step.
segmentation_mask, _ = (instance_masks * instance_labels_voc.reshape(-1, 1, 1)).max(dim=0)
segmentation_mask[instance_masks.sum(dim=0) > 1] = self.INVALID_VALUE
return segmentation_mask
def forward(self, image, target):
segmentation_mask = self._coco_detection_masks_to_voc_segmentation_mask(target)
if segmentation_mask is None:
segmentation_mask = torch.zeros(v2.functional.get_spatial_size(image), dtype=torch.uint8)
return image, datapoints.Mask(segmentation_mask)
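A tiny numeric illustration (not part of the diff) of step 2 in the docstring above: two overlapping instance masks are merged into a single segmentation mask, and the overlapping pixel is marked invalid (255):

import torch

instance_masks = torch.tensor(
    [
        [[1, 1, 0],
         [0, 0, 0]],
        [[0, 1, 1],
         [0, 0, 0]],
    ],
    dtype=torch.uint8,
)  # (N=2, H=2, W=3)
instance_labels_voc = torch.tensor([12, 15], dtype=torch.uint8)

segmentation_mask, _ = (instance_masks * instance_labels_voc.reshape(-1, 1, 1)).max(dim=0)
segmentation_mask[instance_masks.sum(dim=0) > 1] = 255
print(segmentation_mask)
# tensor([[ 12, 255,  15],
#         [  0,   0,   0]], dtype=torch.uint8)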