Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
01d138d8
Unverified
Commit
01d138d8
authored
Jan 20, 2023
by
Philip Meier
Committed by
GitHub
Jan 20, 2023
Browse files
update naming feature -> datapoint in prototype test suite (#7117)
parent
d7e5b6a1
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
69 additions
and
67 deletions
+69
-67
test/prototype_transforms_dispatcher_infos.py
test/prototype_transforms_dispatcher_infos.py
+15
-15
test/test_prototype_datapoints.py
test/test_prototype_datapoints.py
+1
-1
test/test_prototype_transforms.py
test/test_prototype_transforms.py
+6
-6
test/test_prototype_transforms_consistency.py
test/test_prototype_transforms_consistency.py
+10
-10
test/test_prototype_transforms_functional.py
test/test_prototype_transforms_functional.py
+37
-35
No files found.
test/prototype_transforms_dispatcher_infos.py
View file @
01d138d8
...
@@ -44,19 +44,19 @@ class DispatcherInfo(InfoBase):
...
@@ -44,19 +44,19 @@ class DispatcherInfo(InfoBase):
self
.
pil_kernel_info
=
pil_kernel_info
self
.
pil_kernel_info
=
pil_kernel_info
kernel_infos
=
{}
kernel_infos
=
{}
for
feature
_type
,
kernel
in
self
.
kernels
.
items
():
for
datapoint
_type
,
kernel
in
self
.
kernels
.
items
():
kernel_info
=
self
.
_KERNEL_INFO_MAP
.
get
(
kernel
)
kernel_info
=
self
.
_KERNEL_INFO_MAP
.
get
(
kernel
)
if
not
kernel_info
:
if
not
kernel_info
:
raise
pytest
.
UsageError
(
raise
pytest
.
UsageError
(
f
"Can't register
{
kernel
.
__name__
}
for type
{
feature
_type
}
since there is no `KernelInfo` for it. "
f
"Can't register
{
kernel
.
__name__
}
for type
{
datapoint
_type
}
since there is no `KernelInfo` for it. "
f
"Please add a `KernelInfo` for it in `prototype_transforms_kernel_infos.py`."
f
"Please add a `KernelInfo` for it in `prototype_transforms_kernel_infos.py`."
)
)
kernel_infos
[
feature
_type
]
=
kernel_info
kernel_infos
[
datapoint
_type
]
=
kernel_info
self
.
kernel_infos
=
kernel_infos
self
.
kernel_infos
=
kernel_infos
def
sample_inputs
(
self
,
*
feature
_types
,
filter_metadata
=
True
):
def
sample_inputs
(
self
,
*
datapoint
_types
,
filter_metadata
=
True
):
for
feature_type
in
feature
_types
or
self
.
kernel_infos
.
keys
():
for
datapoint_type
in
datapoint
_types
or
self
.
kernel_infos
.
keys
():
kernel_info
=
self
.
kernel_infos
.
get
(
feature
_type
)
kernel_info
=
self
.
kernel_infos
.
get
(
datapoint
_type
)
if
not
kernel_info
:
if
not
kernel_info
:
raise
pytest
.
UsageError
(
f
"There is no kernel registered for type
{
type
.
__name__
}
"
)
raise
pytest
.
UsageError
(
f
"There is no kernel registered for type
{
type
.
__name__
}
"
)
...
@@ -66,7 +66,7 @@ class DispatcherInfo(InfoBase):
...
@@ -66,7 +66,7 @@ class DispatcherInfo(InfoBase):
yield
from
sample_inputs
yield
from
sample_inputs
else
:
else
:
for
args_kwargs
in
sample_inputs
:
for
args_kwargs
in
sample_inputs
:
for
attribute
in
feature
_type
.
__annotations__
.
keys
():
for
attribute
in
datapoint
_type
.
__annotations__
.
keys
():
if
attribute
in
args_kwargs
.
kwargs
:
if
attribute
in
args_kwargs
.
kwargs
:
del
args_kwargs
.
kwargs
[
attribute
]
del
args_kwargs
.
kwargs
[
attribute
]
...
@@ -107,9 +107,9 @@ def xfail_jit_list_of_ints(name, *, reason=None):
...
@@ -107,9 +107,9 @@ def xfail_jit_list_of_ints(name, *, reason=None):
)
)
skip_dispatch_
feature
=
TestMark
(
skip_dispatch_
datapoint
=
TestMark
(
(
"TestDispatchers"
,
"test_dispatch_
feature
"
),
(
"TestDispatchers"
,
"test_dispatch_
datapoint
"
),
pytest
.
mark
.
skip
(
reason
=
"Dispatcher doesn't support arbitrary
feature
dispatch."
),
pytest
.
mark
.
skip
(
reason
=
"Dispatcher doesn't support arbitrary
datapoint
dispatch."
),
)
)
...
@@ -352,7 +352,7 @@ DISPATCHER_INFOS = [
...
@@ -352,7 +352,7 @@ DISPATCHER_INFOS = [
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
erase_image_pil
),
pil_kernel_info
=
PILKernelInfo
(
F
.
erase_image_pil
),
test_marks
=
[
test_marks
=
[
skip_dispatch_
feature
,
skip_dispatch_
datapoint
,
],
],
),
),
DispatcherInfo
(
DispatcherInfo
(
...
@@ -404,7 +404,7 @@ DISPATCHER_INFOS = [
...
@@ -404,7 +404,7 @@ DISPATCHER_INFOS = [
pil_kernel_info
=
PILKernelInfo
(
F
.
five_crop_image_pil
),
pil_kernel_info
=
PILKernelInfo
(
F
.
five_crop_image_pil
),
test_marks
=
[
test_marks
=
[
xfail_jit_python_scalar_arg
(
"size"
),
xfail_jit_python_scalar_arg
(
"size"
),
skip_dispatch_
feature
,
skip_dispatch_
datapoint
,
],
],
),
),
DispatcherInfo
(
DispatcherInfo
(
...
@@ -415,7 +415,7 @@ DISPATCHER_INFOS = [
...
@@ -415,7 +415,7 @@ DISPATCHER_INFOS = [
},
},
test_marks
=
[
test_marks
=
[
xfail_jit_python_scalar_arg
(
"size"
),
xfail_jit_python_scalar_arg
(
"size"
),
skip_dispatch_
feature
,
skip_dispatch_
datapoint
,
],
],
pil_kernel_info
=
PILKernelInfo
(
F
.
ten_crop_image_pil
),
pil_kernel_info
=
PILKernelInfo
(
F
.
ten_crop_image_pil
),
),
),
...
@@ -437,7 +437,7 @@ DISPATCHER_INFOS = [
...
@@ -437,7 +437,7 @@ DISPATCHER_INFOS = [
datapoints
.
Video
:
F
.
convert_dtype_video
,
datapoints
.
Video
:
F
.
convert_dtype_video
,
},
},
test_marks
=
[
test_marks
=
[
skip_dispatch_
feature
,
skip_dispatch_
datapoint
,
],
],
),
),
DispatcherInfo
(
DispatcherInfo
(
...
@@ -446,7 +446,7 @@ DISPATCHER_INFOS = [
...
@@ -446,7 +446,7 @@ DISPATCHER_INFOS = [
datapoints
.
Video
:
F
.
uniform_temporal_subsample_video
,
datapoints
.
Video
:
F
.
uniform_temporal_subsample_video
,
},
},
test_marks
=
[
test_marks
=
[
skip_dispatch_
feature
,
skip_dispatch_
datapoint
,
],
],
),
),
]
]
test/test_prototype_datapoints.py
View file @
01d138d8
...
@@ -28,7 +28,7 @@ def test_to_wrapping():
...
@@ -28,7 +28,7 @@ def test_to_wrapping():
assert
label_to
.
categories
is
label
.
categories
assert
label_to
.
categories
is
label
.
categories
def
test_to_
feature
_reference
():
def
test_to_
datapoint
_reference
():
tensor
=
torch
.
tensor
([
0
,
1
,
0
],
dtype
=
torch
.
int64
)
tensor
=
torch
.
tensor
([
0
,
1
,
0
],
dtype
=
torch
.
int64
)
label
=
datapoints
.
Label
(
tensor
,
categories
=
[
"foo"
,
"bar"
]).
to
(
torch
.
int32
)
label
=
datapoints
.
Label
(
tensor
,
categories
=
[
"foo"
,
"bar"
]).
to
(
torch
.
int32
)
...
...
test/test_prototype_transforms.py
View file @
01d138d8
...
@@ -285,7 +285,7 @@ class TestRandomHorizontalFlip:
...
@@ -285,7 +285,7 @@ class TestRandomHorizontalFlip:
assert_equal
(
expected
,
pil_to_tensor
(
actual
))
assert_equal
(
expected
,
pil_to_tensor
(
actual
))
def
test_
feature
s_image
(
self
,
p
):
def
test_
datapoint
s_image
(
self
,
p
):
input
,
expected
=
self
.
input_expected_image_tensor
(
p
)
input
,
expected
=
self
.
input_expected_image_tensor
(
p
)
transform
=
transforms
.
RandomHorizontalFlip
(
p
=
p
)
transform
=
transforms
.
RandomHorizontalFlip
(
p
=
p
)
...
@@ -293,7 +293,7 @@ class TestRandomHorizontalFlip:
...
@@ -293,7 +293,7 @@ class TestRandomHorizontalFlip:
assert_equal
(
datapoints
.
Image
(
expected
),
actual
)
assert_equal
(
datapoints
.
Image
(
expected
),
actual
)
def
test_
feature
s_mask
(
self
,
p
):
def
test_
datapoint
s_mask
(
self
,
p
):
input
,
expected
=
self
.
input_expected_image_tensor
(
p
)
input
,
expected
=
self
.
input_expected_image_tensor
(
p
)
transform
=
transforms
.
RandomHorizontalFlip
(
p
=
p
)
transform
=
transforms
.
RandomHorizontalFlip
(
p
=
p
)
...
@@ -301,7 +301,7 @@ class TestRandomHorizontalFlip:
...
@@ -301,7 +301,7 @@ class TestRandomHorizontalFlip:
assert_equal
(
datapoints
.
Mask
(
expected
),
actual
)
assert_equal
(
datapoints
.
Mask
(
expected
),
actual
)
def
test_
feature
s_bounding_box
(
self
,
p
):
def
test_
datapoint
s_bounding_box
(
self
,
p
):
input
=
datapoints
.
BoundingBox
([
0
,
0
,
5
,
5
],
format
=
datapoints
.
BoundingBoxFormat
.
XYXY
,
spatial_size
=
(
10
,
10
))
input
=
datapoints
.
BoundingBox
([
0
,
0
,
5
,
5
],
format
=
datapoints
.
BoundingBoxFormat
.
XYXY
,
spatial_size
=
(
10
,
10
))
transform
=
transforms
.
RandomHorizontalFlip
(
p
=
p
)
transform
=
transforms
.
RandomHorizontalFlip
(
p
=
p
)
...
@@ -338,7 +338,7 @@ class TestRandomVerticalFlip:
...
@@ -338,7 +338,7 @@ class TestRandomVerticalFlip:
assert_equal
(
expected
,
pil_to_tensor
(
actual
))
assert_equal
(
expected
,
pil_to_tensor
(
actual
))
def
test_
feature
s_image
(
self
,
p
):
def
test_
datapoint
s_image
(
self
,
p
):
input
,
expected
=
self
.
input_expected_image_tensor
(
p
)
input
,
expected
=
self
.
input_expected_image_tensor
(
p
)
transform
=
transforms
.
RandomVerticalFlip
(
p
=
p
)
transform
=
transforms
.
RandomVerticalFlip
(
p
=
p
)
...
@@ -346,7 +346,7 @@ class TestRandomVerticalFlip:
...
@@ -346,7 +346,7 @@ class TestRandomVerticalFlip:
assert_equal
(
datapoints
.
Image
(
expected
),
actual
)
assert_equal
(
datapoints
.
Image
(
expected
),
actual
)
def
test_
feature
s_mask
(
self
,
p
):
def
test_
datapoint
s_mask
(
self
,
p
):
input
,
expected
=
self
.
input_expected_image_tensor
(
p
)
input
,
expected
=
self
.
input_expected_image_tensor
(
p
)
transform
=
transforms
.
RandomVerticalFlip
(
p
=
p
)
transform
=
transforms
.
RandomVerticalFlip
(
p
=
p
)
...
@@ -354,7 +354,7 @@ class TestRandomVerticalFlip:
...
@@ -354,7 +354,7 @@ class TestRandomVerticalFlip:
assert_equal
(
datapoints
.
Mask
(
expected
),
actual
)
assert_equal
(
datapoints
.
Mask
(
expected
),
actual
)
def
test_
feature
s_bounding_box
(
self
,
p
):
def
test_
datapoint
s_bounding_box
(
self
,
p
):
input
=
datapoints
.
BoundingBox
([
0
,
0
,
5
,
5
],
format
=
datapoints
.
BoundingBoxFormat
.
XYXY
,
spatial_size
=
(
10
,
10
))
input
=
datapoints
.
BoundingBox
([
0
,
0
,
5
,
5
],
format
=
datapoints
.
BoundingBoxFormat
.
XYXY
,
spatial_size
=
(
10
,
10
))
transform
=
transforms
.
RandomVerticalFlip
(
p
=
p
)
transform
=
transforms
.
RandomVerticalFlip
(
p
=
p
)
...
...
test/test_prototype_transforms_consistency.py
View file @
01d138d8
...
@@ -558,15 +558,15 @@ def check_call_consistency(
...
@@ -558,15 +558,15 @@ def check_call_consistency(
output_prototype_image
=
prototype_transform
(
image
)
output_prototype_image
=
prototype_transform
(
image
)
except
Exception
as
exc
:
except
Exception
as
exc
:
raise
AssertionError
(
raise
AssertionError
(
f
"Transforming a
feature image
with shape
{
image_repr
}
failed in the prototype transform with "
f
"Transforming a
image datapoint
with shape
{
image_repr
}
failed in the prototype transform with "
f
"the error above. This means there is a consistency bug either in `_get_params` or in the "
f
"the error above. This means there is a consistency bug either in `_get_params` or in the "
f
"`
feature
s.Image` path in `_transform`."
f
"`
datapoint
s.Image` path in `_transform`."
)
from
exc
)
from
exc
assert_close
(
assert_close
(
output_prototype_image
,
output_prototype_image
,
output_prototype_tensor
,
output_prototype_tensor
,
msg
=
lambda
msg
:
f
"Output for
feature
and tensor images is not equal:
\n\n
{
msg
}
"
,
msg
=
lambda
msg
:
f
"Output for
datapoint
and tensor images is not equal:
\n\n
{
msg
}
"
,
**
closeness_kwargs
,
**
closeness_kwargs
,
)
)
...
@@ -931,7 +931,7 @@ class TestRefDetTransforms:
...
@@ -931,7 +931,7 @@ class TestRefDetTransforms:
yield
(
tensor_image
,
target
)
yield
(
tensor_image
,
target
)
feature
_image
=
make_image
(
size
=
size
,
color_space
=
datapoints
.
ColorSpace
.
RGB
)
datapoint
_image
=
make_image
(
size
=
size
,
color_space
=
datapoints
.
ColorSpace
.
RGB
)
target
=
{
target
=
{
"boxes"
:
make_bounding_box
(
spatial_size
=
size
,
format
=
"XYXY"
,
extra_dims
=
(
num_objects
,),
dtype
=
torch
.
float
),
"boxes"
:
make_bounding_box
(
spatial_size
=
size
,
format
=
"XYXY"
,
extra_dims
=
(
num_objects
,),
dtype
=
torch
.
float
),
"labels"
:
make_label
(
extra_dims
=
(
num_objects
,),
categories
=
80
),
"labels"
:
make_label
(
extra_dims
=
(
num_objects
,),
categories
=
80
),
...
@@ -939,7 +939,7 @@ class TestRefDetTransforms:
...
@@ -939,7 +939,7 @@ class TestRefDetTransforms:
if
with_mask
:
if
with_mask
:
target
[
"masks"
]
=
make_detection_mask
(
size
=
size
,
num_objects
=
num_objects
,
dtype
=
torch
.
long
)
target
[
"masks"
]
=
make_detection_mask
(
size
=
size
,
num_objects
=
num_objects
,
dtype
=
torch
.
long
)
yield
(
feature
_image
,
target
)
yield
(
datapoint
_image
,
target
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"t_ref, t, data_kwargs"
,
"t_ref, t, data_kwargs"
,
...
@@ -1015,13 +1015,13 @@ class TestRefSegTransforms:
...
@@ -1015,13 +1015,13 @@ class TestRefSegTransforms:
conv_fns
.
extend
([
torch
.
Tensor
,
lambda
x
:
x
])
conv_fns
.
extend
([
torch
.
Tensor
,
lambda
x
:
x
])
for
conv_fn
in
conv_fns
:
for
conv_fn
in
conv_fns
:
feature
_image
=
make_image
(
size
=
size
,
color_space
=
datapoints
.
ColorSpace
.
RGB
,
dtype
=
image_dtype
)
datapoint
_image
=
make_image
(
size
=
size
,
color_space
=
datapoints
.
ColorSpace
.
RGB
,
dtype
=
image_dtype
)
feature
_mask
=
make_segmentation_mask
(
size
=
size
,
num_categories
=
num_categories
,
dtype
=
torch
.
uint8
)
datapoint
_mask
=
make_segmentation_mask
(
size
=
size
,
num_categories
=
num_categories
,
dtype
=
torch
.
uint8
)
dp
=
(
conv_fn
(
feature_image
),
feature
_mask
)
dp
=
(
conv_fn
(
datapoint_image
),
datapoint
_mask
)
dp_ref
=
(
dp_ref
=
(
to_image_pil
(
feature
_image
)
if
supports_pil
else
feature
_image
.
as_subclass
(
torch
.
Tensor
),
to_image_pil
(
datapoint
_image
)
if
supports_pil
else
datapoint
_image
.
as_subclass
(
torch
.
Tensor
),
to_image_pil
(
feature
_mask
),
to_image_pil
(
datapoint
_mask
),
)
)
yield
dp
,
dp_ref
yield
dp
,
dp_ref
...
...
test/test_prototype_transforms_functional.py
View file @
01d138d8
...
@@ -162,7 +162,7 @@ class TestKernels:
...
@@ -162,7 +162,7 @@ class TestKernels:
def
test_batched_vs_single
(
self
,
test_id
,
info
,
args_kwargs
,
device
):
def
test_batched_vs_single
(
self
,
test_id
,
info
,
args_kwargs
,
device
):
(
batched_input
,
*
other_args
),
kwargs
=
args_kwargs
.
load
(
device
)
(
batched_input
,
*
other_args
),
kwargs
=
args_kwargs
.
load
(
device
)
feature
_type
=
(
datapoint
_type
=
(
datapoints
.
Image
datapoints
.
Image
if
torchvision
.
prototype
.
transforms
.
utils
.
is_simple_tensor
(
batched_input
)
if
torchvision
.
prototype
.
transforms
.
utils
.
is_simple_tensor
(
batched_input
)
else
type
(
batched_input
)
else
type
(
batched_input
)
...
@@ -178,10 +178,10 @@ class TestKernels:
...
@@ -178,10 +178,10 @@ class TestKernels:
# common ground.
# common ground.
datapoints
.
Mask
:
2
,
datapoints
.
Mask
:
2
,
datapoints
.
Video
:
4
,
datapoints
.
Video
:
4
,
}.
get
(
feature
_type
)
}.
get
(
datapoint
_type
)
if
data_dims
is
None
:
if
data_dims
is
None
:
raise
pytest
.
UsageError
(
raise
pytest
.
UsageError
(
f
"The number of data dimensions cannot be determined for input of type
{
feature
_type
.
__name__
}
."
f
"The number of data dimensions cannot be determined for input of type
{
datapoint
_type
.
__name__
}
."
)
from
None
)
from
None
elif
batched_input
.
ndim
<=
data_dims
:
elif
batched_input
.
ndim
<=
data_dims
:
pytest
.
skip
(
"Input is not batched."
)
pytest
.
skip
(
"Input is not batched."
)
...
@@ -323,8 +323,8 @@ class TestDispatchers:
...
@@ -323,8 +323,8 @@ class TestDispatchers:
def
test_scripted_smoke
(
self
,
info
,
args_kwargs
,
device
):
def
test_scripted_smoke
(
self
,
info
,
args_kwargs
,
device
):
dispatcher
=
script
(
info
.
dispatcher
)
dispatcher
=
script
(
info
.
dispatcher
)
(
image_
feature
,
*
other_args
),
kwargs
=
args_kwargs
.
load
(
device
)
(
image_
datapoint
,
*
other_args
),
kwargs
=
args_kwargs
.
load
(
device
)
image_simple_tensor
=
torch
.
Tensor
(
image_
feature
)
image_simple_tensor
=
torch
.
Tensor
(
image_
datapoint
)
dispatcher
(
image_simple_tensor
,
*
other_args
,
**
kwargs
)
dispatcher
(
image_simple_tensor
,
*
other_args
,
**
kwargs
)
...
@@ -352,8 +352,8 @@ class TestDispatchers:
...
@@ -352,8 +352,8 @@ class TestDispatchers:
@
image_sample_inputs
@
image_sample_inputs
def
test_dispatch_simple_tensor
(
self
,
info
,
args_kwargs
,
spy_on
):
def
test_dispatch_simple_tensor
(
self
,
info
,
args_kwargs
,
spy_on
):
(
image_
feature
,
*
other_args
),
kwargs
=
args_kwargs
.
load
()
(
image_
datapoint
,
*
other_args
),
kwargs
=
args_kwargs
.
load
()
image_simple_tensor
=
torch
.
Tensor
(
image_
feature
)
image_simple_tensor
=
torch
.
Tensor
(
image_
datapoint
)
kernel_info
=
info
.
kernel_infos
[
datapoints
.
Image
]
kernel_info
=
info
.
kernel_infos
[
datapoints
.
Image
]
spy
=
spy_on
(
kernel_info
.
kernel
,
module
=
info
.
dispatcher
.
__module__
,
name
=
kernel_info
.
id
)
spy
=
spy_on
(
kernel_info
.
kernel
,
module
=
info
.
dispatcher
.
__module__
,
name
=
kernel_info
.
id
)
...
@@ -367,12 +367,12 @@ class TestDispatchers:
...
@@ -367,12 +367,12 @@ class TestDispatchers:
args_kwargs_fn
=
lambda
info
:
info
.
sample_inputs
(
datapoints
.
Image
),
args_kwargs_fn
=
lambda
info
:
info
.
sample_inputs
(
datapoints
.
Image
),
)
)
def
test_dispatch_pil
(
self
,
info
,
args_kwargs
,
spy_on
):
def
test_dispatch_pil
(
self
,
info
,
args_kwargs
,
spy_on
):
(
image_
feature
,
*
other_args
),
kwargs
=
args_kwargs
.
load
()
(
image_
datapoint
,
*
other_args
),
kwargs
=
args_kwargs
.
load
()
if
image_
feature
.
ndim
>
3
:
if
image_
datapoint
.
ndim
>
3
:
pytest
.
skip
(
"Input is batched"
)
pytest
.
skip
(
"Input is batched"
)
image_pil
=
F
.
to_image_pil
(
image_
feature
)
image_pil
=
F
.
to_image_pil
(
image_
datapoint
)
pil_kernel_info
=
info
.
pil_kernel_info
pil_kernel_info
=
info
.
pil_kernel_info
spy
=
spy_on
(
pil_kernel_info
.
kernel
,
module
=
info
.
dispatcher
.
__module__
,
name
=
pil_kernel_info
.
id
)
spy
=
spy_on
(
pil_kernel_info
.
kernel
,
module
=
info
.
dispatcher
.
__module__
,
name
=
pil_kernel_info
.
id
)
...
@@ -385,37 +385,39 @@ class TestDispatchers:
...
@@ -385,37 +385,39 @@ class TestDispatchers:
DISPATCHER_INFOS
,
DISPATCHER_INFOS
,
args_kwargs_fn
=
lambda
info
:
info
.
sample_inputs
(),
args_kwargs_fn
=
lambda
info
:
info
.
sample_inputs
(),
)
)
def
test_dispatch_
feature
(
self
,
info
,
args_kwargs
,
spy_on
):
def
test_dispatch_
datapoint
(
self
,
info
,
args_kwargs
,
spy_on
):
(
feature
,
*
other_args
),
kwargs
=
args_kwargs
.
load
()
(
datapoint
,
*
other_args
),
kwargs
=
args_kwargs
.
load
()
method_name
=
info
.
id
method_name
=
info
.
id
method
=
getattr
(
feature
,
method_name
)
method
=
getattr
(
datapoint
,
method_name
)
feature
_type
=
type
(
feature
)
datapoint
_type
=
type
(
datapoint
)
spy
=
spy_on
(
method
,
module
=
feature
_type
.
__module__
,
name
=
f
"
{
feature
_type
.
__name__
}
.
{
method_name
}
"
)
spy
=
spy_on
(
method
,
module
=
datapoint
_type
.
__module__
,
name
=
f
"
{
datapoint
_type
.
__name__
}
.
{
method_name
}
"
)
info
.
dispatcher
(
feature
,
*
other_args
,
**
kwargs
)
info
.
dispatcher
(
datapoint
,
*
other_args
,
**
kwargs
)
spy
.
assert_called_once
()
spy
.
assert_called_once
()
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"dispatcher_info"
,
"
feature
_type"
,
"kernel_info"
),
(
"dispatcher_info"
,
"
datapoint
_type"
,
"kernel_info"
),
[
[
pytest
.
param
(
dispatcher_info
,
feature_type
,
kernel_info
,
id
=
f
"
{
dispatcher_info
.
id
}
-
{
feature_type
.
__name__
}
"
)
pytest
.
param
(
dispatcher_info
,
datapoint_type
,
kernel_info
,
id
=
f
"
{
dispatcher_info
.
id
}
-
{
datapoint_type
.
__name__
}
"
)
for
dispatcher_info
in
DISPATCHER_INFOS
for
dispatcher_info
in
DISPATCHER_INFOS
for
feature
_type
,
kernel_info
in
dispatcher_info
.
kernel_infos
.
items
()
for
datapoint
_type
,
kernel_info
in
dispatcher_info
.
kernel_infos
.
items
()
],
],
)
)
def
test_dispatcher_kernel_signatures_consistency
(
self
,
dispatcher_info
,
feature
_type
,
kernel_info
):
def
test_dispatcher_kernel_signatures_consistency
(
self
,
dispatcher_info
,
datapoint
_type
,
kernel_info
):
dispatcher_signature
=
inspect
.
signature
(
dispatcher_info
.
dispatcher
)
dispatcher_signature
=
inspect
.
signature
(
dispatcher_info
.
dispatcher
)
dispatcher_params
=
list
(
dispatcher_signature
.
parameters
.
values
())[
1
:]
dispatcher_params
=
list
(
dispatcher_signature
.
parameters
.
values
())[
1
:]
kernel_signature
=
inspect
.
signature
(
kernel_info
.
kernel
)
kernel_signature
=
inspect
.
signature
(
kernel_info
.
kernel
)
kernel_params
=
list
(
kernel_signature
.
parameters
.
values
())[
1
:]
kernel_params
=
list
(
kernel_signature
.
parameters
.
values
())[
1
:]
# We filter out metadata that is implicitly passed to the dispatcher through the input
feature
, but has to be
# We filter out metadata that is implicitly passed to the dispatcher through the input
datapoint
, but has to be
# explicit passed to the kernel.
# explicit passed to the kernel.
feature
_type_metadata
=
feature
_type
.
__annotations__
.
keys
()
datapoint
_type_metadata
=
datapoint
_type
.
__annotations__
.
keys
()
kernel_params
=
[
param
for
param
in
kernel_params
if
param
.
name
not
in
feature
_type_metadata
]
kernel_params
=
[
param
for
param
in
kernel_params
if
param
.
name
not
in
datapoint
_type_metadata
]
dispatcher_params
=
iter
(
dispatcher_params
)
dispatcher_params
=
iter
(
dispatcher_params
)
for
dispatcher_param
,
kernel_param
in
zip
(
dispatcher_params
,
kernel_params
):
for
dispatcher_param
,
kernel_param
in
zip
(
dispatcher_params
,
kernel_params
):
...
@@ -433,26 +435,26 @@ class TestDispatchers:
...
@@ -433,26 +435,26 @@ class TestDispatchers:
assert
dispatcher_param
==
kernel_param
assert
dispatcher_param
==
kernel_param
@
pytest
.
mark
.
parametrize
(
"info"
,
DISPATCHER_INFOS
,
ids
=
lambda
info
:
info
.
id
)
@
pytest
.
mark
.
parametrize
(
"info"
,
DISPATCHER_INFOS
,
ids
=
lambda
info
:
info
.
id
)
def
test_dispatcher_
feature
_signatures_consistency
(
self
,
info
):
def
test_dispatcher_
datapoint
_signatures_consistency
(
self
,
info
):
try
:
try
:
feature
_method
=
getattr
(
datapoints
.
_datapoint
.
Datapoint
,
info
.
id
)
datapoint
_method
=
getattr
(
datapoints
.
_datapoint
.
Datapoint
,
info
.
id
)
except
AttributeError
:
except
AttributeError
:
pytest
.
skip
(
"Dispatcher doesn't support arbitrary
feature
dispatch."
)
pytest
.
skip
(
"Dispatcher doesn't support arbitrary
datapoint
dispatch."
)
dispatcher_signature
=
inspect
.
signature
(
info
.
dispatcher
)
dispatcher_signature
=
inspect
.
signature
(
info
.
dispatcher
)
dispatcher_params
=
list
(
dispatcher_signature
.
parameters
.
values
())[
1
:]
dispatcher_params
=
list
(
dispatcher_signature
.
parameters
.
values
())[
1
:]
feature
_signature
=
inspect
.
signature
(
feature
_method
)
datapoint
_signature
=
inspect
.
signature
(
datapoint
_method
)
feature
_params
=
list
(
feature
_signature
.
parameters
.
values
())[
1
:]
datapoint
_params
=
list
(
datapoint
_signature
.
parameters
.
values
())[
1
:]
# Because we use `from __future__ import annotations` inside the module where `
feature
s._datapoint` is
defined,
# Because we use `from __future__ import annotations` inside the module where `
datapoint
s._datapoint` is
# the annotations are stored as strings. This makes them concrete again, so they can be compared to the
natively
#
defined,
the annotations are stored as strings. This makes them concrete again, so they can be compared to the
# concrete dispatcher annotations.
#
natively
concrete dispatcher annotations.
feature
_annotations
=
get_type_hints
(
feature
_method
)
datapoint
_annotations
=
get_type_hints
(
datapoint
_method
)
for
param
in
feature
_params
:
for
param
in
datapoint
_params
:
param
.
_annotation
=
feature
_annotations
[
param
.
name
]
param
.
_annotation
=
datapoint
_annotations
[
param
.
name
]
assert
dispatcher_params
==
feature
_params
assert
dispatcher_params
==
datapoint
_params
@
pytest
.
mark
.
parametrize
(
"info"
,
DISPATCHER_INFOS
,
ids
=
lambda
info
:
info
.
id
)
@
pytest
.
mark
.
parametrize
(
"info"
,
DISPATCHER_INFOS
,
ids
=
lambda
info
:
info
.
id
)
def
test_unkown_type
(
self
,
info
):
def
test_unkown_type
(
self
,
info
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment