Unverified commit bb3aae7b authored by Nicolas Hug, committed by GitHub

Add --backend and --use-v2 support to detection refs (#7732)


Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent 08c9938f
@@ -15,6 +15,9 @@ def get_module(use_v2):
 class ClassificationPresetTrain:
+    # Note: this transform assumes that the input to forward() are always PIL
+    # images, regardless of the backend parameter. We may change that in the
+    # future though, if we change the output type from the dataset.
     def __init__(
         self,
         *,
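The `get_module` helper named in the hunk header above is not part of this diff; it just picks the transforms namespace the presets are built from. A rough sketch of what it presumably looks like, inferred from how it is used below and from the detection-side `get_modules` added later in this commit (not the verbatim upstream code):

def get_module(use_v2):
    # Protected import so the transforms.v2 import (and its warning) only happens when requested.
    if use_v2:
        import torchvision.transforms.v2

        return torchvision.transforms.v2
    else:
        import torchvision.transforms

        return torchvision.transforms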
@@ -30,42 +33,42 @@ class ClassificationPresetTrain:
         backend="pil",
         use_v2=False,
     ):
-        module = get_module(use_v2)
+        T = get_module(use_v2)

         transforms = []
         backend = backend.lower()
         if backend == "tensor":
-            transforms.append(module.PILToTensor())
+            transforms.append(T.PILToTensor())
         elif backend != "pil":
             raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")

-        transforms.append(module.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
+        transforms.append(T.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
         if hflip_prob > 0:
-            transforms.append(module.RandomHorizontalFlip(hflip_prob))
+            transforms.append(T.RandomHorizontalFlip(hflip_prob))
         if auto_augment_policy is not None:
             if auto_augment_policy == "ra":
-                transforms.append(module.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
+                transforms.append(T.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
             elif auto_augment_policy == "ta_wide":
-                transforms.append(module.TrivialAugmentWide(interpolation=interpolation))
+                transforms.append(T.TrivialAugmentWide(interpolation=interpolation))
             elif auto_augment_policy == "augmix":
-                transforms.append(module.AugMix(interpolation=interpolation, severity=augmix_severity))
+                transforms.append(T.AugMix(interpolation=interpolation, severity=augmix_severity))
             else:
-                aa_policy = module.AutoAugmentPolicy(auto_augment_policy)
-                transforms.append(module.AutoAugment(policy=aa_policy, interpolation=interpolation))
+                aa_policy = T.AutoAugmentPolicy(auto_augment_policy)
+                transforms.append(T.AutoAugment(policy=aa_policy, interpolation=interpolation))

         if backend == "pil":
-            transforms.append(module.PILToTensor())
+            transforms.append(T.PILToTensor())

         transforms.extend(
             [
-                module.ConvertImageDtype(torch.float),
-                module.Normalize(mean=mean, std=std),
+                T.ConvertImageDtype(torch.float),
+                T.Normalize(mean=mean, std=std),
             ]
         )
         if random_erase_prob > 0:
-            transforms.append(module.RandomErasing(p=random_erase_prob))
+            transforms.append(T.RandomErasing(p=random_erase_prob))

-        self.transforms = module.Compose(transforms)
+        self.transforms = T.Compose(transforms)

     def __call__(self, img):
         return self.transforms(img)
@@ -83,28 +86,28 @@ class ClassificationPresetEval:
         backend="pil",
         use_v2=False,
     ):
-        module = get_module(use_v2)
+        T = get_module(use_v2)
         transforms = []
         backend = backend.lower()
         if backend == "tensor":
-            transforms.append(module.PILToTensor())
+            transforms.append(T.PILToTensor())
         elif backend != "pil":
             raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")

         transforms += [
-            module.Resize(resize_size, interpolation=interpolation, antialias=True),
-            module.CenterCrop(crop_size),
+            T.Resize(resize_size, interpolation=interpolation, antialias=True),
+            T.CenterCrop(crop_size),
         ]

         if backend == "pil":
-            transforms.append(module.PILToTensor())
+            transforms.append(T.PILToTensor())

         transforms += [
-            module.ConvertImageDtype(torch.float),
-            module.Normalize(mean=mean, std=std),
+            T.ConvertImageDtype(torch.float),
+            T.Normalize(mean=mean, std=std),
         ]

-        self.transforms = module.Compose(transforms)
+        self.transforms = T.Compose(transforms)

     def __call__(self, img):
         return self.transforms(img)
@@ -7,6 +7,7 @@ import torchvision
 import transforms as T
 from pycocotools import mask as coco_mask
 from pycocotools.coco import COCO
+from torchvision.datasets import wrap_dataset_for_transforms_v2


 class FilterAndRemapCocoCategories:
@@ -49,7 +50,6 @@ class ConvertCocoPolysToMask:
         w, h = image.size

         image_id = target["image_id"]
-        image_id = torch.tensor([image_id])

         anno = target["annotations"]
@@ -126,10 +126,6 @@ def _coco_remove_images_without_annotations(dataset, cat_list=None):
             return True
         return False

-    if not isinstance(dataset, torchvision.datasets.CocoDetection):
-        raise TypeError(
-            f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
-        )
     ids = []
     for ds_idx, img_id in enumerate(dataset.ids):
         ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
@@ -196,12 +192,15 @@ def convert_to_coco_api(ds):


 def get_coco_api_from_dataset(dataset):
+    # FIXME: This is... awful?
     for _ in range(10):
         if isinstance(dataset, torchvision.datasets.CocoDetection):
             break
         if isinstance(dataset, torch.utils.data.Subset):
             dataset = dataset.dataset
-    if isinstance(dataset, torchvision.datasets.CocoDetection):
+    if isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance(
+        getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection
+    ):
         return dataset.coco
     return convert_to_coco_api(dataset)
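The extra `getattr(dataset, "_dataset", None)` branch is needed because the object returned by `wrap_dataset_for_transforms_v2` is not itself a `CocoDetection` subclass, but it keeps the original dataset on a private `_dataset` attribute. A rough sketch (names are illustrative):

base = torchvision.datasets.CocoDetection(img_folder, ann_file)
wrapped = wrap_dataset_for_transforms_v2(base)
isinstance(wrapped, torchvision.datasets.CocoDetection)  # False: the wrapper is a different class
isinstance(getattr(wrapped, "_dataset", None), torchvision.datasets.CocoDetection)  # True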
@@ -220,7 +219,7 @@ class CocoDetection(torchvision.datasets.CocoDetection):
         return img, target


-def get_coco(root, image_set, transforms, mode="instances"):
+def get_coco(root, image_set, transforms, mode="instances", use_v2=False):
     anno_file_template = "{}_{}2017.json"
     PATHS = {
         "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
@@ -228,17 +227,21 @@ def get_coco(root, image_set, transforms, mode="instances"):
         # "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val")))
     }

-    t = [ConvertCocoPolysToMask()]
-    if transforms is not None:
-        t.append(transforms)
-    transforms = T.Compose(t)
-
     img_folder, ann_file = PATHS[image_set]
     img_folder = os.path.join(root, img_folder)
     ann_file = os.path.join(root, ann_file)

-    dataset = CocoDetection(img_folder, ann_file, transforms=transforms)
+    if use_v2:
+        dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
+        # TODO: need to update target_keys to handle masks for segmentation!
+        dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"boxes", "labels", "image_id"})
+    else:
+        t = [ConvertCocoPolysToMask()]
+        if transforms is not None:
+            t.append(transforms)
+        transforms = T.Compose(t)
+        dataset = CocoDetection(img_folder, ann_file, transforms=transforms)

     if image_set == "train":
         dataset = _coco_remove_images_without_annotations(dataset)
@@ -248,5 +251,7 @@ def get_coco(root, image_set, transforms, mode="instances"):
     return dataset


-def get_coco_kp(root, image_set, transforms):
+def get_coco_kp(root, image_set, transforms, use_v2=False):
+    if use_v2:
+        raise ValueError("KeyPoints aren't supported by transforms V2 yet.")
     return get_coco(root, image_set, transforms, mode="person_keypoints")
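A hypothetical call into the updated loader, with placeholder path and transforms:

dataset = get_coco("/path/to/coco", image_set="train", transforms=train_transforms, use_v2=True)
img, target = dataset[0]
# With use_v2=True the sample goes through wrap_dataset_for_transforms_v2, so target only
# carries the requested keys: "boxes", "labels" and "image_id".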
@@ -26,7 +26,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
     for images, targets in metric_logger.log_every(data_loader, print_freq, header):
         images = list(image.to(device) for image in images)
-        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
+        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
         with torch.cuda.amp.autocast(enabled=scaler is not None):
             loss_dict = model(images, targets)
             losses = sum(loss for loss in loss_dict.values())
@@ -97,7 +97,7 @@ def evaluate(model, data_loader, device):
         outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
         model_time = time.time() - model_time

-        res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
+        res = {target["image_id"]: output for target, output in zip(targets, outputs)}
         evaluator_time = time.time()
         coco_evaluator.update(res)
         evaluator_time = time.time() - evaluator_time
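Both engine changes above follow from the coco_utils change: once `ConvertCocoPolysToMask` no longer wraps `image_id` in a tensor, the target dicts can contain plain Python values, so `train_one_epoch` only moves actual tensors to the device and `evaluate` keys its results by the raw `image_id` instead of calling `.item()`.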
......
@@ -164,7 +164,9 @@ def compute_aspect_ratios(dataset, indices=None):
     if hasattr(dataset, "get_height_and_width"):
         return _compute_aspect_ratios_custom_dataset(dataset, indices)

-    if isinstance(dataset, torchvision.datasets.CocoDetection):
+    if isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance(
+        getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection
+    ):
         return _compute_aspect_ratios_coco_dataset(dataset, indices)

     if isinstance(dataset, torchvision.datasets.VOCDetection):
......
+from collections import defaultdict
+
 import torch
-import transforms as T
+import transforms as reference_transforms
+
+
+def get_modules(use_v2):
+    # We need a protected import to avoid the V2 warning in case just V1 is used
+    if use_v2:
+        import torchvision.datapoints
+        import torchvision.transforms.v2
+
+        return torchvision.transforms.v2, torchvision.datapoints
+    else:
+        return reference_transforms, None


 class DetectionPresetTrain:
-    def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)):
+    # Note: this transform assumes that the input to forward() are always PIL
+    # images, regardless of the backend parameter.
+    def __init__(
+        self,
+        *,
+        data_augmentation,
+        hflip_prob=0.5,
+        mean=(123.0, 117.0, 104.0),
+        backend="pil",
+        use_v2=False,
+    ):
+        T, datapoints = get_modules(use_v2)
+
+        transforms = []
+        backend = backend.lower()
+        if backend == "datapoint":
+            transforms.append(T.ToImageTensor())
+        elif backend == "tensor":
+            transforms.append(T.PILToTensor())
+        elif backend != "pil":
+            raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
+
         if data_augmentation == "hflip":
-            self.transforms = T.Compose(
-                [
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms += [T.RandomHorizontalFlip(p=hflip_prob)]
         elif data_augmentation == "lsj":
-            self.transforms = T.Compose(
-                [
-                    T.ScaleJitter(target_size=(1024, 1024)),
-                    T.FixedSizeCrop(size=(1024, 1024), fill=mean),
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms += [
+                T.ScaleJitter(target_size=(1024, 1024), antialias=True),
+                # TODO: FixedSizeCrop below doesn't work on tensors!
+                reference_transforms.FixedSizeCrop(size=(1024, 1024), fill=mean),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
         elif data_augmentation == "multiscale":
-            self.transforms = T.Compose(
-                [
-                    T.RandomShortestSize(
-                        min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
-                    ),
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms += [
+                T.RandomShortestSize(min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
         elif data_augmentation == "ssd":
-            self.transforms = T.Compose(
-                [
-                    T.RandomPhotometricDistort(),
-                    T.RandomZoomOut(fill=list(mean)),
-                    T.RandomIoUCrop(),
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            fill = defaultdict(lambda: mean, {datapoints.Mask: 0}) if use_v2 else list(mean)
+            transforms += [
+                T.RandomPhotometricDistort(),
+                T.RandomZoomOut(fill=fill),
+                T.RandomIoUCrop(),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
         elif data_augmentation == "ssdlite":
-            self.transforms = T.Compose(
-                [
-                    T.RandomIoUCrop(),
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms += [
+                T.RandomIoUCrop(),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
         else:
             raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')

+        if backend == "pil":
+            # Note: we could just convert to pure tensors even in v2.
+            transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
+
+        transforms += [T.ConvertImageDtype(torch.float)]
+
+        if use_v2:
+            transforms += [
+                T.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.XYXY),
+                T.SanitizeBoundingBox(),
+            ]
+
+        self.transforms = T.Compose(transforms)
+
     def __call__(self, img, target):
         return self.transforms(img, target)
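A hypothetical instantiation of the rewritten train preset, to show how the new arguments interact (the augmentation policy is an arbitrary choice):

train_tf = DetectionPresetTrain(data_augmentation="ssd", backend="datapoint", use_v2=True)
img, target = train_tf(img, target)
# With use_v2=True the pipeline ends by converting boxes to XYXY and dropping degenerate
# ones, per the ConvertBoundingBoxFormat / SanitizeBoundingBox steps above.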
 class DetectionPresetEval:
-    def __init__(self):
-        self.transforms = T.Compose(
-            [
-                T.PILToTensor(),
-                T.ConvertImageDtype(torch.float),
-            ]
-        )
+    def __init__(self, backend="pil", use_v2=False):
+        T, _ = get_modules(use_v2)
+        transforms = []
+        backend = backend.lower()
+        if backend == "pil":
+            # Note: we could just convert to pure tensors even in v2?
+            transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
+        elif backend == "tensor":
+            transforms += [T.PILToTensor()]
+        elif backend == "datapoint":
+            transforms += [T.ToImageTensor()]
+        else:
+            raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
+
+        transforms += [T.ConvertImageDtype(torch.float)]
+        self.transforms = T.Compose(transforms)

     def __call__(self, img, target):
         return self.transforms(img, target)
@@ -40,23 +40,26 @@ def copypaste_collate_fn(batch):
     return copypaste(*utils.collate_fn(batch))


-def get_dataset(name, image_set, transform, data_path):
-    paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)}
-    p, ds_fn, num_classes = paths[name]
+def get_dataset(is_train, args):
+    image_set = "train" if is_train else "val"
+    paths = {"coco": (args.data_path, get_coco, 91), "coco_kp": (args.data_path, get_coco_kp, 2)}
+    p, ds_fn, num_classes = paths[args.dataset]

-    ds = ds_fn(p, image_set=image_set, transforms=transform)
+    ds = ds_fn(p, image_set=image_set, transforms=get_transform(is_train, args), use_v2=args.use_v2)
     return ds, num_classes


-def get_transform(train, args):
-    if train:
-        return presets.DetectionPresetTrain(data_augmentation=args.data_augmentation)
+def get_transform(is_train, args):
+    if is_train:
+        return presets.DetectionPresetTrain(
+            data_augmentation=args.data_augmentation, backend=args.backend, use_v2=args.use_v2
+        )
     elif args.weights and args.test_only:
         weights = torchvision.models.get_weight(args.weights)
         trans = weights.transforms()
         return lambda img, target: (trans(img), target)
     else:
-        return presets.DetectionPresetEval()
+        return presets.DetectionPresetEval(backend=args.backend, use_v2=args.use_v2)


 def get_args_parser(add_help=True):
@@ -159,10 +162,16 @@ def get_args_parser(add_help=True):
         help="Use CopyPaste data augmentation. Works only with data-augmentation='lsj'.",
     )

+    parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
+    parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
+
     return parser


 def main(args):
+    if args.backend.lower() == "datapoint" and not args.use_v2:
+        raise ValueError("Use --use-v2 if you want to use the datapoint backend.")
+
     if args.output_dir:
         utils.mkdir(args.output_dir)
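Taken together, a hypothetical launch of the reference script could look like `python train.py --dataset coco --data-path /path/to/coco --model ssd300_vgg16 --data-augmentation ssd --backend datapoint --use-v2` (model name and paths are illustrative); the guard above rejects the datapoint backend unless `--use-v2` is also passed.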
@@ -177,8 +186,8 @@ def main(args):
     # Data loading code
     print("Loading data")

-    dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args), args.data_path)
-    dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args), args.data_path)
+    dataset, num_classes = get_dataset(is_train=True, args=args)
+    dataset_test, _ = get_dataset(is_train=False, args=args)

     print("Creating data loaders")
     if args.distributed:
......
@@ -293,11 +293,13 @@ class ScaleJitter(nn.Module):
         target_size: Tuple[int, int],
         scale_range: Tuple[float, float] = (0.1, 2.0),
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        antialias=True,
     ):
         super().__init__()
         self.target_size = target_size
         self.scale_range = scale_range
         self.interpolation = interpolation
+        self.antialias = antialias

     def forward(
         self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
@@ -315,14 +317,17 @@ class ScaleJitter(nn.Module):
         new_width = int(orig_width * r)
         new_height = int(orig_height * r)

-        image = F.resize(image, [new_height, new_width], interpolation=self.interpolation)
+        image = F.resize(image, [new_height, new_width], interpolation=self.interpolation, antialias=self.antialias)

         if target is not None:
             target["boxes"][:, 0::2] *= new_width / orig_width
             target["boxes"][:, 1::2] *= new_height / orig_height
             if "masks" in target:
                 target["masks"] = F.resize(
-                    target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST
+                    target["masks"],
+                    [new_height, new_width],
+                    interpolation=InterpolationMode.NEAREST,
+                    antialias=self.antialias,
                 )
         return image, target
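Note that the reference (v1) `ScaleJitter` now antialiases by default, which is what allows the consistency test below to compare it against `v2_transforms.ScaleJitter(..., antialias=True)`. An illustrative instantiation:

jitter = ScaleJitter(target_size=(1024, 1024))  # antialias=True unless explicitly disabled
image, target = jitter(image, target)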
......
@@ -1133,7 +1133,7 @@ class TestRefDetTransforms:
             {"with_mask": False},
         ),
         (det_transforms.RandomZoomOut(), v2_transforms.RandomZoomOut(), {"with_mask": False}),
-        (det_transforms.ScaleJitter((1024, 1024)), v2_transforms.ScaleJitter((1024, 1024)), {}),
+        (det_transforms.ScaleJitter((1024, 1024)), v2_transforms.ScaleJitter((1024, 1024), antialias=True), {}),
         (
             det_transforms.RandomShortestSize(
                 min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
......