Unverified commit 01d138d8, authored Jan 20, 2023 by Philip Meier, committed by GitHub on Jan 20, 2023
update naming feature -> datapoint in prototype test suite (#7117)
Parent: d7e5b6a1
Showing 5 changed files with 69 additions and 67 deletions.
test/prototype_transforms_dispatcher_infos.py    +15 -15
test/test_prototype_datapoints.py                 +1  -1
test/test_prototype_transforms.py                 +6  -6
test/test_prototype_transforms_consistency.py    +10 -10
test/test_prototype_transforms_functional.py     +37 -35
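For context on the rename: in the prototype API a "datapoint" (previously called a "feature") is a tensor subclass that carries task metadata such as label categories or bounding-box format. A minimal sketch of that idea, assuming the `torchvision.prototype.datapoints` module as it exists at this commit (the example values are illustrative only):

    import torch
    from torchvision.prototype import datapoints

    # Datapoints are torch.Tensor subclasses that carry extra metadata.
    image = datapoints.Image(torch.rand(3, 32, 32))
    label = datapoints.Label(torch.tensor([0, 1, 0]), categories=["foo", "bar"])

    assert isinstance(image, torch.Tensor)      # still behaves like a tensor
    assert label.categories == ["foo", "bar"]   # ... but keeps its metadata

    # Unwrapping back to a plain tensor, as several tests below do:
    plain = image.as_subclass(torch.Tensor)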
test/prototype_transforms_dispatcher_infos.py

@@ -44,19 +44,19 @@ class DispatcherInfo(InfoBase):
         self.pil_kernel_info = pil_kernel_info
 
         kernel_infos = {}
-        for feature_type, kernel in self.kernels.items():
+        for datapoint_type, kernel in self.kernels.items():
             kernel_info = self._KERNEL_INFO_MAP.get(kernel)
             if not kernel_info:
                 raise pytest.UsageError(
-                    f"Can't register {kernel.__name__} for type {feature_type} since there is no `KernelInfo` for it. "
+                    f"Can't register {kernel.__name__} for type {datapoint_type} since there is no `KernelInfo` for it. "
                     f"Please add a `KernelInfo` for it in `prototype_transforms_kernel_infos.py`."
                 )
-            kernel_infos[feature_type] = kernel_info
+            kernel_infos[datapoint_type] = kernel_info
         self.kernel_infos = kernel_infos
 
-    def sample_inputs(self, *feature_types, filter_metadata=True):
-        for feature_type in feature_types or self.kernel_infos.keys():
-            kernel_info = self.kernel_infos.get(feature_type)
+    def sample_inputs(self, *datapoint_types, filter_metadata=True):
+        for datapoint_type in datapoint_types or self.kernel_infos.keys():
+            kernel_info = self.kernel_infos.get(datapoint_type)
             if not kernel_info:
                 raise pytest.UsageError(f"There is no kernel registered for type {type.__name__}")

@@ -66,7 +66,7 @@ class DispatcherInfo(InfoBase):
             yield from sample_inputs
         else:
             for args_kwargs in sample_inputs:
-                for attribute in feature_type.__annotations__.keys():
+                for attribute in datapoint_type.__annotations__.keys():
                     if attribute in args_kwargs.kwargs:
                         del args_kwargs.kwargs[attribute]

@@ -107,9 +107,9 @@ def xfail_jit_list_of_ints(name, *, reason=None):
     )
 
 
-skip_dispatch_feature = TestMark(
-    ("TestDispatchers", "test_dispatch_feature"),
-    pytest.mark.skip(reason="Dispatcher doesn't support arbitrary feature dispatch."),
+skip_dispatch_datapoint = TestMark(
+    ("TestDispatchers", "test_dispatch_datapoint"),
+    pytest.mark.skip(reason="Dispatcher doesn't support arbitrary datapoint dispatch."),
 )

@@ -352,7 +352,7 @@ DISPATCHER_INFOS = [
         },
         pil_kernel_info=PILKernelInfo(F.erase_image_pil),
         test_marks=[
-            skip_dispatch_feature,
+            skip_dispatch_datapoint,
         ],
     ),
     DispatcherInfo(

@@ -404,7 +404,7 @@ DISPATCHER_INFOS = [
         pil_kernel_info=PILKernelInfo(F.five_crop_image_pil),
         test_marks=[
             xfail_jit_python_scalar_arg("size"),
-            skip_dispatch_feature,
+            skip_dispatch_datapoint,
         ],
     ),
     DispatcherInfo(

@@ -415,7 +415,7 @@ DISPATCHER_INFOS = [
         },
         test_marks=[
             xfail_jit_python_scalar_arg("size"),
-            skip_dispatch_feature,
+            skip_dispatch_datapoint,
         ],
         pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil),
     ),

@@ -437,7 +437,7 @@ DISPATCHER_INFOS = [
             datapoints.Video: F.convert_dtype_video,
         },
         test_marks=[
-            skip_dispatch_feature,
+            skip_dispatch_datapoint,
         ],
     ),
     DispatcherInfo(

@@ -446,7 +446,7 @@ DISPATCHER_INFOS = [
             datapoints.Video: F.uniform_temporal_subsample_video,
         },
         test_marks=[
-            skip_dispatch_feature,
+            skip_dispatch_datapoint,
         ],
     ),
 ]
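The hunks above only rename `feature_type` to `datapoint_type`; the registration logic is unchanged: each datapoint type maps to a kernel, and each kernel must have a matching `KernelInfo`. A rough standalone sketch of that loop with hypothetical placeholder objects (the real code looks kernels up in `_KERNEL_INFO_MAP`, which is built elsewhere in this module):

    import pytest

    # Hypothetical placeholders standing in for a real kernel and its KernelInfo.
    def erase_image_tensor(image):
        return image

    class Image:  # placeholder for datapoints.Image
        pass

    _KERNEL_INFO_MAP = {erase_image_tensor: "KernelInfo(erase_image_tensor)"}


    def build_kernel_infos(kernels):
        """Map each datapoint type to the KernelInfo of its registered kernel."""
        kernel_infos = {}
        for datapoint_type, kernel in kernels.items():
            kernel_info = _KERNEL_INFO_MAP.get(kernel)
            if not kernel_info:
                raise pytest.UsageError(
                    f"Can't register {kernel.__name__} for type {datapoint_type} "
                    "since there is no `KernelInfo` for it."
                )
            kernel_infos[datapoint_type] = kernel_info
        return kernel_infos


    print(build_kernel_infos({Image: erase_image_tensor}))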
test/test_prototype_datapoints.py

@@ -28,7 +28,7 @@ def test_to_wrapping():
     assert label_to.categories is label.categories
 
 
-def test_to_feature_reference():
+def test_to_datapoint_reference():
     tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
     label = datapoints.Label(tensor, categories=["foo", "bar"]).to(torch.int32)
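The renamed test sits next to `test_to_wrapping`, which checks that tensor methods such as `.to()` return a re-wrapped datapoint whose metadata is shared with the original. A minimal sketch of that behaviour, assuming the same `torchvision.prototype.datapoints` import this file uses:

    import torch
    from torchvision.prototype import datapoints

    label = datapoints.Label(torch.tensor([0, 1, 0], dtype=torch.int64), categories=["foo", "bar"])
    label_to = label.to(torch.int32)

    # `.to()` keeps the wrapping and shares the categories object with the original.
    assert isinstance(label_to, datapoints.Label)
    assert label_to.dtype == torch.int32
    assert label_to.categories is label.categories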
test/test_prototype_transforms.py

@@ -285,7 +285,7 @@ class TestRandomHorizontalFlip:
         assert_equal(expected, pil_to_tensor(actual))
 
-    def test_features_image(self, p):
+    def test_datapoints_image(self, p):
         input, expected = self.input_expected_image_tensor(p)
         transform = transforms.RandomHorizontalFlip(p=p)

@@ -293,7 +293,7 @@ class TestRandomHorizontalFlip:
         assert_equal(datapoints.Image(expected), actual)
 
-    def test_features_mask(self, p):
+    def test_datapoints_mask(self, p):
         input, expected = self.input_expected_image_tensor(p)
         transform = transforms.RandomHorizontalFlip(p=p)

@@ -301,7 +301,7 @@ class TestRandomHorizontalFlip:
         assert_equal(datapoints.Mask(expected), actual)
 
-    def test_features_bounding_box(self, p):
+    def test_datapoints_bounding_box(self, p):
         input = datapoints.BoundingBox([0, 0, 5, 5], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10))
         transform = transforms.RandomHorizontalFlip(p=p)

@@ -338,7 +338,7 @@ class TestRandomVerticalFlip:
         assert_equal(expected, pil_to_tensor(actual))
 
-    def test_features_image(self, p):
+    def test_datapoints_image(self, p):
         input, expected = self.input_expected_image_tensor(p)
         transform = transforms.RandomVerticalFlip(p=p)

@@ -346,7 +346,7 @@ class TestRandomVerticalFlip:
         assert_equal(datapoints.Image(expected), actual)
 
-    def test_features_mask(self, p):
+    def test_datapoints_mask(self, p):
         input, expected = self.input_expected_image_tensor(p)
         transform = transforms.RandomVerticalFlip(p=p)

@@ -354,7 +354,7 @@ class TestRandomVerticalFlip:
         assert_equal(datapoints.Mask(expected), actual)
 
-    def test_features_bounding_box(self, p):
+    def test_datapoints_bounding_box(self, p):
         input = datapoints.BoundingBox([0, 0, 5, 5], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10))
         transform = transforms.RandomVerticalFlip(p=p)
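All six renamed tests follow the same pattern: build a datapoint, run the transform, and compare against an expected wrapped output. A rough usage sketch with a deterministic flip (`p=1`), assuming the prototype `transforms` and `datapoints` modules imported by this file; the mirrored box coordinates are worked out by hand rather than taken from the test:

    import torch
    from torchvision.prototype import datapoints, transforms

    transform = transforms.RandomHorizontalFlip(p=1.0)  # always flips

    image = datapoints.Image(torch.rand(3, 10, 10))
    bbox = datapoints.BoundingBox(
        [0, 0, 5, 5],
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=(10, 10),
    )

    flipped_image = transform(image)
    flipped_bbox = transform(bbox)

    # The transform dispatches on the datapoint type and re-wraps its output;
    # for a width-10 image the box [0, 0, 5, 5] mirrors to [5, 0, 10, 5].
    assert isinstance(flipped_image, datapoints.Image)
    assert flipped_bbox.tolist() == [5, 0, 10, 5]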
test/test_prototype_transforms_consistency.py

@@ -558,15 +558,15 @@ def check_call_consistency(
             output_prototype_image = prototype_transform(image)
         except Exception as exc:
             raise AssertionError(
-                f"Transforming a feature image with shape {image_repr} failed in the prototype transform with "
+                f"Transforming a image datapoint with shape {image_repr} failed in the prototype transform with "
                 f"the error above. This means there is a consistency bug either in `_get_params` or in the "
-                f"`features.Image` path in `_transform`."
+                f"`datapoints.Image` path in `_transform`."
             ) from exc
 
         assert_close(
             output_prototype_image,
             output_prototype_tensor,
-            msg=lambda msg: f"Output for feature and tensor images is not equal: \n\n{msg}",
+            msg=lambda msg: f"Output for datapoint and tensor images is not equal: \n\n{msg}",
             **closeness_kwargs,
         )

@@ -931,7 +931,7 @@ class TestRefDetTransforms:
             yield (tensor_image, target)
 
-            feature_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB)
+            datapoint_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB)
             target = {
                 "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
                 "labels": make_label(extra_dims=(num_objects,), categories=80),

@@ -939,7 +939,7 @@ class TestRefDetTransforms:
             if with_mask:
                 target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)
 
-            yield (feature_image, target)
+            yield (datapoint_image, target)
 
     @pytest.mark.parametrize(
         "t_ref, t, data_kwargs",

@@ -1015,13 +1015,13 @@ class TestRefSegTransforms:
             conv_fns.extend([torch.Tensor, lambda x: x])
 
         for conv_fn in conv_fns:
-            feature_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB, dtype=image_dtype)
-            feature_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)
+            datapoint_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB, dtype=image_dtype)
+            datapoint_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)
 
-            dp = (conv_fn(feature_image), feature_mask)
+            dp = (conv_fn(datapoint_image), datapoint_mask)
             dp_ref = (
-                to_image_pil(feature_image) if supports_pil else feature_image.as_subclass(torch.Tensor),
-                to_image_pil(feature_mask),
+                to_image_pil(datapoint_image) if supports_pil else datapoint_image.as_subclass(torch.Tensor),
+                to_image_pil(datapoint_mask),
             )
 
             yield dp, dp_ref
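The `check_call_consistency` helper whose messages are reworded above runs the same prototype transform on a plain tensor and on the wrapped datapoint and requires identical results. A condensed sketch of that comparison, assuming the same prototype modules (the real helper also exercises PIL inputs and accepts custom closeness kwargs):

    import torch
    from torch.testing import assert_close
    from torchvision.prototype import datapoints, transforms

    prototype_transform = transforms.RandomHorizontalFlip(p=1.0)  # deterministic

    tensor_image = torch.rand(3, 32, 32)
    datapoint_image = datapoints.Image(tensor_image)

    output_tensor = prototype_transform(tensor_image)
    output_datapoint = prototype_transform(datapoint_image)

    # The datapoint path must match the plain-tensor path; unwrap before the
    # comparison so assert_close sees two plain tensors.
    assert_close(output_datapoint.as_subclass(torch.Tensor), output_tensor)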
test/test_prototype_transforms_functional.py

@@ -162,7 +162,7 @@ class TestKernels:
     def test_batched_vs_single(self, test_id, info, args_kwargs, device):
         (batched_input, *other_args), kwargs = args_kwargs.load(device)
 
-        feature_type = (
+        datapoint_type = (
             datapoints.Image
             if torchvision.prototype.transforms.utils.is_simple_tensor(batched_input)
             else type(batched_input)

@@ -178,10 +178,10 @@ class TestKernels:
             # common ground.
             datapoints.Mask: 2,
             datapoints.Video: 4,
-        }.get(feature_type)
+        }.get(datapoint_type)
         if data_dims is None:
             raise pytest.UsageError(
-                f"The number of data dimensions cannot be determined for input of type {feature_type.__name__}."
+                f"The number of data dimensions cannot be determined for input of type {datapoint_type.__name__}."
             ) from None
         elif batched_input.ndim <= data_dims:
             pytest.skip("Input is not batched.")

@@ -323,8 +323,8 @@ class TestDispatchers:
     def test_scripted_smoke(self, info, args_kwargs, device):
         dispatcher = script(info.dispatcher)
 
-        (image_feature, *other_args), kwargs = args_kwargs.load(device)
-        image_simple_tensor = torch.Tensor(image_feature)
+        (image_datapoint, *other_args), kwargs = args_kwargs.load(device)
+        image_simple_tensor = torch.Tensor(image_datapoint)
 
         dispatcher(image_simple_tensor, *other_args, **kwargs)

@@ -352,8 +352,8 @@ class TestDispatchers:
     @image_sample_inputs
     def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on):
-        (image_feature, *other_args), kwargs = args_kwargs.load()
-        image_simple_tensor = torch.Tensor(image_feature)
+        (image_datapoint, *other_args), kwargs = args_kwargs.load()
+        image_simple_tensor = torch.Tensor(image_datapoint)
 
         kernel_info = info.kernel_infos[datapoints.Image]
         spy = spy_on(kernel_info.kernel, module=info.dispatcher.__module__, name=kernel_info.id)

@@ -367,12 +367,12 @@ class TestDispatchers:
         args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
     )
     def test_dispatch_pil(self, info, args_kwargs, spy_on):
-        (image_feature, *other_args), kwargs = args_kwargs.load()
+        (image_datapoint, *other_args), kwargs = args_kwargs.load()
 
-        if image_feature.ndim > 3:
+        if image_datapoint.ndim > 3:
             pytest.skip("Input is batched")
 
-        image_pil = F.to_image_pil(image_feature)
+        image_pil = F.to_image_pil(image_datapoint)
 
         pil_kernel_info = info.pil_kernel_info
         spy = spy_on(pil_kernel_info.kernel, module=info.dispatcher.__module__, name=pil_kernel_info.id)

@@ -385,37 +385,39 @@ class TestDispatchers:
         DISPATCHER_INFOS,
         args_kwargs_fn=lambda info: info.sample_inputs(),
     )
-    def test_dispatch_feature(self, info, args_kwargs, spy_on):
-        (feature, *other_args), kwargs = args_kwargs.load()
+    def test_dispatch_datapoint(self, info, args_kwargs, spy_on):
+        (datapoint, *other_args), kwargs = args_kwargs.load()
 
         method_name = info.id
-        method = getattr(feature, method_name)
-        feature_type = type(feature)
-        spy = spy_on(method, module=feature_type.__module__, name=f"{feature_type.__name__}.{method_name}")
+        method = getattr(datapoint, method_name)
+        datapoint_type = type(datapoint)
+        spy = spy_on(method, module=datapoint_type.__module__, name=f"{datapoint_type.__name__}.{method_name}")
 
-        info.dispatcher(feature, *other_args, **kwargs)
+        info.dispatcher(datapoint, *other_args, **kwargs)
 
         spy.assert_called_once()
 
     @pytest.mark.parametrize(
-        ("dispatcher_info", "feature_type", "kernel_info"),
+        ("dispatcher_info", "datapoint_type", "kernel_info"),
         [
-            pytest.param(dispatcher_info, feature_type, kernel_info, id=f"{dispatcher_info.id}-{feature_type.__name__}")
+            pytest.param(dispatcher_info, datapoint_type, kernel_info, id=f"{dispatcher_info.id}-{datapoint_type.__name__}")
             for dispatcher_info in DISPATCHER_INFOS
-            for feature_type, kernel_info in dispatcher_info.kernel_infos.items()
+            for datapoint_type, kernel_info in dispatcher_info.kernel_infos.items()
         ],
     )
-    def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, feature_type, kernel_info):
+    def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, datapoint_type, kernel_info):
         dispatcher_signature = inspect.signature(dispatcher_info.dispatcher)
         dispatcher_params = list(dispatcher_signature.parameters.values())[1:]
 
         kernel_signature = inspect.signature(kernel_info.kernel)
         kernel_params = list(kernel_signature.parameters.values())[1:]
 
-        # We filter out metadata that is implicitly passed to the dispatcher through the input feature, but has to be
+        # We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be
         # explicit passed to the kernel.
-        feature_type_metadata = feature_type.__annotations__.keys()
-        kernel_params = [param for param in kernel_params if param.name not in feature_type_metadata]
+        datapoint_type_metadata = datapoint_type.__annotations__.keys()
+        kernel_params = [param for param in kernel_params if param.name not in datapoint_type_metadata]
 
         dispatcher_params = iter(dispatcher_params)
         for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):

@@ -433,26 +435,26 @@ class TestDispatchers:
             assert dispatcher_param == kernel_param
 
     @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
-    def test_dispatcher_feature_signatures_consistency(self, info):
+    def test_dispatcher_datapoint_signatures_consistency(self, info):
         try:
-            feature_method = getattr(datapoints._datapoint.Datapoint, info.id)
+            datapoint_method = getattr(datapoints._datapoint.Datapoint, info.id)
         except AttributeError:
-            pytest.skip("Dispatcher doesn't support arbitrary feature dispatch.")
+            pytest.skip("Dispatcher doesn't support arbitrary datapoint dispatch.")
 
         dispatcher_signature = inspect.signature(info.dispatcher)
         dispatcher_params = list(dispatcher_signature.parameters.values())[1:]
 
-        feature_signature = inspect.signature(feature_method)
-        feature_params = list(feature_signature.parameters.values())[1:]
+        datapoint_signature = inspect.signature(datapoint_method)
+        datapoint_params = list(datapoint_signature.parameters.values())[1:]
 
-        # Because we use `from __future__ import annotations` inside the module where `features._datapoint` is defined,
-        # the annotations are stored as strings. This makes them concrete again, so they can be compared to the natively
-        # concrete dispatcher annotations.
-        feature_annotations = get_type_hints(feature_method)
-        for param in feature_params:
-            param._annotation = feature_annotations[param.name]
+        # Because we use `from __future__ import annotations` inside the module where `datapoints._datapoint` is
+        # defined, the annotations are stored as strings. This makes them concrete again, so they can be compared to the
+        # natively concrete dispatcher annotations.
+        datapoint_annotations = get_type_hints(datapoint_method)
+        for param in datapoint_params:
+            param._annotation = datapoint_annotations[param.name]
 
-        assert dispatcher_params == feature_params
+        assert dispatcher_params == datapoint_params
 
     @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
     def test_unkown_type(self, info):
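`test_dispatch_datapoint` above checks that a dispatcher routes a wrapped input to the corresponding `Datapoint` method by spying on that method and asserting a single call. A rough standalone sketch of the same idea without the suite's `spy_on` fixture, using `F.horizontal_flip` as the dispatcher (assumed to be available in `torchvision.prototype.transforms.functional` at this commit):

    import torch
    from torchvision.prototype import datapoints
    from torchvision.prototype.transforms import functional as F

    image = datapoints.Image(torch.rand(3, 8, 8))

    # Temporarily wrap Image.horizontal_flip so calls to it can be recorded.
    calls = []
    original = datapoints.Image.horizontal_flip

    def spying_horizontal_flip(self, *args, **kwargs):
        calls.append((args, kwargs))
        return original(self, *args, **kwargs)

    datapoints.Image.horizontal_flip = spying_horizontal_flip
    try:
        # The dispatcher should detect the datapoint and call its method once.
        F.horizontal_flip(image)
    finally:
        datapoints.Image.horizontal_flip = original

    assert len(calls) == 1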