Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
8acf1ca2
Unverified
Commit
8acf1ca2
authored
Sep 09, 2022
by
Philip Meier
Committed by
GitHub
Sep 09, 2022
Browse files
extract common utils for prototype transform tests (#6552)
parent
b5c961d4
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
246 additions
and
243 deletions
+246
-243
test/prototype_common_utils.py
test/prototype_common_utils.py
+199
-0
test/test_prototype_transforms.py
test/test_prototype_transforms.py
+1
-1
test/test_prototype_transforms_consistency.py
test/test_prototype_transforms_consistency.py
+2
-42
test/test_prototype_transforms_functional.py
test/test_prototype_transforms_functional.py
+43
-199
test/test_prototype_transforms_utils.py
test/test_prototype_transforms_utils.py
+1
-1
No files found.
test/prototype_common_utils.py
0 → 100644
View file @
8acf1ca2
import
functools
import
itertools
import
PIL.Image
import
pytest
import
torch
import
torch.testing
from
torch.nn.functional
import
one_hot
from
torch.testing._comparison
import
assert_equal
as
_assert_equal
,
TensorLikePair
from
torchvision.prototype
import
features
from
torchvision.prototype.transforms.functional
import
to_image_tensor
from
torchvision.transforms.functional_tensor
import
_max_value
as
get_max_value
class
ImagePair
(
TensorLikePair
):
def
_process_inputs
(
self
,
actual
,
expected
,
*
,
id
,
allow_subclasses
):
return
super
().
_process_inputs
(
*
[
to_image_tensor
(
input
)
if
isinstance
(
input
,
PIL
.
Image
.
Image
)
else
input
for
input
in
[
actual
,
expected
]],
id
=
id
,
allow_subclasses
=
allow_subclasses
,
)
assert_equal
=
functools
.
partial
(
_assert_equal
,
pair_types
=
[
ImagePair
],
rtol
=
0
,
atol
=
0
)
class
ArgsKwargs
:
def
__init__
(
self
,
*
args
,
**
kwargs
):
self
.
args
=
args
self
.
kwargs
=
kwargs
def
__iter__
(
self
):
yield
self
.
args
yield
self
.
kwargs
def
__str__
(
self
):
def
short_repr
(
obj
,
max
=
20
):
repr_
=
repr
(
obj
)
if
len
(
repr_
)
<=
max
:
return
repr_
return
f
"
{
repr_
[:
max
//
2
]
}
...
{
repr_
[
-
(
max
//
2
-
3
):]
}
"
return
", "
.
join
(
itertools
.
chain
(
[
short_repr
(
arg
)
for
arg
in
self
.
args
],
[
f
"
{
param
}
=
{
short_repr
(
kwarg
)
}
"
for
param
,
kwarg
in
self
.
kwargs
.
items
()],
)
)
make_tensor
=
functools
.
partial
(
torch
.
testing
.
make_tensor
,
device
=
"cpu"
)
def
make_image
(
size
=
None
,
*
,
color_space
,
extra_dims
=
(),
dtype
=
torch
.
float32
,
constant_alpha
=
True
):
size
=
size
or
torch
.
randint
(
16
,
33
,
(
2
,)).
tolist
()
try
:
num_channels
=
{
features
.
ColorSpace
.
GRAY
:
1
,
features
.
ColorSpace
.
GRAY_ALPHA
:
2
,
features
.
ColorSpace
.
RGB
:
3
,
features
.
ColorSpace
.
RGB_ALPHA
:
4
,
}[
color_space
]
except
KeyError
as
error
:
raise
pytest
.
UsageError
()
from
error
shape
=
(
*
extra_dims
,
num_channels
,
*
size
)
max_value
=
get_max_value
(
dtype
)
data
=
make_tensor
(
shape
,
low
=
0
,
high
=
max_value
,
dtype
=
dtype
)
if
color_space
in
{
features
.
ColorSpace
.
GRAY_ALPHA
,
features
.
ColorSpace
.
RGB_ALPHA
}
and
constant_alpha
:
data
[...,
-
1
,
:,
:]
=
max_value
return
features
.
Image
(
data
,
color_space
=
color_space
)
make_grayscale_image
=
functools
.
partial
(
make_image
,
color_space
=
features
.
ColorSpace
.
GRAY
)
make_rgb_image
=
functools
.
partial
(
make_image
,
color_space
=
features
.
ColorSpace
.
RGB
)
def
make_images
(
sizes
=
((
16
,
16
),
(
7
,
33
),
(
31
,
9
)),
color_spaces
=
(
features
.
ColorSpace
.
GRAY
,
features
.
ColorSpace
.
GRAY_ALPHA
,
features
.
ColorSpace
.
RGB
,
features
.
ColorSpace
.
RGB_ALPHA
,
),
dtypes
=
(
torch
.
float32
,
torch
.
uint8
),
extra_dims
=
((),
(
0
,),
(
4
,),
(
2
,
3
),
(
5
,
0
),
(
0
,
5
)),
):
for
size
,
color_space
,
dtype
in
itertools
.
product
(
sizes
,
color_spaces
,
dtypes
):
yield
make_image
(
size
,
color_space
=
color_space
,
dtype
=
dtype
)
for
color_space
,
dtype
,
extra_dims_
in
itertools
.
product
(
color_spaces
,
dtypes
,
extra_dims
):
yield
make_image
(
size
=
sizes
[
0
],
color_space
=
color_space
,
extra_dims
=
extra_dims_
,
dtype
=
dtype
)
def
randint_with_tensor_bounds
(
arg1
,
arg2
=
None
,
**
kwargs
):
low
,
high
=
torch
.
broadcast_tensors
(
*
[
torch
.
as_tensor
(
arg
)
for
arg
in
((
0
,
arg1
)
if
arg2
is
None
else
(
arg1
,
arg2
))]
)
return
torch
.
stack
(
[
torch
.
randint
(
low_scalar
,
high_scalar
,
(),
**
kwargs
)
for
low_scalar
,
high_scalar
in
zip
(
low
.
flatten
().
tolist
(),
high
.
flatten
().
tolist
())
]
).
reshape
(
low
.
shape
)
def
make_bounding_box
(
*
,
format
,
image_size
=
(
32
,
32
),
extra_dims
=
(),
dtype
=
torch
.
int64
):
if
isinstance
(
format
,
str
):
format
=
features
.
BoundingBoxFormat
[
format
]
if
any
(
dim
==
0
for
dim
in
extra_dims
):
return
features
.
BoundingBox
(
torch
.
empty
(
*
extra_dims
,
4
),
format
=
format
,
image_size
=
image_size
)
height
,
width
=
image_size
if
format
==
features
.
BoundingBoxFormat
.
XYXY
:
x1
=
torch
.
randint
(
0
,
width
//
2
,
extra_dims
)
y1
=
torch
.
randint
(
0
,
height
//
2
,
extra_dims
)
x2
=
randint_with_tensor_bounds
(
x1
+
1
,
width
-
x1
)
+
x1
y2
=
randint_with_tensor_bounds
(
y1
+
1
,
height
-
y1
)
+
y1
parts
=
(
x1
,
y1
,
x2
,
y2
)
elif
format
==
features
.
BoundingBoxFormat
.
XYWH
:
x
=
torch
.
randint
(
0
,
width
//
2
,
extra_dims
)
y
=
torch
.
randint
(
0
,
height
//
2
,
extra_dims
)
w
=
randint_with_tensor_bounds
(
1
,
width
-
x
)
h
=
randint_with_tensor_bounds
(
1
,
height
-
y
)
parts
=
(
x
,
y
,
w
,
h
)
elif
format
==
features
.
BoundingBoxFormat
.
CXCYWH
:
cx
=
torch
.
randint
(
1
,
width
-
1
,
())
cy
=
torch
.
randint
(
1
,
height
-
1
,
())
w
=
randint_with_tensor_bounds
(
1
,
torch
.
minimum
(
cx
,
width
-
cx
)
+
1
)
h
=
randint_with_tensor_bounds
(
1
,
torch
.
minimum
(
cy
,
height
-
cy
)
+
1
)
parts
=
(
cx
,
cy
,
w
,
h
)
else
:
raise
pytest
.
UsageError
()
return
features
.
BoundingBox
(
torch
.
stack
(
parts
,
dim
=-
1
).
to
(
dtype
),
format
=
format
,
image_size
=
image_size
)
make_xyxy_bounding_box
=
functools
.
partial
(
make_bounding_box
,
format
=
features
.
BoundingBoxFormat
.
XYXY
)
def
make_bounding_boxes
(
formats
=
(
features
.
BoundingBoxFormat
.
XYXY
,
features
.
BoundingBoxFormat
.
XYWH
,
features
.
BoundingBoxFormat
.
CXCYWH
),
image_sizes
=
((
32
,
32
),),
dtypes
=
(
torch
.
int64
,
torch
.
float32
),
extra_dims
=
((
0
,),
(),
(
4
,),
(
2
,
3
),
(
5
,
0
),
(
0
,
5
)),
):
for
format
,
image_size
,
dtype
in
itertools
.
product
(
formats
,
image_sizes
,
dtypes
):
yield
make_bounding_box
(
format
=
format
,
image_size
=
image_size
,
dtype
=
dtype
)
for
format
,
extra_dims_
in
itertools
.
product
(
formats
,
extra_dims
):
yield
make_bounding_box
(
format
=
format
,
extra_dims
=
extra_dims_
)
def
make_label
(
size
=
(),
*
,
categories
=
(
"category0"
,
"category1"
)):
return
features
.
Label
(
torch
.
randint
(
0
,
len
(
categories
)
if
categories
else
10
,
size
),
categories
=
categories
)
def
make_one_hot_label
(
*
args
,
**
kwargs
):
label
=
make_label
(
*
args
,
**
kwargs
)
return
features
.
OneHotLabel
(
one_hot
(
label
,
num_classes
=
len
(
label
.
categories
)),
categories
=
label
.
categories
)
def
make_one_hot_labels
(
*
,
num_categories
=
(
1
,
2
,
10
),
extra_dims
=
((),
(
0
,),
(
4
,),
(
2
,
3
),
(
5
,
0
),
(
0
,
5
)),
):
for
num_categories_
in
num_categories
:
yield
make_one_hot_label
(
categories
=
[
f
"category
{
idx
}
"
for
idx
in
range
(
num_categories_
)])
for
extra_dims_
in
extra_dims
:
yield
make_one_hot_label
(
extra_dims_
)
def
make_segmentation_mask
(
size
=
None
,
*
,
num_objects
=
None
,
extra_dims
=
(),
dtype
=
torch
.
uint8
):
size
=
size
if
size
is
not
None
else
torch
.
randint
(
16
,
33
,
(
2
,)).
tolist
()
num_objects
=
num_objects
if
num_objects
is
not
None
else
int
(
torch
.
randint
(
1
,
11
,
()))
shape
=
(
*
extra_dims
,
num_objects
,
*
size
)
data
=
make_tensor
(
shape
,
low
=
0
,
high
=
2
,
dtype
=
dtype
)
return
features
.
SegmentationMask
(
data
)
def
make_segmentation_masks
(
sizes
=
((
16
,
16
),
(
7
,
33
),
(
31
,
9
)),
dtypes
=
(
torch
.
uint8
,),
extra_dims
=
((),
(
0
,),
(
4
,),
(
2
,
3
),
(
5
,
0
),
(
0
,
5
)),
num_objects
=
(
1
,
0
,
10
),
):
for
size
,
dtype
,
extra_dims_
in
itertools
.
product
(
sizes
,
dtypes
,
extra_dims
):
yield
make_segmentation_mask
(
size
=
size
,
dtype
=
dtype
,
extra_dims
=
extra_dims_
)
for
dtype
,
extra_dims_
,
num_objects_
in
itertools
.
product
(
dtypes
,
extra_dims
,
num_objects
):
yield
make_segmentation_mask
(
size
=
sizes
[
0
],
num_objects
=
num_objects_
,
dtype
=
dtype
,
extra_dims
=
extra_dims_
)
test/test_prototype_transforms.py
View file @
8acf1ca2
...
@@ -7,7 +7,7 @@ import PIL.Image
...
@@ -7,7 +7,7 @@ import PIL.Image
import
pytest
import
pytest
import
torch
import
torch
from
common_utils
import
assert_equal
,
cpu_and_gpu
from
common_utils
import
assert_equal
,
cpu_and_gpu
from
test_
prototype_
transforms_functional
import
(
from
prototype_
common_utils
import
(
make_bounding_box
,
make_bounding_box
,
make_bounding_boxes
,
make_bounding_boxes
,
make_image
,
make_image
,
...
...
test/test_prototype_transforms_consistency.py
View file @
8acf1ca2
import
enum
import
enum
import
functools
import
inspect
import
inspect
import
itertools
import
numpy
as
np
import
numpy
as
np
import
PIL.Image
import
PIL.Image
import
pytest
import
pytest
import
torch
import
torch
from
test_prototype_transforms_functional
import
make_images
from
prototype_common_utils
import
ArgsKwargs
,
assert_equal
,
make_images
from
torch.testing._comparison
import
assert_equal
as
_assert_equal
,
TensorLikePair
from
torchvision
import
transforms
as
legacy_transforms
from
torchvision
import
transforms
as
legacy_transforms
from
torchvision._utils
import
sequence_to_str
from
torchvision._utils
import
sequence_to_str
from
torchvision.prototype
import
features
,
transforms
as
prototype_transforms
from
torchvision.prototype
import
features
,
transforms
as
prototype_transforms
from
torchvision.prototype.transforms.functional
import
to_image_pil
,
to_image_tensor
from
torchvision.prototype.transforms.functional
import
to_image_pil
class
ImagePair
(
TensorLikePair
):
def
_process_inputs
(
self
,
actual
,
expected
,
*
,
id
,
allow_subclasses
):
return
super
().
_process_inputs
(
*
[
to_image_tensor
(
input
)
if
isinstance
(
input
,
PIL
.
Image
.
Image
)
else
input
for
input
in
[
actual
,
expected
]],
id
=
id
,
allow_subclasses
=
allow_subclasses
,
)
assert_equal
=
functools
.
partial
(
_assert_equal
,
pair_types
=
[
ImagePair
],
rtol
=
0
,
atol
=
0
)
DEFAULT_MAKE_IMAGES_KWARGS
=
dict
(
color_spaces
=
[
features
.
ColorSpace
.
RGB
],
extra_dims
=
[(
4
,)])
DEFAULT_MAKE_IMAGES_KWARGS
=
dict
(
color_spaces
=
[
features
.
ColorSpace
.
RGB
],
extra_dims
=
[(
4
,)])
class
ArgsKwargs
:
def
__init__
(
self
,
*
args
,
**
kwargs
):
self
.
args
=
args
self
.
kwargs
=
kwargs
def
__iter__
(
self
):
yield
self
.
args
yield
self
.
kwargs
def
__str__
(
self
):
def
short_repr
(
obj
,
max
=
20
):
repr_
=
repr
(
obj
)
if
len
(
repr_
)
<=
max
:
return
repr_
return
f
"
{
repr_
[:
max
//
2
]
}
...
{
repr_
[
-
(
max
//
2
-
3
):]
}
"
return
", "
.
join
(
itertools
.
chain
(
[
short_repr
(
arg
)
for
arg
in
self
.
args
],
[
f
"
{
param
}
=
{
short_repr
(
kwarg
)
}
"
for
param
,
kwarg
in
self
.
kwargs
.
items
()],
)
)
class
ConsistencyConfig
:
class
ConsistencyConfig
:
def
__init__
(
def
__init__
(
self
,
self
,
...
...
test/test_prototype_transforms_functional.py
View file @
8acf1ca2
import
functools
import
itertools
import
itertools
import
math
import
math
import
os
import
os
...
@@ -9,167 +8,12 @@ import pytest
...
@@ -9,167 +8,12 @@ import pytest
import
torch.testing
import
torch.testing
import
torchvision.prototype.transforms.functional
as
F
import
torchvision.prototype.transforms.functional
as
F
from
common_utils
import
cpu_and_gpu
from
common_utils
import
cpu_and_gpu
from
prototype_common_utils
import
ArgsKwargs
,
make_bounding_boxes
,
make_image
,
make_images
,
make_segmentation_masks
from
torch
import
jit
from
torch
import
jit
from
torch.nn.functional
import
one_hot
from
torchvision.prototype
import
features
from
torchvision.prototype
import
features
from
torchvision.prototype.transforms.functional._geometry
import
_center_crop_compute_padding
from
torchvision.prototype.transforms.functional._geometry
import
_center_crop_compute_padding
from
torchvision.prototype.transforms.functional._meta
import
convert_bounding_box_format
from
torchvision.prototype.transforms.functional._meta
import
convert_bounding_box_format
from
torchvision.transforms.functional
import
_get_perspective_coeffs
from
torchvision.transforms.functional
import
_get_perspective_coeffs
from
torchvision.transforms.functional_tensor
import
_max_value
as
get_max_value
make_tensor
=
functools
.
partial
(
torch
.
testing
.
make_tensor
,
device
=
"cpu"
)
def
make_image
(
size
=
None
,
*
,
color_space
,
extra_dims
=
(),
dtype
=
torch
.
float32
,
constant_alpha
=
True
):
size
=
size
or
torch
.
randint
(
16
,
33
,
(
2
,)).
tolist
()
try
:
num_channels
=
{
features
.
ColorSpace
.
GRAY
:
1
,
features
.
ColorSpace
.
GRAY_ALPHA
:
2
,
features
.
ColorSpace
.
RGB
:
3
,
features
.
ColorSpace
.
RGB_ALPHA
:
4
,
}[
color_space
]
except
KeyError
as
error
:
raise
pytest
.
UsageError
()
from
error
shape
=
(
*
extra_dims
,
num_channels
,
*
size
)
max_value
=
get_max_value
(
dtype
)
data
=
make_tensor
(
shape
,
low
=
0
,
high
=
max_value
,
dtype
=
dtype
)
if
color_space
in
{
features
.
ColorSpace
.
GRAY_ALPHA
,
features
.
ColorSpace
.
RGB_ALPHA
}
and
constant_alpha
:
data
[...,
-
1
,
:,
:]
=
max_value
return
features
.
Image
(
data
,
color_space
=
color_space
)
make_grayscale_image
=
functools
.
partial
(
make_image
,
color_space
=
features
.
ColorSpace
.
GRAY
)
make_rgb_image
=
functools
.
partial
(
make_image
,
color_space
=
features
.
ColorSpace
.
RGB
)
def
make_images
(
sizes
=
((
16
,
16
),
(
7
,
33
),
(
31
,
9
)),
color_spaces
=
(
features
.
ColorSpace
.
GRAY
,
features
.
ColorSpace
.
GRAY_ALPHA
,
features
.
ColorSpace
.
RGB
,
features
.
ColorSpace
.
RGB_ALPHA
,
),
dtypes
=
(
torch
.
float32
,
torch
.
uint8
),
extra_dims
=
((),
(
0
,),
(
4
,),
(
2
,
3
),
(
5
,
0
),
(
0
,
5
)),
):
for
size
,
color_space
,
dtype
in
itertools
.
product
(
sizes
,
color_spaces
,
dtypes
):
yield
make_image
(
size
,
color_space
=
color_space
,
dtype
=
dtype
)
for
color_space
,
dtype
,
extra_dims_
in
itertools
.
product
(
color_spaces
,
dtypes
,
extra_dims
):
yield
make_image
(
size
=
sizes
[
0
],
color_space
=
color_space
,
extra_dims
=
extra_dims_
,
dtype
=
dtype
)
def
randint_with_tensor_bounds
(
arg1
,
arg2
=
None
,
**
kwargs
):
low
,
high
=
torch
.
broadcast_tensors
(
*
[
torch
.
as_tensor
(
arg
)
for
arg
in
((
0
,
arg1
)
if
arg2
is
None
else
(
arg1
,
arg2
))]
)
return
torch
.
stack
(
[
torch
.
randint
(
low_scalar
,
high_scalar
,
(),
**
kwargs
)
for
low_scalar
,
high_scalar
in
zip
(
low
.
flatten
().
tolist
(),
high
.
flatten
().
tolist
())
]
).
reshape
(
low
.
shape
)
def
make_bounding_box
(
*
,
format
,
image_size
=
(
32
,
32
),
extra_dims
=
(),
dtype
=
torch
.
int64
):
if
isinstance
(
format
,
str
):
format
=
features
.
BoundingBoxFormat
[
format
]
if
any
(
dim
==
0
for
dim
in
extra_dims
):
return
features
.
BoundingBox
(
torch
.
empty
(
*
extra_dims
,
4
),
format
=
format
,
image_size
=
image_size
)
height
,
width
=
image_size
if
format
==
features
.
BoundingBoxFormat
.
XYXY
:
x1
=
torch
.
randint
(
0
,
width
//
2
,
extra_dims
)
y1
=
torch
.
randint
(
0
,
height
//
2
,
extra_dims
)
x2
=
randint_with_tensor_bounds
(
x1
+
1
,
width
-
x1
)
+
x1
y2
=
randint_with_tensor_bounds
(
y1
+
1
,
height
-
y1
)
+
y1
parts
=
(
x1
,
y1
,
x2
,
y2
)
elif
format
==
features
.
BoundingBoxFormat
.
XYWH
:
x
=
torch
.
randint
(
0
,
width
//
2
,
extra_dims
)
y
=
torch
.
randint
(
0
,
height
//
2
,
extra_dims
)
w
=
randint_with_tensor_bounds
(
1
,
width
-
x
)
h
=
randint_with_tensor_bounds
(
1
,
height
-
y
)
parts
=
(
x
,
y
,
w
,
h
)
elif
format
==
features
.
BoundingBoxFormat
.
CXCYWH
:
cx
=
torch
.
randint
(
1
,
width
-
1
,
())
cy
=
torch
.
randint
(
1
,
height
-
1
,
())
w
=
randint_with_tensor_bounds
(
1
,
torch
.
minimum
(
cx
,
width
-
cx
)
+
1
)
h
=
randint_with_tensor_bounds
(
1
,
torch
.
minimum
(
cy
,
height
-
cy
)
+
1
)
parts
=
(
cx
,
cy
,
w
,
h
)
else
:
raise
pytest
.
UsageError
()
return
features
.
BoundingBox
(
torch
.
stack
(
parts
,
dim
=-
1
).
to
(
dtype
),
format
=
format
,
image_size
=
image_size
)
make_xyxy_bounding_box
=
functools
.
partial
(
make_bounding_box
,
format
=
features
.
BoundingBoxFormat
.
XYXY
)
def
make_bounding_boxes
(
formats
=
(
features
.
BoundingBoxFormat
.
XYXY
,
features
.
BoundingBoxFormat
.
XYWH
,
features
.
BoundingBoxFormat
.
CXCYWH
),
image_sizes
=
((
32
,
32
),),
dtypes
=
(
torch
.
int64
,
torch
.
float32
),
extra_dims
=
((
0
,),
(),
(
4
,),
(
2
,
3
),
(
5
,
0
),
(
0
,
5
)),
):
for
format
,
image_size
,
dtype
in
itertools
.
product
(
formats
,
image_sizes
,
dtypes
):
yield
make_bounding_box
(
format
=
format
,
image_size
=
image_size
,
dtype
=
dtype
)
for
format
,
extra_dims_
in
itertools
.
product
(
formats
,
extra_dims
):
yield
make_bounding_box
(
format
=
format
,
extra_dims
=
extra_dims_
)
def
make_label
(
size
=
(),
*
,
categories
=
(
"category0"
,
"category1"
)):
return
features
.
Label
(
torch
.
randint
(
0
,
len
(
categories
)
if
categories
else
10
,
size
),
categories
=
categories
)
def
make_one_hot_label
(
*
args
,
**
kwargs
):
label
=
make_label
(
*
args
,
**
kwargs
)
return
features
.
OneHotLabel
(
one_hot
(
label
,
num_classes
=
len
(
label
.
categories
)),
categories
=
label
.
categories
)
def
make_one_hot_labels
(
*
,
num_categories
=
(
1
,
2
,
10
),
extra_dims
=
((),
(
0
,),
(
4
,),
(
2
,
3
),
(
5
,
0
),
(
0
,
5
)),
):
for
num_categories_
in
num_categories
:
yield
make_one_hot_label
(
categories
=
[
f
"category
{
idx
}
"
for
idx
in
range
(
num_categories_
)])
for
extra_dims_
in
extra_dims
:
yield
make_one_hot_label
(
extra_dims_
)
def
make_segmentation_mask
(
size
=
None
,
*
,
num_objects
=
None
,
extra_dims
=
(),
dtype
=
torch
.
uint8
):
size
=
size
if
size
is
not
None
else
torch
.
randint
(
16
,
33
,
(
2
,)).
tolist
()
num_objects
=
num_objects
if
num_objects
is
not
None
else
int
(
torch
.
randint
(
1
,
11
,
()))
shape
=
(
*
extra_dims
,
num_objects
,
*
size
)
data
=
make_tensor
(
shape
,
low
=
0
,
high
=
2
,
dtype
=
dtype
)
return
features
.
SegmentationMask
(
data
)
def
make_segmentation_masks
(
sizes
=
((
16
,
16
),
(
7
,
33
),
(
31
,
9
)),
dtypes
=
(
torch
.
uint8
,),
extra_dims
=
((),
(
0
,),
(
4
,),
(
2
,
3
),
(
5
,
0
),
(
0
,
5
)),
num_objects
=
(
1
,
0
,
10
),
):
for
size
,
dtype
,
extra_dims_
in
itertools
.
product
(
sizes
,
dtypes
,
extra_dims
):
yield
make_segmentation_mask
(
size
=
size
,
dtype
=
dtype
,
extra_dims
=
extra_dims_
)
for
dtype
,
extra_dims_
,
num_objects_
in
itertools
.
product
(
dtypes
,
extra_dims
,
num_objects
):
yield
make_segmentation_mask
(
size
=
sizes
[
0
],
num_objects
=
num_objects_
,
dtype
=
dtype
,
extra_dims
=
extra_dims_
)
class
SampleInput
:
def
__init__
(
self
,
*
args
,
**
kwargs
):
self
.
args
=
args
self
.
kwargs
=
kwargs
class
FunctionalInfo
:
class
FunctionalInfo
:
...
@@ -182,7 +26,7 @@ class FunctionalInfo:
...
@@ -182,7 +26,7 @@ class FunctionalInfo:
yield
from
self
.
_sample_inputs_fn
()
yield
from
self
.
_sample_inputs_fn
()
def
__call__
(
self
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
*
args
,
**
kwargs
):
if
len
(
args
)
==
1
and
not
kwargs
and
isinstance
(
args
[
0
],
SampleInput
):
if
len
(
args
)
==
1
and
not
kwargs
and
isinstance
(
args
[
0
],
ArgsKwargs
):
sample_input
=
args
[
0
]
sample_input
=
args
[
0
]
return
self
.
functional
(
*
sample_input
.
args
,
**
sample_input
.
kwargs
)
return
self
.
functional
(
*
sample_input
.
args
,
**
sample_input
.
kwargs
)
...
@@ -200,37 +44,37 @@ def register_kernel_info_from_sample_inputs_fn(sample_inputs_fn):
...
@@ -200,37 +44,37 @@ def register_kernel_info_from_sample_inputs_fn(sample_inputs_fn):
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
def
horizontal_flip_image_tensor
():
def
horizontal_flip_image_tensor
():
for
image
in
make_images
():
for
image
in
make_images
():
yield
SampleInput
(
image
)
yield
ArgsKwargs
(
image
)
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
def
horizontal_flip_bounding_box
():
def
horizontal_flip_bounding_box
():
for
bounding_box
in
make_bounding_boxes
(
formats
=
[
features
.
BoundingBoxFormat
.
XYXY
]):
for
bounding_box
in
make_bounding_boxes
(
formats
=
[
features
.
BoundingBoxFormat
.
XYXY
]):
yield
SampleInput
(
bounding_box
,
format
=
bounding_box
.
format
,
image_size
=
bounding_box
.
image_size
)
yield
ArgsKwargs
(
bounding_box
,
format
=
bounding_box
.
format
,
image_size
=
bounding_box
.
image_size
)
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
def
horizontal_flip_segmentation_mask
():
def
horizontal_flip_segmentation_mask
():
for
mask
in
make_segmentation_masks
():
for
mask
in
make_segmentation_masks
():
yield
SampleInput
(
mask
)
yield
ArgsKwargs
(
mask
)
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
def
vertical_flip_image_tensor
():
def
vertical_flip_image_tensor
():
for
image
in
make_images
():
for
image
in
make_images
():
yield
SampleInput
(
image
)
yield
ArgsKwargs
(
image
)
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
def
vertical_flip_bounding_box
():
def
vertical_flip_bounding_box
():
for
bounding_box
in
make_bounding_boxes
(
formats
=
[
features
.
BoundingBoxFormat
.
XYXY
]):
for
bounding_box
in
make_bounding_boxes
(
formats
=
[
features
.
BoundingBoxFormat
.
XYXY
]):
yield
SampleInput
(
bounding_box
,
format
=
bounding_box
.
format
,
image_size
=
bounding_box
.
image_size
)
yield
ArgsKwargs
(
bounding_box
,
format
=
bounding_box
.
format
,
image_size
=
bounding_box
.
image_size
)
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
def
vertical_flip_segmentation_mask
():
def
vertical_flip_segmentation_mask
():
for
mask
in
make_segmentation_masks
():
for
mask
in
make_segmentation_masks
():
yield
SampleInput
(
mask
)
yield
ArgsKwargs
(
mask
)
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
...
@@ -252,7 +96,7 @@ def resize_image_tensor():
...
@@ -252,7 +96,7 @@ def resize_image_tensor():
]:
]:
if
max_size
is
not
None
:
if
max_size
is
not
None
:
size
=
[
size
[
0
]]
size
=
[
size
[
0
]]
yield
SampleInput
(
image
,
size
=
size
,
interpolation
=
interpolation
,
max_size
=
max_size
,
antialias
=
antialias
)
yield
ArgsKwargs
(
image
,
size
=
size
,
interpolation
=
interpolation
,
max_size
=
max_size
,
antialias
=
antialias
)
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
...
@@ -268,7 +112,7 @@ def resize_bounding_box():
...
@@ -268,7 +112,7 @@ def resize_bounding_box():
]:
]:
if
max_size
is
not
None
:
if
max_size
is
not
None
:
size
=
[
size
[
0
]]
size
=
[
size
[
0
]]
yield
SampleInput
(
bounding_box
,
size
=
size
,
image_size
=
bounding_box
.
image_size
)
yield
ArgsKwargs
(
bounding_box
,
size
=
size
,
image_size
=
bounding_box
.
image_size
)
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
...
@@ -284,7 +128,7 @@ def resize_segmentation_mask():
...
@@ -284,7 +128,7 @@ def resize_segmentation_mask():
]:
]:
if
max_size
is
not
None
:
if
max_size
is
not
None
:
size
=
[
size
[
0
]]
size
=
[
size
[
0
]]
yield
SampleInput
(
mask
,
size
=
size
,
max_size
=
max_size
)
yield
ArgsKwargs
(
mask
,
size
=
size
,
max_size
=
max_size
)
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
...
@@ -296,7 +140,7 @@ def affine_image_tensor():
...
@@ -296,7 +140,7 @@ def affine_image_tensor():
[
0.77
,
1.27
],
# scale
[
0.77
,
1.27
],
# scale
[
0
,
12
],
# shear
[
0
,
12
],
# shear
):
):
yield
SampleInput
(
yield
ArgsKwargs
(
image
,
image
,
angle
=
angle
,
angle
=
angle
,
translate
=
(
translate
,
translate
),
translate
=
(
translate
,
translate
),
...
@@ -315,7 +159,7 @@ def affine_bounding_box():
...
@@ -315,7 +159,7 @@ def affine_bounding_box():
[
0.77
,
1.27
],
# scale
[
0.77
,
1.27
],
# scale
[
0
,
12
],
# shear
[
0
,
12
],
# shear
):
):
yield
SampleInput
(
yield
ArgsKwargs
(
bounding_box
,
bounding_box
,
format
=
bounding_box
.
format
,
format
=
bounding_box
.
format
,
image_size
=
bounding_box
.
image_size
,
image_size
=
bounding_box
.
image_size
,
...
@@ -335,7 +179,7 @@ def affine_segmentation_mask():
...
@@ -335,7 +179,7 @@ def affine_segmentation_mask():
[
0.77
,
1.27
],
# scale
[
0.77
,
1.27
],
# scale
[
0
,
12
],
# shear
[
0
,
12
],
# shear
):
):
yield
SampleInput
(
yield
ArgsKwargs
(
mask
,
mask
,
angle
=
angle
,
angle
=
angle
,
translate
=
(
translate
,
translate
),
translate
=
(
translate
,
translate
),
...
@@ -357,7 +201,7 @@ def rotate_image_tensor():
...
@@ -357,7 +201,7 @@ def rotate_image_tensor():
# Skip warning: The provided center argument is ignored if expand is True
# Skip warning: The provided center argument is ignored if expand is True
continue
continue
yield
SampleInput
(
image
,
angle
=
angle
,
expand
=
expand
,
center
=
center
,
fill
=
fill
)
yield
ArgsKwargs
(
image
,
angle
=
angle
,
expand
=
expand
,
center
=
center
,
fill
=
fill
)
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
...
@@ -369,7 +213,7 @@ def rotate_bounding_box():
...
@@ -369,7 +213,7 @@ def rotate_bounding_box():
# Skip warning: The provided center argument is ignored if expand is True
# Skip warning: The provided center argument is ignored if expand is True
continue
continue
yield
SampleInput
(
yield
ArgsKwargs
(
bounding_box
,
bounding_box
,
format
=
bounding_box
.
format
,
format
=
bounding_box
.
format
,
image_size
=
bounding_box
.
image_size
,
image_size
=
bounding_box
.
image_size
,
...
@@ -391,7 +235,7 @@ def rotate_segmentation_mask():
...
@@ -391,7 +235,7 @@ def rotate_segmentation_mask():
# Skip warning: The provided center argument is ignored if expand is True
# Skip warning: The provided center argument is ignored if expand is True
continue
continue
yield
SampleInput
(
yield
ArgsKwargs
(
mask
,
mask
,
angle
=
angle
,
angle
=
angle
,
expand
=
expand
,
expand
=
expand
,
...
@@ -402,7 +246,7 @@ def rotate_segmentation_mask():
...
@@ -402,7 +246,7 @@ def rotate_segmentation_mask():
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
def
crop_image_tensor
():
def
crop_image_tensor
():
for
image
,
top
,
left
,
height
,
width
in
itertools
.
product
(
make_images
(),
[
-
8
,
0
,
9
],
[
-
8
,
0
,
9
],
[
12
,
20
],
[
12
,
20
]):
for
image
,
top
,
left
,
height
,
width
in
itertools
.
product
(
make_images
(),
[
-
8
,
0
,
9
],
[
-
8
,
0
,
9
],
[
12
,
20
],
[
12
,
20
]):
yield
SampleInput
(
yield
ArgsKwargs
(
image
,
image
,
top
=
top
,
top
=
top
,
left
=
left
,
left
=
left
,
...
@@ -414,7 +258,7 @@ def crop_image_tensor():
...
@@ -414,7 +258,7 @@ def crop_image_tensor():
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
def
crop_bounding_box
():
def
crop_bounding_box
():
for
bounding_box
,
top
,
left
in
itertools
.
product
(
make_bounding_boxes
(),
[
-
8
,
0
,
9
],
[
-
8
,
0
,
9
]):
for
bounding_box
,
top
,
left
in
itertools
.
product
(
make_bounding_boxes
(),
[
-
8
,
0
,
9
],
[
-
8
,
0
,
9
]):
yield
SampleInput
(
yield
ArgsKwargs
(
bounding_box
,
bounding_box
,
format
=
bounding_box
.
format
,
format
=
bounding_box
.
format
,
top
=
top
,
top
=
top
,
...
@@ -427,7 +271,7 @@ def crop_segmentation_mask():
...
@@ -427,7 +271,7 @@ def crop_segmentation_mask():
for
mask
,
top
,
left
,
height
,
width
in
itertools
.
product
(
for
mask
,
top
,
left
,
height
,
width
in
itertools
.
product
(
make_segmentation_masks
(),
[
-
8
,
0
,
9
],
[
-
8
,
0
,
9
],
[
12
,
20
],
[
12
,
20
]
make_segmentation_masks
(),
[
-
8
,
0
,
9
],
[
-
8
,
0
,
9
],
[
12
,
20
],
[
12
,
20
]
):
):
yield
SampleInput
(
yield
ArgsKwargs
(
mask
,
mask
,
top
=
top
,
top
=
top
,
left
=
left
,
left
=
left
,
...
@@ -447,7 +291,7 @@ def resized_crop_image_tensor():
...
@@ -447,7 +291,7 @@ def resized_crop_image_tensor():
[(
16
,
18
)],
[(
16
,
18
)],
[
True
,
False
],
[
True
,
False
],
):
):
yield
SampleInput
(
mask
,
top
=
top
,
left
=
left
,
height
=
height
,
width
=
width
,
size
=
size
,
antialias
=
antialias
)
yield
ArgsKwargs
(
mask
,
top
=
top
,
left
=
left
,
height
=
height
,
width
=
width
,
size
=
size
,
antialias
=
antialias
)
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
...
@@ -455,7 +299,7 @@ def resized_crop_bounding_box():
...
@@ -455,7 +299,7 @@ def resized_crop_bounding_box():
for
bounding_box
,
top
,
left
,
height
,
width
,
size
in
itertools
.
product
(
for
bounding_box
,
top
,
left
,
height
,
width
,
size
in
itertools
.
product
(
make_bounding_boxes
(),
[
-
8
,
9
],
[
-
8
,
9
],
[
32
,
22
],
[
34
,
20
],
[(
32
,
32
),
(
16
,
18
)]
make_bounding_boxes
(),
[
-
8
,
9
],
[
-
8
,
9
],
[
32
,
22
],
[
34
,
20
],
[(
32
,
32
),
(
16
,
18
)]
):
):
yield
SampleInput
(
yield
ArgsKwargs
(
bounding_box
,
format
=
bounding_box
.
format
,
top
=
top
,
left
=
left
,
height
=
height
,
width
=
width
,
size
=
size
bounding_box
,
format
=
bounding_box
.
format
,
top
=
top
,
left
=
left
,
height
=
height
,
width
=
width
,
size
=
size
)
)
...
@@ -465,7 +309,7 @@ def resized_crop_segmentation_mask():
...
@@ -465,7 +309,7 @@ def resized_crop_segmentation_mask():
for
mask
,
top
,
left
,
height
,
width
,
size
in
itertools
.
product
(
for
mask
,
top
,
left
,
height
,
width
,
size
in
itertools
.
product
(
make_segmentation_masks
(),
[
-
8
,
0
,
9
],
[
-
8
,
0
,
9
],
[
12
,
20
],
[
12
,
20
],
[(
32
,
32
),
(
16
,
18
)]
make_segmentation_masks
(),
[
-
8
,
0
,
9
],
[
-
8
,
0
,
9
],
[
12
,
20
],
[
12
,
20
],
[(
32
,
32
),
(
16
,
18
)]
):
):
yield
SampleInput
(
mask
,
top
=
top
,
left
=
left
,
height
=
height
,
width
=
width
,
size
=
size
)
yield
ArgsKwargs
(
mask
,
top
=
top
,
left
=
left
,
height
=
height
,
width
=
width
,
size
=
size
)
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
...
@@ -476,7 +320,7 @@ def pad_image_tensor():
...
@@ -476,7 +320,7 @@ def pad_image_tensor():
[
None
,
12
,
12.0
],
# fill
[
None
,
12
,
12.0
],
# fill
[
"constant"
,
"symmetric"
,
"edge"
,
"reflect"
],
# padding mode,
[
"constant"
,
"symmetric"
,
"edge"
,
"reflect"
],
# padding mode,
):
):
yield
SampleInput
(
image
,
padding
=
padding
,
fill
=
fill
,
padding_mode
=
padding_mode
)
yield
ArgsKwargs
(
image
,
padding
=
padding
,
fill
=
fill
,
padding_mode
=
padding_mode
)
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
...
@@ -486,7 +330,7 @@ def pad_segmentation_mask():
...
@@ -486,7 +330,7 @@ def pad_segmentation_mask():
[[
1
],
[
1
,
1
],
[
1
,
1
,
2
,
2
]],
# padding
[[
1
],
[
1
,
1
],
[
1
,
1
,
2
,
2
]],
# padding
[
"constant"
,
"symmetric"
,
"edge"
,
"reflect"
],
# padding mode,
[
"constant"
,
"symmetric"
,
"edge"
,
"reflect"
],
# padding mode,
):
):
yield
SampleInput
(
mask
,
padding
=
padding
,
padding_mode
=
padding_mode
)
yield
ArgsKwargs
(
mask
,
padding
=
padding
,
padding_mode
=
padding_mode
)
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
...
@@ -495,7 +339,7 @@ def pad_bounding_box():
...
@@ -495,7 +339,7 @@ def pad_bounding_box():
make_bounding_boxes
(),
make_bounding_boxes
(),
[[
1
],
[
1
,
1
],
[
1
,
1
,
2
,
2
]],
[[
1
],
[
1
,
1
],
[
1
,
1
,
2
,
2
]],
):
):
yield
SampleInput
(
bounding_box
,
padding
=
padding
,
format
=
bounding_box
.
format
)
yield
ArgsKwargs
(
bounding_box
,
padding
=
padding
,
format
=
bounding_box
.
format
)
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
...
@@ -508,7 +352,7 @@ def perspective_image_tensor():
...
@@ -508,7 +352,7 @@ def perspective_image_tensor():
],
],
[
None
,
[
128
],
[
12.0
]],
# fill
[
None
,
[
128
],
[
12.0
]],
# fill
):
):
yield
SampleInput
(
image
,
perspective_coeffs
=
perspective_coeffs
,
fill
=
fill
)
yield
ArgsKwargs
(
image
,
perspective_coeffs
=
perspective_coeffs
,
fill
=
fill
)
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
...
@@ -520,7 +364,7 @@ def perspective_bounding_box():
...
@@ -520,7 +364,7 @@ def perspective_bounding_box():
[
0.7366
,
-
0.11724
,
1.45775
,
-
0.15012
,
0.73406
,
2.6019
,
-
0.0072
,
-
0.0063
],
[
0.7366
,
-
0.11724
,
1.45775
,
-
0.15012
,
0.73406
,
2.6019
,
-
0.0072
,
-
0.0063
],
],
],
):
):
yield
SampleInput
(
yield
ArgsKwargs
(
bounding_box
,
bounding_box
,
format
=
bounding_box
.
format
,
format
=
bounding_box
.
format
,
perspective_coeffs
=
perspective_coeffs
,
perspective_coeffs
=
perspective_coeffs
,
...
@@ -536,7 +380,7 @@ def perspective_segmentation_mask():
...
@@ -536,7 +380,7 @@ def perspective_segmentation_mask():
[
0.7366
,
-
0.11724
,
1.45775
,
-
0.15012
,
0.73406
,
2.6019
,
-
0.0072
,
-
0.0063
],
[
0.7366
,
-
0.11724
,
1.45775
,
-
0.15012
,
0.73406
,
2.6019
,
-
0.0072
,
-
0.0063
],
],
],
):
):
yield
SampleInput
(
yield
ArgsKwargs
(
mask
,
mask
,
perspective_coeffs
=
perspective_coeffs
,
perspective_coeffs
=
perspective_coeffs
,
)
)
...
@@ -550,7 +394,7 @@ def elastic_image_tensor():
...
@@ -550,7 +394,7 @@ def elastic_image_tensor():
):
):
h
,
w
=
image
.
shape
[
-
2
:]
h
,
w
=
image
.
shape
[
-
2
:]
displacement
=
torch
.
rand
(
1
,
h
,
w
,
2
)
displacement
=
torch
.
rand
(
1
,
h
,
w
,
2
)
yield
SampleInput
(
image
,
displacement
=
displacement
,
fill
=
fill
)
yield
ArgsKwargs
(
image
,
displacement
=
displacement
,
fill
=
fill
)
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
...
@@ -558,7 +402,7 @@ def elastic_bounding_box():
...
@@ -558,7 +402,7 @@ def elastic_bounding_box():
for
bounding_box
in
make_bounding_boxes
():
for
bounding_box
in
make_bounding_boxes
():
h
,
w
=
bounding_box
.
image_size
h
,
w
=
bounding_box
.
image_size
displacement
=
torch
.
rand
(
1
,
h
,
w
,
2
)
displacement
=
torch
.
rand
(
1
,
h
,
w
,
2
)
yield
SampleInput
(
yield
ArgsKwargs
(
bounding_box
,
bounding_box
,
format
=
bounding_box
.
format
,
format
=
bounding_box
.
format
,
displacement
=
displacement
,
displacement
=
displacement
,
...
@@ -570,7 +414,7 @@ def elastic_segmentation_mask():
...
@@ -570,7 +414,7 @@ def elastic_segmentation_mask():
for
mask
in
make_segmentation_masks
(
extra_dims
=
((),
(
4
,))):
for
mask
in
make_segmentation_masks
(
extra_dims
=
((),
(
4
,))):
h
,
w
=
mask
.
shape
[
-
2
:]
h
,
w
=
mask
.
shape
[
-
2
:]
displacement
=
torch
.
rand
(
1
,
h
,
w
,
2
)
displacement
=
torch
.
rand
(
1
,
h
,
w
,
2
)
yield
SampleInput
(
yield
ArgsKwargs
(
mask
,
mask
,
displacement
=
displacement
,
displacement
=
displacement
,
)
)
...
@@ -582,13 +426,13 @@ def center_crop_image_tensor():
...
@@ -582,13 +426,13 @@ def center_crop_image_tensor():
make_images
(
sizes
=
((
16
,
16
),
(
7
,
33
),
(
31
,
9
))),
make_images
(
sizes
=
((
16
,
16
),
(
7
,
33
),
(
31
,
9
))),
[[
4
,
3
],
[
42
,
70
],
[
4
]],
# crop sizes < image sizes, crop_sizes > image sizes, single crop size
[[
4
,
3
],
[
42
,
70
],
[
4
]],
# crop sizes < image sizes, crop_sizes > image sizes, single crop size
):
):
yield
SampleInput
(
mask
,
output_size
)
yield
ArgsKwargs
(
mask
,
output_size
)
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
def
center_crop_bounding_box
():
def
center_crop_bounding_box
():
for
bounding_box
,
output_size
in
itertools
.
product
(
make_bounding_boxes
(),
[(
24
,
12
),
[
16
,
18
],
[
46
,
48
],
[
12
]]):
for
bounding_box
,
output_size
in
itertools
.
product
(
make_bounding_boxes
(),
[(
24
,
12
),
[
16
,
18
],
[
46
,
48
],
[
12
]]):
yield
SampleInput
(
yield
ArgsKwargs
(
bounding_box
,
format
=
bounding_box
.
format
,
output_size
=
output_size
,
image_size
=
bounding_box
.
image_size
bounding_box
,
format
=
bounding_box
.
format
,
output_size
=
output_size
,
image_size
=
bounding_box
.
image_size
)
)
...
@@ -599,7 +443,7 @@ def center_crop_segmentation_mask():
...
@@ -599,7 +443,7 @@ def center_crop_segmentation_mask():
make_segmentation_masks
(
sizes
=
((
16
,
16
),
(
7
,
33
),
(
31
,
9
))),
make_segmentation_masks
(
sizes
=
((
16
,
16
),
(
7
,
33
),
(
31
,
9
))),
[[
4
,
3
],
[
42
,
70
],
[
4
]],
# crop sizes < image sizes, crop_sizes > image sizes, single crop size
[[
4
,
3
],
[
42
,
70
],
[
4
]],
# crop sizes < image sizes, crop_sizes > image sizes, single crop size
):
):
yield
SampleInput
(
mask
,
output_size
)
yield
ArgsKwargs
(
mask
,
output_size
)
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
...
@@ -609,7 +453,7 @@ def gaussian_blur_image_tensor():
...
@@ -609,7 +453,7 @@ def gaussian_blur_image_tensor():
[[
3
,
3
]],
[[
3
,
3
]],
[
None
,
[
3.0
,
3.0
]],
[
None
,
[
3.0
,
3.0
]],
):
):
yield
SampleInput
(
image
,
kernel_size
=
kernel_size
,
sigma
=
sigma
)
yield
ArgsKwargs
(
image
,
kernel_size
=
kernel_size
,
sigma
=
sigma
)
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
...
@@ -617,13 +461,13 @@ def equalize_image_tensor():
...
@@ -617,13 +461,13 @@ def equalize_image_tensor():
for
image
in
make_images
(
extra_dims
=
(),
color_spaces
=
(
features
.
ColorSpace
.
GRAY
,
features
.
ColorSpace
.
RGB
)):
for
image
in
make_images
(
extra_dims
=
(),
color_spaces
=
(
features
.
ColorSpace
.
GRAY
,
features
.
ColorSpace
.
RGB
)):
if
image
.
dtype
!=
torch
.
uint8
:
if
image
.
dtype
!=
torch
.
uint8
:
continue
continue
yield
SampleInput
(
image
)
yield
ArgsKwargs
(
image
)
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
def
invert_image_tensor
():
def
invert_image_tensor
():
for
image
in
make_images
(
color_spaces
=
(
features
.
ColorSpace
.
GRAY
,
features
.
ColorSpace
.
RGB
)):
for
image
in
make_images
(
color_spaces
=
(
features
.
ColorSpace
.
GRAY
,
features
.
ColorSpace
.
RGB
)):
yield
SampleInput
(
image
)
yield
ArgsKwargs
(
image
)
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
...
@@ -634,7 +478,7 @@ def posterize_image_tensor():
...
@@ -634,7 +478,7 @@ def posterize_image_tensor():
):
):
if
image
.
dtype
!=
torch
.
uint8
:
if
image
.
dtype
!=
torch
.
uint8
:
continue
continue
yield
SampleInput
(
image
,
bits
=
bits
)
yield
ArgsKwargs
(
image
,
bits
=
bits
)
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
...
@@ -645,13 +489,13 @@ def solarize_image_tensor():
...
@@ -645,13 +489,13 @@ def solarize_image_tensor():
):
):
if
image
.
is_floating_point
()
and
threshold
>
1.0
:
if
image
.
is_floating_point
()
and
threshold
>
1.0
:
continue
continue
yield
SampleInput
(
image
,
threshold
=
threshold
)
yield
ArgsKwargs
(
image
,
threshold
=
threshold
)
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
def
autocontrast_image_tensor
():
def
autocontrast_image_tensor
():
for
image
in
make_images
(
color_spaces
=
(
features
.
ColorSpace
.
GRAY
,
features
.
ColorSpace
.
RGB
)):
for
image
in
make_images
(
color_spaces
=
(
features
.
ColorSpace
.
GRAY
,
features
.
ColorSpace
.
RGB
)):
yield
SampleInput
(
image
)
yield
ArgsKwargs
(
image
)
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
...
@@ -660,14 +504,14 @@ def adjust_sharpness_image_tensor():
...
@@ -660,14 +504,14 @@ def adjust_sharpness_image_tensor():
make_images
(
extra_dims
=
((
4
,),),
color_spaces
=
(
features
.
ColorSpace
.
GRAY
,
features
.
ColorSpace
.
RGB
)),
make_images
(
extra_dims
=
((
4
,),),
color_spaces
=
(
features
.
ColorSpace
.
GRAY
,
features
.
ColorSpace
.
RGB
)),
[
0.1
,
0.5
],
[
0.1
,
0.5
],
):
):
yield
SampleInput
(
image
,
sharpness_factor
=
sharpness_factor
)
yield
ArgsKwargs
(
image
,
sharpness_factor
=
sharpness_factor
)
@
register_kernel_info_from_sample_inputs_fn
@
register_kernel_info_from_sample_inputs_fn
def
erase_image_tensor
():
def
erase_image_tensor
():
for
image
in
make_images
():
for
image
in
make_images
():
c
=
image
.
shape
[
-
3
]
c
=
image
.
shape
[
-
3
]
yield
SampleInput
(
image
,
i
=
1
,
j
=
2
,
h
=
6
,
w
=
7
,
v
=
torch
.
rand
(
c
,
6
,
7
))
yield
ArgsKwargs
(
image
,
i
=
1
,
j
=
2
,
h
=
6
,
w
=
7
,
v
=
torch
.
rand
(
c
,
6
,
7
))
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
...
test/test_prototype_transforms_utils.py
View file @
8acf1ca2
...
@@ -3,7 +3,7 @@ import pytest
...
@@ -3,7 +3,7 @@ import pytest
import
torch
import
torch
from
test_
prototype_
transforms_functional
import
make_bounding_box
,
make_image
,
make_segmentation_mask
from
prototype_
common_utils
import
make_bounding_box
,
make_image
,
make_segmentation_mask
from
torchvision.prototype
import
features
from
torchvision.prototype
import
features
from
torchvision.prototype.transforms._utils
import
has_all
,
has_any
from
torchvision.prototype.transforms._utils
import
has_all
,
has_any
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment