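"""Test infos for the dispatchers of ``torchvision.transforms.v2.functional``.

Each `DispatcherInfo` pairs a dispatcher with the kernels it dispatches to per tv_tensor
type, so the dispatch tests can be parametrized over all of them.
"""
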
import itertools

import pytest
import torchvision.transforms.v2.functional as F
from torchvision import tv_tensors
from transforms_v2_kernel_infos import KERNEL_INFOS
from transforms_v2_legacy_utils import InfoBase, TestMark

__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]


class PILKernelInfo(InfoBase):
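    """Info for the PIL kernel that backs a dispatcher, used by the PIL dispatch tests."""
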
    def __init__(
        self,
        kernel,
        *,
        # Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
        # TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
        kernel_name=None,
    ):
        super().__init__(id=kernel_name or kernel.__name__)
        self.kernel = kernel


class DispatcherInfo(InfoBase):
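    """Groups a v2 dispatcher with the kernels it dispatches to and their `KernelInfo`s."""
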
    _KERNEL_INFO_MAP = {info.kernel: info for info in KERNEL_INFOS}

    def __init__(
        self,
        dispatcher,
        *,
        # Dictionary of types that map to the kernel the dispatcher dispatches to.
        kernels,
        # If omitted, no PIL dispatch test will be performed.
        pil_kernel_info=None,
        # See InfoBase
        test_marks=None,
        # See InfoBase
        closeness_kwargs=None,
    ):
        super().__init__(id=dispatcher.__name__, test_marks=test_marks, closeness_kwargs=closeness_kwargs)
        self.dispatcher = dispatcher
        self.kernels = kernels
        self.pil_kernel_info = pil_kernel_info

        kernel_infos = {}
        for tv_tensor_type, kernel in self.kernels.items():
            kernel_info = self._KERNEL_INFO_MAP.get(kernel)
            if not kernel_info:
                raise pytest.UsageError(
                    f"Can't register {kernel.__name__} for type {tv_tensor_type} since there is no `KernelInfo` for it. "
                    f"Please add a `KernelInfo` for it in `transforms_v2_kernel_infos.py`."
                )
            kernel_infos[tv_tensor_type] = kernel_info
        self.kernel_infos = kernel_infos

    def sample_inputs(self, *tv_tensor_types, filter_metadata=True):
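        """Yield the sample inputs of the kernels registered for the given tv_tensor types.

        If no types are given, samples are yielded for every registered type. Unless
        ``filter_metadata`` is disabled, metadata kwargs that the dispatcher reads from the
        tv_tensor itself (e.g. ``format`` or ``canvas_size`` of ``BoundingBoxes``) are
        stripped from the samples.
        """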
        for tv_tensor_type in tv_tensor_types or self.kernel_infos.keys():
            kernel_info = self.kernel_infos.get(tv_tensor_type)
            if not kernel_info:
                raise pytest.UsageError(f"There is no kernel registered for type {tv_tensor_type.__name__}")

            sample_inputs = kernel_info.sample_inputs_fn()

            if not filter_metadata:
                yield from sample_inputs
                continue

            for args_kwargs in sample_inputs:
                if hasattr(tv_tensor_type, "__annotations__"):
                    # The dispatcher reads the metadata of the tv_tensor from the input itself, so the
                    # corresponding kwargs are stripped from the sample before dispatching.
                    for name in itertools.chain(
                        tv_tensor_type.__annotations__.keys(),
                        # FIXME: this seems ok for conversion dispatchers, but we should probably handle this on a
                        #  per-dispatcher level. However, so far there is no option for that.
                        (f"old_{name}" for name in tv_tensor_type.__annotations__.keys()),
                    ):
                        if name in args_kwargs.kwargs:
                            del args_kwargs.kwargs[name]

                yield args_kwargs
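
# A minimal sketch of how a test might consume these infos, assuming the sample objects
# returned by `sample_inputs_fn` expose the `ArgsKwargs.load()` helper from the legacy
# test utilities (an assumption; this file does not define them):
#
#   info = next(info for info in DISPATCHER_INFOS if info.dispatcher is F.perspective)
#   for args_kwargs in info.sample_inputs(tv_tensors.Image):
#       args, kwargs = args_kwargs.load()
#       info.dispatcher(*args, **kwargs)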


def xfail_jit(reason, *, condition=None):
    return TestMark(
        ("TestDispatchers", "test_scripted_smoke"),
        pytest.mark.xfail(reason=reason),
        condition=condition,
    )


def xfail_jit_python_scalar_arg(name, *, reason=None):
    return xfail_jit(
        reason or f"Python scalar int or float for `{name}` is not supported when scripting",
        condition=lambda args_kwargs: isinstance(args_kwargs.kwargs.get(name), (int, float)),
    )
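
# For example, `xfail_jit_python_scalar_arg("fill")` xfails the torchscript smoke test for
# every sample that passes `fill` as a plain `int` or `float`.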


skip_dispatch_tv_tensor = TestMark(
    ("TestDispatchers", "test_dispatch_tv_tensor"),
    pytest.mark.skip(reason="Dispatcher doesn't support arbitrary tv_tensor dispatch."),
)

# Multi-crop dispatchers return a sequence of items rather than a single one, so neither the
# single-output type checks nor arbitrary tv_tensor dispatch apply to them.
multi_crop_skips = [
    TestMark(
        ("TestDispatchers", test_name),
        pytest.mark.skip(reason="Multi-crop dispatchers return a sequence of items rather than a single one."),
    )
    for test_name in ["test_pure_tensor_output_type", "test_pil_output_type", "test_tv_tensor_output_type"]
]
multi_crop_skips.append(skip_dispatch_tv_tensor)


DISPATCHER_INFOS = [
    DispatcherInfo(
        F.perspective,
        kernels={
            tv_tensors.Image: F.perspective_image,
            tv_tensors.Video: F.perspective_video,
            tv_tensors.BoundingBoxes: F.perspective_bounding_boxes,
            tv_tensors.Mask: F.perspective_mask,
        },
        pil_kernel_info=PILKernelInfo(F._perspective_image_pil),
        test_marks=[
            xfail_jit_python_scalar_arg("fill"),
        ],
    ),
    DispatcherInfo(
        F.elastic,
        kernels={
            tv_tensors.Image: F.elastic_image,
            tv_tensors.Video: F.elastic_video,
            tv_tensors.BoundingBoxes: F.elastic_bounding_boxes,
            tv_tensors.Mask: F.elastic_mask,
        },
        pil_kernel_info=PILKernelInfo(F._elastic_image_pil),
        test_marks=[xfail_jit_python_scalar_arg("fill")],
    ),
    DispatcherInfo(
        F.center_crop,
        kernels={
            tv_tensors.Image: F.center_crop_image,
            tv_tensors.Video: F.center_crop_video,
            tv_tensors.BoundingBoxes: F.center_crop_bounding_boxes,
            tv_tensors.Mask: F.center_crop_mask,
        },
        pil_kernel_info=PILKernelInfo(F._center_crop_image_pil),
        test_marks=[
            xfail_jit_python_scalar_arg("output_size"),
        ],
    ),
    DispatcherInfo(
        F.equalize,
        kernels={
            tv_tensors.Image: F.equalize_image,
            tv_tensors.Video: F.equalize_video,
        },
        pil_kernel_info=PILKernelInfo(F._equalize_image_pil, kernel_name="equalize_image_pil"),
    ),
    DispatcherInfo(
        F.invert,
        kernels={
            tv_tensors.Image: F.invert_image,
            tv_tensors.Video: F.invert_video,
        },
        pil_kernel_info=PILKernelInfo(F._invert_image_pil, kernel_name="invert_image_pil"),
    ),
    DispatcherInfo(
        F.posterize,
        kernels={
            tv_tensors.Image: F.posterize_image,
            tv_tensors.Video: F.posterize_video,
        },
        pil_kernel_info=PILKernelInfo(F._posterize_image_pil, kernel_name="posterize_image_pil"),
    ),
    DispatcherInfo(
        F.solarize,
        kernels={
            tv_tensors.Image: F.solarize_image,
            tv_tensors.Video: F.solarize_video,
        },
        pil_kernel_info=PILKernelInfo(F._solarize_image_pil, kernel_name="solarize_image_pil"),
    ),
    DispatcherInfo(
        F.autocontrast,
        kernels={
            tv_tensors.Image: F.autocontrast_image,
            tv_tensors.Video: F.autocontrast_video,
        },
        pil_kernel_info=PILKernelInfo(F._autocontrast_image_pil, kernel_name="autocontrast_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_sharpness,
        kernels={
            tv_tensors.Image: F.adjust_sharpness_image,
            tv_tensors.Video: F.adjust_sharpness_video,
        },
        pil_kernel_info=PILKernelInfo(F._adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_contrast,
        kernels={
            tv_tensors.Image: F.adjust_contrast_image,
            tv_tensors.Video: F.adjust_contrast_video,
        },
        pil_kernel_info=PILKernelInfo(F._adjust_contrast_image_pil, kernel_name="adjust_contrast_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_gamma,
        kernels={
            tv_tensors.Image: F.adjust_gamma_image,
            tv_tensors.Video: F.adjust_gamma_video,
        },
        pil_kernel_info=PILKernelInfo(F._adjust_gamma_image_pil, kernel_name="adjust_gamma_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_hue,
        kernels={
            tv_tensors.Image: F.adjust_hue_image,
            tv_tensors.Video: F.adjust_hue_video,
        },
        pil_kernel_info=PILKernelInfo(F._adjust_hue_image_pil, kernel_name="adjust_hue_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_saturation,
        kernels={
            tv_tensors.Image: F.adjust_saturation_image,
            tv_tensors.Video: F.adjust_saturation_video,
        },
        pil_kernel_info=PILKernelInfo(F._adjust_saturation_image_pil, kernel_name="adjust_saturation_image_pil"),
    ),
    DispatcherInfo(
        F.five_crop,
        kernels={
            tv_tensors.Image: F.five_crop_image,
            tv_tensors.Video: F.five_crop_video,
        },
        pil_kernel_info=PILKernelInfo(F._five_crop_image_pil),
        test_marks=[
            xfail_jit_python_scalar_arg("size"),
            *multi_crop_skips,
        ],
    ),
    DispatcherInfo(
        F.ten_crop,
        kernels={
            tv_tensors.Image: F.ten_crop_image,
            tv_tensors.Video: F.ten_crop_video,
        },
        pil_kernel_info=PILKernelInfo(F._ten_crop_image_pil),
        test_marks=[
            xfail_jit_python_scalar_arg("size"),
            *multi_crop_skips,
        ],
    ),
    DispatcherInfo(
        F.normalize,
        kernels={
            tv_tensors.Image: F.normalize_image,
            tv_tensors.Video: F.normalize_video,
        },
        test_marks=[
            xfail_jit_python_scalar_arg("mean"),
            xfail_jit_python_scalar_arg("std"),
        ],
    ),
    DispatcherInfo(
        F.uniform_temporal_subsample,
        kernels={
            tv_tensors.Video: F.uniform_temporal_subsample_video,
        },
        test_marks=[
            skip_dispatch_tv_tensor,
        ],
    ),
    DispatcherInfo(
        F.clamp_bounding_boxes,
        kernels={tv_tensors.BoundingBoxes: F.clamp_bounding_boxes},
        test_marks=[
            skip_dispatch_tv_tensor,
        ],
    ),
]
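
# A hypothetical parametrization over these infos, matching the ("TestDispatchers", ...)
# targets referenced by the `TestMark`s above:
#
#   @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
#   def test_scripted_smoke(info):
#       ...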