"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "d9c25521bcbdbcaa6d2927ce04df0eeb59bafa99"
Unverified commit d0e16b76, authored by Philip Meier, committed by GitHub

allow len 1 sequences for fill with PIL (#7928)

parent 439c5e34
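In user-facing terms, the commit lets a length-1 `fill` sequence behave like a scalar fill for PIL inputs. Below is a minimal sketch of the effect, assuming a torchvision build that includes this change; the image size and fill value are arbitrary:

```python
# Sketch of the behavior enabled by this commit (assumes torchvision with this
# change installed). Previously, a length-1 fill sequence raised a ValueError
# for multi-channel PIL images because it was not broadcast across channels.
import PIL.Image
import torchvision.transforms.v2.functional as F

image = PIL.Image.new("RGB", (32, 32))

# A scalar fill and a length-1 sequence fill are now equivalent for PIL inputs:
# [127] is broadcast to (127, 127, 127) for the three RGB channels.
padded_scalar = F.pad(image, padding=4, fill=127)
padded_sequence = F.pad(image, padding=4, fill=[127])

assert padded_scalar.size == padded_sequence.size == (40, 40)
```

The call with `fill=[127]` used to hit the `ValueError` raised by `_parse_fill`, which is exactly the branch relaxed in the last hunk below.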
@@ -2581,9 +2581,6 @@ class TestCrop:
         # 2. the fill parameter only has an affect if we need padding
         kwargs["size"] = [s + 4 for s in self.INPUT_SIZE]
 
-        if isinstance(input, PIL.Image.Image) and isinstance(value, (tuple, list)) and len(value) == 1:
-            pytest.xfail("F._pad_image_pil does not support sequences of length 1 for fill.")
-
         if isinstance(input, tv_tensors.Mask) and isinstance(value, (tuple, list)):
             pytest.skip("F.pad_mask doesn't support non-scalar fill.")
...
-import collections.abc
 import pytest
 import torchvision.transforms.v2.functional as F
 from torchvision import tv_tensors
@@ -112,32 +110,6 @@ multi_crop_skips = [
 multi_crop_skips.append(skip_dispatch_tv_tensor)
 
 
-def xfails_pil(reason, *, condition=None):
-    return [
-        TestMark(("TestDispatchers", test_name), pytest.mark.xfail(reason=reason), condition=condition)
-        for test_name in ["test_dispatch_pil", "test_pil_output_type"]
-    ]
-
-
-def fill_sequence_needs_broadcast(args_kwargs):
-    (image_loader, *_), kwargs = args_kwargs
-    try:
-        fill = kwargs["fill"]
-    except KeyError:
-        return False
-
-    if not isinstance(fill, collections.abc.Sequence) or len(fill) > 1:
-        return False
-
-    return image_loader.num_channels > 1
-
-
-xfails_pil_if_fill_sequence_needs_broadcast = xfails_pil(
-    "PIL kernel doesn't support sequences of length 1 for `fill` if the number of color channels is larger.",
-    condition=fill_sequence_needs_broadcast,
-)
-
-
 DISPATCHER_INFOS = [
     DispatcherInfo(
         F.resized_crop,
@@ -159,14 +131,6 @@ DISPATCHER_INFOS = [
         },
         pil_kernel_info=PILKernelInfo(F._pad_image_pil, kernel_name="pad_image_pil"),
         test_marks=[
-            *xfails_pil(
-                reason=(
-                    "PIL kernel doesn't support sequences of length 1 for argument `fill` and "
-                    "`padding_mode='constant'`, if the number of color channels is larger."
-                ),
-                condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs)
-                and args_kwargs.kwargs.get("padding_mode", "constant") == "constant",
-            ),
             xfail_jit("F.pad only supports vector fills for list of floats", condition=pad_xfail_jit_fill_condition),
             xfail_jit_python_scalar_arg("padding"),
         ],
@@ -181,7 +145,6 @@ DISPATCHER_INFOS = [
         },
         pil_kernel_info=PILKernelInfo(F._perspective_image_pil),
         test_marks=[
-            *xfails_pil_if_fill_sequence_needs_broadcast,
             xfail_jit_python_scalar_arg("fill"),
         ],
     ),
...
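The helpers removed above encoded the only situation in which those PIL xfails applied: a length-1 `fill` sequence combined with a multi-channel image. A standalone sketch of that predicate follows; the `needs_broadcast` name and the bare `fill`/`num_channels` arguments are illustrative, while the real helper inspected the test suite's `args_kwargs` and image loaders:

```python
# Hedged sketch of the condition the removed helpers checked. With this commit,
# the condition no longer marks an expected failure, because the PIL path
# broadcasts length-1 fill sequences itself.
import collections.abc


def needs_broadcast(fill, num_channels):
    # A scalar or a full-length sequence never needed broadcasting.
    if not isinstance(fill, collections.abc.Sequence) or len(fill) > 1:
        return False
    # Only multi-channel images triggered the PIL limitation.
    return num_channels > 1


assert needs_broadcast([127], num_channels=3)           # RGB image, len-1 fill
assert not needs_broadcast(127, num_channels=3)         # scalar fill
assert not needs_broadcast([1, 2, 3], num_channels=3)   # full-length fill
```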
@@ -264,11 +264,13 @@ def _parse_fill(
     if isinstance(fill, (int, float)) and num_channels > 1:
         fill = tuple([fill] * num_channels)
     if isinstance(fill, (list, tuple)):
-        if len(fill) != num_channels:
+        if len(fill) == 1:
+            fill = fill * num_channels
+        elif len(fill) != num_channels:
             msg = "The number of elements in 'fill' does not match the number of channels of the image ({} != {})"
             raise ValueError(msg.format(len(fill), num_channels))
-        fill = tuple(fill)
+        fill = tuple(fill)  # type: ignore[arg-type]
 
     if img.mode != "F":
         if isinstance(fill, (list, tuple)):
...
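For reference, here is a minimal standalone sketch of the normalization that `_parse_fill` performs after this change; `broadcast_fill` is an illustrative name for this sketch, not a torchvision API:

```python
# Sketch of the new fill normalization: a length-1 sequence is now repeated
# per channel instead of being rejected with a ValueError.
def broadcast_fill(fill, num_channels):
    if isinstance(fill, (int, float)) and num_channels > 1:
        fill = tuple([fill] * num_channels)
    if isinstance(fill, (list, tuple)):
        if len(fill) == 1:
            # New behavior: broadcast the single value across all channels.
            fill = fill * num_channels
        elif len(fill) != num_channels:
            msg = "The number of elements in 'fill' does not match the number of channels of the image ({} != {})"
            raise ValueError(msg.format(len(fill), num_channels))
        fill = tuple(fill)
    return fill


assert broadcast_fill([255], num_channels=3) == (255, 255, 255)
assert broadcast_fill(0, num_channels=3) == (0, 0, 0)
assert broadcast_fill((10, 20, 30), num_channels=3) == (10, 20, 30)
```

Broadcasting a length-1 sequence mirrors how a scalar fill is already expanded in the first branch, so `fill=127` and `fill=[127]` end up as the same per-channel tuple.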