Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
4a99bae8
Unverified
Commit
4a99bae8
authored
Oct 05, 2022
by
Philip Meier
Committed by
GitHub
Oct 05, 2022
Browse files
add dispatch tests for prototype transform dispatchers (#6631)
parent
0e006a9f
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
189 additions
and
10 deletions
+189
-10
test/prototype_transforms_dispatcher_infos.py
test/prototype_transforms_dispatcher_infos.py
+113
-9
test/prototype_transforms_kernel_infos.py
test/prototype_transforms_kernel_infos.py
+1
-1
test/test_prototype_transforms_functional.py
test/test_prototype_transforms_functional.py
+75
-0
No files found.
test/prototype_transforms_dispatcher_infos.py
View file @
4a99bae8
import
collections.abc
import
dataclasses
import
dataclasses
from
collections
import
defaultdict
from
collections
import
defaultdict
from
typing
import
Callable
,
Dict
,
List
,
Sequence
,
Type
from
typing
import
Callable
,
Dict
,
List
,
Optional
,
Sequence
,
Type
import
pytest
import
pytest
import
torchvision.prototype.transforms.functional
as
F
import
torchvision.prototype.transforms.functional
as
F
from
prototype_transforms_kernel_infos
import
KERNEL_INFOS
,
Skip
from
prototype_common_utils
import
BoundingBoxLoader
from
prototype_transforms_kernel_infos
import
KERNEL_INFOS
,
KernelInfo
,
Skip
from
torchvision.prototype
import
features
from
torchvision.prototype
import
features
__all__
=
[
"DispatcherInfo"
,
"DISPATCHER_INFOS"
]
__all__
=
[
"DispatcherInfo"
,
"DISPATCHER_INFOS"
]
KERNEL_SAMPLE_INPUTS_FN_MAP
=
{
info
.
kernel
:
info
.
sample_inputs_fn
for
info
in
KERNEL_INFOS
}
KERNEL_INFO_MAP
=
{
info
.
kernel
:
info
for
info
in
KERNEL_INFOS
}
@
dataclasses
.
dataclass
class
PILKernelInfo
:
kernel
:
Callable
kernel_name
:
str
=
dataclasses
.
field
(
default
=
None
)
def
__post_init__
(
self
):
self
.
kernel_name
=
self
.
kernel_name
or
self
.
kernel
.
__name__
def
skip_python_scalar_arg_jit
(
name
,
*
,
reason
=
"Python scalar int or float is not supported when scripting"
):
def
skip_python_scalar_arg_jit
(
name
,
*
,
reason
=
"Python scalar int or float is not supported when scripting"
):
...
@@ -28,21 +40,35 @@ def skip_integer_size_jit(name="size"):
...
@@ -28,21 +40,35 @@ def skip_integer_size_jit(name="size"):
class
DispatcherInfo
:
class
DispatcherInfo
:
dispatcher
:
Callable
dispatcher
:
Callable
kernels
:
Dict
[
Type
,
Callable
]
kernels
:
Dict
[
Type
,
Callable
]
kernel_infos
:
Dict
[
Type
,
KernelInfo
]
=
dataclasses
.
field
(
default
=
None
)
pil_kernel_info
:
Optional
[
PILKernelInfo
]
=
None
method_name
:
str
=
dataclasses
.
field
(
default
=
None
)
skips
:
Sequence
[
Skip
]
=
dataclasses
.
field
(
default_factory
=
list
)
skips
:
Sequence
[
Skip
]
=
dataclasses
.
field
(
default_factory
=
list
)
_skips_map
:
Dict
[
str
,
List
[
Skip
]]
=
dataclasses
.
field
(
default
=
None
,
init
=
False
)
_skips_map
:
Dict
[
str
,
List
[
Skip
]]
=
dataclasses
.
field
(
default
=
None
,
init
=
False
)
def
__post_init__
(
self
):
def
__post_init__
(
self
):
self
.
kernel_infos
=
{
feature_type
:
KERNEL_INFO_MAP
[
kernel
]
for
feature_type
,
kernel
in
self
.
kernels
.
items
()}
self
.
method_name
=
self
.
method_name
or
self
.
dispatcher
.
__name__
skips_map
=
defaultdict
(
list
)
skips_map
=
defaultdict
(
list
)
for
skip
in
self
.
skips
:
for
skip
in
self
.
skips
:
skips_map
[
skip
.
test_name
].
append
(
skip
)
skips_map
[
skip
.
test_name
].
append
(
skip
)
self
.
_skips_map
=
dict
(
skips_map
)
self
.
_skips_map
=
dict
(
skips_map
)
def
sample_inputs
(
self
,
*
types
):
def
sample_inputs
(
self
,
*
feature_types
,
filter_metadata
=
True
):
for
type
in
types
or
self
.
kernels
.
keys
():
for
feature_type
in
feature_types
or
self
.
kernels
.
keys
():
if
type
not
in
self
.
kernels
:
if
feature_type
not
in
self
.
kernels
:
raise
pytest
.
UsageError
(
f
"There is no kernel registered for type
{
type
.
__name__
}
"
)
raise
pytest
.
UsageError
(
f
"There is no kernel registered for type
{
feature_type
.
__name__
}
"
)
sample_inputs
=
self
.
kernel_infos
[
feature_type
].
sample_inputs_fn
()
if
not
filter_metadata
:
yield
from
sample_inputs
else
:
for
args_kwargs
in
sample_inputs
:
for
attribute
in
feature_type
.
__annotations__
.
keys
():
if
attribute
in
args_kwargs
.
kwargs
:
del
args_kwargs
.
kwargs
[
attribute
]
yield
from
KERNEL_SAMPLE_INPUTS_FN_MAP
[
self
.
kernels
[
type
]]()
yield
args_kwargs
def
maybe_skip
(
self
,
*
,
test_name
,
args_kwargs
,
device
):
def
maybe_skip
(
self
,
*
,
test_name
,
args_kwargs
,
device
):
skips
=
self
.
_skips_map
.
get
(
test_name
)
skips
=
self
.
_skips_map
.
get
(
test_name
)
...
@@ -54,6 +80,31 @@ class DispatcherInfo:
...
@@ -54,6 +80,31 @@ class DispatcherInfo:
pytest
.
skip
(
skip
.
reason
)
pytest
.
skip
(
skip
.
reason
)
def
fill_sequence_needs_broadcast
(
args_kwargs
,
device
):
(
image_loader
,
*
_
),
kwargs
=
args_kwargs
try
:
fill
=
kwargs
[
"fill"
]
except
KeyError
:
return
False
if
not
isinstance
(
fill
,
collections
.
abc
.
Sequence
)
or
len
(
fill
)
>
1
:
return
False
return
image_loader
.
num_channels
>
1
skip_dispatch_pil_if_fill_sequence_needs_broadcast
=
Skip
(
"test_dispatch_pil"
,
condition
=
fill_sequence_needs_broadcast
,
reason
=
"PIL kernel doesn't support sequences of length 1 if the number of channels is larger."
,
)
skip_dispatch_feature
=
Skip
(
"test_dispatch_feature"
,
reason
=
"Dispatcher doesn't support arbitrary feature dispatch."
,
)
DISPATCHER_INFOS
=
[
DISPATCHER_INFOS
=
[
DispatcherInfo
(
DispatcherInfo
(
F
.
horizontal_flip
,
F
.
horizontal_flip
,
...
@@ -62,6 +113,7 @@ DISPATCHER_INFOS = [
...
@@ -62,6 +113,7 @@ DISPATCHER_INFOS = [
features
.
BoundingBox
:
F
.
horizontal_flip_bounding_box
,
features
.
BoundingBox
:
F
.
horizontal_flip_bounding_box
,
features
.
Mask
:
F
.
horizontal_flip_mask
,
features
.
Mask
:
F
.
horizontal_flip_mask
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
horizontal_flip_image_pil
,
kernel_name
=
"horizontal_flip_image_pil"
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
resize
,
F
.
resize
,
...
@@ -70,6 +122,7 @@ DISPATCHER_INFOS = [
...
@@ -70,6 +122,7 @@ DISPATCHER_INFOS = [
features
.
BoundingBox
:
F
.
resize_bounding_box
,
features
.
BoundingBox
:
F
.
resize_bounding_box
,
features
.
Mask
:
F
.
resize_mask
,
features
.
Mask
:
F
.
resize_mask
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
resize_image_pil
),
skips
=
[
skips
=
[
skip_integer_size_jit
(),
skip_integer_size_jit
(),
],
],
...
@@ -81,7 +134,11 @@ DISPATCHER_INFOS = [
...
@@ -81,7 +134,11 @@ DISPATCHER_INFOS = [
features
.
BoundingBox
:
F
.
affine_bounding_box
,
features
.
BoundingBox
:
F
.
affine_bounding_box
,
features
.
Mask
:
F
.
affine_mask
,
features
.
Mask
:
F
.
affine_mask
,
},
},
skips
=
[
skip_python_scalar_arg_jit
(
"shear"
,
reason
=
"Scalar shear is not supported by JIT"
)],
pil_kernel_info
=
PILKernelInfo
(
F
.
affine_image_pil
),
skips
=
[
skip_dispatch_pil_if_fill_sequence_needs_broadcast
,
skip_python_scalar_arg_jit
(
"shear"
,
reason
=
"Scalar shear is not supported by JIT"
),
],
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
vertical_flip
,
F
.
vertical_flip
,
...
@@ -90,6 +147,7 @@ DISPATCHER_INFOS = [
...
@@ -90,6 +147,7 @@ DISPATCHER_INFOS = [
features
.
BoundingBox
:
F
.
vertical_flip_bounding_box
,
features
.
BoundingBox
:
F
.
vertical_flip_bounding_box
,
features
.
Mask
:
F
.
vertical_flip_mask
,
features
.
Mask
:
F
.
vertical_flip_mask
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
vertical_flip_image_pil
,
kernel_name
=
"vertical_flip_image_pil"
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
rotate
,
F
.
rotate
,
...
@@ -98,6 +156,7 @@ DISPATCHER_INFOS = [
...
@@ -98,6 +156,7 @@ DISPATCHER_INFOS = [
features
.
BoundingBox
:
F
.
rotate_bounding_box
,
features
.
BoundingBox
:
F
.
rotate_bounding_box
,
features
.
Mask
:
F
.
rotate_mask
,
features
.
Mask
:
F
.
rotate_mask
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
rotate_image_pil
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
crop
,
F
.
crop
,
...
@@ -106,6 +165,17 @@ DISPATCHER_INFOS = [
...
@@ -106,6 +165,17 @@ DISPATCHER_INFOS = [
features
.
BoundingBox
:
F
.
crop_bounding_box
,
features
.
BoundingBox
:
F
.
crop_bounding_box
,
features
.
Mask
:
F
.
crop_mask
,
features
.
Mask
:
F
.
crop_mask
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
crop_image_pil
,
kernel_name
=
"crop_image_pil"
),
skips
=
[
Skip
(
"test_dispatch_feature"
,
condition
=
lambda
args_kwargs
,
device
:
isinstance
(
args_kwargs
.
args
[
0
],
BoundingBoxLoader
),
reason
=
(
"F.crop expects 4 coordinates as input, but bounding box sample inputs only generate two "
"since that is sufficient for the kernel."
),
)
],
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
resized_crop
,
F
.
resized_crop
,
...
@@ -114,6 +184,7 @@ DISPATCHER_INFOS = [
...
@@ -114,6 +184,7 @@ DISPATCHER_INFOS = [
features
.
BoundingBox
:
F
.
resized_crop_bounding_box
,
features
.
BoundingBox
:
F
.
resized_crop_bounding_box
,
features
.
Mask
:
F
.
resized_crop_mask
,
features
.
Mask
:
F
.
resized_crop_mask
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
resized_crop_image_pil
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
pad
,
F
.
pad
,
...
@@ -122,6 +193,10 @@ DISPATCHER_INFOS = [
...
@@ -122,6 +193,10 @@ DISPATCHER_INFOS = [
features
.
BoundingBox
:
F
.
pad_bounding_box
,
features
.
BoundingBox
:
F
.
pad_bounding_box
,
features
.
Mask
:
F
.
pad_mask
,
features
.
Mask
:
F
.
pad_mask
,
},
},
skips
=
[
skip_dispatch_pil_if_fill_sequence_needs_broadcast
,
],
pil_kernel_info
=
PILKernelInfo
(
F
.
pad_image_pil
,
kernel_name
=
"pad_image_pil"
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
perspective
,
F
.
perspective
,
...
@@ -130,6 +205,10 @@ DISPATCHER_INFOS = [
...
@@ -130,6 +205,10 @@ DISPATCHER_INFOS = [
features
.
BoundingBox
:
F
.
perspective_bounding_box
,
features
.
BoundingBox
:
F
.
perspective_bounding_box
,
features
.
Mask
:
F
.
perspective_mask
,
features
.
Mask
:
F
.
perspective_mask
,
},
},
skips
=
[
skip_dispatch_pil_if_fill_sequence_needs_broadcast
,
],
pil_kernel_info
=
PILKernelInfo
(
F
.
perspective_image_pil
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
elastic
,
F
.
elastic
,
...
@@ -138,6 +217,7 @@ DISPATCHER_INFOS = [
...
@@ -138,6 +217,7 @@ DISPATCHER_INFOS = [
features
.
BoundingBox
:
F
.
elastic_bounding_box
,
features
.
BoundingBox
:
F
.
elastic_bounding_box
,
features
.
Mask
:
F
.
elastic_mask
,
features
.
Mask
:
F
.
elastic_mask
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
elastic_image_pil
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
center_crop
,
F
.
center_crop
,
...
@@ -146,6 +226,7 @@ DISPATCHER_INFOS = [
...
@@ -146,6 +226,7 @@ DISPATCHER_INFOS = [
features
.
BoundingBox
:
F
.
center_crop_bounding_box
,
features
.
BoundingBox
:
F
.
center_crop_bounding_box
,
features
.
Mask
:
F
.
center_crop_mask
,
features
.
Mask
:
F
.
center_crop_mask
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
center_crop_image_pil
),
skips
=
[
skips
=
[
skip_integer_size_jit
(
"output_size"
),
skip_integer_size_jit
(
"output_size"
),
],
],
...
@@ -155,6 +236,7 @@ DISPATCHER_INFOS = [
...
@@ -155,6 +236,7 @@ DISPATCHER_INFOS = [
kernels
=
{
kernels
=
{
features
.
Image
:
F
.
gaussian_blur_image_tensor
,
features
.
Image
:
F
.
gaussian_blur_image_tensor
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
gaussian_blur_image_pil
),
skips
=
[
skips
=
[
skip_python_scalar_arg_jit
(
"kernel_size"
),
skip_python_scalar_arg_jit
(
"kernel_size"
),
skip_python_scalar_arg_jit
(
"sigma"
),
skip_python_scalar_arg_jit
(
"sigma"
),
...
@@ -165,80 +247,97 @@ DISPATCHER_INFOS = [
...
@@ -165,80 +247,97 @@ DISPATCHER_INFOS = [
kernels
=
{
kernels
=
{
features
.
Image
:
F
.
equalize_image_tensor
,
features
.
Image
:
F
.
equalize_image_tensor
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
equalize_image_pil
,
kernel_name
=
"equalize_image_pil"
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
invert
,
F
.
invert
,
kernels
=
{
kernels
=
{
features
.
Image
:
F
.
invert_image_tensor
,
features
.
Image
:
F
.
invert_image_tensor
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
invert_image_pil
,
kernel_name
=
"invert_image_pil"
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
posterize
,
F
.
posterize
,
kernels
=
{
kernels
=
{
features
.
Image
:
F
.
posterize_image_tensor
,
features
.
Image
:
F
.
posterize_image_tensor
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
posterize_image_pil
,
kernel_name
=
"posterize_image_pil"
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
solarize
,
F
.
solarize
,
kernels
=
{
kernels
=
{
features
.
Image
:
F
.
solarize_image_tensor
,
features
.
Image
:
F
.
solarize_image_tensor
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
solarize_image_pil
,
kernel_name
=
"solarize_image_pil"
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
autocontrast
,
F
.
autocontrast
,
kernels
=
{
kernels
=
{
features
.
Image
:
F
.
autocontrast_image_tensor
,
features
.
Image
:
F
.
autocontrast_image_tensor
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
autocontrast_image_pil
,
kernel_name
=
"autocontrast_image_pil"
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
adjust_sharpness
,
F
.
adjust_sharpness
,
kernels
=
{
kernels
=
{
features
.
Image
:
F
.
adjust_sharpness_image_tensor
,
features
.
Image
:
F
.
adjust_sharpness_image_tensor
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
adjust_sharpness_image_pil
,
kernel_name
=
"adjust_sharpness_image_pil"
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
erase
,
F
.
erase
,
kernels
=
{
kernels
=
{
features
.
Image
:
F
.
erase_image_tensor
,
features
.
Image
:
F
.
erase_image_tensor
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
erase_image_pil
),
skips
=
[
skip_dispatch_feature
,
],
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
adjust_brightness
,
F
.
adjust_brightness
,
kernels
=
{
kernels
=
{
features
.
Image
:
F
.
adjust_brightness_image_tensor
,
features
.
Image
:
F
.
adjust_brightness_image_tensor
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
adjust_brightness_image_pil
,
kernel_name
=
"adjust_brightness_image_pil"
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
adjust_contrast
,
F
.
adjust_contrast
,
kernels
=
{
kernels
=
{
features
.
Image
:
F
.
adjust_contrast_image_tensor
,
features
.
Image
:
F
.
adjust_contrast_image_tensor
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
adjust_contrast_image_pil
,
kernel_name
=
"adjust_contrast_image_pil"
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
adjust_gamma
,
F
.
adjust_gamma
,
kernels
=
{
kernels
=
{
features
.
Image
:
F
.
adjust_gamma_image_tensor
,
features
.
Image
:
F
.
adjust_gamma_image_tensor
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
adjust_gamma_image_pil
,
kernel_name
=
"adjust_gamma_image_pil"
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
adjust_hue
,
F
.
adjust_hue
,
kernels
=
{
kernels
=
{
features
.
Image
:
F
.
adjust_hue_image_tensor
,
features
.
Image
:
F
.
adjust_hue_image_tensor
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
adjust_hue_image_pil
,
kernel_name
=
"adjust_hue_image_pil"
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
adjust_saturation
,
F
.
adjust_saturation
,
kernels
=
{
kernels
=
{
features
.
Image
:
F
.
adjust_saturation_image_tensor
,
features
.
Image
:
F
.
adjust_saturation_image_tensor
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
adjust_saturation_image_pil
,
kernel_name
=
"adjust_saturation_image_pil"
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
five_crop
,
F
.
five_crop
,
kernels
=
{
kernels
=
{
features
.
Image
:
F
.
five_crop_image_tensor
,
features
.
Image
:
F
.
five_crop_image_tensor
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
five_crop_image_pil
),
skips
=
[
skips
=
[
skip_integer_size_jit
(),
skip_integer_size_jit
(),
skip_dispatch_feature
,
],
],
),
),
DispatcherInfo
(
DispatcherInfo
(
...
@@ -246,8 +345,10 @@ DISPATCHER_INFOS = [
...
@@ -246,8 +345,10 @@ DISPATCHER_INFOS = [
kernels
=
{
kernels
=
{
features
.
Image
:
F
.
ten_crop_image_tensor
,
features
.
Image
:
F
.
ten_crop_image_tensor
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
ten_crop_image_pil
),
skips
=
[
skips
=
[
skip_integer_size_jit
(),
skip_integer_size_jit
(),
skip_dispatch_feature
,
],
],
),
),
DispatcherInfo
(
DispatcherInfo
(
...
@@ -255,5 +356,8 @@ DISPATCHER_INFOS = [
...
@@ -255,5 +356,8 @@ DISPATCHER_INFOS = [
kernels
=
{
kernels
=
{
features
.
Image
:
F
.
normalize_image_tensor
,
features
.
Image
:
F
.
normalize_image_tensor
,
},
},
skips
=
[
skip_dispatch_feature
,
],
),
),
]
]
test/prototype_transforms_kernel_infos.py
View file @
4a99bae8
...
@@ -33,7 +33,7 @@ class KernelInfo:
...
@@ -33,7 +33,7 @@ class KernelInfo:
sample_inputs_fn
:
Callable
[[],
Iterable
[
ArgsKwargs
]]
sample_inputs_fn
:
Callable
[[],
Iterable
[
ArgsKwargs
]]
# Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
# Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
# TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
# TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
kernel_name
:
Optional
[
str
]
=
None
kernel_name
:
str
=
dataclasses
.
field
(
default
=
None
)
# This function should mirror the kernel. It should have the same signature as the `kernel` and as such also take
# This function should mirror the kernel. It should have the same signature as the `kernel` and as such also take
# tensors as inputs. Any conversion into another object type, e.g. PIL images or numpy arrays, should happen
# tensors as inputs. Any conversion into another object type, e.g. PIL images or numpy arrays, should happen
# inside the function. It should return a tensor or to be more precise an object that can be compared to a
# inside the function. It should return a tensor or to be more precise an object that can be compared to a
...
...
test/test_prototype_transforms_functional.py
View file @
4a99bae8
...
@@ -174,6 +174,18 @@ class TestKernels:
...
@@ -174,6 +174,18 @@ class TestKernels:
assert_close
(
actual
,
expected
,
check_dtype
=
False
,
**
info
.
closeness_kwargs
)
assert_close
(
actual
,
expected
,
check_dtype
=
False
,
**
info
.
closeness_kwargs
)
@
pytest
.
fixture
def
spy_on
(
mocker
):
def
make_spy
(
fn
,
*
,
module
=
None
,
name
=
None
):
# TODO: we can probably get rid of the non-default modules and names if we eliminate aliasing
module
=
module
or
fn
.
__module__
name
=
name
or
fn
.
__name__
spy
=
mocker
.
patch
(
f
"
{
module
}
.
{
name
}
"
,
wraps
=
fn
)
return
spy
return
make_spy
class
TestDispatchers
:
class
TestDispatchers
:
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"info"
,
"args_kwargs"
),
(
"info"
,
"args_kwargs"
),
...
@@ -211,6 +223,69 @@ class TestDispatchers:
...
@@ -211,6 +223,69 @@ class TestDispatchers:
def
test_scriptable
(
self
,
dispatcher
):
def
test_scriptable
(
self
,
dispatcher
):
script
(
dispatcher
)
script
(
dispatcher
)
@
pytest
.
mark
.
parametrize
(
(
"info"
,
"args_kwargs"
),
[
pytest
.
param
(
info
,
args_kwargs
,
id
=
f
"
{
info
.
dispatcher
.
__name__
}
-
{
idx
}
"
)
for
info
in
DISPATCHER_INFOS
for
idx
,
args_kwargs
in
enumerate
(
info
.
sample_inputs
(
features
.
Image
))
if
features
.
Image
in
info
.
kernels
],
)
def
test_dispatch_simple_tensor
(
self
,
info
,
args_kwargs
,
spy_on
):
(
image_feature
,
*
other_args
),
kwargs
=
args_kwargs
.
load
()
image_simple_tensor
=
torch
.
Tensor
(
image_feature
)
kernel_info
=
info
.
kernel_infos
[
features
.
Image
]
spy
=
spy_on
(
kernel_info
.
kernel
,
module
=
info
.
dispatcher
.
__module__
,
name
=
kernel_info
.
kernel_name
)
info
.
dispatcher
(
image_simple_tensor
,
*
other_args
,
**
kwargs
)
spy
.
assert_called_once
()
@
pytest
.
mark
.
parametrize
(
(
"info"
,
"args_kwargs"
),
[
pytest
.
param
(
info
,
args_kwargs
,
id
=
f
"
{
info
.
dispatcher
.
__name__
}
-
{
idx
}
"
)
for
info
in
DISPATCHER_INFOS
for
idx
,
args_kwargs
in
enumerate
(
info
.
sample_inputs
(
features
.
Image
))
if
features
.
Image
in
info
.
kernels
and
info
.
pil_kernel_info
is
not
None
],
)
def
test_dispatch_pil
(
self
,
info
,
args_kwargs
,
spy_on
):
(
image_feature
,
*
other_args
),
kwargs
=
args_kwargs
.
load
()
if
image_feature
.
ndim
>
3
:
pytest
.
skip
(
"Input is batched"
)
image_pil
=
F
.
to_image_pil
(
image_feature
)
pil_kernel_info
=
info
.
pil_kernel_info
spy
=
spy_on
(
pil_kernel_info
.
kernel
,
module
=
info
.
dispatcher
.
__module__
,
name
=
pil_kernel_info
.
kernel_name
)
info
.
dispatcher
(
image_pil
,
*
other_args
,
**
kwargs
)
spy
.
assert_called_once
()
@
pytest
.
mark
.
parametrize
(
(
"info"
,
"args_kwargs"
),
[
pytest
.
param
(
info
,
args_kwargs
,
id
=
f
"
{
info
.
dispatcher
.
__name__
}
-
{
idx
}
"
)
for
info
in
DISPATCHER_INFOS
for
idx
,
args_kwargs
in
enumerate
(
info
.
sample_inputs
())
],
)
def
test_dispatch_feature
(
self
,
info
,
args_kwargs
,
spy_on
):
(
feature
,
*
other_args
),
kwargs
=
args_kwargs
.
load
()
method
=
getattr
(
feature
,
info
.
method_name
)
feature_type
=
type
(
feature
)
spy
=
spy_on
(
method
,
module
=
feature_type
.
__module__
,
name
=
f
"
{
feature_type
.
__name__
}
.
{
info
.
method_name
}
"
)
info
.
dispatcher
(
feature
,
*
other_args
,
**
kwargs
)
spy
.
assert_called_once
()
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"alias"
,
"target"
),
(
"alias"
,
"target"
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment