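"""Dispatcher infos for testing the transforms v2 API.

Each `DispatcherInfo` below ties a dispatcher from `torchvision.transforms.v2.functional`
to the kernels it dispatches to per datapoint type, to the corresponding PIL kernel (if
any), and to the test marks (skips / xfails) that apply to it. The resulting
`DISPATCHER_INFOS` list is consumed by the dispatcher tests.
"""
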
import collections.abc
import itertools

import pytest
import torchvision.transforms.v2.functional as F
from common_utils import InfoBase, TestMark
from torchvision import datapoints
from transforms_v2_kernel_infos import KERNEL_INFOS, pad_xfail_jit_fill_condition

__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]


class PILKernelInfo(InfoBase):
    def __init__(
        self,
        kernel,
        *,
        # Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
        # TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
        kernel_name=None,
    ):
        super().__init__(id=kernel_name or kernel.__name__)
        self.kernel = kernel


class DispatcherInfo(InfoBase):
    _KERNEL_INFO_MAP = {info.kernel: info for info in KERNEL_INFOS}

    def __init__(
        self,
        dispatcher,
        *,
        # Dictionary mapping datapoint types to the kernel the dispatcher routes them to.
        kernels,
        # If omitted, no PIL dispatch test will be performed.
        pil_kernel_info=None,
        # See InfoBase
        test_marks=None,
        # See InfoBase
        closeness_kwargs=None,
    ):
        super().__init__(id=dispatcher.__name__, test_marks=test_marks, closeness_kwargs=closeness_kwargs)
        self.dispatcher = dispatcher
        self.kernels = kernels
        self.pil_kernel_info = pil_kernel_info

        kernel_infos = {}
        for datapoint_type, kernel in self.kernels.items():
            kernel_info = self._KERNEL_INFO_MAP.get(kernel)
            if not kernel_info:
                raise pytest.UsageError(
                    f"Can't register {kernel.__name__} for type {datapoint_type} since there is no `KernelInfo` for it. "
                    f"Please add a `KernelInfo` for it in `transforms_v2_kernel_infos.py`."
                )
            kernel_infos[datapoint_type] = kernel_info
        self.kernel_infos = kernel_infos

    def sample_inputs(self, *datapoint_types, filter_metadata=True):
        # Yield the sample inputs registered for the given datapoint types (all registered
        # types by default), optionally stripping kwargs that duplicate the metadata
        # annotated on the datapoint type (and their `old_*` variants).
        for datapoint_type in datapoint_types or self.kernel_infos.keys():
            kernel_info = self.kernel_infos.get(datapoint_type)
            if not kernel_info:
                raise pytest.UsageError(f"There is no kernel registered for type {datapoint_type.__name__}")

            sample_inputs = kernel_info.sample_inputs_fn()

            if not filter_metadata:
                yield from sample_inputs
                # Move on to the next datapoint type instead of ending the generator early.
                continue
            for args_kwargs in sample_inputs:
                for name in itertools.chain(
                    datapoint_type.__annotations__.keys(),
                    # FIXME: this seems ok for conversion dispatchers, but we should probably handle this on a
                    #  per-dispatcher level. However, so far there is no option for that.
                    (f"old_{name}" for name in datapoint_type.__annotations__.keys()),
                ):
                    if name in args_kwargs.kwargs:
                        del args_kwargs.kwargs[name]

                yield args_kwargs
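

# Hypothetical usage sketch (for orientation; not part of the test machinery): with the
# infos defined below, a test can look up which kernel backs a dispatcher for a given
# datapoint type and iterate the metadata-filtered sample inputs for it, e.g.
#
#   info = next(info for info in DISPATCHER_INFOS if info.dispatcher is F.resize)
#   kernel = info.kernels[datapoints.Image]  # -> F.resize_image_tensor
#   for args_kwargs in info.sample_inputs(datapoints.Image):
#       ...  # materialize args_kwargs and call the dispatcher / kernel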


def xfail_jit(reason, *, condition=None):
    return TestMark(
        ("TestDispatchers", "test_scripted_smoke"),
        pytest.mark.xfail(reason=reason),
        condition=condition,
    )

def xfail_jit_python_scalar_arg(name, *, reason=None):
    return xfail_jit(
        reason or f"Python scalar int or float for `{name}` is not supported when scripting",
        condition=lambda args_kwargs: isinstance(args_kwargs.kwargs.get(name), (int, float)),
    )
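

# For example (illustrative): `xfail_jit_python_scalar_arg("size")` xfails
# `TestDispatchers.test_scripted_smoke` only for sample inputs whose `size` kwarg is a
# bare `int` / `float`, which is not supported when scripting.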


skip_dispatch_datapoint = TestMark(
    ("TestDispatchers", "test_dispatch_datapoint"),
    pytest.mark.skip(reason="Dispatcher doesn't support arbitrary datapoint dispatch."),
)

multi_crop_skips = [
    TestMark(
        ("TestDispatchers", test_name),
        pytest.mark.skip(reason="Multi-crop dispatchers return a sequence of items rather than a single one."),
    )
    for test_name in ["test_simple_tensor_output_type", "test_pil_output_type", "test_datapoint_output_type"]
]
multi_crop_skips.append(skip_dispatch_datapoint)

def xfails_pil(reason, *, condition=None):
    return [
        TestMark(("TestDispatchers", test_name), pytest.mark.xfail(reason=reason), condition=condition)
        for test_name in ["test_dispatch_pil", "test_pil_output_type"]
    ]


def fill_sequence_needs_broadcast(args_kwargs):
    # Detect samples whose `fill` is a length-1 sequence that would have to be
    # broadcast over more than one color channel.
    (image_loader, *_), kwargs = args_kwargs
    try:
        fill = kwargs["fill"]
    except KeyError:
        return False

    if not isinstance(fill, collections.abc.Sequence) or len(fill) > 1:
        return False

    return image_loader.num_channels > 1


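# Example (illustrative): for a 3-channel image, `fill=[0.5]` would have to be
# broadcast to `[0.5, 0.5, 0.5]`; the PIL kernels refuse to do that, hence: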
xfails_pil_if_fill_sequence_needs_broadcast = xfails_pil(
    "PIL kernel doesn't support sequences of length 1 for `fill` if the number of color channels is larger.",
    condition=fill_sequence_needs_broadcast,
)


DISPATCHER_INFOS = [
    DispatcherInfo(
        F.horizontal_flip,
        kernels={
            datapoints.Image: F.horizontal_flip_image_tensor,
            datapoints.Video: F.horizontal_flip_video,
            datapoints.BoundingBox: F.horizontal_flip_bounding_box,
            datapoints.Mask: F.horizontal_flip_mask,
        },
        pil_kernel_info=PILKernelInfo(F.horizontal_flip_image_pil, kernel_name="horizontal_flip_image_pil"),
    ),
    DispatcherInfo(
        F.resize,
        kernels={
            datapoints.Image: F.resize_image_tensor,
            datapoints.Video: F.resize_video,
            datapoints.BoundingBox: F.resize_bounding_box,
            datapoints.Mask: F.resize_mask,
        },
        pil_kernel_info=PILKernelInfo(F.resize_image_pil),
        test_marks=[
            xfail_jit_python_scalar_arg("size"),
        ],
    ),
    DispatcherInfo(
        F.affine,
        kernels={
            datapoints.Image: F.affine_image_tensor,
            datapoints.Video: F.affine_video,
            datapoints.BoundingBox: F.affine_bounding_box,
            datapoints.Mask: F.affine_mask,
        },
        pil_kernel_info=PILKernelInfo(F.affine_image_pil),
        test_marks=[
            *xfails_pil_if_fill_sequence_needs_broadcast,
            xfail_jit_python_scalar_arg("shear"),
            xfail_jit_python_scalar_arg("fill"),
        ],
    ),
    DispatcherInfo(
        F.vertical_flip,
        kernels={
            datapoints.Image: F.vertical_flip_image_tensor,
            datapoints.Video: F.vertical_flip_video,
            datapoints.BoundingBox: F.vertical_flip_bounding_box,
            datapoints.Mask: F.vertical_flip_mask,
        },
        pil_kernel_info=PILKernelInfo(F.vertical_flip_image_pil, kernel_name="vertical_flip_image_pil"),
    ),
    DispatcherInfo(
        F.rotate,
        kernels={
            datapoints.Image: F.rotate_image_tensor,
            datapoints.Video: F.rotate_video,
            datapoints.BoundingBox: F.rotate_bounding_box,
            datapoints.Mask: F.rotate_mask,
        },
        pil_kernel_info=PILKernelInfo(F.rotate_image_pil),
        test_marks=[
            xfail_jit_python_scalar_arg("fill"),
            *xfails_pil_if_fill_sequence_needs_broadcast,
        ],
    ),
    DispatcherInfo(
        F.crop,
        kernels={
            datapoints.Image: F.crop_image_tensor,
            datapoints.Video: F.crop_video,
            datapoints.BoundingBox: F.crop_bounding_box,
            datapoints.Mask: F.crop_mask,
        },
        pil_kernel_info=PILKernelInfo(F.crop_image_pil, kernel_name="crop_image_pil"),
    ),
    DispatcherInfo(
        F.resized_crop,
        kernels={
            datapoints.Image: F.resized_crop_image_tensor,
            datapoints.Video: F.resized_crop_video,
            datapoints.BoundingBox: F.resized_crop_bounding_box,
            datapoints.Mask: F.resized_crop_mask,
        },
        pil_kernel_info=PILKernelInfo(F.resized_crop_image_pil),
    ),
    DispatcherInfo(
        F.pad,
        kernels={
            datapoints.Image: F.pad_image_tensor,
            datapoints.Video: F.pad_video,
            datapoints.BoundingBox: F.pad_bounding_box,
            datapoints.Mask: F.pad_mask,
        },
        pil_kernel_info=PILKernelInfo(F.pad_image_pil, kernel_name="pad_image_pil"),
        test_marks=[
            *xfails_pil(
                reason=(
                    "PIL kernel doesn't support sequences of length 1 for argument `fill` and "
                    "`padding_mode='constant'`, if the number of color channels is larger."
                ),
                condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs)
                and args_kwargs.kwargs.get("padding_mode", "constant") == "constant",
            ),
            xfail_jit("F.pad only supports vector fills for list of floats", condition=pad_xfail_jit_fill_condition),
            xfail_jit_python_scalar_arg("padding"),
        ],
    ),
    DispatcherInfo(
        F.perspective,
        kernels={
            datapoints.Image: F.perspective_image_tensor,
            datapoints.Video: F.perspective_video,
            datapoints.BoundingBox: F.perspective_bounding_box,
            datapoints.Mask: F.perspective_mask,
        },
        pil_kernel_info=PILKernelInfo(F.perspective_image_pil),
        test_marks=[
            *xfails_pil_if_fill_sequence_needs_broadcast,
            xfail_jit_python_scalar_arg("fill"),
        ],
    ),
    DispatcherInfo(
        F.elastic,
        kernels={
            datapoints.Image: F.elastic_image_tensor,
            datapoints.Video: F.elastic_video,
            datapoints.BoundingBox: F.elastic_bounding_box,
            datapoints.Mask: F.elastic_mask,
        },
        pil_kernel_info=PILKernelInfo(F.elastic_image_pil),
        test_marks=[xfail_jit_python_scalar_arg("fill")],
    ),
    DispatcherInfo(
        F.center_crop,
        kernels={
            datapoints.Image: F.center_crop_image_tensor,
            datapoints.Video: F.center_crop_video,
            datapoints.BoundingBox: F.center_crop_bounding_box,
            datapoints.Mask: F.center_crop_mask,
        },
        pil_kernel_info=PILKernelInfo(F.center_crop_image_pil),
        test_marks=[
            xfail_jit_python_scalar_arg("output_size"),
        ],
    ),
    DispatcherInfo(
        F.gaussian_blur,
        kernels={
            datapoints.Image: F.gaussian_blur_image_tensor,
            datapoints.Video: F.gaussian_blur_video,
        },
        pil_kernel_info=PILKernelInfo(F.gaussian_blur_image_pil),
        test_marks=[
            xfail_jit_python_scalar_arg("kernel_size"),
            xfail_jit_python_scalar_arg("sigma"),
        ],
    ),
    DispatcherInfo(
        F.equalize,
        kernels={
            datapoints.Image: F.equalize_image_tensor,
            datapoints.Video: F.equalize_video,
        },
        pil_kernel_info=PILKernelInfo(F.equalize_image_pil, kernel_name="equalize_image_pil"),
    ),
    DispatcherInfo(
        F.invert,
        kernels={
            datapoints.Image: F.invert_image_tensor,
            datapoints.Video: F.invert_video,
        },
        pil_kernel_info=PILKernelInfo(F.invert_image_pil, kernel_name="invert_image_pil"),
    ),
    DispatcherInfo(
        F.posterize,
        kernels={
            datapoints.Image: F.posterize_image_tensor,
            datapoints.Video: F.posterize_video,
        },
        pil_kernel_info=PILKernelInfo(F.posterize_image_pil, kernel_name="posterize_image_pil"),
    ),
    DispatcherInfo(
        F.solarize,
        kernels={
            datapoints.Image: F.solarize_image_tensor,
            datapoints.Video: F.solarize_video,
        },
        pil_kernel_info=PILKernelInfo(F.solarize_image_pil, kernel_name="solarize_image_pil"),
    ),
    DispatcherInfo(
        F.autocontrast,
        kernels={
            datapoints.Image: F.autocontrast_image_tensor,
            datapoints.Video: F.autocontrast_video,
        },
        pil_kernel_info=PILKernelInfo(F.autocontrast_image_pil, kernel_name="autocontrast_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_sharpness,
        kernels={
            datapoints.Image: F.adjust_sharpness_image_tensor,
            datapoints.Video: F.adjust_sharpness_video,
        },
        pil_kernel_info=PILKernelInfo(F.adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"),
    ),
    DispatcherInfo(
        F.erase,
        kernels={
            datapoints.Image: F.erase_image_tensor,
            datapoints.Video: F.erase_video,
        },
        pil_kernel_info=PILKernelInfo(F.erase_image_pil),
        test_marks=[
            skip_dispatch_datapoint,
        ],
    ),
    DispatcherInfo(
        F.adjust_brightness,
        kernels={
            datapoints.Image: F.adjust_brightness_image_tensor,
            datapoints.Video: F.adjust_brightness_video,
        },
        pil_kernel_info=PILKernelInfo(F.adjust_brightness_image_pil, kernel_name="adjust_brightness_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_contrast,
        kernels={
            datapoints.Image: F.adjust_contrast_image_tensor,
            datapoints.Video: F.adjust_contrast_video,
        },
        pil_kernel_info=PILKernelInfo(F.adjust_contrast_image_pil, kernel_name="adjust_contrast_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_gamma,
        kernels={
            datapoints.Image: F.adjust_gamma_image_tensor,
            datapoints.Video: F.adjust_gamma_video,
        },
        pil_kernel_info=PILKernelInfo(F.adjust_gamma_image_pil, kernel_name="adjust_gamma_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_hue,
        kernels={
            datapoints.Image: F.adjust_hue_image_tensor,
            datapoints.Video: F.adjust_hue_video,
        },
        pil_kernel_info=PILKernelInfo(F.adjust_hue_image_pil, kernel_name="adjust_hue_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_saturation,
        kernels={
            datapoints.Image: F.adjust_saturation_image_tensor,
            datapoints.Video: F.adjust_saturation_video,
        },
        pil_kernel_info=PILKernelInfo(F.adjust_saturation_image_pil, kernel_name="adjust_saturation_image_pil"),
    ),
    DispatcherInfo(
        F.five_crop,
        kernels={
            datapoints.Image: F.five_crop_image_tensor,
            datapoints.Video: F.five_crop_video,
        },
        pil_kernel_info=PILKernelInfo(F.five_crop_image_pil),
        test_marks=[
            xfail_jit_python_scalar_arg("size"),
            *multi_crop_skips,
        ],
    ),
    DispatcherInfo(
        F.ten_crop,
        kernels={
            datapoints.Image: F.ten_crop_image_tensor,
            datapoints.Video: F.ten_crop_video,
        },
        test_marks=[
            xfail_jit_python_scalar_arg("size"),
            *multi_crop_skips,
        ],
        pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil),
    ),
    DispatcherInfo(
        F.normalize,
        kernels={
            datapoints.Image: F.normalize_image_tensor,
            datapoints.Video: F.normalize_video,
        },
        test_marks=[
            xfail_jit_python_scalar_arg("mean"),
            xfail_jit_python_scalar_arg("std"),
        ],
    ),
    DispatcherInfo(
        F.convert_dtype,
        kernels={
            datapoints.Image: F.convert_dtype_image_tensor,
            datapoints.Video: F.convert_dtype_video,
        },
        test_marks=[
            skip_dispatch_datapoint,
        ],
    ),
    DispatcherInfo(
        F.uniform_temporal_subsample,
        kernels={
            datapoints.Video: F.uniform_temporal_subsample_video,
        },
        test_marks=[
            skip_dispatch_datapoint,
        ],
    ),
    DispatcherInfo(
        F.clamp_bounding_box,
        kernels={datapoints.BoundingBox: F.clamp_bounding_box},
        test_marks=[
            skip_dispatch_datapoint,
        ],
    ),
    DispatcherInfo(
        F.convert_format_bounding_box,
        kernels={datapoints.BoundingBox: F.convert_format_bounding_box},
        test_marks=[
            skip_dispatch_datapoint,
        ],
    ),
]
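

# Consumption sketch (hypothetical; the actual tests live in the `TestDispatchers`
# class referenced by the marks above):
#
#   @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
#   def test_dispatch_smoke(info):
#       for args_kwargs in info.sample_inputs():
#           ...  # materialize the sample input and invoke info.dispatcher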