import functools
import itertools
import numpy as np
import PIL.Image
import pytest
import torch.testing
import torchvision.ops
import torchvision.transforms.v2.functional as F
from torchvision import tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value, _parse_pad_padding
from transforms_v2_legacy_utils import (
ArgsKwargs,
combinations_grid,
DEFAULT_PORTRAIT_SPATIAL_SIZE,
get_num_channels,
ImageLoader,
InfoBase,
make_bounding_box_loader,
make_bounding_box_loaders,
make_detection_mask_loader,
make_image_loader,
make_image_loaders,
make_image_loaders_for_interpolation,
make_mask_loaders,
make_video_loader,
make_video_loaders,
mark_framework_limitation,
TestMark,
)
__all__ = ["KernelInfo", "KERNEL_INFOS"]
class KernelInfo(InfoBase):
def __init__(
self,
kernel,
*,
# Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
# TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
kernel_name=None,
# Most common tests use these inputs to check the kernel. As such, this function should cover all valid code
# paths, but should not include extensive parameter combinations, to keep the overall test count moderate.
sample_inputs_fn,
# This function should mirror the kernel. It should have the same signature as the `kernel` and as such also
# take tensors as inputs. Any conversion into another object type, e.g. PIL images or numpy arrays, should
# happen inside the function. It should return a tensor, or, more precisely, an object that can be compared to
# a tensor by `assert_close`. If omitted, no reference test will be performed.
reference_fn=None,
# These inputs are only used for the reference tests and thus can be comprehensive with regard to the parameter
# values to be tested. If not specified, `sample_inputs_fn` will be used.
reference_inputs_fn=None,
# If true-ish, triggers a test that checks the kernel for consistency between uint8 and float32 inputs with the
# reference inputs. This is usually used whenever we use a PIL kernel as reference.
# Can be a callable in which case it will be called with `other_args, kwargs`. It should return the same
# structure, but with adapted parameters. This is useful in case a parameter value is closely tied to the input
# dtype.
float32_vs_uint8=False,
# Some kernels don't have dispatchers that would handle logging the usage. Thus, the kernel has to do it
# manually. If set, triggers a test that makes sure this happens.
logs_usage=False,
# See InfoBase
test_marks=None,
# See InfoBase
closeness_kwargs=None,
):
super().__init__(id=kernel_name or kernel.__name__, test_marks=test_marks, closeness_kwargs=closeness_kwargs)
self.kernel = kernel
self.sample_inputs_fn = sample_inputs_fn
self.reference_fn = reference_fn
self.reference_inputs_fn = reference_inputs_fn
if float32_vs_uint8 and not callable(float32_vs_uint8):
float32_vs_uint8 = lambda other_args, kwargs: (other_args, kwargs) # noqa: E731
self.float32_vs_uint8 = float32_vs_uint8
self.logs_usage = logs_usage
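# Illustration only: a hypothetical, minimal `KernelInfo` (sketch, not registered in `KERNEL_INFOS`).
# A kernel plus a `sample_inputs_fn` is all that is required; everything else is optional and mirrors
# the real entries further down in this module.
def _example_minimal_kernel_info():
    def sample_inputs():
        for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE]):
            yield ArgsKwargs(image_loader)
    return KernelInfo(F.invert_image, sample_inputs_fn=sample_inputs)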
def pixel_difference_closeness_kwargs(uint8_atol, *, dtype=torch.uint8, mae=False):
return dict(atol=uint8_atol / 255 * get_max_value(dtype), rtol=0, mae=mae)
def cuda_vs_cpu_pixel_difference(atol=1):
return {
(("TestKernels", "test_cuda_vs_cpu"), dtype, "cuda"): pixel_difference_closeness_kwargs(atol, dtype=dtype)
for dtype in [torch.uint8, torch.float32]
}
def pil_reference_pixel_difference(atol=1, mae=False):
return {
(("TestKernels", "test_against_reference"), torch.uint8, "cpu"): pixel_difference_closeness_kwargs(
atol, mae=mae
)
}
def float32_vs_uint8_pixel_difference(atol=1, mae=False):
return {
(
("TestKernels", "test_float32_vs_uint8"),
torch.float32,
"cpu",
): pixel_difference_closeness_kwargs(atol, dtype=torch.float32, mae=mae)
}
def scripted_vs_eager_float64_tolerances(device, atol=1e-6, rtol=1e-6):
return {
(("TestKernels", "test_scripted_vs_eager"), torch.float64, device): {"atol": atol, "rtol": rtol, "mae": False},
}
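# Illustration only (hypothetical helper, not used by the tests): the helpers above all key the closeness
# kwargs by `((test_class, test_name), dtype, device)`, and the same perceptual tolerance is rescaled by the
# maximum value of the target dtype, i.e. ~2 for torch.uint8 (max value 255) vs. ~2 / 255 for torch.float32
# (max value 1.0).
def _example_closeness_kwargs_scaling():
    return {
        torch.uint8: pixel_difference_closeness_kwargs(2),
        torch.float32: pixel_difference_closeness_kwargs(2, dtype=torch.float32),
    }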
def pil_reference_wrapper(pil_kernel):
@functools.wraps(pil_kernel)
def wrapper(input_tensor, *other_args, **kwargs):
if input_tensor.dtype != torch.uint8:
raise pytest.UsageError(f"Can only test uint8 tensor images against PIL, but input is {input_tensor.dtype}")
if input_tensor.ndim > 3:
raise pytest.UsageError(
f"Can only test single tensor images against PIL, but input has shape {input_tensor.shape}"
)
input_pil = F.to_pil_image(input_tensor)
output_pil = pil_kernel(input_pil, *other_args, **kwargs)
if not isinstance(output_pil, PIL.Image.Image):
return output_pil
output_tensor = F.to_image(output_pil)
# 2D mask handling: the PIL roundtrip may add or drop the channel dimension, so realign the output rank with the input's
if output_tensor.ndim == 2 and input_tensor.ndim == 3:
output_tensor = output_tensor.unsqueeze(0)
elif output_tensor.ndim == 3 and input_tensor.ndim == 2:
output_tensor = output_tensor.squeeze(0)
return output_tensor
return wrapper
def xfail_jit(reason, *, condition=None):
return TestMark(("TestKernels", "test_scripted_vs_eager"), pytest.mark.xfail(reason=reason), condition=condition)
def xfail_jit_python_scalar_arg(name, *, reason=None):
return xfail_jit(
reason or f"Python scalar int or float for `{name}` is not supported when scripting",
condition=lambda args_kwargs: isinstance(args_kwargs.kwargs.get(name), (int, float)),
)
KERNEL_INFOS = []
def get_fills(*, num_channels, dtype):
yield None
int_value = get_max_value(dtype)
float_value = int_value / 2
yield int_value
yield float_value
for vector_type in [list, tuple]:
yield vector_type([int_value])
yield vector_type([float_value])
if num_channels > 1:
yield vector_type(float_value * c / 10 for c in range(num_channels))
yield vector_type(int_value if c % 2 == 0 else 0 for c in range(num_channels))
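# Illustration only (hypothetical helper, not used by the tests): for a 3-channel uint8 image, `get_fills`
# yields None, scalar fills (255 and 127.5), single-element lists/tuples, and per-channel fills such as
# [0.0, 12.75, 25.5] and (255, 0, 255).
def _example_get_fills():
    return list(get_fills(num_channels=3, dtype=torch.uint8))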
def float32_vs_uint8_fill_adapter(other_args, kwargs):
fill = kwargs.get("fill")
if fill is None:
return other_args, kwargs
if isinstance(fill, (int, float)):
fill /= 255
else:
fill = type(fill)(fill_ / 255 for fill_ in fill)
return other_args, dict(kwargs, fill=fill)
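# Illustration only (hypothetical helper, not used by the tests): the adapter rescales uint8 fills to the
# float32 value range, e.g. fill=127 becomes 127 / 255 and fill=[0, 255, 0] becomes [0.0, 1.0, 0.0], while
# all other kwargs pass through unchanged.
def _example_fill_adapter():
    return (
        float32_vs_uint8_fill_adapter((), dict(fill=127)),
        float32_vs_uint8_fill_adapter((), dict(fill=[0, 255, 0])),
    )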
def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_size, affine_matrix):
def transform(bbox, affine_matrix_, format_, canvas_size_):
# Convert to float before the format conversion to prevent precision loss, e.g. for CXCYWH -> XYXY when W or H is 1
in_dtype = bbox.dtype
if not torch.is_floating_point(bbox):
bbox = bbox.float()
bbox_xyxy = F.convert_bounding_box_format(
bbox.as_subclass(torch.Tensor),
old_format=format_,
new_format=tv_tensors.BoundingBoxFormat.XYXY,
inplace=True,
)
points = np.array(
[
[bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
[bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
[bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
[bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
]
)
transformed_points = np.matmul(points, affine_matrix_.T)
out_bbox = torch.tensor(
[
np.min(transformed_points[:, 0]).item(),
np.min(transformed_points[:, 1]).item(),
np.max(transformed_points[:, 0]).item(),
np.max(transformed_points[:, 1]).item(),
],
dtype=bbox_xyxy.dtype,
)
out_bbox = F.convert_bounding_box_format(
out_bbox, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
)
# It is important to clamp before casting back, especially for the CXCYWH format with dtype=int64
out_bbox = F.clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_)
out_bbox = out_bbox.to(dtype=in_dtype)
return out_bbox
return torch.stack(
[transform(b, affine_matrix, format, canvas_size) for b in bounding_boxes.reshape(-1, 4).unbind()]
).reshape(bounding_boxes.shape)
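# Illustration only (hypothetical helper, not used by the tests): the reference maps the four corner points
# of each box through the affine matrix and takes the min/max again. A pure translation by (tx, ty) = (2, 3)
# therefore shifts an XYXY box [0, 0, 10, 10] to [2, 3, 12, 13] (before clamping to the canvas).
def _example_affine_bounding_boxes_reference():
    bounding_boxes = tv_tensors.BoundingBoxes(
        torch.tensor([[0.0, 0.0, 10.0, 10.0]]),
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=(32, 32),
    )
    affine_matrix = np.array([[1, 0, 2], [0, 1, 3]], dtype="float32")
    return reference_affine_bounding_boxes_helper(
        bounding_boxes,
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=(32, 32),
        affine_matrix=affine_matrix,
    )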
def sample_inputs_convert_bounding_box_format():
formats = list(tv_tensors.BoundingBoxFormat)
for bounding_boxes_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
yield ArgsKwargs(bounding_boxes_loader, old_format=bounding_boxes_loader.format, new_format=new_format)
def reference_convert_bounding_box_format(bounding_boxes, old_format, new_format):
return torchvision.ops.box_convert(
bounding_boxes, in_fmt=old_format.name.lower(), out_fmt=new_format.name.lower()
).to(bounding_boxes.dtype)
def reference_inputs_convert_bounding_box_format():
for args_kwargs in sample_inputs_convert_bounding_box_format():
if len(args_kwargs.args[0].shape) == 2:
yield args_kwargs
KERNEL_INFOS.append(
KernelInfo(
F.convert_bounding_box_format,
sample_inputs_fn=sample_inputs_convert_bounding_box_format,
reference_fn=reference_convert_bounding_box_format,
reference_inputs_fn=reference_inputs_convert_bounding_box_format,
logs_usage=True,
closeness_kwargs={
(("TestKernels", "test_against_reference"), torch.int64, "cpu"): dict(atol=1, rtol=0),
},
),
)
_RESIZED_CROP_PARAMS = combinations_grid(top=[-8, 9], left=[-8, 9], height=[12], width=[12], size=[(16, 18)])
def sample_inputs_resized_crop_image_tensor():
for image_loader in make_image_loaders():
yield ArgsKwargs(image_loader, **_RESIZED_CROP_PARAMS[0])
@pil_reference_wrapper
def reference_resized_crop_image_tensor(*args, **kwargs):
if not kwargs.pop("antialias", False) and kwargs.get("interpolation", F.InterpolationMode.BILINEAR) in {
F.InterpolationMode.BILINEAR,
F.InterpolationMode.BICUBIC,
}:
raise pytest.UsageError("Anti-aliasing is always active in PIL")
return F._resized_crop_image_pil(*args, **kwargs)
def reference_inputs_resized_crop_image_tensor():
for image_loader, interpolation, params in itertools.product(
make_image_loaders_for_interpolation(),
[
F.InterpolationMode.NEAREST,
F.InterpolationMode.NEAREST_EXACT,
F.InterpolationMode.BILINEAR,
F.InterpolationMode.BICUBIC,
],
_RESIZED_CROP_PARAMS,
):
yield ArgsKwargs(
image_loader,
interpolation=interpolation,
antialias=interpolation
in {
F.InterpolationMode.BILINEAR,
F.InterpolationMode.BICUBIC,
},
**params,
)
def sample_inputs_resized_crop_bounding_boxes():
for bounding_boxes_loader in make_bounding_box_loaders():
yield ArgsKwargs(bounding_boxes_loader, format=bounding_boxes_loader.format, **_RESIZED_CROP_PARAMS[0])
def sample_inputs_resized_crop_mask():
for mask_loader in make_mask_loaders():
yield ArgsKwargs(mask_loader, **_RESIZED_CROP_PARAMS[0])
def sample_inputs_resized_crop_video():
for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
yield ArgsKwargs(video_loader, **_RESIZED_CROP_PARAMS[0])
KERNEL_INFOS.extend(
[
KernelInfo(
F.resized_crop_image,
sample_inputs_fn=sample_inputs_resized_crop_image_tensor,
reference_fn=reference_resized_crop_image_tensor,
reference_inputs_fn=reference_inputs_resized_crop_image_tensor,
float32_vs_uint8=True,
closeness_kwargs={
**cuda_vs_cpu_pixel_difference(),
**pil_reference_pixel_difference(3, mae=True),
**float32_vs_uint8_pixel_difference(3, mae=True),
},
),
KernelInfo(
F.resized_crop_bounding_boxes,
sample_inputs_fn=sample_inputs_resized_crop_bounding_boxes,
),
KernelInfo(
F.resized_crop_mask,
sample_inputs_fn=sample_inputs_resized_crop_mask,
),
KernelInfo(
F.resized_crop_video,
sample_inputs_fn=sample_inputs_resized_crop_video,
closeness_kwargs=cuda_vs_cpu_pixel_difference(),
),
]
)
_PAD_PARAMS = combinations_grid(
padding=[[1], [1, 1], [1, 1, 2, 2]],
padding_mode=["constant", "symmetric", "edge", "reflect"],
)
def sample_inputs_pad_image_tensor():
make_pad_image_loaders = functools.partial(
make_image_loaders, sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=["RGB"], dtypes=[torch.float32]
)
for image_loader, padding in itertools.product(
make_pad_image_loaders(),
[1, (1,), (1, 2), (1, 2, 3, 4), [1], [1, 2], [1, 2, 3, 4]],
):
yield ArgsKwargs(image_loader, padding=padding)
for image_loader in make_pad_image_loaders():
for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
yield ArgsKwargs(image_loader, padding=[1], fill=fill)
for image_loader, padding_mode in itertools.product(
# We branch for non-constant padding and integer inputs
make_pad_image_loaders(dtypes=[torch.uint8]),
["constant", "symmetric", "edge", "reflect"],
):
yield ArgsKwargs(image_loader, padding=[1], padding_mode=padding_mode)
# `torch.nn.functional.pad` does not support symmetric padding, and thus we have a custom implementation. Apart
# from negative padding, that implementation is already covered by the inputs above.
for image_loader in make_pad_image_loaders():
yield ArgsKwargs(image_loader, padding=[-1], padding_mode="symmetric")
def reference_inputs_pad_image_tensor():
for image_loader, params in itertools.product(
make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]), _PAD_PARAMS
):
for fill in get_fills(
num_channels=image_loader.num_channels,
dtype=image_loader.dtype,
):
# FIXME: the PIL kernel doesn't support fill sequences of length 1 if the number of channels is larger than one. Shouldn't it?
if isinstance(fill, (list, tuple)):
continue
yield ArgsKwargs(image_loader, fill=fill, **params)
def sample_inputs_pad_bounding_boxes():
for bounding_boxes_loader, padding in itertools.product(
make_bounding_box_loaders(), [1, (1,), (1, 2), (1, 2, 3, 4), [1], [1, 2], [1, 2, 3, 4]]
):
yield ArgsKwargs(
bounding_boxes_loader,
format=bounding_boxes_loader.format,
canvas_size=bounding_boxes_loader.canvas_size,
padding=padding,
padding_mode="constant",
)
def sample_inputs_pad_mask():
for mask_loader in make_mask_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_categories=[10], num_objects=[5]):
yield ArgsKwargs(mask_loader, padding=[1])
def reference_inputs_pad_mask():
for mask_loader, fill, params in itertools.product(
make_mask_loaders(num_objects=[1], extra_dims=[()]), [None, 127], _PAD_PARAMS
):
yield ArgsKwargs(mask_loader, fill=fill, **params)
def sample_inputs_pad_video():
for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
yield ArgsKwargs(video_loader, padding=[1])
def reference_pad_bounding_boxes(bounding_boxes, *, format, canvas_size, padding, padding_mode):
left, right, top, bottom = _parse_pad_padding(padding)
affine_matrix = np.array(
[
[1, 0, left],
[0, 1, top],
],
dtype="float64" if bounding_boxes.dtype == torch.float64 else "float32",
)
height = canvas_size[0] + top + bottom
width = canvas_size[1] + left + right
expected_bboxes = reference_affine_bounding_boxes_helper(
bounding_boxes, format=format, canvas_size=(height, width), affine_matrix=affine_matrix
)
return expected_bboxes, (height, width)
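# Illustration only (hypothetical helper, not used by the tests): for padding=[1, 2], `_parse_pad_padding`
# yields left=right=1 and top=bottom=2, so a (17, 11) canvas grows to (21, 13) and every box is translated
# by (left, top) = (1, 2).
def _example_pad_bounding_boxes_reference():
    bounding_boxes = tv_tensors.BoundingBoxes(
        torch.tensor([[0.0, 0.0, 5.0, 5.0]]),
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=(17, 11),
    )
    return reference_pad_bounding_boxes(
        bounding_boxes,
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=(17, 11),
        padding=[1, 2],
        padding_mode="constant",
    )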
def reference_inputs_pad_bounding_boxes():
for bounding_boxes_loader, padding in itertools.product(
make_bounding_box_loaders(extra_dims=((), (4,))), [1, (1,), (1, 2), (1, 2, 3, 4), [1], [1, 2], [1, 2, 3, 4]]
):
yield ArgsKwargs(
bounding_boxes_loader,
format=bounding_boxes_loader.format,
canvas_size=bounding_boxes_loader.canvas_size,
padding=padding,
padding_mode="constant",
)
def pad_xfail_jit_fill_condition(args_kwargs):
fill = args_kwargs.kwargs.get("fill")
if not isinstance(fill, (list, tuple)):
return False
elif isinstance(fill, tuple):
return True
else: # isinstance(fill, list):
return all(isinstance(f, int) for f in fill)
KERNEL_INFOS.extend(
[
KernelInfo(
F.pad_image,
sample_inputs_fn=sample_inputs_pad_image_tensor,
reference_fn=pil_reference_wrapper(F._pad_image_pil),
reference_inputs_fn=reference_inputs_pad_image_tensor,
float32_vs_uint8=float32_vs_uint8_fill_adapter,
closeness_kwargs=float32_vs_uint8_pixel_difference(),
test_marks=[
xfail_jit_python_scalar_arg("padding"),
xfail_jit(
"F.pad only supports vector fills for list of floats", condition=pad_xfail_jit_fill_condition
),
],
),
KernelInfo(
F.pad_bounding_boxes,
sample_inputs_fn=sample_inputs_pad_bounding_boxes,
reference_fn=reference_pad_bounding_boxes,
reference_inputs_fn=reference_inputs_pad_bounding_boxes,
test_marks=[
xfail_jit_python_scalar_arg("padding"),
],
),
KernelInfo(
F.pad_mask,
sample_inputs_fn=sample_inputs_pad_mask,
reference_fn=pil_reference_wrapper(F._pad_image_pil),
reference_inputs_fn=reference_inputs_pad_mask,
float32_vs_uint8=float32_vs_uint8_fill_adapter,
),
KernelInfo(
F.pad_video,
sample_inputs_fn=sample_inputs_pad_video,
),
]
)
_PERSPECTIVE_COEFFS = [
[1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
[0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
]
_STARTPOINTS = [[0, 1], [2, 3], [4, 5], [6, 7]]
_ENDPOINTS = [[9, 8], [7, 6], [5, 4], [3, 2]]
def sample_inputs_perspective_image_tensor():
for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE]):
for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
yield ArgsKwargs(
image_loader, startpoints=None, endpoints=None, fill=fill, coefficients=_PERSPECTIVE_COEFFS[0]
)
yield ArgsKwargs(make_image_loader(), startpoints=_STARTPOINTS, endpoints=_ENDPOINTS)
def reference_inputs_perspective_image_tensor():
for image_loader, coefficients, interpolation in itertools.product(
make_image_loaders_for_interpolation(),
_PERSPECTIVE_COEFFS,
[
F.InterpolationMode.NEAREST,
F.InterpolationMode.BILINEAR,
],
):
for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
# FIXME: the PIL kernel doesn't support fill sequences of length 1 if the number of channels is larger than one. Shouldn't it?
if isinstance(fill, (list, tuple)):
continue
yield ArgsKwargs(
image_loader,
startpoints=None,
endpoints=None,
interpolation=interpolation,
fill=fill,
coefficients=coefficients,
)
def sample_inputs_perspective_bounding_boxes():
for bounding_boxes_loader in make_bounding_box_loaders():
yield ArgsKwargs(
bounding_boxes_loader,
format=bounding_boxes_loader.format,
canvas_size=bounding_boxes_loader.canvas_size,
startpoints=None,
endpoints=None,
coefficients=_PERSPECTIVE_COEFFS[0],
)
format = tv_tensors.BoundingBoxFormat.XYXY
loader = make_bounding_box_loader(format=format)
yield ArgsKwargs(
loader, format=format, canvas_size=loader.canvas_size, startpoints=_STARTPOINTS, endpoints=_ENDPOINTS
)
def sample_inputs_perspective_mask():
for mask_loader in make_mask_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE]):
yield ArgsKwargs(mask_loader, startpoints=None, endpoints=None, coefficients=_PERSPECTIVE_COEFFS[0])
yield ArgsKwargs(make_detection_mask_loader(), startpoints=_STARTPOINTS, endpoints=_ENDPOINTS)
def reference_inputs_perspective_mask():
for mask_loader, perspective_coeffs in itertools.product(
make_mask_loaders(extra_dims=[()], num_objects=[1]), _PERSPECTIVE_COEFFS
):
yield ArgsKwargs(mask_loader, startpoints=None, endpoints=None, coefficients=perspective_coeffs)
def sample_inputs_perspective_video():
for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
yield ArgsKwargs(video_loader, startpoints=None, endpoints=None, coefficients=_PERSPECTIVE_COEFFS[0])
yield ArgsKwargs(make_video_loader(), startpoints=_STARTPOINTS, endpoints=_ENDPOINTS)
KERNEL_INFOS.extend(
[
KernelInfo(
F.perspective_image,
sample_inputs_fn=sample_inputs_perspective_image_tensor,
reference_fn=pil_reference_wrapper(F._perspective_image_pil),
reference_inputs_fn=reference_inputs_perspective_image_tensor,
float32_vs_uint8=float32_vs_uint8_fill_adapter,
closeness_kwargs={
**pil_reference_pixel_difference(2, mae=True),
**cuda_vs_cpu_pixel_difference(),
**float32_vs_uint8_pixel_difference(),
**scripted_vs_eager_float64_tolerances("cpu", atol=1e-5, rtol=1e-5),
**scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
},
test_marks=[xfail_jit_python_scalar_arg("fill")],
),
KernelInfo(
F.perspective_bounding_boxes,
sample_inputs_fn=sample_inputs_perspective_bounding_boxes,
closeness_kwargs={
**scripted_vs_eager_float64_tolerances("cpu", atol=1e-6, rtol=1e-6),
**scripted_vs_eager_float64_tolerances("cuda", atol=1e-6, rtol=1e-6),
},
),
KernelInfo(
F.perspective_mask,
sample_inputs_fn=sample_inputs_perspective_mask,
reference_fn=pil_reference_wrapper(F._perspective_image_pil),
reference_inputs_fn=reference_inputs_perspective_mask,
float32_vs_uint8=True,
closeness_kwargs={
(("TestKernels", "test_against_reference"), torch.uint8, "cpu"): dict(atol=10, rtol=0),
},
),
KernelInfo(
F.perspective_video,
sample_inputs_fn=sample_inputs_perspective_video,
closeness_kwargs={
**cuda_vs_cpu_pixel_difference(),
**scripted_vs_eager_float64_tolerances("cpu", atol=1e-5, rtol=1e-5),
**scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
},
),
]
)
def _get_elastic_displacement(canvas_size):
return torch.rand(1, *canvas_size, 2)
def sample_inputs_elastic_image_tensor():
for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE]):
displacement = _get_elastic_displacement(image_loader.canvas_size)
for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
yield ArgsKwargs(image_loader, displacement=displacement, fill=fill)
def reference_inputs_elastic_image_tensor():
for image_loader, interpolation in itertools.product(
make_image_loaders_for_interpolation(),
[
F.InterpolationMode.NEAREST,
F.InterpolationMode.BILINEAR,
F.InterpolationMode.BICUBIC,
],
):
displacement = _get_elastic_displacement(image_loader.canvas_size)
for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
yield ArgsKwargs(image_loader, interpolation=interpolation, displacement=displacement, fill=fill)
def sample_inputs_elastic_bounding_boxes():
for bounding_boxes_loader in make_bounding_box_loaders():
displacement = _get_elastic_displacement(bounding_boxes_loader.canvas_size)
yield ArgsKwargs(
bounding_boxes_loader,
format=bounding_boxes_loader.format,
canvas_size=bounding_boxes_loader.canvas_size,
displacement=displacement,
)
def sample_inputs_elastic_mask():
for mask_loader in make_mask_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE]):
displacement = _get_elastic_displacement(mask_loader.shape[-2:])
yield ArgsKwargs(mask_loader, displacement=displacement)
def sample_inputs_elastic_video():
for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
displacement = _get_elastic_displacement(video_loader.shape[-2:])
yield ArgsKwargs(video_loader, displacement=displacement)
KERNEL_INFOS.extend(
[
KernelInfo(
F.elastic_image,
sample_inputs_fn=sample_inputs_elastic_image_tensor,
reference_inputs_fn=reference_inputs_elastic_image_tensor,
float32_vs_uint8=float32_vs_uint8_fill_adapter,
closeness_kwargs={
**float32_vs_uint8_pixel_difference(6, mae=True),
**cuda_vs_cpu_pixel_difference(),
},
test_marks=[xfail_jit_python_scalar_arg("fill")],
),
KernelInfo(
F.elastic_bounding_boxes,
sample_inputs_fn=sample_inputs_elastic_bounding_boxes,
),
KernelInfo(
F.elastic_mask,
sample_inputs_fn=sample_inputs_elastic_mask,
),
KernelInfo(
F.elastic_video,
sample_inputs_fn=sample_inputs_elastic_video,
closeness_kwargs=cuda_vs_cpu_pixel_difference(),
),
]
)
_CENTER_CROP_SPATIAL_SIZES = [(16, 16), (7, 33), (31, 9)]
_CENTER_CROP_OUTPUT_SIZES = [[4, 3], [42, 70], [4], 3, (5, 2), (6,)]
def sample_inputs_center_crop_image_tensor():
for image_loader, output_size in itertools.product(
make_image_loaders(sizes=[(16, 17)], color_spaces=["RGB"], dtypes=[torch.float32]),
[
# valid `output_size` types for which cropping is applied to both dimensions
*[5, (4,), (2, 3), [6], [3, 2]],
# `output_size` values for which at least one dimension needs to be padded
*[[4, 18], [17, 5], [17, 18]],
],
):
yield ArgsKwargs(image_loader, output_size=output_size)
def reference_inputs_center_crop_image_tensor():
for image_loader, output_size in itertools.product(
make_image_loaders(sizes=_CENTER_CROP_SPATIAL_SIZES, extra_dims=[()], dtypes=[torch.uint8]),
_CENTER_CROP_OUTPUT_SIZES,
):
yield ArgsKwargs(image_loader, output_size=output_size)
def sample_inputs_center_crop_bounding_boxes():
for bounding_boxes_loader, output_size in itertools.product(make_bounding_box_loaders(), _CENTER_CROP_OUTPUT_SIZES):
yield ArgsKwargs(
bounding_boxes_loader,
format=bounding_boxes_loader.format,
canvas_size=bounding_boxes_loader.canvas_size,
output_size=output_size,
)
def sample_inputs_center_crop_mask():
for mask_loader in make_mask_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_categories=[10], num_objects=[5]):
height, width = mask_loader.shape[-2:]
yield ArgsKwargs(mask_loader, output_size=(height // 2, width // 2))
def reference_inputs_center_crop_mask():
for mask_loader, output_size in itertools.product(
make_mask_loaders(sizes=_CENTER_CROP_SPATIAL_SIZES, extra_dims=[()], num_objects=[1]), _CENTER_CROP_OUTPUT_SIZES
):
yield ArgsKwargs(mask_loader, output_size=output_size)
def sample_inputs_center_crop_video():
for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
height, width = video_loader.shape[-2:]
yield ArgsKwargs(video_loader, output_size=(height // 2, width // 2))
KERNEL_INFOS.extend(
[
KernelInfo(
F.center_crop_image,
sample_inputs_fn=sample_inputs_center_crop_image_tensor,
reference_fn=pil_reference_wrapper(F._center_crop_image_pil),
reference_inputs_fn=reference_inputs_center_crop_image_tensor,
float32_vs_uint8=True,
test_marks=[
xfail_jit_python_scalar_arg("output_size"),
],
),
KernelInfo(
F.center_crop_bounding_boxes,
sample_inputs_fn=sample_inputs_center_crop_bounding_boxes,
test_marks=[
xfail_jit_python_scalar_arg("output_size"),
],
),
KernelInfo(
F.center_crop_mask,
sample_inputs_fn=sample_inputs_center_crop_mask,
reference_fn=pil_reference_wrapper(F._center_crop_image_pil),
reference_inputs_fn=reference_inputs_center_crop_mask,
float32_vs_uint8=True,
test_marks=[
xfail_jit_python_scalar_arg("output_size"),
],
),
KernelInfo(
F.center_crop_video,
sample_inputs_fn=sample_inputs_center_crop_video,
),
]
)
def sample_inputs_gaussian_blur_image_tensor():
make_gaussian_blur_image_loaders = functools.partial(make_image_loaders, sizes=[(7, 33)], color_spaces=["RGB"])
for image_loader, kernel_size in itertools.product(make_gaussian_blur_image_loaders(), [5, (3, 3), [3, 3]]):
yield ArgsKwargs(image_loader, kernel_size=kernel_size)
for image_loader, sigma in itertools.product(
make_gaussian_blur_image_loaders(), [None, (3.0, 3.0), [2.0, 2.0], 4.0, [1.5], (3.14,)]
):
yield ArgsKwargs(image_loader, kernel_size=5, sigma=sigma)
def sample_inputs_gaussian_blur_video():
for video_loader in make_video_loaders(sizes=[(7, 33)], num_frames=[5]):
yield ArgsKwargs(video_loader, kernel_size=[3, 3])
KERNEL_INFOS.extend(
[
KernelInfo(
F.gaussian_blur_image,
sample_inputs_fn=sample_inputs_gaussian_blur_image_tensor,
closeness_kwargs=cuda_vs_cpu_pixel_difference(),
test_marks=[
xfail_jit_python_scalar_arg("kernel_size"),
xfail_jit_python_scalar_arg("sigma"),
],
),
KernelInfo(
F.gaussian_blur_video,
sample_inputs_fn=sample_inputs_gaussian_blur_video,
closeness_kwargs=cuda_vs_cpu_pixel_difference(),
),
]
)
def sample_inputs_equalize_image_tensor():
for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
yield ArgsKwargs(image_loader)
def reference_inputs_equalize_image_tensor():
# We are not using `make_image_loaders` here since that uniformly samples the values over the whole value range.
# Since the whole point of this kernel is to transform an arbitrary distribution of values into a uniform one,
# the information gain is low if we already provide something really close to the expected value.
def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor, memory_format):
if dtype.is_floating_point:
low = low_factor
high = high_factor
else:
max_value = torch.iinfo(dtype).max
low = int(low_factor * max_value)
high = int(high_factor * max_value)
return torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high).to(
memory_format=memory_format, copy=True
)
def make_beta_distributed_image(shape, dtype, device, *, alpha, beta, memory_format):
image = torch.distributions.Beta(alpha, beta).sample(shape)
if not dtype.is_floating_point:
image.mul_(torch.iinfo(dtype).max).round_()
return image.to(dtype=dtype, device=device, memory_format=memory_format, copy=True)
canvas_size = (256, 256)
for dtype, color_space, fn in itertools.product(
[torch.uint8],
["GRAY", "RGB"],
[
lambda shape, dtype, device, memory_format: torch.zeros(shape, dtype=dtype, device=device).to(
memory_format=memory_format, copy=True
),
lambda shape, dtype, device, memory_format: torch.full(
shape, 1.0 if dtype.is_floating_point else torch.iinfo(dtype).max, dtype=dtype, device=device
).to(memory_format=memory_format, copy=True),
*[
functools.partial(make_uniform_band_image, low_factor=low_factor, high_factor=high_factor)
for low_factor, high_factor in [
(0.0, 0.25),
(0.25, 0.75),
(0.75, 1.0),
]
],
*[
functools.partial(make_beta_distributed_image, alpha=alpha, beta=beta)
for alpha, beta in [
(0.5, 0.5),
(2, 2),
(2, 5),
(5, 2),
]
],
],
):
image_loader = ImageLoader(fn, shape=(get_num_channels(color_space), *canvas_size), dtype=dtype)
yield ArgsKwargs(image_loader)
def sample_inputs_equalize_video():
for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
yield ArgsKwargs(video_loader)
KERNEL_INFOS.extend(
[
KernelInfo(
F.equalize_image,
kernel_name="equalize_image_tensor",
sample_inputs_fn=sample_inputs_equalize_image_tensor,
reference_fn=pil_reference_wrapper(F._equalize_image_pil),
float32_vs_uint8=True,
reference_inputs_fn=reference_inputs_equalize_image_tensor,
),
KernelInfo(
F.equalize_video,
sample_inputs_fn=sample_inputs_equalize_video,
),
]
)
def sample_inputs_invert_image_tensor():
for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
yield ArgsKwargs(image_loader)
def reference_inputs_invert_image_tensor():
for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
yield ArgsKwargs(image_loader)
def sample_inputs_invert_video():
for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
yield ArgsKwargs(video_loader)
KERNEL_INFOS.extend(
[
KernelInfo(
F.invert_image,
kernel_name="invert_image_tensor",
sample_inputs_fn=sample_inputs_invert_image_tensor,
reference_fn=pil_reference_wrapper(F._invert_image_pil),
reference_inputs_fn=reference_inputs_invert_image_tensor,
float32_vs_uint8=True,
),
KernelInfo(
F.invert_video,
sample_inputs_fn=sample_inputs_invert_video,
),
]
)
_POSTERIZE_BITS = [1, 4, 8]
def sample_inputs_posterize_image_tensor():
for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
yield ArgsKwargs(image_loader, bits=_POSTERIZE_BITS[0])
def reference_inputs_posterize_image_tensor():
for image_loader, bits in itertools.product(
make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
_POSTERIZE_BITS,
):
yield ArgsKwargs(image_loader, bits=bits)
def sample_inputs_posterize_video():
for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
yield ArgsKwargs(video_loader, bits=_POSTERIZE_BITS[0])
KERNEL_INFOS.extend(
[
KernelInfo(
F.posterize_image,
kernel_name="posterize_image_tensor",
sample_inputs_fn=sample_inputs_posterize_image_tensor,
reference_fn=pil_reference_wrapper(F._posterize_image_pil),
reference_inputs_fn=reference_inputs_posterize_image_tensor,
float32_vs_uint8=True,
closeness_kwargs=float32_vs_uint8_pixel_difference(),
),
KernelInfo(
F.posterize_video,
sample_inputs_fn=sample_inputs_posterize_video,
),
]
)
def _get_solarize_thresholds(dtype):
for factor in [0.1, 0.5]:
max_value = get_max_value(dtype)
yield (float if dtype.is_floating_point else int)(max_value * factor)
def sample_inputs_solarize_image_tensor():
for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
yield ArgsKwargs(image_loader, threshold=next(_get_solarize_thresholds(image_loader.dtype)))
def reference_inputs_solarize_image_tensor():
for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
for threshold in _get_solarize_thresholds(image_loader.dtype):
yield ArgsKwargs(image_loader, threshold=threshold)
def uint8_to_float32_threshold_adapter(other_args, kwargs):
return other_args, dict(threshold=kwargs["threshold"] / 255)
def sample_inputs_solarize_video():
for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
yield ArgsKwargs(video_loader, threshold=next(_get_solarize_thresholds(video_loader.dtype)))
KERNEL_INFOS.extend(
[
KernelInfo(
F.solarize_image,
kernel_name="solarize_image_tensor",
sample_inputs_fn=sample_inputs_solarize_image_tensor,
reference_fn=pil_reference_wrapper(F._solarize_image_pil),
reference_inputs_fn=reference_inputs_solarize_image_tensor,
float32_vs_uint8=uint8_to_float32_threshold_adapter,
closeness_kwargs=float32_vs_uint8_pixel_difference(),
),
KernelInfo(
F.solarize_video,
sample_inputs_fn=sample_inputs_solarize_video,
),
]
)
def sample_inputs_autocontrast_image_tensor():
for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
yield ArgsKwargs(image_loader)
def reference_inputs_autocontrast_image_tensor():
for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
yield ArgsKwargs(image_loader)
def sample_inputs_autocontrast_video():
for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
yield ArgsKwargs(video_loader)
KERNEL_INFOS.extend(
[
KernelInfo(
F.autocontrast_image,
kernel_name="autocontrast_image_tensor",
sample_inputs_fn=sample_inputs_autocontrast_image_tensor,
reference_fn=pil_reference_wrapper(F._autocontrast_image_pil),
reference_inputs_fn=reference_inputs_autocontrast_image_tensor,
float32_vs_uint8=True,
closeness_kwargs={
**pil_reference_pixel_difference(),
**float32_vs_uint8_pixel_difference(),
},
),
KernelInfo(
F.autocontrast_video,
sample_inputs_fn=sample_inputs_autocontrast_video,
),
]
)
_ADJUST_SHARPNESS_FACTORS = [0.1, 0.5]
def sample_inputs_adjust_sharpness_image_tensor():
for image_loader in make_image_loaders(
sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE, (2, 2)],
color_spaces=("GRAY", "RGB"),
):
yield ArgsKwargs(image_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0])
def reference_inputs_adjust_sharpness_image_tensor():
for image_loader, sharpness_factor in itertools.product(
make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
_ADJUST_SHARPNESS_FACTORS,
):
yield ArgsKwargs(image_loader, sharpness_factor=sharpness_factor)
def sample_inputs_adjust_sharpness_video():
for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
yield ArgsKwargs(video_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0])
KERNEL_INFOS.extend(
[
KernelInfo(
F.adjust_sharpness_image,
kernel_name="adjust_sharpness_image_tensor",
sample_inputs_fn=sample_inputs_adjust_sharpness_image_tensor,
reference_fn=pil_reference_wrapper(F._adjust_sharpness_image_pil),
reference_inputs_fn=reference_inputs_adjust_sharpness_image_tensor,
float32_vs_uint8=True,
closeness_kwargs=float32_vs_uint8_pixel_difference(2),
),
KernelInfo(
F.adjust_sharpness_video,
sample_inputs_fn=sample_inputs_adjust_sharpness_video,
),
]
)
_ADJUST_CONTRAST_FACTORS = [0.1, 0.5]
def sample_inputs_adjust_contrast_image_tensor():
for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
yield ArgsKwargs(image_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0])
def reference_inputs_adjust_contrast_image_tensor():
for image_loader, contrast_factor in itertools.product(
make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
_ADJUST_CONTRAST_FACTORS,
):
yield ArgsKwargs(image_loader, contrast_factor=contrast_factor)
def sample_inputs_adjust_contrast_video():
for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
yield ArgsKwargs(video_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0])
KERNEL_INFOS.extend(
[
KernelInfo(
F.adjust_contrast_image,
kernel_name="adjust_contrast_image_tensor",
sample_inputs_fn=sample_inputs_adjust_contrast_image_tensor,
reference_fn=pil_reference_wrapper(F._adjust_contrast_image_pil),
reference_inputs_fn=reference_inputs_adjust_contrast_image_tensor,
float32_vs_uint8=True,
closeness_kwargs={
**pil_reference_pixel_difference(),
**float32_vs_uint8_pixel_difference(2),
**cuda_vs_cpu_pixel_difference(),
(("TestKernels", "test_against_reference"), torch.uint8, "cpu"): pixel_difference_closeness_kwargs(1),
},
),
KernelInfo(
F.adjust_contrast_video,
sample_inputs_fn=sample_inputs_adjust_contrast_video,
closeness_kwargs={
**cuda_vs_cpu_pixel_difference(),
(("TestKernels", "test_against_reference"), torch.uint8, "cpu"): pixel_difference_closeness_kwargs(1),
},
),
]
)
_ADJUST_GAMMA_GAMMAS_GAINS = [
(0.5, 2.0),
(0.0, 1.0),
]
def sample_inputs_adjust_gamma_image_tensor():
gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0]
for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
yield ArgsKwargs(image_loader, gamma=gamma, gain=gain)
def reference_inputs_adjust_gamma_image_tensor():
for image_loader, (gamma, gain) in itertools.product(
make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
_ADJUST_GAMMA_GAMMAS_GAINS,
):
yield ArgsKwargs(image_loader, gamma=gamma, gain=gain)
def sample_inputs_adjust_gamma_video():
gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0]
for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
yield ArgsKwargs(video_loader, gamma=gamma, gain=gain)
KERNEL_INFOS.extend(
[
KernelInfo(
F.adjust_gamma_image,
kernel_name="adjust_gamma_image_tensor",
sample_inputs_fn=sample_inputs_adjust_gamma_image_tensor,
reference_fn=pil_reference_wrapper(F._adjust_gamma_image_pil),
reference_inputs_fn=reference_inputs_adjust_gamma_image_tensor,
float32_vs_uint8=True,
closeness_kwargs={
**pil_reference_pixel_difference(),
**float32_vs_uint8_pixel_difference(),
},
),
KernelInfo(
F.adjust_gamma_video,
sample_inputs_fn=sample_inputs_adjust_gamma_video,
),
]
)
_ADJUST_HUE_FACTORS = [-0.1, 0.5]
def sample_inputs_adjust_hue_image_tensor():
for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
yield ArgsKwargs(image_loader, hue_factor=_ADJUST_HUE_FACTORS[0])
def reference_inputs_adjust_hue_image_tensor():
for image_loader, hue_factor in itertools.product(
make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
_ADJUST_HUE_FACTORS,
):
yield ArgsKwargs(image_loader, hue_factor=hue_factor)
def sample_inputs_adjust_hue_video():
for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
yield ArgsKwargs(video_loader, hue_factor=_ADJUST_HUE_FACTORS[0])
KERNEL_INFOS.extend(
[
KernelInfo(
F.adjust_hue_image,
kernel_name="adjust_hue_image_tensor",
sample_inputs_fn=sample_inputs_adjust_hue_image_tensor,
reference_fn=pil_reference_wrapper(F._adjust_hue_image_pil),
reference_inputs_fn=reference_inputs_adjust_hue_image_tensor,
float32_vs_uint8=True,
closeness_kwargs={
**pil_reference_pixel_difference(2, mae=True),
**float32_vs_uint8_pixel_difference(),
},
),
KernelInfo(
F.adjust_hue_video,
sample_inputs_fn=sample_inputs_adjust_hue_video,
),
]
)
_ADJUST_SATURATION_FACTORS = [0.1, 0.5]
def sample_inputs_adjust_saturation_image_tensor():
for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
yield ArgsKwargs(image_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0])
def reference_inputs_adjust_saturation_image_tensor():
for image_loader, saturation_factor in itertools.product(
make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
_ADJUST_SATURATION_FACTORS,
):
yield ArgsKwargs(image_loader, saturation_factor=saturation_factor)
def sample_inputs_adjust_saturation_video():
for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
yield ArgsKwargs(video_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0])
KERNEL_INFOS.extend(
[
KernelInfo(
F.adjust_saturation_image,
kernel_name="adjust_saturation_image_tensor",
sample_inputs_fn=sample_inputs_adjust_saturation_image_tensor,
reference_fn=pil_reference_wrapper(F._adjust_saturation_image_pil),
reference_inputs_fn=reference_inputs_adjust_saturation_image_tensor,
float32_vs_uint8=True,
closeness_kwargs={
**pil_reference_pixel_difference(),
**float32_vs_uint8_pixel_difference(2),
**cuda_vs_cpu_pixel_difference(),
},
),
KernelInfo(
F.adjust_saturation_video,
sample_inputs_fn=sample_inputs_adjust_saturation_video,
closeness_kwargs=cuda_vs_cpu_pixel_difference(),
),
]
)
def sample_inputs_clamp_bounding_boxes():
for bounding_boxes_loader in make_bounding_box_loaders():
yield ArgsKwargs(
bounding_boxes_loader,
format=bounding_boxes_loader.format,
canvas_size=bounding_boxes_loader.canvas_size,
)
KERNEL_INFOS.append(
KernelInfo(
F.clamp_bounding_boxes,
sample_inputs_fn=sample_inputs_clamp_bounding_boxes,
logs_usage=True,
)
)
_FIVE_TEN_CROP_SIZES = [7, (6,), [5], (6, 5), [7, 6]]
def _get_five_ten_crop_canvas_size(size):
if isinstance(size, int):
crop_height = crop_width = size
elif len(size) == 1:
crop_height = crop_width = size[0]
else:
crop_height, crop_width = size
return 2 * crop_height, 2 * crop_width
def sample_inputs_five_crop_image_tensor():
for size in _FIVE_TEN_CROP_SIZES:
for image_loader in make_image_loaders(
sizes=[_get_five_ten_crop_canvas_size(size)],
color_spaces=["RGB"],
dtypes=[torch.float32],
):
yield ArgsKwargs(image_loader, size=size)
def reference_inputs_five_crop_image_tensor():
for size in _FIVE_TEN_CROP_SIZES:
for image_loader in make_image_loaders(
sizes=[_get_five_ten_crop_canvas_size(size)], extra_dims=[()], dtypes=[torch.uint8]
):
yield ArgsKwargs(image_loader, size=size)
def sample_inputs_five_crop_video():
size = _FIVE_TEN_CROP_SIZES[0]
for video_loader in make_video_loaders(sizes=[_get_five_ten_crop_canvas_size(size)]):
yield ArgsKwargs(video_loader, size=size)
def sample_inputs_ten_crop_image_tensor():
for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
for image_loader in make_image_loaders(
sizes=[_get_five_ten_crop_canvas_size(size)],
color_spaces=["RGB"],
dtypes=[torch.float32],
):
yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip)
def reference_inputs_ten_crop_image_tensor():
for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
for image_loader in make_image_loaders(
sizes=[_get_five_ten_crop_canvas_size(size)], extra_dims=[()], dtypes=[torch.uint8]
):
yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip)
def sample_inputs_ten_crop_video():
size = _FIVE_TEN_CROP_SIZES[0]
for video_loader in make_video_loaders(sizes=[_get_five_ten_crop_canvas_size(size)]):
yield ArgsKwargs(video_loader, size=size)
def multi_crop_pil_reference_wrapper(pil_kernel):
def wrapper(input_tensor, *other_args, **kwargs):
output = pil_reference_wrapper(pil_kernel)(input_tensor, *other_args, **kwargs)
return type(output)(
F.to_dtype_image(F.to_image(output_pil), dtype=input_tensor.dtype, scale=True) for output_pil in output
)
return wrapper
_common_five_ten_crop_marks = [
xfail_jit_python_scalar_arg("size"),
mark_framework_limitation(("TestKernels", "test_batched_vs_single"), "Custom batching needed."),
]
KERNEL_INFOS.extend(
[
KernelInfo(
F.five_crop_image,
sample_inputs_fn=sample_inputs_five_crop_image_tensor,
reference_fn=multi_crop_pil_reference_wrapper(F._five_crop_image_pil),
reference_inputs_fn=reference_inputs_five_crop_image_tensor,
test_marks=_common_five_ten_crop_marks,
),
KernelInfo(
F.five_crop_video,
sample_inputs_fn=sample_inputs_five_crop_video,
test_marks=_common_five_ten_crop_marks,
),
KernelInfo(
F.ten_crop_image,
sample_inputs_fn=sample_inputs_ten_crop_image_tensor,
reference_fn=multi_crop_pil_reference_wrapper(F._ten_crop_image_pil),
reference_inputs_fn=reference_inputs_ten_crop_image_tensor,
test_marks=_common_five_ten_crop_marks,
),
KernelInfo(
F.ten_crop_video,
sample_inputs_fn=sample_inputs_ten_crop_video,
test_marks=_common_five_ten_crop_marks,
),
]
)
_NORMALIZE_MEANS_STDS = [
((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
(0.5, 2.0),
]
def sample_inputs_normalize_image_tensor():
for image_loader, (mean, std) in itertools.product(
make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=["RGB"], dtypes=[torch.float32]),
_NORMALIZE_MEANS_STDS,
):
yield ArgsKwargs(image_loader, mean=mean, std=std)
def reference_normalize_image_tensor(image, mean, std, inplace=False):
mean = torch.tensor(mean).view(-1, 1, 1)
std = torch.tensor(std).view(-1, 1, 1)
sub = torch.Tensor.sub_ if inplace else torch.Tensor.sub
return sub(image, mean).div_(std)
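# Illustration only (hypothetical helper, not used by the tests): normalization is channel-wise,
# (image - mean) / std, so a constant 0.5 image normalized with mean=0.5 and std=2.0 becomes a constant
# 0.0 image.
def _example_reference_normalize():
    image = torch.full((3, 4, 4), 0.5)
    return reference_normalize_image_tensor(image, mean=[0.5, 0.5, 0.5], std=[2.0, 2.0, 2.0])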
def reference_inputs_normalize_image_tensor():
yield ArgsKwargs(
make_image_loader(size=(32, 32), color_space="RGB", extra_dims=[1]),
mean=[0.5, 0.5, 0.5],
std=[1.0, 1.0, 1.0],
)
def sample_inputs_normalize_video():
mean, std = _NORMALIZE_MEANS_STDS[0]
for video_loader in make_video_loaders(
sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=["RGB"], num_frames=[3], dtypes=[torch.float32]
):
yield ArgsKwargs(video_loader, mean=mean, std=std)
KERNEL_INFOS.extend(
[
KernelInfo(
F.normalize_image,
kernel_name="normalize_image_tensor",
sample_inputs_fn=sample_inputs_normalize_image_tensor,
reference_fn=reference_normalize_image_tensor,
reference_inputs_fn=reference_inputs_normalize_image_tensor,
test_marks=[
xfail_jit_python_scalar_arg("mean"),
xfail_jit_python_scalar_arg("std"),
],
),
KernelInfo(
F.normalize_video,
sample_inputs_fn=sample_inputs_normalize_video,
),
]
)
def sample_inputs_uniform_temporal_subsample_video():
for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[4]):
yield ArgsKwargs(video_loader, num_samples=2)
def reference_uniform_temporal_subsample_video(x, num_samples):
# Copy-pasted from
# https://github.com/facebookresearch/pytorchvideo/blob/c8d23d8b7e597586a9e2d18f6ed31ad8aa379a7a/pytorchvideo/transforms/functional.py#L19
t = x.shape[-4]
assert num_samples > 0 and t > 0
# Sample by nearest neighbor interpolation if num_samples > t.
indices = torch.linspace(0, t - 1, num_samples)
indices = torch.clamp(indices, 0, t - 1).long()
return torch.index_select(x, -4, indices)
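# Illustration only (hypothetical helper, not used by the tests): for a video with t=10 frames and
# num_samples=4, the reference picks the frames at torch.linspace(0, 9, 4) -> indices [0, 3, 6, 9], i.e.
# evenly spaced over the temporal dimension.
def _example_uniform_temporal_subsample_indices():
    video = torch.rand(10, 3, 4, 4)
    return reference_uniform_temporal_subsample_video(video, 4)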
def reference_inputs_uniform_temporal_subsample_video():
for video_loader in make_video_loaders(
sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=["RGB"], num_frames=[10]
):
for num_samples in range(1, video_loader.shape[-4] + 1):
yield ArgsKwargs(video_loader, num_samples)
KERNEL_INFOS.append(
KernelInfo(
F.uniform_temporal_subsample_video,
sample_inputs_fn=sample_inputs_uniform_temporal_subsample_video,
reference_fn=reference_uniform_temporal_subsample_video,
reference_inputs_fn=reference_inputs_uniform_temporal_subsample_video,
)
)
"""
As the name implies, these are legacy utilities that will hopefully be removed soon. The future of
transforms v2 testing is in test/test_transforms_v2_refactored.py. All new tests should be
implemented there and must not use any of the utilities here.
The following legacy modules depend on this module:
- transforms_v2_kernel_infos.py
- transforms_v2_dispatcher_infos.py
- test_transforms_v2_functional.py
- test_transforms_v2_consistency.py
- test_transforms.py
When all the logic is ported from the files above to test_transforms_v2_refactored.py, delete
all the legacy modules including this one and drop the _refactored prefix from the name.
"""
import collections.abc
import dataclasses
import enum
import itertools
import pathlib
from collections import defaultdict
from typing import Callable, Sequence, Tuple, Union
import PIL.Image
import pytest
import torch
from torchvision import tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.v2.functional import to_dtype_image, to_image, to_pil_image
def combinations_grid(**kwargs):
"""Creates a grid of input combinations.
Each element in the returned sequence is a dictionary containing one possible combination as values.
Example:
>>> combinations_grid(foo=("bar", "baz"), spam=("eggs", "ham"))
[
{'foo': 'bar', 'spam': 'eggs'},
{'foo': 'bar', 'spam': 'ham'},
{'foo': 'baz', 'spam': 'eggs'},
{'foo': 'baz', 'spam': 'ham'}
]
"""
return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]
DEFAULT_SIZE = (17, 11)
NUM_CHANNELS_MAP = {
"GRAY": 1,
"GRAY_ALPHA": 2,
"RGB": 3,
"RGBA": 4,
}
def make_image(
size=DEFAULT_SIZE,
*,
color_space="RGB",
batch_dims=(),
dtype=None,
device="cpu",
memory_format=torch.contiguous_format,
):
num_channels = NUM_CHANNELS_MAP[color_space]
dtype = dtype or torch.uint8
max_value = get_max_value(dtype)
data = torch.testing.make_tensor(
(*batch_dims, num_channels, *size),
low=0,
high=max_value,
dtype=dtype,
device=device,
memory_format=memory_format,
)
if color_space in {"GRAY_ALPHA", "RGBA"}:
data[..., -1, :, :] = max_value
return tv_tensors.Image(data)
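# Illustration only (hypothetical helper, not used by the tests): an RGBA image of spatial size (16, 16)
# with batch_dims=(2,) has shape (2, 4, 16, 16), and its alpha channel is constant at the dtype's maximum
# value.
def _example_make_image():
    return make_image((16, 16), color_space="RGBA", batch_dims=(2,)).shape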
def make_image_tensor(*args, **kwargs):
return make_image(*args, **kwargs).as_subclass(torch.Tensor)
def make_image_pil(*args, **kwargs):
return to_pil_image(make_image(*args, **kwargs))
def make_bounding_boxes(
canvas_size=DEFAULT_SIZE,
*,
format=tv_tensors.BoundingBoxFormat.XYXY,
batch_dims=(),
dtype=None,
device="cpu",
):
def sample_position(values, max_value):
# We cannot use torch.randint directly here, because it only allows integer scalars as values for low and high.
# However, if we have batch_dims, we need tensors as limits.
return torch.stack([torch.randint(max_value - v, ()) for v in values.flatten().tolist()]).reshape(values.shape)
if isinstance(format, str):
format = tv_tensors.BoundingBoxFormat[format]
dtype = dtype or torch.float32
if any(dim == 0 for dim in batch_dims):
return tv_tensors.BoundingBoxes(
torch.empty(*batch_dims, 4, dtype=dtype, device=device), format=format, canvas_size=canvas_size
)
h, w = [torch.randint(1, c, batch_dims) for c in canvas_size]
y = sample_position(h, canvas_size[0])
x = sample_position(w, canvas_size[1])
if format is tv_tensors.BoundingBoxFormat.XYWH:
parts = (x, y, w, h)
elif format is tv_tensors.BoundingBoxFormat.XYXY:
x1, y1 = x, y
x2 = x1 + w
y2 = y1 + h
parts = (x1, y1, x2, y2)
elif format is tv_tensors.BoundingBoxFormat.CXCYWH:
cx = x + w / 2
cy = y + h / 2
parts = (cx, cy, w, h)
else:
raise ValueError(f"Format {format} is not supported")
return tv_tensors.BoundingBoxes(
torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, canvas_size=canvas_size
)
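# Illustration only (hypothetical helper, not used by the tests): `batch_dims=(2,)` yields a (2, 4)
# `BoundingBoxes` tensor whose boxes are sampled such that they always fit inside the canvas, regardless of
# the requested format.
def _example_make_bounding_boxes():
    return make_bounding_boxes(canvas_size=(32, 32), format="XYWH", batch_dims=(2,), dtype=torch.float32).shape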
def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtype=None, device="cpu"):
"""Make a "detection" mask, i.e. (*, N, H, W), where each object is encoded as one of N boolean masks"""
return tv_tensors.Mask(
torch.testing.make_tensor(
(*batch_dims, num_objects, *size),
low=0,
high=2,
dtype=dtype or torch.bool,
device=device,
)
)
def make_segmentation_mask(size=DEFAULT_SIZE, *, num_categories=10, batch_dims=(), dtype=None, device="cpu"):
"""Make a "segmentation" mask, i.e. (*, H, W), where the category is encoded as pixel value"""
return tv_tensors.Mask(
torch.testing.make_tensor(
(*batch_dims, *size),
low=0,
high=num_categories,
dtype=dtype or torch.uint8,
device=device,
)
)
def make_video(size=DEFAULT_SIZE, *, num_frames=3, batch_dims=(), **kwargs):
return tv_tensors.Video(make_image(size, batch_dims=(*batch_dims, num_frames), **kwargs))
def make_video_tensor(*args, **kwargs):
return make_video(*args, **kwargs).as_subclass(torch.Tensor)
DEFAULT_SQUARE_SPATIAL_SIZE = 15
DEFAULT_LANDSCAPE_SPATIAL_SIZE = (7, 33)
DEFAULT_PORTRAIT_SPATIAL_SIZE = (31, 9)
DEFAULT_SPATIAL_SIZES = (
DEFAULT_LANDSCAPE_SPATIAL_SIZE,
DEFAULT_PORTRAIT_SPATIAL_SIZE,
DEFAULT_SQUARE_SPATIAL_SIZE,
)
def _parse_size(size, *, name="size"):
if size == "random":
raise ValueError("This should never happen")
elif isinstance(size, int) and size > 0:
return (size, size)
elif (
isinstance(size, collections.abc.Sequence)
and len(size) == 2
and all(isinstance(length, int) and length > 0 for length in size)
):
return tuple(size)
else:
raise pytest.UsageError(
f"'{name}' can either be `'random'`, a positive integer, or a sequence of two positive integers,"
f"but got {size} instead."
)
def get_num_channels(color_space):
num_channels = NUM_CHANNELS_MAP.get(color_space)
if not num_channels:
raise pytest.UsageError(f"Can't determine the number of channels for color space {color_space}")
return num_channels
VALID_EXTRA_DIMS = ((), (4,), (2, 3))
DEGENERATE_BATCH_DIMS = ((0,), (5, 0), (0, 5))
DEFAULT_EXTRA_DIMS = (*VALID_EXTRA_DIMS, *DEGENERATE_BATCH_DIMS)
def from_loader(loader_fn):
def wrapper(*args, **kwargs):
device = kwargs.pop("device", "cpu")
loader = loader_fn(*args, **kwargs)
return loader.load(device)
return wrapper
def from_loaders(loaders_fn):
def wrapper(*args, **kwargs):
device = kwargs.pop("device", "cpu")
loaders = loaders_fn(*args, **kwargs)
for loader in loaders:
yield loader.load(device)
return wrapper
@dataclasses.dataclass
class TensorLoader:
fn: Callable[[Sequence[int], torch.dtype, Union[str, torch.device]], torch.Tensor]
shape: Sequence[int]
dtype: torch.dtype
def load(self, device):
return self.fn(self.shape, self.dtype, device)
@dataclasses.dataclass
class ImageLoader(TensorLoader):
spatial_size: Tuple[int, int] = dataclasses.field(init=False)
num_channels: int = dataclasses.field(init=False)
memory_format: torch.memory_format = torch.contiguous_format
canvas_size: Tuple[int, int] = dataclasses.field(init=False)
def __post_init__(self):
self.spatial_size = self.canvas_size = self.shape[-2:]
self.num_channels = self.shape[-3]
def load(self, device):
return self.fn(self.shape, self.dtype, device, memory_format=self.memory_format)
def make_image_loader(
size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
*,
color_space="RGB",
extra_dims=(),
dtype=torch.float32,
constant_alpha=True,
memory_format=torch.contiguous_format,
):
if not constant_alpha:
raise ValueError("This should never happen")
size = _parse_size(size)
num_channels = get_num_channels(color_space)
def fn(shape, dtype, device, memory_format):
*batch_dims, _, height, width = shape
return make_image(
(height, width),
color_space=color_space,
batch_dims=batch_dims,
dtype=dtype,
device=device,
memory_format=memory_format,
)
return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype, memory_format=memory_format)
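# Illustration only (hypothetical helper, not used by the tests): an `ImageLoader` only records the shape and
# dtype; the actual tensor is materialized lazily on a given device via `.load()`, which is what
# `from_loader` / `from_loaders` do under the hood.
def _example_image_loader_roundtrip():
    loader = make_image_loader((16, 16), color_space="RGB", dtype=torch.uint8)
    return loader.shape, loader.load("cpu").shape  # both are (3, 16, 16)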
def make_image_loaders(
*,
sizes=DEFAULT_SPATIAL_SIZES,
color_spaces=(
"GRAY",
"GRAY_ALPHA",
"RGB",
"RGBA",
),
extra_dims=DEFAULT_EXTRA_DIMS,
dtypes=(torch.float32, torch.float64, torch.uint8),
constant_alpha=True,
):
for params in combinations_grid(size=sizes, color_space=color_spaces, extra_dims=extra_dims, dtype=dtypes):
yield make_image_loader(**params, constant_alpha=constant_alpha)
make_images = from_loaders(make_image_loaders)
def make_image_loader_for_interpolation(
size=(233, 147), *, color_space="RGB", dtype=torch.uint8, memory_format=torch.contiguous_format
):
size = _parse_size(size)
num_channels = get_num_channels(color_space)
def fn(shape, dtype, device, memory_format):
height, width = shape[-2:]
image_pil = (
PIL.Image.open(pathlib.Path(__file__).parent / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg")
.resize((width, height))
.convert(
{
"GRAY": "L",
"GRAY_ALPHA": "LA",
"RGB": "RGB",
"RGBA": "RGBA",
}[color_space]
)
)
image_tensor = to_image(image_pil)
if memory_format == torch.contiguous_format:
image_tensor = image_tensor.to(device=device, memory_format=memory_format, copy=True)
else:
image_tensor = image_tensor.to(device=device)
image_tensor = to_dtype_image(image_tensor, dtype=dtype, scale=True)
return tv_tensors.Image(image_tensor)
return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype, memory_format=memory_format)
def make_image_loaders_for_interpolation(
sizes=((233, 147),),
color_spaces=("RGB",),
dtypes=(torch.uint8,),
memory_formats=(torch.contiguous_format, torch.channels_last),
):
for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes, memory_format=memory_formats):
yield make_image_loader_for_interpolation(**params)
@dataclasses.dataclass
class BoundingBoxesLoader(TensorLoader):
format: tv_tensors.BoundingBoxFormat
spatial_size: Tuple[int, int]
canvas_size: Tuple[int, int] = dataclasses.field(init=False)
def __post_init__(self):
self.canvas_size = self.spatial_size
def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtype=torch.float32):
if isinstance(format, str):
format = tv_tensors.BoundingBoxFormat[format]
spatial_size = _parse_size(spatial_size, name="spatial_size")
def fn(shape, dtype, device):
*batch_dims, num_coordinates = shape
if num_coordinates != 4:
raise pytest.UsageError()
return make_bounding_boxes(
format=format, canvas_size=spatial_size, batch_dims=batch_dims, dtype=dtype, device=device
)
return BoundingBoxesLoader(fn, shape=(*extra_dims[-1:], 4), dtype=dtype, format=format, spatial_size=spatial_size)
def make_bounding_box_loaders(
*,
extra_dims=tuple(d for d in DEFAULT_EXTRA_DIMS if len(d) < 2),
formats=tuple(tv_tensors.BoundingBoxFormat),
spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
dtypes=(torch.float32, torch.float64, torch.int64),
):
for params in combinations_grid(extra_dims=extra_dims, format=formats, dtype=dtypes):
yield make_bounding_box_loader(**params, spatial_size=spatial_size)
make_multiple_bounding_boxes = from_loaders(make_bounding_box_loaders)
class MaskLoader(TensorLoader):
pass
def make_detection_mask_loader(size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, num_objects=5, extra_dims=(), dtype=torch.uint8):
# This produces "detection" masks, i.e. `(*, N, H, W)`, where `N` denotes the number of objects
size = _parse_size(size)
def fn(shape, dtype, device):
*batch_dims, num_objects, height, width = shape
return make_detection_mask(
(height, width), num_objects=num_objects, batch_dims=batch_dims, dtype=dtype, device=device
)
return MaskLoader(fn, shape=(*extra_dims, num_objects, *size), dtype=dtype)
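# For example (illustrative): make_detection_mask_loader((32, 32), num_objects=5,
# extra_dims=(4,)) describes uint8 tensors of shape (4, 5, 32, 32), matching the
# `(*, N, H, W)` layout mentioned in the comment above.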
def make_detection_mask_loaders(
sizes=DEFAULT_SPATIAL_SIZES,
num_objects=(1, 0, 5),
extra_dims=DEFAULT_EXTRA_DIMS,
dtypes=(torch.uint8,),
):
for params in combinations_grid(size=sizes, num_objects=num_objects, extra_dims=extra_dims, dtype=dtypes):
yield make_detection_mask_loader(**params)
make_detection_masks = from_loaders(make_detection_mask_loaders)
def make_segmentation_mask_loader(
size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, num_categories=10, extra_dims=(), dtype=torch.uint8
):
# This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values
size = _parse_size(size)
def fn(shape, dtype, device):
*batch_dims, height, width = shape
return make_segmentation_mask(
(height, width), num_categories=num_categories, batch_dims=batch_dims, dtype=dtype, device=device
)
return MaskLoader(fn, shape=(*extra_dims, *size), dtype=dtype)
def make_segmentation_mask_loaders(
*,
sizes=DEFAULT_SPATIAL_SIZES,
num_categories=(1, 2, 10),
extra_dims=DEFAULT_EXTRA_DIMS,
dtypes=(torch.uint8,),
):
for params in combinations_grid(size=sizes, num_categories=num_categories, extra_dims=extra_dims, dtype=dtypes):
yield make_segmentation_mask_loader(**params)
make_segmentation_masks = from_loaders(make_segmentation_mask_loaders)
def make_mask_loaders(
*,
sizes=DEFAULT_SPATIAL_SIZES,
num_objects=(1, 0, 5),
num_categories=(1, 2, 10),
extra_dims=DEFAULT_EXTRA_DIMS,
dtypes=(torch.uint8,),
):
yield from make_detection_mask_loaders(sizes=sizes, num_objects=num_objects, extra_dims=extra_dims, dtypes=dtypes)
yield from make_segmentation_mask_loaders(
sizes=sizes, num_categories=num_categories, extra_dims=extra_dims, dtypes=dtypes
)
make_masks = from_loaders(make_mask_loaders)
class VideoLoader(ImageLoader):
pass
def make_video_loader(
size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
*,
color_space="RGB",
num_frames=3,
extra_dims=(),
dtype=torch.uint8,
):
size = _parse_size(size)
def fn(shape, dtype, device, memory_format):
*batch_dims, num_frames, _, height, width = shape
return make_video(
(height, width),
num_frames=num_frames,
batch_dims=batch_dims,
color_space=color_space,
dtype=dtype,
device=device,
memory_format=memory_format,
)
return VideoLoader(fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype)
def make_video_loaders(
*,
sizes=DEFAULT_SPATIAL_SIZES,
color_spaces=(
"GRAY",
"RGB",
),
num_frames=(1, 0, 3),
extra_dims=DEFAULT_EXTRA_DIMS,
dtypes=(torch.uint8, torch.float32, torch.float64),
):
for params in combinations_grid(
size=sizes, color_space=color_spaces, num_frames=num_frames, extra_dims=extra_dims, dtype=dtypes
):
yield make_video_loader(**params)
make_videos = from_loaders(make_video_loaders)
class TestMark:
def __init__(
self,
# Tuple of test class name and test function name that identifies the test the mark is applied to. If there is
# no test class, i.e. a standalone test function, use `None`.
test_id,
# `pytest.mark.*` to apply, e.g. `pytest.mark.skip` or `pytest.mark.xfail`
mark,
*,
# Callable that will be passed an `ArgsKwargs` and should return a boolean to indicate if the mark will be
# applied. If omitted, defaults to always apply.
condition=None,
):
self.test_id = test_id
self.mark = mark
self.condition = condition or (lambda args_kwargs: True)
def mark_framework_limitation(test_id, reason, condition=None):
# The purpose of this function is to have a single entry point for skip marks that are only there because the test
# framework cannot handle the kernel in general or a specific parameter combination.
# As development progresses, we can change the `mark.skip` to `mark.xfail` from time to time to see if the skip is
# still justified.
# We don't want to use `mark.xfail` all the time, because that actually runs the test until an error happens. Thus,
# we would be wasting CI resources for no reason most of the time.
return TestMark(test_id, pytest.mark.skip(reason=reason), condition=condition)
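# Hedged example of attaching such a mark (test id, reason and condition are
# hypothetical; real marks are passed via the `test_marks` argument of the info
# classes below):
#
#   skip_degenerate_boxes = mark_framework_limitation(
#       ("TestKernels", "test_against_reference"),
#       reason="reference kernel cannot handle degenerate boxes",
#       condition=lambda args_kwargs: args_kwargs.kwargs.get("degenerate", False),
#   )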
class InfoBase:
def __init__(
self,
*,
# Identifier of the info that shows up in the parametrization.
id,
# Test markers that will be (conditionally) applied to an `ArgsKwargs` parametrization.
# See the `TestMark` class for details
test_marks=None,
# Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`. Keys are a 3-tuple of `test_id` (see
# `TestMark`), the dtype, and the device.
closeness_kwargs=None,
):
self.id = id
self.test_marks = test_marks or []
test_marks_map = defaultdict(list)
for test_mark in self.test_marks:
test_marks_map[test_mark.test_id].append(test_mark)
self._test_marks_map = dict(test_marks_map)
self.closeness_kwargs = closeness_kwargs or dict()
def get_marks(self, test_id, args_kwargs):
return [
test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs)
]
def get_closeness_kwargs(self, test_id, *, dtype, device):
if not (isinstance(test_id, tuple) and len(test_id) == 2):
msg = "`test_id` should be a `Tuple[Optional[str], str]` denoting the test class and function name"
if callable(test_id):
msg += ". Did you forget to add the `test_id` fixture to parameters of the test?"
else:
msg += f", but got {test_id} instead."
raise pytest.UsageError(msg)
if isinstance(device, torch.device):
device = device.type
return self.closeness_kwargs.get((test_id, dtype, device), dict())
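# Illustrative shape of `closeness_kwargs` (hypothetical test id and tolerances):
# keys are the (test_id, dtype, device) triples described above, values are extra
# kwargs forwarded to `assert_close`, e.g.
#
#   closeness_kwargs = {
#       (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): dict(atol=1, rtol=0),
#   }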
class ArgsKwargs:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
def __iter__(self):
yield self.args
yield self.kwargs
def load(self, device="cpu"):
return ArgsKwargs(
*(arg.load(device) if isinstance(arg, TensorLoader) else arg for arg in self.args),
**{
keyword: arg.load(device) if isinstance(arg, TensorLoader) else arg
for keyword, arg in self.kwargs.items()
},
)
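# Minimal sketch (hypothetical parameters): loader arguments stay lazy until
# `.load(device)` is called, everything else is passed through unchanged.
#
#   args_kwargs = ArgsKwargs(make_image_loader((32, 32)), angle=30.0)
#   args, kwargs = args_kwargs.load("cpu")  # -> (loaded image tensor,), {"angle": 30.0}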
def parametrized_error_message(*args, **kwargs):
def to_str(obj):
if isinstance(obj, torch.Tensor) and obj.numel() > 30:
return f"tensor(shape={list(obj.shape)}, dtype={obj.dtype}, device={obj.device})"
elif isinstance(obj, enum.Enum):
return f"{type(obj).__name__}.{obj.name}"
else:
return repr(obj)
if args or kwargs:
postfix = "\n".join(
[
"",
"Failure happened for the following parameters:",
"",
*[to_str(arg) for arg in args],
*[f"{name}={to_str(kwarg)}" for name, kwarg in kwargs.items()],
]
)
else:
postfix = ""
def wrapper(msg):
return msg + postfix
return wrapper
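# Sketch of the intended call site (hypothetical arguments): the returned wrapper
# is typically passed as `msg=` to `torch.testing.assert_close`, which calls it
# with the generated failure message so that the parameters get appended.
#
#   msg = parametrized_error_message(input_tensor, interpolation="bilinear")
#   # torch.testing.assert_close(actual, expected, msg=msg)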
import os
import warnings
from modulefinder import Module
import torch
from torchvision import datasets, io, models, ops, transforms, utils
from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils
from .extension import _HAS_OPS
......@@ -71,11 +72,16 @@ def set_video_backend(backend):
backend, please compile torchvision from source.
"""
global _video_backend
if backend not in ["pyav", "video_reader"]:
raise ValueError("Invalid video backend '%s'. Options are 'pyav' and 'video_reader'" % backend)
if backend not in ["pyav", "video_reader", "cuda"]:
raise ValueError("Invalid video backend '%s'. Options are 'pyav', 'video_reader' and 'cuda'" % backend)
if backend == "video_reader" and not io._HAS_VIDEO_OPT:
# TODO: better messages
message = "video_reader video backend is not available. Please compile torchvision from source and try again"
warnings.warn(message)
raise RuntimeError(message)
elif backend == "cuda" and not io._HAS_GPU_VIDEO_DECODER:
# TODO: better messages
message = "cuda video backend is not available."
raise RuntimeError(message)
else:
_video_backend = backend
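# Illustrative usage of the new "cuda" option (assumes torchvision was built with
# the GPU video decoder; otherwise the RuntimeError above is raised):
#
#   import torchvision
#   torchvision.set_video_backend("cuda")
#   torchvision.get_video_backend()  # -> "cuda"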
......@@ -93,3 +99,9 @@ def get_video_backend():
def _is_tracing():
return torch._C._get_tracing_state()
def disable_beta_transforms_warning():
# Noop, only exists to avoid breaking existing code.
# See https://github.com/pytorch/vision/issues/7896
pass
......@@ -28,7 +28,6 @@ def _get_extension_path(lib_name):
if os.name == "nt":
# Register the main torchvision library location on the default DLL path
import ctypes
import sys
kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
......@@ -37,14 +36,7 @@ def _get_extension_path(lib_name):
if with_load_library_flags:
kernel32.AddDllDirectory.restype = ctypes.c_void_p
if sys.version_info >= (3, 8):
os.add_dll_directory(lib_dir)
elif with_load_library_flags:
res = kernel32.AddDllDirectory(lib_dir)
if res is None:
err = ctypes.WinError(ctypes.get_last_error())
err.strerror += f' Error adding "{lib_dir}" to the DLL directories.'
raise err
os.add_dll_directory(lib_dir)
kernel32.SetErrorMode(prev_error_mode)
......
import functools
import torch
import torch.library
# Ensure that torch.ops.torchvision is visible
import torchvision.extension # noqa: F401
@functools.lru_cache(None)
def get_meta_lib():
return torch.library.Library("torchvision", "IMPL", "Meta")
def register_meta(op_name, overload_name="default"):
def wrapper(fn):
if torchvision.extension._has_ops():
get_meta_lib().impl(getattr(getattr(torch.ops.torchvision, op_name), overload_name), fn)
return fn
return wrapper
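# Note: these meta implementations only compute output shapes and dtypes under the
# "Meta" dispatch key; they never touch real data. That is what allows fake-tensor
# tracing and torch.compile to reason about torchvision ops without executing them.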
@register_meta("roi_align")
def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
torch._check(
input.dtype == rois.dtype,
lambda: (
"Expected tensor for input to have the same type as tensor for rois; "
f"but type {input.dtype} does not equal {rois.dtype}"
),
)
num_rois = rois.size(0)
_, channels, height, width = input.size()
return input.new_empty((num_rois, channels, pooled_height, pooled_width))
@register_meta("_roi_align_backward")
def meta_roi_align_backward(
grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio, aligned
):
torch._check(
grad.dtype == rois.dtype,
lambda: (
"Expected tensor for grad to have the same type as tensor for rois; "
f"but type {grad.dtype} does not equal {rois.dtype}"
),
)
return grad.new_empty((batch_size, channels, height, width))
......@@ -312,6 +312,8 @@ bool Decoder::init(
}
}
av_dict_set_int(&options, "probesize", params_.probeSize, 0);
interrupted_ = false;
// ffmpeg avformat_open_input call can hang if media source doesn't respond
......
......@@ -165,7 +165,7 @@ struct MediaFormat {
struct DecoderParameters {
// local file, remote file, http url, rtmp stream uri, etc. anything that
// ffmpeg can recognize
std::string uri;
std::string uri{std::string()};
// timeout on getting bytes for decoding
size_t timeoutMs{1000};
// logging level, default AV_LOG_PANIC
......@@ -213,6 +213,12 @@ struct DecoderParameters {
// Skip packets that fail with EPERM errors and continue decoding.
bool skipOperationNotPermittedPackets{false};
// probing size in bytes, i.e. the size of the data to analyze to get stream
// information. A higher value will enable detecting more information in case
// it is dispersed into the stream, but will increase latency. Must be an
// integer no less than 32. It is 5000000 by default.
int64_t probeSize{5000000};
};
struct DecoderHeader {
......@@ -295,7 +301,7 @@ struct DecoderMetadata {
};
/**
* Abstract class for decoding media bytes
* It has two diffrent modes. Internal media bytes retrieval for given uri and
* It has two different modes. Internal media bytes retrieval for given uri and
* external media bytes provider in case of memory streams
*/
class MediaDecoder {
......
......@@ -61,7 +61,7 @@ DecoderInCallback MemoryBuffer::getCallback(
}
// seek mode
if (!timeoutMs) {
// seek capabilty, yes - supported
// seek capability, yes - supported
return 0;
}
return object.seek(size, whence);
......
......@@ -368,7 +368,7 @@ TEST(SyncDecoder, TestMemoryBufferNoSeekableWithFullRead) {
}
// seek mode
if (!timeoutMs) {
// seek capabilty, yes - no
// seek capability, yes - no
return -1;
}
return object.seek(size, whence);
......@@ -408,7 +408,7 @@ TEST(SyncDecoder, TestMemoryBufferNoSeekableWithPartialRead) {
}
// seek mode
if (!timeoutMs) {
// seek capabilty, yes - no
// seek capability, yes - no
return -1;
}
return object.seek(size, whence);
......
......@@ -181,6 +181,23 @@ bool VideoSampler::init(const SamplerParameters& params) {
// set output format
params_ = params;
if (params.in.video.format == AV_PIX_FMT_YUV420P) {
/* When the video width and height are not multiples of 8,
* and there is no size change in the conversion,
* a blurry band appears on the right side of the frame.
* This problem was discovered in 2012 and still exists in
* ffmpeg 4.1.3 as of 2019. It can be avoided by adding the
* SWS_ACCURATE_RND flag.
* Details: https://trac.ffmpeg.org/ticket/1582
*/
if ((params.in.video.width & 0x7) || (params.in.video.height & 0x7)) {
VLOG(1) << "The width " << params.in.video.width << " and height "
<< params.in.video.height << " the image is not a multiple of 8, "
<< "the decoding speed may be reduced";
swsFlags_ |= SWS_ACCURATE_RND;
}
}
scaleContext_ = sws_getContext(
params.in.video.width,
params.in.video.height,
......
......@@ -7,6 +7,8 @@ namespace vision {
namespace image {
torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) {
// Check that tensor is a CPU tensor
TORCH_CHECK(data.device() == torch::kCPU, "Expected a CPU tensor");
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
......
......@@ -67,6 +67,58 @@ static void torch_jpeg_set_source_mgr(
src->pub.next_input_byte = src->data;
}
inline unsigned char clamped_cmyk_rgb_convert(
unsigned char k,
unsigned char cmy) {
// Inspired from Pillow:
// https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L568-L569
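// k * cmy has to be divided by 255; adding 128 and then computing
// ((v >> 8) + v) >> 8 is the usual integer approximation of v / 255.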
int v = k * cmy + 128;
v = ((v >> 8) + v) >> 8;
return std::clamp(k - v, 0, 255);
}
void convert_line_cmyk_to_rgb(
j_decompress_ptr cinfo,
const unsigned char* cmyk_line,
unsigned char* rgb_line) {
int width = cinfo->output_width;
for (int i = 0; i < width; ++i) {
int c = cmyk_line[i * 4 + 0];
int m = cmyk_line[i * 4 + 1];
int y = cmyk_line[i * 4 + 2];
int k = cmyk_line[i * 4 + 3];
rgb_line[i * 3 + 0] = clamped_cmyk_rgb_convert(k, 255 - c);
rgb_line[i * 3 + 1] = clamped_cmyk_rgb_convert(k, 255 - m);
rgb_line[i * 3 + 2] = clamped_cmyk_rgb_convert(k, 255 - y);
}
}
inline unsigned char rgb_to_gray(int r, int g, int b) {
// Inspired from Pillow:
// https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L226
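// 19595, 38470 and 7471 are the ITU-R BT.601 luma weights (0.299, 0.587, 0.114)
// scaled by 2^16; the 0x8000 term rounds before the final >> 16.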
return (r * 19595 + g * 38470 + b * 7471 + 0x8000) >> 16;
}
void convert_line_cmyk_to_gray(
j_decompress_ptr cinfo,
const unsigned char* cmyk_line,
unsigned char* gray_line) {
int width = cinfo->output_width;
for (int i = 0; i < width; ++i) {
int c = cmyk_line[i * 4 + 0];
int m = cmyk_line[i * 4 + 1];
int y = cmyk_line[i * 4 + 2];
int k = cmyk_line[i * 4 + 3];
int r = clamped_cmyk_rgb_convert(k, 255 - c);
int g = clamped_cmyk_rgb_convert(k, 255 - m);
int b = clamped_cmyk_rgb_convert(k, 255 - y);
gray_line[i] = rgb_to_gray(r, g, b);
}
}
} // namespace
torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
......@@ -102,20 +154,29 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
jpeg_read_header(&cinfo, TRUE);
int channels = cinfo.num_components;
bool cmyk_to_rgb_or_gray = false;
if (mode != IMAGE_READ_MODE_UNCHANGED) {
switch (mode) {
case IMAGE_READ_MODE_GRAY:
if (cinfo.jpeg_color_space != JCS_GRAYSCALE) {
if (cinfo.jpeg_color_space == JCS_CMYK ||
cinfo.jpeg_color_space == JCS_YCCK) {
cinfo.out_color_space = JCS_CMYK;
cmyk_to_rgb_or_gray = true;
} else {
cinfo.out_color_space = JCS_GRAYSCALE;
channels = 1;
}
channels = 1;
break;
case IMAGE_READ_MODE_RGB:
if (cinfo.jpeg_color_space != JCS_RGB) {
if (cinfo.jpeg_color_space == JCS_CMYK ||
cinfo.jpeg_color_space == JCS_YCCK) {
cinfo.out_color_space = JCS_CMYK;
cmyk_to_rgb_or_gray = true;
} else {
cinfo.out_color_space = JCS_RGB;
channels = 3;
}
channels = 3;
break;
/*
* Libjpeg does not support converting from CMYK to grayscale etc. There
......@@ -139,12 +200,28 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
auto tensor =
torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8);
auto ptr = tensor.data_ptr<uint8_t>();
torch::Tensor cmyk_line_tensor;
if (cmyk_to_rgb_or_gray) {
cmyk_line_tensor = torch::empty({int64_t(width), 4}, torch::kU8);
}
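// In the CMYK case libjpeg emits raw 4-channel scanlines into cmyk_line_tensor
// and the conversion to RGB or grayscale is done per line by the helpers above,
// since libjpeg cannot perform that conversion itself.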
while (cinfo.output_scanline < cinfo.output_height) {
/* jpeg_read_scanlines expects an array of pointers to scanlines.
* Here the array is only one element long, but you could ask for
* more than one scanline at a time if that's more convenient.
*/
jpeg_read_scanlines(&cinfo, &ptr, 1);
if (cmyk_to_rgb_or_gray) {
auto cmyk_line_ptr = cmyk_line_tensor.data_ptr<uint8_t>();
jpeg_read_scanlines(&cinfo, &cmyk_line_ptr, 1);
if (channels == 3) {
convert_line_cmyk_to_rgb(&cinfo, cmyk_line_ptr, ptr);
} else if (channels == 1) {
convert_line_cmyk_to_gray(&cinfo, cmyk_line_ptr, ptr);
}
} else {
jpeg_read_scanlines(&cinfo, &ptr, 1);
}
ptr += stride;
}
......@@ -152,8 +229,23 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
jpeg_destroy_decompress(&cinfo);
return tensor.permute({2, 0, 1});
}
#endif // #if !JPEG_FOUND
int64_t _jpeg_version() {
#if JPEG_FOUND
return JPEG_LIB_VERSION;
#else
return -1;
#endif
}
bool _is_compiled_against_turbo() {
#ifdef LIBJPEG_TURBO_VERSION
return true;
#else
return false;
#endif
}
} // namespace image
} // namespace vision
......@@ -10,5 +10,8 @@ C10_EXPORT torch::Tensor decode_jpeg(
const torch::Tensor& data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
C10_EXPORT int64_t _jpeg_version();
C10_EXPORT bool _is_compiled_against_turbo();
} // namespace image
} // namespace vision
......@@ -49,6 +49,7 @@ torch::Tensor decode_png(
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(false, "Internal error.");
}
TORCH_CHECK(datap_len >= 8, "Content is too small for png!")
auto is_png = !png_sig_cmp(datap, 0, 8);
TORCH_CHECK(is_png, "Content is not png!")
......
......@@ -19,15 +19,18 @@ PyMODINIT_FUNC PyInit_image(void) {
namespace vision {
namespace image {
static auto registry = torch::RegisterOperators()
.op("image::decode_png", &decode_png)
.op("image::encode_png", &encode_png)
.op("image::decode_jpeg", &decode_jpeg)
.op("image::encode_jpeg", &encode_jpeg)
.op("image::read_file", &read_file)
.op("image::write_file", &write_file)
.op("image::decode_image", &decode_image)
.op("image::decode_jpeg_cuda", &decode_jpeg_cuda);
static auto registry =
torch::RegisterOperators()
.op("image::decode_png", &decode_png)
.op("image::encode_png", &encode_png)
.op("image::decode_jpeg", &decode_jpeg)
.op("image::encode_jpeg", &encode_jpeg)
.op("image::read_file", &read_file)
.op("image::write_file", &write_file)
.op("image::decode_image", &decode_image)
.op("image::decode_jpeg_cuda", &decode_jpeg_cuda)
.op("image::_jpeg_version", &_jpeg_version)
.op("image::_is_compiled_against_turbo", &_is_compiled_against_turbo);
} // namespace image
} // namespace vision
......@@ -156,14 +156,34 @@ void Video::_getDecoderParams(
} // _get decoder params
Video::Video(std::string videoPath, std::string stream, int64_t numThreads) {
C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.video.video.Video");
void Video::initFromFile(
std::string videoPath,
std::string stream,
int64_t numThreads) {
TORCH_CHECK(!initialized, "Video object can only be initialized once");
initialized = true;
params.uri = videoPath;
_init(stream, numThreads);
}
void Video::initFromMemory(
torch::Tensor videoTensor,
std::string stream,
int64_t numThreads) {
TORCH_CHECK(!initialized, "Video object can only be initialized once");
initialized = true;
callback = MemoryBuffer::getCallback(
videoTensor.data_ptr<uint8_t>(), videoTensor.size(0));
_init(stream, numThreads);
}
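// Both entry points funnel into _init(): initFromFile sets params.uri, while
// initFromMemory installs a MemoryBuffer callback; the `initialized` flag makes
// sure a Video object is initialized at most once.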
void Video::_init(std::string stream, int64_t numThreads) {
// set number of threads global
numThreads_ = numThreads;
// parse stream information
current_stream = _parseStream(stream);
// note that in the initial call we want to get all streams
Video::_getDecoderParams(
_getDecoderParams(
0, // video start
0, // headerOnly
std::get<0>(current_stream), // stream info - remove that
......@@ -175,11 +195,6 @@ Video::Video(std::string videoPath, std::string stream, int64_t numThreads) {
std::string logMessage, logType;
// TODO: add read from memory option
params.uri = videoPath;
logType = "file";
logMessage = videoPath;
// locals
std::vector<double> audioFPS, videoFPS;
std::vector<double> audioDuration, videoDuration, ccDuration, subsDuration;
......@@ -190,7 +205,8 @@ Video::Video(std::string videoPath, std::string stream, int64_t numThreads) {
c10::Dict<std::string, std::vector<double>> subsMetadata;
// callback and metadata defined in struct
succeeded = decoder.init(params, std::move(callback), &metadata);
DecoderInCallback tmp_callback = callback;
succeeded = decoder.init(params, std::move(tmp_callback), &metadata);
if (succeeded) {
for (const auto& header : metadata) {
double fps = double(header.fps);
......@@ -225,16 +241,24 @@ Video::Video(std::string videoPath, std::string stream, int64_t numThreads) {
streamsMetadata.insert("subtitles", subsMetadata);
streamsMetadata.insert("cc", ccMetadata);
succeeded = Video::setCurrentStream(stream);
succeeded = setCurrentStream(stream);
LOG(INFO) << "\nDecoder inited with: " << succeeded << "\n";
if (std::get<1>(current_stream) != -1) {
LOG(INFO)
<< "Stream index set to " << std::get<1>(current_stream)
<< ". If you encounter trouble, consider switching it to automatic stream discovery. \n";
}
}
Video::Video(std::string videoPath, std::string stream, int64_t numThreads) {
C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.video.video.Video");
if (!videoPath.empty()) {
initFromFile(videoPath, stream, numThreads);
}
} // video
bool Video::setCurrentStream(std::string stream = "video") {
TORCH_CHECK(initialized, "Video object has to be initialized first");
if ((!stream.empty()) && (_parseStream(stream) != current_stream)) {
current_stream = _parseStream(stream);
}
......@@ -256,19 +280,23 @@ bool Video::setCurrentStream(std::string stream = "video") {
);
// callback and metadata defined in Video.h
return (decoder.init(params, std::move(callback), &metadata));
DecoderInCallback tmp_callback = callback;
return (decoder.init(params, std::move(tmp_callback), &metadata));
}
std::tuple<std::string, int64_t> Video::getCurrentStream() const {
TORCH_CHECK(initialized, "Video object has to be initialized first");
return current_stream;
}
c10::Dict<std::string, c10::Dict<std::string, std::vector<double>>> Video::
getStreamMetadata() const {
TORCH_CHECK(initialized, "Video object has to be initialized first");
return streamsMetadata;
}
void Video::Seek(double ts, bool fastSeek = false) {
TORCH_CHECK(initialized, "Video object has to be initialized first");
// initialize the class variables used for seeking and return
_getDecoderParams(
ts, // video start
......@@ -282,20 +310,23 @@ void Video::Seek(double ts, bool fastSeek = false) {
);
// callback and metadata defined in Video.h
succeeded = decoder.init(params, std::move(callback), &metadata);
DecoderInCallback tmp_callback = callback;
succeeded = decoder.init(params, std::move(tmp_callback), &metadata);
LOG(INFO) << "Decoder init at seek " << succeeded << "\n";
}
std::tuple<torch::Tensor, double> Video::Next() {
TORCH_CHECK(initialized, "Video object has to be initialized first");
// if failing to decode simply return a null tensor (note, should we
// raise an exeption?)
// raise an exception?)
double frame_pts_s;
torch::Tensor outFrame = torch::zeros({0}, torch::kByte);
// decode single frame
DecoderOutputMessage out;
int64_t res = decoder.decode(&out, decoderTimeoutMs);
// if successfull
// if successful
if (res == 0) {
frame_pts_s = double(double(out.header.pts) * 1e-6);
......@@ -345,6 +376,8 @@ std::tuple<torch::Tensor, double> Video::Next() {
static auto registerVideo =
torch::class_<Video>("torchvision", "Video")
.def(torch::init<std::string, std::string, int64_t>())
.def("init_from_file", &Video::initFromFile)
.def("init_from_memory", &Video::initFromMemory)
.def("get_current_stream", &Video::getCurrentStream)
.def("set_current_stream", &Video::setCurrentStream)
.def("get_metadata", &Video::getStreamMetadata)
......
......@@ -19,7 +19,19 @@ struct Video : torch::CustomClassHolder {
int64_t numThreads_{0};
public:
Video(std::string videoPath, std::string stream, int64_t numThreads);
Video(
std::string videoPath = std::string(),
std::string stream = std::string("video"),
int64_t numThreads = 0);
void initFromFile(
std::string videoPath,
std::string stream,
int64_t numThreads);
void initFromMemory(
torch::Tensor videoTensor,
std::string stream,
int64_t numThreads);
std::tuple<std::string, int64_t> getCurrentStream() const;
c10::Dict<std::string, c10::Dict<std::string, std::vector<double>>>
getStreamMetadata() const;
......@@ -30,10 +42,16 @@ struct Video : torch::CustomClassHolder {
private:
bool succeeded = false; // decoder init flag
// seekTS and doSeek act as a flag - if it's not set, next function simply
// retruns the next frame. If it's set, we look at the global seek
// time in comination with any_frame settings
// returns the next frame. If it's set, we look at the global seek
// time in combination with any_frame settings
double seekTS = -1;
bool initialized = false;
void _init(
std::string stream,
int64_t numThreads); // expects params.uri OR callback to be set
void _getDecoderParams(
double videoStartS,
int64_t getPtsOnly,
......
#pragma once
#ifdef _WIN32
#if defined(_WIN32) && !defined(TORCHVISION_BUILD_STATIC_LIBS)
#if defined(torchvision_EXPORTS)
#define VISION_API __declspec(dllexport)
#else
......
......@@ -15,8 +15,8 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
const torch::autograd::Variable& input,
const torch::autograd::Variable& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
int64_t sampling_ratio,
bool aligned) {
ctx->saved_data["spatial_scale"] = spatial_scale;
......@@ -24,10 +24,10 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["sampling_ratio"] = sampling_ratio;
ctx->saved_data["aligned"] = aligned;
ctx->saved_data["input_shape"] = input.sizes();
ctx->saved_data["input_shape"] = input.sym_sizes();
ctx->save_for_backward({rois});
at::AutoDispatchBelowADInplaceOrView g;
auto result = roi_align(
auto result = roi_align_symint(
input,
rois,
spatial_scale,
......@@ -44,17 +44,17 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
// Use data saved in forward
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = detail::_roi_align_backward(
auto input_shape = ctx->saved_data["input_shape"].toList();
auto grad_in = detail::_roi_align_backward_symint(
grad_output[0],
rois,
ctx->saved_data["spatial_scale"].toDouble(),
ctx->saved_data["pooled_height"].toInt(),
ctx->saved_data["pooled_width"].toInt(),
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3],
ctx->saved_data["pooled_height"].toSymInt(),
ctx->saved_data["pooled_width"].toSymInt(),
input_shape[0].get().toSymInt(),
input_shape[1].get().toSymInt(),
input_shape[2].get().toSymInt(),
input_shape[3].get().toSymInt(),
ctx->saved_data["sampling_ratio"].toInt(),
ctx->saved_data["aligned"].toBool());
return {
......@@ -77,16 +77,16 @@ class ROIAlignBackwardFunction
const torch::autograd::Variable& grad,
const torch::autograd::Variable& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
c10::SymInt batch_size,
c10::SymInt channels,
c10::SymInt height,
c10::SymInt width,
int64_t sampling_ratio,
bool aligned) {
at::AutoDispatchBelowADInplaceOrView g;
auto result = detail::_roi_align_backward(
auto result = detail::_roi_align_backward_symint(
grad,
rois,
spatial_scale,
......@@ -112,8 +112,8 @@ at::Tensor roi_align_autograd(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
int64_t sampling_ratio,
bool aligned) {
return ROIAlignFunction::apply(
......@@ -130,12 +130,12 @@ at::Tensor roi_align_backward_autograd(
const at::Tensor& grad,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
c10::SymInt batch_size,
c10::SymInt channels,
c10::SymInt height,
c10::SymInt width,
int64_t sampling_ratio,
bool aligned) {
return ROIAlignBackwardFunction::apply(
......
......@@ -11,8 +11,8 @@ at::Tensor nms_kernel_impl(
const at::Tensor& dets,
const at::Tensor& scores,
double iou_threshold) {
TORCH_CHECK(!dets.is_cuda(), "dets must be a CPU tensor");
TORCH_CHECK(!scores.is_cuda(), "scores must be a CPU tensor");
TORCH_CHECK(dets.is_cpu(), "dets must be a CPU tensor");
TORCH_CHECK(scores.is_cpu(), "scores must be a CPU tensor");
TORCH_CHECK(
dets.scalar_type() == scores.scalar_type(),
"dets should have the same type as scores");
......