Unverified Commit 1d646d41 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

port prototype tests to new utilities (#8022)

parent 67f3ce28
import collections.abc
import dataclasses
from typing import Optional, Sequence
import pytest
import torch
from torch.nn.functional import one_hot
from torchvision.prototype import tv_tensors
from transforms_v2_legacy_utils import combinations_grid, DEFAULT_EXTRA_DIMS, from_loader, from_loaders, TensorLoader
@dataclasses.dataclass
class LabelLoader(TensorLoader):
    # Human-readable category names backing the integer labels. `None` when only
    # the number of categories is known (see `_parse_categories`).
    categories: Optional[Sequence[str]]
def _parse_categories(categories):
    """Normalize the ``categories`` argument of the ``make_*`` helpers.

    Args:
        categories: ``None`` (a random number of categories in ``[1, 10]`` is
            drawn and the names stay unknown), an ``int`` (that many
            auto-generated ``"category{idx}"`` names), or a sequence of
            category name strings.

    Returns:
        Tuple ``(categories, num_categories)`` where ``categories`` is either
        ``None`` or a list of strings.

    Raises:
        pytest.UsageError: for any other input, including a bare string.
    """
    if categories is None:
        num_categories = int(torch.randint(1, 11, ()))
    elif isinstance(categories, int):
        num_categories = categories
        categories = [f"category{idx}" for idx in range(num_categories)]
    elif (
        isinstance(categories, collections.abc.Sequence)
        # A bare str is itself a Sequence of str's and would silently be split
        # into single-character categories, so reject it explicitly.
        and not isinstance(categories, str)
        and all(isinstance(category, str) for category in categories)
    ):
        categories = list(categories)
        num_categories = len(categories)
    else:
        raise pytest.UsageError(
            f"`categories` can either be `None` (default), an integer, or a sequence of strings, "
            f"but got '{categories}' instead."
        )
    return categories, num_categories
def make_label_loader(*, extra_dims=(), categories=None, dtype=torch.int64):
    """Create a :class:`LabelLoader` that produces random ``tv_tensors.Label``'s.

    ``categories`` may be ``None``, an int, or a sequence of strings and is
    normalized through ``_parse_categories``.
    """
    parsed_categories, num_categories = _parse_categories(categories)

    def load_fn(shape, dtype, device):
        # Sample as int64 first so the values are always integers regardless of
        # the requested dtype (e.g. 0 or 0.0 rather than 0 or 0.123), then cast.
        values = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=torch.int64, device=device)
        return tv_tensors.Label(values.to(dtype), categories=parsed_categories)

    return LabelLoader(load_fn, shape=extra_dims, dtype=dtype, categories=parsed_categories)
# `from_loader` (from transforms_v2_legacy_utils) wraps the loader factory into
# a direct make_* helper that returns a loaded Label.
make_label = from_loader(make_label_loader)
@dataclasses.dataclass
class OneHotLabelLoader(TensorLoader):
    # Human-readable category names for the one-hot encoded labels. `None` when
    # only the number of categories is known (see `_parse_categories`).
    categories: Optional[Sequence[str]]
def make_one_hot_label_loader(*, categories=None, extra_dims=(), dtype=torch.int64):
    """Create a :class:`OneHotLabelLoader` that produces random ``tv_tensors.OneHotLabel``'s."""
    parsed_categories, num_categories = _parse_categories(categories)

    def load_fn(shape, dtype, device):
        if num_categories:
            # `one_hot` only supports int64 inputs, so sample the label as
            # int64 and cast the encoded result to the requested dtype after.
            label = make_label_loader(extra_dims=extra_dims, categories=num_categories, dtype=torch.int64).load(device)
            data = one_hot(label, num_classes=num_categories).to(dtype)
        else:
            data = torch.empty(shape, dtype=dtype, device=device)
        return tv_tensors.OneHotLabel(data, categories=parsed_categories)

    return OneHotLabelLoader(load_fn, shape=(*extra_dims, num_categories), dtype=dtype, categories=parsed_categories)
def make_one_hot_label_loaders(
    *,
    categories=(1, 0, None),
    extra_dims=DEFAULT_EXTRA_DIMS,
    dtypes=(torch.int64, torch.float32),
):
    """Yield a one-hot label loader for every combination of the given parameters."""
    grid = combinations_grid(categories=categories, extra_dims=extra_dims, dtype=dtypes)
    yield from (make_one_hot_label_loader(**params) for params in grid)
# `from_loaders` (from transforms_v2_legacy_utils) wraps the loader-generator
# into a helper that yields loaded OneHotLabel's.
make_one_hot_labels = from_loaders(make_one_hot_label_loaders)
import collections.abc
import re
import PIL.Image
import pytest
import torch
from common_utils import assert_equal
from common_utils import assert_equal, make_bounding_boxes, make_detection_masks, make_image, make_video
from prototype_common_utils import make_label
from torchvision.prototype import transforms, tv_tensors
from torchvision.transforms.v2._utils import check_type, is_pure_tensor
from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_pil_image
from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
from transforms_v2_legacy_utils import (
DEFAULT_EXTRA_DIMS,
make_bounding_boxes,
make_detection_mask,
make_image,
make_video,
)
# Only the extra_dims configurations that add at least one batch dimension,
# i.e. DEFAULT_EXTRA_DIMS with the empty tuple filtered out.
BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
def _parse_categories(categories):
    """Normalize the ``categories`` argument of the ``make_*`` helpers.

    Args:
        categories: ``None`` (a random number of categories in ``[1, 10]`` is
            drawn and the names stay unknown), an ``int`` (that many
            auto-generated ``"category{idx}"`` names), or a sequence of
            category name strings.

    Returns:
        Tuple ``(categories, num_categories)`` where ``categories`` is either
        ``None`` or a list of strings.

    Raises:
        pytest.UsageError: for any other input, including a bare string.
    """
    if categories is None:
        num_categories = int(torch.randint(1, 11, ()))
    elif isinstance(categories, int):
        num_categories = categories
        categories = [f"category{idx}" for idx in range(num_categories)]
    elif (
        isinstance(categories, collections.abc.Sequence)
        # A bare str is itself a Sequence of str's and would silently be split
        # into single-character categories, so reject it explicitly.
        and not isinstance(categories, str)
        and all(isinstance(category, str) for category in categories)
    ):
        categories = list(categories)
        num_categories = len(categories)
    else:
        raise pytest.UsageError(
            f"`categories` can either be `None` (default), an integer, or a sequence of strings, "
            f"but got '{categories}' instead."
        )
    return categories, num_categories
def parametrize(transforms_with_inputs):
    """Build a ``pytest.mark.parametrize`` decorator over (transform, input) pairs.

    Each parameter gets an id of the form
    ``<TransformClass>-<input module>.<input class>-<index>``.
    """
    params = []
    for transform, inputs in transforms_with_inputs:
        transform_name = type(transform).__name__
        for idx, input in enumerate(inputs):
            input_type = type(input)
            params.append(
                pytest.param(
                    transform,
                    input,
                    id=f"{transform_name}-{input_type.__module__}.{input_type.__name__}-{idx}",
                )
            )
    return pytest.mark.parametrize(("transform", "input"), params)
def make_label(*, extra_dims=(), categories=10, dtype=torch.int64, device="cpu"):
    """Create a random ``tv_tensors.Label`` of shape ``extra_dims``.

    ``categories`` may be ``None``, an int, or a sequence of strings and is
    normalized through ``_parse_categories``.
    """
    parsed_categories, num_categories = _parse_categories(categories)
    # Sample as int64 first so the values are always integers regardless of the
    # requested dtype (e.g. 0 or 0.0 rather than 0 or 0.123), then cast.
    values = torch.testing.make_tensor(extra_dims, low=0, high=num_categories, dtype=torch.int64, device=device)
    return tv_tensors.Label(values.to(dtype), categories=parsed_categories)
class TestSimpleCopyPaste:
......@@ -167,7 +168,7 @@ class TestFixedSizeCrop:
flat_inputs = [
make_image(size=canvas_size, color_space="RGB"),
make_bounding_boxes(format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=batch_shape),
make_bounding_boxes(format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, num_boxes=batch_shape[0]),
]
params = transform._get_params(flat_inputs)
......@@ -203,9 +204,9 @@ class TestFixedSizeCrop:
)
bounding_boxes = make_bounding_boxes(
format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(batch_size,)
format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, num_boxes=batch_size
)
masks = make_detection_mask(size=canvas_size, batch_dims=(batch_size,))
masks = make_detection_masks(size=canvas_size, num_masks=batch_size)
labels = make_label(extra_dims=(batch_size,))
transform = transforms.FixedSizeCrop((-1, -1))
......@@ -241,7 +242,7 @@ class TestFixedSizeCrop:
)
bounding_boxes = make_bounding_boxes(
format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(batch_size,)
format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, num_boxes=batch_size
)
mock = mocker.patch(
"torchvision.prototype.transforms._geometry.F.clamp_bounding_boxes", wraps=clamp_bounding_boxes
......@@ -389,27 +390,27 @@ def test_fixed_sized_crop_against_detection_reference():
pil_image = to_pil_image(make_image(size=size, color_space="RGB"))
target = {
"boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"boxes": make_bounding_boxes(canvas_size=size, format="XYXY", num_boxes=num_objects, dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
"masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
"masks": make_detection_masks(size=size, num_masks=num_objects, dtype=torch.long),
}
yield (pil_image, target)
tensor_image = torch.Tensor(make_image(size=size, color_space="RGB"))
target = {
"boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"boxes": make_bounding_boxes(canvas_size=size, format="XYXY", num_boxes=num_objects, dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
"masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
"masks": make_detection_masks(size=size, num_masks=num_objects, dtype=torch.long),
}
yield (tensor_image, target)
tv_tensor_image = make_image(size=size, color_space="RGB")
target = {
"boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"boxes": make_bounding_boxes(canvas_size=size, format="XYXY", num_boxes=num_objects, dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
"masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
"masks": make_detection_masks(size=size, num_masks=num_objects, dtype=torch.long),
}
yield (tv_tensor_image, target)
......
......@@ -6,7 +6,6 @@ implemented there and must not use any of the utilities here.
The following legacy modules depend on this module
- test_transforms_v2_consistency.py
- test_prototype_transforms.py
"""
import collections.abc
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment