import collections.abc
import pytest
import torchvision.transforms.v2.functional as F
from torchvision import tv_tensors
from transforms_v2_kernel_infos import KERNEL_INFOS, pad_xfail_jit_fill_condition
from transforms_v2_legacy_utils import InfoBase, TestMark
__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]


class PILKernelInfo(InfoBase):
    """Test metadata for a PIL kernel, identified by its (possibly aliased) name."""

    def __init__(
        self,
        kernel,
        *,
        # Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
        # TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
        kernel_name=None,
    ):
        id_ = kernel_name or kernel.__name__
        super().__init__(id=id_)
        self.kernel = kernel


class DispatcherInfo(InfoBase):
    """Test metadata for a dispatcher: maps each tv_tensor type to the kernel it dispatches to.

    Construction fails early (``pytest.UsageError``) if any registered kernel has no
    corresponding ``KernelInfo`` in ``transforms_v2_kernel_infos.py``, so misconfigured
    entries surface at collection time rather than mid-test.
    """

    # Reverse lookup from kernel function to its KernelInfo, shared by all instances.
    _KERNEL_INFO_MAP = {info.kernel: info for info in KERNEL_INFOS}

    def __init__(
        self,
        dispatcher,
        *,
        # Dictionary of types that map to the kernel the dispatcher dispatches to.
        kernels,
        # If omitted, no PIL dispatch test will be performed.
        pil_kernel_info=None,
        # See InfoBase
        test_marks=None,
        # See InfoBase
        closeness_kwargs=None,
    ):
        super().__init__(id=dispatcher.__name__, test_marks=test_marks, closeness_kwargs=closeness_kwargs)
        self.dispatcher = dispatcher
        self.kernels = kernels
        self.pil_kernel_info = pil_kernel_info

        kernel_infos = {}
        for tv_tensor_type, kernel in self.kernels.items():
            kernel_info = self._KERNEL_INFO_MAP.get(kernel)
            if not kernel_info:
                raise pytest.UsageError(
                    f"Can't register {kernel.__name__} for type {tv_tensor_type} since there is no `KernelInfo` for it. "
                    f"Please add a `KernelInfo` for it in `transforms_v2_kernel_infos.py`."
                )
            kernel_infos[tv_tensor_type] = kernel_info
        self.kernel_infos = kernel_infos

    def sample_inputs(self, *tv_tensor_types, filter_metadata=True):
        """Yield sample ``ArgsKwargs`` for the given types (all registered types if none given).

        With ``filter_metadata=True``, kwargs whose names match the tv_tensor's metadata
        attributes (and their ``old_*`` counterparts) are stripped, since dispatchers
        supply those themselves.

        Raises:
            pytest.UsageError: if a requested type has no registered kernel.
        """
        import itertools  # hoisted out of the per-type loop where it previously lived

        for tv_tensor_type in tv_tensor_types or self.kernel_infos.keys():
            kernel_info = self.kernel_infos.get(tv_tensor_type)
            if not kernel_info:
                # Fixed: previously formatted the builtin `type` (message always read
                # "type") instead of the offending tv_tensor type.
                raise pytest.UsageError(f"There is no kernel registered for type {tv_tensor_type.__name__}")

            sample_inputs = kernel_info.sample_inputs_fn()

            if not filter_metadata:
                yield from sample_inputs
                # NOTE(review): this stops after the first requested type, so later types
                # are skipped when `filter_metadata=False` — presumably only ever called
                # with a single type in that mode; confirm before relying on multi-type use.
                return

            for args_kwargs in sample_inputs:
                if hasattr(tv_tensor_type, "__annotations__"):
                    for name in itertools.chain(
                        tv_tensor_type.__annotations__.keys(),
                        # FIXME: this seems ok for conversion dispatchers, but we should probably handle this on a
                        #  per-dispatcher level. However, so far there is no option for that.
                        (f"old_{name}" for name in tv_tensor_type.__annotations__.keys()),
                    ):
                        if name in args_kwargs.kwargs:
                            del args_kwargs.kwargs[name]

                yield args_kwargs


def xfail_jit(reason, *, condition=None):
    """Mark `TestDispatchers.test_scripted_smoke` as an expected failure.

    `condition` optionally restricts the mark to matching sample inputs.
    """
    mark = pytest.mark.xfail(reason=reason)
    return TestMark(("TestDispatchers", "test_scripted_smoke"), mark, condition=condition)


def xfail_jit_python_scalar_arg(name, *, reason=None):
    """xfail the scripted smoke test when `name` is passed as a plain Python int/float."""
    reason = reason or f"Python scalar int or float for `{name}` is not supported when scripting"

    def _is_python_scalar(args_kwargs):
        # bools also pass this check (bool is a subclass of int), matching the original lambda
        return isinstance(args_kwargs.kwargs.get(name), (int, float))

    return xfail_jit(reason, condition=_is_python_scalar)


# Dispatchers that cannot take arbitrary tv_tensor inputs skip the generic dispatch test.
skip_dispatch_tv_tensor = TestMark(
    ("TestDispatchers", "test_dispatch_tv_tensor"),
    pytest.mark.skip(reason="Dispatcher doesn't support arbitrary tv_tensor dispatch."),
)


# Multi-crop dispatchers (five_crop / ten_crop) return a sequence of outputs, which
# breaks every test that assumes a single output value.
_multi_crop_test_names = ["test_pure_tensor_output_type", "test_pil_output_type", "test_tv_tensor_output_type"]
multi_crop_skips = [
    TestMark(
        ("TestDispatchers", test_name),
        pytest.mark.skip(reason="Multi-crop dispatchers return a sequence of items rather than a single one."),
    )
    for test_name in _multi_crop_test_names
]
multi_crop_skips.append(skip_dispatch_tv_tensor)


def xfails_pil(reason, *, condition=None):
    """Return xfail marks for both PIL dispatch tests, sharing one reason/condition."""
    marks = []
    for test_name in ["test_dispatch_pil", "test_pil_output_type"]:
        marks.append(TestMark(("TestDispatchers", test_name), pytest.mark.xfail(reason=reason), condition=condition))
    return marks


def fill_sequence_needs_broadcast(args_kwargs):
    """Return True iff `fill` is a sequence of length <= 1 that PIL would have to
    broadcast over more than one color channel.

    `args_kwargs` is iterable as ``(args, kwargs)`` with the image loader as the
    first positional argument.
    """
    (image_loader, *_), kwargs = args_kwargs
    if "fill" not in kwargs:
        return False
    fill = kwargs["fill"]

    is_short_sequence = isinstance(fill, collections.abc.Sequence) and len(fill) <= 1
    if not is_short_sequence:
        return False

    return image_loader.num_channels > 1


# Shared PIL xfail marks for dispatchers whose PIL kernel cannot broadcast a
# length-1 `fill` sequence over multiple color channels
# (condition logic lives in `fill_sequence_needs_broadcast`).
xfails_pil_if_fill_sequence_needs_broadcast = xfails_pil(
    "PIL kernel doesn't support sequences of length 1 for `fill` if the number of color channels is larger.",
    condition=fill_sequence_needs_broadcast,
)


# Registry of all dispatchers under test. Each entry declares the per-type kernels,
# the PIL kernel (if any), and test marks for known limitations (JIT scripting,
# PIL fill broadcasting, multi-output crops).
DISPATCHER_INFOS = [
    # --- geometry dispatchers: dispatch on images, videos, bounding boxes, and masks ---
    DispatcherInfo(
        F.resized_crop,
        kernels={
            tv_tensors.Image: F.resized_crop_image,
            tv_tensors.Video: F.resized_crop_video,
            tv_tensors.BoundingBoxes: F.resized_crop_bounding_boxes,
            tv_tensors.Mask: F.resized_crop_mask,
        },
        pil_kernel_info=PILKernelInfo(F._resized_crop_image_pil),
    ),
    DispatcherInfo(
        F.pad,
        kernels={
            tv_tensors.Image: F.pad_image,
            tv_tensors.Video: F.pad_video,
            tv_tensors.BoundingBoxes: F.pad_bounding_boxes,
            tv_tensors.Mask: F.pad_mask,
        },
        pil_kernel_info=PILKernelInfo(F._pad_image_pil, kernel_name="pad_image_pil"),
        test_marks=[
            *xfails_pil(
                reason=(
                    "PIL kernel doesn't support sequences of length 1 for argument `fill` and "
                    "`padding_mode='constant'`, if the number of color channels is larger."
                ),
                # only constant padding consumes `fill`, so only that mode hits the limitation
                condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs)
                and args_kwargs.kwargs.get("padding_mode", "constant") == "constant",
            ),
            xfail_jit("F.pad only supports vector fills for list of floats", condition=pad_xfail_jit_fill_condition),
            xfail_jit_python_scalar_arg("padding"),
        ],
    ),
    DispatcherInfo(
        F.perspective,
        kernels={
            tv_tensors.Image: F.perspective_image,
            tv_tensors.Video: F.perspective_video,
            tv_tensors.BoundingBoxes: F.perspective_bounding_boxes,
            tv_tensors.Mask: F.perspective_mask,
        },
        pil_kernel_info=PILKernelInfo(F._perspective_image_pil),
        test_marks=[
            *xfails_pil_if_fill_sequence_needs_broadcast,
            xfail_jit_python_scalar_arg("fill"),
        ],
    ),
    DispatcherInfo(
        F.elastic,
        kernels={
            tv_tensors.Image: F.elastic_image,
            tv_tensors.Video: F.elastic_video,
            tv_tensors.BoundingBoxes: F.elastic_bounding_boxes,
            tv_tensors.Mask: F.elastic_mask,
        },
        pil_kernel_info=PILKernelInfo(F._elastic_image_pil),
        test_marks=[xfail_jit_python_scalar_arg("fill")],
    ),
    DispatcherInfo(
        F.center_crop,
        kernels={
            tv_tensors.Image: F.center_crop_image,
            tv_tensors.Video: F.center_crop_video,
            tv_tensors.BoundingBoxes: F.center_crop_bounding_boxes,
            tv_tensors.Mask: F.center_crop_mask,
        },
        pil_kernel_info=PILKernelInfo(F._center_crop_image_pil),
        test_marks=[
            xfail_jit_python_scalar_arg("output_size"),
        ],
    ),
    # --- image/video-only dispatchers (no bounding box / mask kernels) ---
    DispatcherInfo(
        F.gaussian_blur,
        kernels={
            tv_tensors.Image: F.gaussian_blur_image,
            tv_tensors.Video: F.gaussian_blur_video,
        },
        pil_kernel_info=PILKernelInfo(F._gaussian_blur_image_pil),
        test_marks=[
            xfail_jit_python_scalar_arg("kernel_size"),
            xfail_jit_python_scalar_arg("sigma"),
        ],
    ),
    DispatcherInfo(
        F.equalize,
        kernels={
            tv_tensors.Image: F.equalize_image,
            tv_tensors.Video: F.equalize_video,
        },
        pil_kernel_info=PILKernelInfo(F._equalize_image_pil, kernel_name="equalize_image_pil"),
    ),
    DispatcherInfo(
        F.invert,
        kernels={
            tv_tensors.Image: F.invert_image,
            tv_tensors.Video: F.invert_video,
        },
        pil_kernel_info=PILKernelInfo(F._invert_image_pil, kernel_name="invert_image_pil"),
    ),
    DispatcherInfo(
        F.posterize,
        kernels={
            tv_tensors.Image: F.posterize_image,
            tv_tensors.Video: F.posterize_video,
        },
        pil_kernel_info=PILKernelInfo(F._posterize_image_pil, kernel_name="posterize_image_pil"),
    ),
    DispatcherInfo(
        F.solarize,
        kernels={
            tv_tensors.Image: F.solarize_image,
            tv_tensors.Video: F.solarize_video,
        },
        pil_kernel_info=PILKernelInfo(F._solarize_image_pil, kernel_name="solarize_image_pil"),
    ),
    DispatcherInfo(
        F.autocontrast,
        kernels={
            tv_tensors.Image: F.autocontrast_image,
            tv_tensors.Video: F.autocontrast_video,
        },
        pil_kernel_info=PILKernelInfo(F._autocontrast_image_pil, kernel_name="autocontrast_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_sharpness,
        kernels={
            tv_tensors.Image: F.adjust_sharpness_image,
            tv_tensors.Video: F.adjust_sharpness_video,
        },
        pil_kernel_info=PILKernelInfo(F._adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"),
    ),
    DispatcherInfo(
        F.erase,
        kernels={
            tv_tensors.Image: F.erase_image,
            tv_tensors.Video: F.erase_video,
        },
        pil_kernel_info=PILKernelInfo(F._erase_image_pil),
        test_marks=[
            skip_dispatch_tv_tensor,
        ],
    ),
    DispatcherInfo(
        F.adjust_contrast,
        kernels={
            tv_tensors.Image: F.adjust_contrast_image,
            tv_tensors.Video: F.adjust_contrast_video,
        },
        pil_kernel_info=PILKernelInfo(F._adjust_contrast_image_pil, kernel_name="adjust_contrast_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_gamma,
        kernels={
            tv_tensors.Image: F.adjust_gamma_image,
            tv_tensors.Video: F.adjust_gamma_video,
        },
        pil_kernel_info=PILKernelInfo(F._adjust_gamma_image_pil, kernel_name="adjust_gamma_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_hue,
        kernels={
            tv_tensors.Image: F.adjust_hue_image,
            tv_tensors.Video: F.adjust_hue_video,
        },
        pil_kernel_info=PILKernelInfo(F._adjust_hue_image_pil, kernel_name="adjust_hue_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_saturation,
        kernels={
            tv_tensors.Image: F.adjust_saturation_image,
            tv_tensors.Video: F.adjust_saturation_video,
        },
        pil_kernel_info=PILKernelInfo(F._adjust_saturation_image_pil, kernel_name="adjust_saturation_image_pil"),
    ),
    # --- multi-output crops: return sequences, hence the multi_crop_skips ---
    DispatcherInfo(
        F.five_crop,
        kernels={
            tv_tensors.Image: F.five_crop_image,
            tv_tensors.Video: F.five_crop_video,
        },
        pil_kernel_info=PILKernelInfo(F._five_crop_image_pil),
        test_marks=[
            xfail_jit_python_scalar_arg("size"),
            *multi_crop_skips,
        ],
    ),
    DispatcherInfo(
        F.ten_crop,
        kernels={
            tv_tensors.Image: F.ten_crop_image,
            tv_tensors.Video: F.ten_crop_video,
        },
        test_marks=[
            xfail_jit_python_scalar_arg("size"),
            *multi_crop_skips,
        ],
        pil_kernel_info=PILKernelInfo(F._ten_crop_image_pil),
    ),
    # --- dispatchers without a PIL kernel ---
    DispatcherInfo(
        F.normalize,
        kernels={
            tv_tensors.Image: F.normalize_image,
            tv_tensors.Video: F.normalize_video,
        },
        test_marks=[
            xfail_jit_python_scalar_arg("mean"),
            xfail_jit_python_scalar_arg("std"),
        ],
    ),
    DispatcherInfo(
        F.uniform_temporal_subsample,
        kernels={
            tv_tensors.Video: F.uniform_temporal_subsample_video,
        },
        test_marks=[
            skip_dispatch_tv_tensor,
        ],
    ),
    DispatcherInfo(
        F.clamp_bounding_boxes,
        kernels={tv_tensors.BoundingBoxes: F.clamp_bounding_boxes},
        test_marks=[
            skip_dispatch_tv_tensor,
        ],
    ),
    DispatcherInfo(
        F.convert_bounding_box_format,
        kernels={tv_tensors.BoundingBoxes: F.convert_bounding_box_format},
        test_marks=[
            skip_dispatch_tv_tensor,
        ],
    ),
]