Unverified Commit 17969eba authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

enable arbitrary batch size for all prototype kernels (#6726)

* enable arbitrary batch size for all prototype kernels

* put back perspective dispatcher
parent 019139f7
...@@ -138,12 +138,6 @@ def xfail_all_tests(*, reason, condition): ...@@ -138,12 +138,6 @@ def xfail_all_tests(*, reason, condition):
] ]
xfails_degenerate_or_multi_batch_dims = xfail_all_tests(
reason="See https://github.com/pytorch/vision/issues/6670 for details.",
condition=lambda args_kwargs: len(args_kwargs.args[0].shape) > 4 or not all(args_kwargs.args[0].shape[:-3]),
)
DISPATCHER_INFOS = [ DISPATCHER_INFOS = [
DispatcherInfo( DispatcherInfo(
F.horizontal_flip, F.horizontal_flip,
...@@ -260,7 +254,6 @@ DISPATCHER_INFOS = [ ...@@ -260,7 +254,6 @@ DISPATCHER_INFOS = [
pil_kernel_info=PILKernelInfo(F.perspective_image_pil), pil_kernel_info=PILKernelInfo(F.perspective_image_pil),
test_marks=[ test_marks=[
xfail_dispatch_pil_if_fill_sequence_needs_broadcast, xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
*xfails_degenerate_or_multi_batch_dims,
], ],
), ),
DispatcherInfo( DispatcherInfo(
...@@ -271,7 +264,6 @@ DISPATCHER_INFOS = [ ...@@ -271,7 +264,6 @@ DISPATCHER_INFOS = [
features.Mask: F.elastic_mask, features.Mask: F.elastic_mask,
}, },
pil_kernel_info=PILKernelInfo(F.elastic_image_pil), pil_kernel_info=PILKernelInfo(F.elastic_image_pil),
test_marks=xfails_degenerate_or_multi_batch_dims,
), ),
DispatcherInfo( DispatcherInfo(
F.center_crop, F.center_crop,
...@@ -294,7 +286,6 @@ DISPATCHER_INFOS = [ ...@@ -294,7 +286,6 @@ DISPATCHER_INFOS = [
test_marks=[ test_marks=[
xfail_jit_python_scalar_arg("kernel_size"), xfail_jit_python_scalar_arg("kernel_size"),
xfail_jit_python_scalar_arg("sigma"), xfail_jit_python_scalar_arg("sigma"),
*xfails_degenerate_or_multi_batch_dims,
], ],
), ),
DispatcherInfo( DispatcherInfo(
......
...@@ -156,12 +156,6 @@ def xfail_all_tests(*, reason, condition): ...@@ -156,12 +156,6 @@ def xfail_all_tests(*, reason, condition):
] ]
xfails_image_degenerate_or_multi_batch_dims = xfail_all_tests(
reason="See https://github.com/pytorch/vision/issues/6670 for details.",
condition=lambda args_kwargs: len(args_kwargs.args[0].shape) > 4 or not all(args_kwargs.args[0].shape[:-3]),
)
KERNEL_INFOS = [] KERNEL_INFOS = []
...@@ -1156,7 +1150,6 @@ KERNEL_INFOS.extend( ...@@ -1156,7 +1150,6 @@ KERNEL_INFOS.extend(
reference_fn=pil_reference_wrapper(F.perspective_image_pil), reference_fn=pil_reference_wrapper(F.perspective_image_pil),
reference_inputs_fn=reference_inputs_perspective_image_tensor, reference_inputs_fn=reference_inputs_perspective_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
test_marks=xfails_image_degenerate_or_multi_batch_dims,
), ),
KernelInfo( KernelInfo(
F.perspective_bounding_box, F.perspective_bounding_box,
...@@ -1168,7 +1161,6 @@ KERNEL_INFOS.extend( ...@@ -1168,7 +1161,6 @@ KERNEL_INFOS.extend(
reference_fn=pil_reference_wrapper(F.perspective_image_pil), reference_fn=pil_reference_wrapper(F.perspective_image_pil),
reference_inputs_fn=reference_inputs_perspective_mask, reference_inputs_fn=reference_inputs_perspective_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
test_marks=xfails_image_degenerate_or_multi_batch_dims,
), ),
KernelInfo( KernelInfo(
F.perspective_video, F.perspective_video,
...@@ -1239,7 +1231,6 @@ KERNEL_INFOS.extend( ...@@ -1239,7 +1231,6 @@ KERNEL_INFOS.extend(
reference_fn=pil_reference_wrapper(F.elastic_image_pil), reference_fn=pil_reference_wrapper(F.elastic_image_pil),
reference_inputs_fn=reference_inputs_elastic_image_tensor, reference_inputs_fn=reference_inputs_elastic_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
test_marks=xfails_image_degenerate_or_multi_batch_dims,
), ),
KernelInfo( KernelInfo(
F.elastic_bounding_box, F.elastic_bounding_box,
...@@ -1251,7 +1242,6 @@ KERNEL_INFOS.extend( ...@@ -1251,7 +1242,6 @@ KERNEL_INFOS.extend(
reference_fn=pil_reference_wrapper(F.elastic_image_pil), reference_fn=pil_reference_wrapper(F.elastic_image_pil),
reference_inputs_fn=reference_inputs_elastic_mask, reference_inputs_fn=reference_inputs_elastic_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
test_marks=xfails_image_degenerate_or_multi_batch_dims,
), ),
KernelInfo( KernelInfo(
F.elastic_video, F.elastic_video,
...@@ -1379,7 +1369,6 @@ KERNEL_INFOS.extend( ...@@ -1379,7 +1369,6 @@ KERNEL_INFOS.extend(
test_marks=[ test_marks=[
xfail_jit_python_scalar_arg("kernel_size"), xfail_jit_python_scalar_arg("kernel_size"),
xfail_jit_python_scalar_arg("sigma"), xfail_jit_python_scalar_arg("sigma"),
*xfails_image_degenerate_or_multi_batch_dims,
], ],
), ),
KernelInfo( KernelInfo(
......
...@@ -882,7 +882,23 @@ def perspective_image_tensor( ...@@ -882,7 +882,23 @@ def perspective_image_tensor(
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: features.FillTypeJIT = None, fill: features.FillTypeJIT = None,
) -> torch.Tensor: ) -> torch.Tensor:
return _FT.perspective(image, perspective_coeffs, interpolation=interpolation.value, fill=fill) if image.numel() == 0:
return image
shape = image.shape
if image.ndim > 4:
image = image.view((-1,) + shape[-3:])
needs_unsquash = True
else:
needs_unsquash = False
output = _FT.perspective(image, perspective_coeffs, interpolation=interpolation.value, fill=fill)
if needs_unsquash:
output = output.view(shape)
return output
@torch.jit.unused @torch.jit.unused
...@@ -1007,25 +1023,7 @@ def perspective_video( ...@@ -1007,25 +1023,7 @@ def perspective_video(
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: features.FillTypeJIT = None, fill: features.FillTypeJIT = None,
) -> torch.Tensor: ) -> torch.Tensor:
# TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when return perspective_image_tensor(video, perspective_coeffs, interpolation=interpolation, fill=fill)
# https://github.com/pytorch/vision/issues/6670 is resolved.
if video.numel() == 0:
return video
shape = video.shape
if video.ndim > 4:
video = video.view((-1,) + shape[-3:])
needs_unsquash = True
else:
needs_unsquash = False
output = perspective_image_tensor(video, perspective_coeffs, interpolation=interpolation, fill=fill)
if needs_unsquash:
output = output.view(shape)
return output
def perspective( def perspective(
...@@ -1048,7 +1046,23 @@ def elastic_image_tensor( ...@@ -1048,7 +1046,23 @@ def elastic_image_tensor(
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: features.FillTypeJIT = None, fill: features.FillTypeJIT = None,
) -> torch.Tensor: ) -> torch.Tensor:
return _FT.elastic_transform(image, displacement, interpolation=interpolation.value, fill=fill) if image.numel() == 0:
return image
shape = image.shape
if image.ndim > 4:
image = image.view((-1,) + shape[-3:])
needs_unsquash = True
else:
needs_unsquash = False
output = _FT.elastic_transform(image, displacement, interpolation=interpolation.value, fill=fill)
if needs_unsquash:
output = output.view(shape)
return output
@torch.jit.unused @torch.jit.unused
...@@ -1128,25 +1142,7 @@ def elastic_video( ...@@ -1128,25 +1142,7 @@ def elastic_video(
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: features.FillTypeJIT = None, fill: features.FillTypeJIT = None,
) -> torch.Tensor: ) -> torch.Tensor:
# TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when return elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill)
# https://github.com/pytorch/vision/issues/6670 is resolved.
if video.numel() == 0:
return video
shape = video.shape
if video.ndim > 4:
video = video.view((-1,) + shape[-3:])
needs_unsquash = True
else:
needs_unsquash = False
output = elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill)
if needs_unsquash:
output = output.view(shape)
return output
def elastic( def elastic(
......
...@@ -56,7 +56,23 @@ def gaussian_blur_image_tensor( ...@@ -56,7 +56,23 @@ def gaussian_blur_image_tensor(
if s <= 0.0: if s <= 0.0:
raise ValueError(f"sigma should have positive values. Got {sigma}") raise ValueError(f"sigma should have positive values. Got {sigma}")
return _FT.gaussian_blur(image, kernel_size, sigma) if image.numel() == 0:
return image
shape = image.shape
if image.ndim > 4:
image = image.view((-1,) + shape[-3:])
needs_unsquash = True
else:
needs_unsquash = False
output = _FT.gaussian_blur(image, kernel_size, sigma)
if needs_unsquash:
output = output.view(shape)
return output
@torch.jit.unused @torch.jit.unused
...@@ -71,25 +87,7 @@ def gaussian_blur_image_pil( ...@@ -71,25 +87,7 @@ def gaussian_blur_image_pil(
def gaussian_blur_video( def gaussian_blur_video(
video: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None video: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> torch.Tensor: ) -> torch.Tensor:
# TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when return gaussian_blur_image_tensor(video, kernel_size, sigma)
# https://github.com/pytorch/vision/issues/6670 is resolved.
if video.numel() == 0:
return video
shape = video.shape
if video.ndim > 4:
video = video.view((-1,) + shape[-3:])
needs_unsquash = True
else:
needs_unsquash = False
output = gaussian_blur_image_tensor(video, kernel_size, sigma)
if needs_unsquash:
output = output.view(shape)
return output
def gaussian_blur( def gaussian_blur(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment