OpenDAS / vision — commit d5f4cc38 (unverified)

Authored Aug 30, 2023 by Nicolas Hug; committed by GitHub on Aug 30, 2023
Datapoint -> TVTensor; datapoint[s] -> tv_tensor[s] (#7894)
parent b9447fdd
Changes: 85
Showing 20 changed files with 455 additions and 455 deletions (+455, −455)
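Note: the rename this commit applies is mechanical — the `torchvision.datapoints` namespace becomes `torchvision.tv_tensors`, the `Datapoint` base class becomes `TVTensor`, and `datapoint[s]` identifiers become `tv_tensor[s]` throughout. A minimal before/after sketch of what that means for code built on these classes (a hedged illustration, assuming a build that ships the new namespace; the constructor arguments mirror the tests below):

    # Before this commit:
    # from torchvision import datapoints
    # boxes = datapoints.BoundingBoxes(...)

    # After this commit:
    import torch
    from torchvision import tv_tensors

    boxes = tv_tensors.BoundingBoxes(
        [[0, 0, 10, 10]],
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=(20, 20),
    )
    image = tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 20, 20), dtype=torch.uint8))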
test/test_prototype_transforms.py                       +15 −15
test/test_transforms_v2.py                              +57 −57
test/test_transforms_v2_consistency.py                  +20 −20
test/test_transforms_v2_functional.py                   +67 −67
test/test_transforms_v2_refactored.py                   +100 −100
test/test_transforms_v2_utils.py                        +32 −32
test/test_tv_tensors.py                                 +52 −52
test/transforms_v2_dispatcher_infos.py                  +76 −76
test/transforms_v2_kernel_infos.py                      +5 −5
test/transforms_v2_legacy_utils.py                      +16 −16
torchvision/datasets/__init__.py                        +1 −1
torchvision/prototype/__init__.py                       +1 −1
torchvision/prototype/datasets/_builtin/caltech.py      +2 −2
torchvision/prototype/datasets/_builtin/celeba.py       +2 −2
torchvision/prototype/datasets/_builtin/cifar.py        +2 −2
torchvision/prototype/datasets/_builtin/clevr.py        +1 −1
torchvision/prototype/datasets/_builtin/coco.py         +2 −2
torchvision/prototype/datasets/_builtin/country211.py   +1 −1
torchvision/prototype/datasets/_builtin/cub200.py       +2 −2
torchvision/prototype/datasets/_builtin/dtd.py          +1 −1
test/test_prototype_transforms.py

@@ -7,11 +7,11 @@ import torch
 from common_utils import assert_equal
 from prototype_common_utils import make_label
-from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
-from torchvision.prototype import datapoints, transforms
+from torchvision.prototype import transforms, tv_tensors
 from torchvision.transforms.v2._utils import check_type, is_pure_tensor
 from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_pil_image
+from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
 from transforms_v2_legacy_utils import (
     DEFAULT_EXTRA_DIMS,
     make_bounding_boxes,

@@ -51,7 +51,7 @@ class TestSimpleCopyPaste:
             # images, batch size = 2
             self.create_fake_image(mocker, Image),
             # labels, bboxes, masks
-            mocker.MagicMock(spec=datapoints.Label),
+            mocker.MagicMock(spec=tv_tensors.Label),
             mocker.MagicMock(spec=BoundingBoxes),
             mocker.MagicMock(spec=Mask),
             # labels, bboxes, masks

@@ -63,7 +63,7 @@ class TestSimpleCopyPaste:
         transform._extract_image_targets(flat_sample)

     @pytest.mark.parametrize("image_type", [Image, PIL.Image.Image, torch.Tensor])
-    @pytest.mark.parametrize("label_type", [datapoints.Label, datapoints.OneHotLabel])
+    @pytest.mark.parametrize("label_type", [tv_tensors.Label, tv_tensors.OneHotLabel])
     def test__extract_image_targets(self, image_type, label_type, mocker):
         transform = transforms.SimpleCopyPaste()

@@ -101,7 +101,7 @@ class TestSimpleCopyPaste:
             assert isinstance(target[key], type_)
             assert target[key] in flat_sample

-    @pytest.mark.parametrize("label_type", [datapoints.Label, datapoints.OneHotLabel])
+    @pytest.mark.parametrize("label_type", [tv_tensors.Label, tv_tensors.OneHotLabel])
     def test__copy_paste(self, label_type):
         image = 2 * torch.ones(3, 32, 32)
         masks = torch.zeros(2, 32, 32)

@@ -111,7 +111,7 @@ class TestSimpleCopyPaste:
         blending = True
         resize_interpolation = InterpolationMode.BILINEAR
         antialias = None
-        if label_type == datapoints.OneHotLabel:
+        if label_type == tv_tensors.OneHotLabel:
             labels = torch.nn.functional.one_hot(labels, num_classes=5)
         target = {
             "boxes": BoundingBoxes(

@@ -126,7 +126,7 @@ class TestSimpleCopyPaste:
         paste_masks[0, 13:19, 12:18] = 1
         paste_masks[1, 15:19, 1:8] = 1
         paste_labels = torch.tensor([3, 4])
-        if label_type == datapoints.OneHotLabel:
+        if label_type == tv_tensors.OneHotLabel:
             paste_labels = torch.nn.functional.one_hot(paste_labels, num_classes=5)
         paste_target = {
             "boxes": BoundingBoxes(

@@ -148,7 +148,7 @@ class TestSimpleCopyPaste:
         torch.testing.assert_close(output_target["boxes"][2:, :], paste_target["boxes"])

         expected_labels = torch.tensor([1, 2, 3, 4])
-        if label_type == datapoints.OneHotLabel:
+        if label_type == tv_tensors.OneHotLabel:
             expected_labels = torch.nn.functional.one_hot(expected_labels, num_classes=5)
         torch.testing.assert_close(output_target["labels"], label_type(expected_labels))

@@ -258,10 +258,10 @@ class TestFixedSizeCrop:
 class TestLabelToOneHot:
     def test__transform(self):
         categories = ["apple", "pear", "pineapple"]
-        labels = datapoints.Label(torch.tensor([0, 1, 2, 1]), categories=categories)
+        labels = tv_tensors.Label(torch.tensor([0, 1, 2, 1]), categories=categories)
         transform = transforms.LabelToOneHot()
         ohe_labels = transform(labels)
-        assert isinstance(ohe_labels, datapoints.OneHotLabel)
+        assert isinstance(ohe_labels, tv_tensors.OneHotLabel)
         assert ohe_labels.shape == (4, 3)
         assert ohe_labels.categories == labels.categories == categories

@@ -383,7 +383,7 @@ det_transforms = import_transforms_from_references("detection")

 def test_fixed_sized_crop_against_detection_reference():
-    def make_datapoints():
+    def make_tv_tensors():
         size = (600, 800)
         num_objects = 22

@@ -405,19 +405,19 @@ def test_fixed_sized_crop_against_detection_reference():
             yield (tensor_image, target)

-        datapoint_image = make_image(size=size, color_space="RGB")
+        tv_tensor_image = make_image(size=size, color_space="RGB")
         target = {
             "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
             "labels": make_label(extra_dims=(num_objects,), categories=80),
             "masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
         }

-        yield (datapoint_image, target)
+        yield (tv_tensor_image, target)

     t = transforms.FixedSizeCrop((1024, 1024), fill=0)
     t_ref = det_transforms.FixedSizeCrop((1024, 1024), fill=0)

-    for dp in make_datapoints():
+    for dp in make_tv_tensors():
         # We should use prototype transform first as reference transform performs inplace target update
         torch.manual_seed(12)
         output = t(dp)
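For orientation, here is roughly what the renamed make_tv_tensors() generator above yields per sample, written against the public tv_tensors classes instead of the test-suite helpers (make_image, make_bounding_boxes, and make_detection_mask are test utilities from transforms_v2_legacy_utils; this hypothetical make_sample only mirrors their shapes and dtypes):

    import torch
    from torchvision import tv_tensors

    def make_sample(size=(600, 800), num_objects=22):
        # Illustrative stand-in for the test's make_tv_tensors() generator.
        h, w = size
        image = tv_tensors.Image(torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8))
        # Random XYXY boxes: x2/y2 are offsets added to x1/y1, so they stay on the canvas.
        boxes = torch.randint(0, min(h, w) // 2, size=(num_objects, 4), dtype=torch.float)
        boxes[:, 2:] += boxes[:, :2]
        target = {
            "boxes": tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=size),
            "labels": torch.randint(0, 80, size=(num_objects,)),
            "masks": tv_tensors.Mask(torch.randint(0, 2, size=(num_objects, h, w), dtype=torch.long)),
        }
        return image, target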
test/test_transforms_v2.py

@@ -13,7 +13,7 @@ import torchvision.transforms.v2 as transforms
 from common_utils import assert_equal, cpu_and_cuda
 from torch.utils._pytree import tree_flatten, tree_unflatten
-from torchvision import datapoints
+from torchvision import tv_tensors
 from torchvision.ops.boxes import box_iou
 from torchvision.transforms.functional import to_pil_image
 from torchvision.transforms.v2 import functional as F

@@ -66,10 +66,10 @@ def auto_augment_adapter(transform, input, device):
     adapted_input = {}
     image_or_video_found = False
     for key, value in input.items():
-        if isinstance(value, (datapoints.BoundingBoxes, datapoints.Mask)):
+        if isinstance(value, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
             # AA transforms don't support bounding boxes or masks
             continue
-        elif check_type(value, (datapoints.Image, datapoints.Video, is_pure_tensor, PIL.Image.Image)):
+        elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor, PIL.Image.Image)):
             if image_or_video_found:
                 # AA transforms only support a single image or video
                 continue

@@ -99,7 +99,7 @@ def normalize_adapter(transform, input, device):
         if isinstance(value, PIL.Image.Image):
             # normalize doesn't support PIL images
             continue
-        elif check_type(value, (datapoints.Image, datapoints.Video, is_pure_tensor)):
+        elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor)):
             # normalize doesn't support integer images
             value = F.to_dtype(value, torch.float32, scale=True)
         adapted_input[key] = value

@@ -142,7 +142,7 @@ class TestSmoke:
         (transforms.Resize([16, 16], antialias=True), None),
         (transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2), antialias=True), None),
         (transforms.ClampBoundingBoxes(), None),
-        (transforms.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.CXCYWH), None),
+        (transforms.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.CXCYWH), None),
         (transforms.ConvertImageDtype(), None),
         (transforms.GaussianBlur(kernel_size=3), None),
         (

@@ -178,19 +178,19 @@ class TestSmoke:
         canvas_size = F.get_size(image_or_video)
         input = dict(
             image_or_video=image_or_video,
-            image_datapoint=make_image(size=canvas_size),
-            video_datapoint=make_video(size=canvas_size),
+            image_tv_tensor=make_image(size=canvas_size),
+            video_tv_tensor=make_video(size=canvas_size),
             image_pil=next(make_pil_images(sizes=[canvas_size], color_spaces=["RGB"])),
             bounding_boxes_xyxy=make_bounding_boxes(
-                format=datapoints.BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(3,)
+                format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(3,)
             ),
             bounding_boxes_xywh=make_bounding_boxes(
-                format=datapoints.BoundingBoxFormat.XYWH, canvas_size=canvas_size, batch_dims=(4,)
+                format=tv_tensors.BoundingBoxFormat.XYWH, canvas_size=canvas_size, batch_dims=(4,)
             ),
             bounding_boxes_cxcywh=make_bounding_boxes(
-                format=datapoints.BoundingBoxFormat.CXCYWH, canvas_size=canvas_size, batch_dims=(5,)
+                format=tv_tensors.BoundingBoxFormat.CXCYWH, canvas_size=canvas_size, batch_dims=(5,)
             ),
-            bounding_boxes_degenerate_xyxy=datapoints.BoundingBoxes(
+            bounding_boxes_degenerate_xyxy=tv_tensors.BoundingBoxes(
                 [
                     [0, 0, 0, 0],  # no height or width
                     [0, 0, 0, 1],  # no height

@@ -199,10 +199,10 @@ class TestSmoke:
                     [0, 2, 1, 1],  # x1 < x2, y1 > y2
                     [2, 2, 1, 1],  # x1 > x2, y1 > y2
                 ],
-                format=datapoints.BoundingBoxFormat.XYXY,
+                format=tv_tensors.BoundingBoxFormat.XYXY,
                 canvas_size=canvas_size,
             ),
-            bounding_boxes_degenerate_xywh=datapoints.BoundingBoxes(
+            bounding_boxes_degenerate_xywh=tv_tensors.BoundingBoxes(
                 [
                     [0, 0, 0, 0],  # no height or width
                     [0, 0, 0, 1],  # no height

@@ -211,10 +211,10 @@ class TestSmoke:
                     [0, 0, -1, 1],  # negative width
                     [0, 0, -1, -1],  # negative height and width
                 ],
-                format=datapoints.BoundingBoxFormat.XYWH,
+                format=tv_tensors.BoundingBoxFormat.XYWH,
                 canvas_size=canvas_size,
             ),
-            bounding_boxes_degenerate_cxcywh=datapoints.BoundingBoxes(
+            bounding_boxes_degenerate_cxcywh=tv_tensors.BoundingBoxes(
                 [
                     [0, 0, 0, 0],  # no height or width
                     [0, 0, 0, 1],  # no height

@@ -223,7 +223,7 @@ class TestSmoke:
                     [0, 0, -1, 1],  # negative width
                     [0, 0, -1, -1],  # negative height and width
                 ],
-                format=datapoints.BoundingBoxFormat.CXCYWH,
+                format=tv_tensors.BoundingBoxFormat.CXCYWH,
                 canvas_size=canvas_size,
             ),
             detection_mask=make_detection_mask(size=canvas_size),

@@ -262,7 +262,7 @@ class TestSmoke:
             else:
                 assert output_item is input_item

-            if isinstance(input_item, datapoints.BoundingBoxes) and not isinstance(
+            if isinstance(input_item, tv_tensors.BoundingBoxes) and not isinstance(
                 transform, transforms.ConvertBoundingBoxFormat
             ):
                 assert output_item.format == input_item.format

@@ -270,9 +270,9 @@ class TestSmoke:
         # Enforce that the transform does not turn a degenerate box marked by RandomIoUCrop (or any other future
         # transform that does this), back into a valid one.
         # TODO: we should test that against all degenerate boxes above
-        for format in list(datapoints.BoundingBoxFormat):
+        for format in list(tv_tensors.BoundingBoxFormat):
             sample = dict(
-                boxes=datapoints.BoundingBoxes([[0, 0, 0, 0]], format=format, canvas_size=(224, 244)),
+                boxes=tv_tensors.BoundingBoxes([[0, 0, 0, 0]], format=format, canvas_size=(224, 244)),
                 labels=torch.tensor([3]),
             )
             assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4)

@@ -652,7 +652,7 @@ class TestRandomErasing:
 class TestTransform:
     @pytest.mark.parametrize(
         "inpt_type",
-        [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
+        [torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
     )
     def test_check_transformed_types(self, inpt_type, mocker):
         # This test ensures that we correctly handle which types to transform and which to bypass

@@ -670,7 +670,7 @@ class TestTransform:
 class TestToImage:
     @pytest.mark.parametrize(
         "inpt_type",
-        [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
+        [torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
     )
     def test__transform(self, inpt_type, mocker):
         fn = mocker.patch(

@@ -681,7 +681,7 @@ class TestToImage:
         inpt = mocker.MagicMock(spec=inpt_type)
         transform = transforms.ToImage()
         transform(inpt)
-        if inpt_type in (datapoints.BoundingBoxes, datapoints.Image, str, int):
+        if inpt_type in (tv_tensors.BoundingBoxes, tv_tensors.Image, str, int):
             assert fn.call_count == 0
         else:
             fn.assert_called_once_with(inpt)

@@ -690,7 +690,7 @@ class TestToImage:
 class TestToPILImage:
     @pytest.mark.parametrize(
         "inpt_type",
-        [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
+        [torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
     )
     def test__transform(self, inpt_type, mocker):
         fn = mocker.patch("torchvision.transforms.v2.functional.to_pil_image")

@@ -698,7 +698,7 @@ class TestToPILImage:
         inpt = mocker.MagicMock(spec=inpt_type)
         transform = transforms.ToPILImage()
         transform(inpt)
-        if inpt_type in (PIL.Image.Image, datapoints.BoundingBoxes, str, int):
+        if inpt_type in (PIL.Image.Image, tv_tensors.BoundingBoxes, str, int):
             assert fn.call_count == 0
         else:
             fn.assert_called_once_with(inpt, mode=transform.mode)

@@ -707,7 +707,7 @@ class TestToPILImage:
 class TestToTensor:
     @pytest.mark.parametrize(
         "inpt_type",
-        [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
+        [torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
     )
     def test__transform(self, inpt_type, mocker):
         fn = mocker.patch("torchvision.transforms.functional.to_tensor")

@@ -716,7 +716,7 @@ class TestToTensor:
         with pytest.warns(UserWarning, match="deprecated and will be removed"):
             transform = transforms.ToTensor()
         transform(inpt)
-        if inpt_type in (datapoints.Image, torch.Tensor, datapoints.BoundingBoxes, str, int):
+        if inpt_type in (tv_tensors.Image, torch.Tensor, tv_tensors.BoundingBoxes, str, int):
             assert fn.call_count == 0
         else:
             fn.assert_called_once_with(inpt)

@@ -757,7 +757,7 @@ class TestRandomIoUCrop:
     def test__get_params(self, device, options):
         orig_h, orig_w = size = (24, 32)
         image = make_image(size)
-        bboxes = datapoints.BoundingBoxes(
+        bboxes = tv_tensors.BoundingBoxes(
             torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23], [1, 20, 10, 23], [20, 1, 23, 10]]),
             format="XYXY",
             canvas_size=size,

@@ -792,8 +792,8 @@ class TestRandomIoUCrop:
     def test__transform_empty_params(self, mocker):
         transform = transforms.RandomIoUCrop(sampler_options=[2.0])
-        image = datapoints.Image(torch.rand(1, 3, 4, 4))
-        bboxes = datapoints.BoundingBoxes(torch.tensor([[1, 1, 2, 2]]), format="XYXY", canvas_size=(4, 4))
+        image = tv_tensors.Image(torch.rand(1, 3, 4, 4))
+        bboxes = tv_tensors.BoundingBoxes(torch.tensor([[1, 1, 2, 2]]), format="XYXY", canvas_size=(4, 4))
         label = torch.tensor([1])
         sample = [image, bboxes, label]
         # Let's mock transform._get_params to control the output:

@@ -827,11 +827,11 @@ class TestRandomIoUCrop:
         # check number of bboxes vs number of labels:
         output_bboxes = output[1]
-        assert isinstance(output_bboxes, datapoints.BoundingBoxes)
+        assert isinstance(output_bboxes, tv_tensors.BoundingBoxes)
         assert (output_bboxes[~is_within_crop_area] == 0).all()

         output_masks = output[2]
-        assert isinstance(output_masks, datapoints.Mask)
+        assert isinstance(output_masks, tv_tensors.Mask)

 class TestScaleJitter:

@@ -899,7 +899,7 @@ class TestLinearTransformation:
         [
             122 * torch.ones(1, 3, 8, 8),
             122.0 * torch.ones(1, 3, 8, 8),
-            datapoints.Image(122 * torch.ones(1, 3, 8, 8)),
+            tv_tensors.Image(122 * torch.ones(1, 3, 8, 8)),
             PIL.Image.new("RGB", (8, 8), (122, 122, 122)),
         ],
     )

@@ -941,7 +941,7 @@ class TestUniformTemporalSubsample:
         [
             torch.zeros(10, 3, 8, 8),
             torch.zeros(1, 10, 3, 8, 8),
-            datapoints.Video(torch.zeros(1, 10, 3, 8, 8)),
+            tv_tensors.Video(torch.zeros(1, 10, 3, 8, 8)),
         ],
     )
     def test__transform(self, inpt):

@@ -971,12 +971,12 @@ def test_antialias_warning():
         transforms.RandomResize(10, 20)(tensor_img)

     with pytest.warns(UserWarning, match=match):
-        F.resized_crop(datapoints.Image(tensor_img), 0, 0, 10, 10, (20, 20))
+        F.resized_crop(tv_tensors.Image(tensor_img), 0, 0, 10, 10, (20, 20))
     with pytest.warns(UserWarning, match=match):
-        F.resize(datapoints.Video(tensor_video), (20, 20))
+        F.resize(tv_tensors.Video(tensor_video), (20, 20))
     with pytest.warns(UserWarning, match=match):
-        F.resized_crop(datapoints.Video(tensor_video), 0, 0, 10, 10, (20, 20))
+        F.resized_crop(tv_tensors.Video(tensor_video), 0, 0, 10, 10, (20, 20))

     with warnings.catch_warnings():
         warnings.simplefilter("error")

@@ -990,17 +990,17 @@ def test_antialias_warning():
         transforms.RandomShortestSize((20, 20), antialias=True)(tensor_img)
         transforms.RandomResize(10, 20, antialias=True)(tensor_img)
-        F.resized_crop(datapoints.Image(tensor_img), 0, 0, 10, 10, (20, 20), antialias=True)
-        F.resized_crop(datapoints.Video(tensor_video), 0, 0, 10, 10, (20, 20), antialias=True)
+        F.resized_crop(tv_tensors.Image(tensor_img), 0, 0, 10, 10, (20, 20), antialias=True)
+        F.resized_crop(tv_tensors.Video(tensor_video), 0, 0, 10, 10, (20, 20), antialias=True)

-@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
+@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, tv_tensors.Image))
 @pytest.mark.parametrize("label_type", (torch.Tensor, int))
 @pytest.mark.parametrize("dataset_return_type", (dict, tuple))
 @pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImage))
 def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
-    image = datapoints.Image(torch.randint(0, 256, size=(1, 3, 250, 250), dtype=torch.uint8))
+    image = tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 250, 250), dtype=torch.uint8))
     if image_type is PIL.Image:
         image = to_pil_image(image[0])
     elif image_type is torch.Tensor:

@@ -1056,7 +1056,7 @@ def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
     assert out_label == label

-@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
+@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, tv_tensors.Image))
 @pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))
 @pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImage))
 @pytest.mark.parametrize("sanitize", (True, False))

@@ -1082,7 +1082,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
             # leaving FixedSizeCrop in prototype for now, and it expects Label
             # classes which we won't release yet.
             # transforms.FixedSizeCrop(
-            #     size=(1024, 1024), fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})
+            #     size=(1024, 1024), fill=defaultdict(lambda: (123.0, 117.0, 104.0), {tv_tensors.Mask: 0})
             # ),
             transforms.RandomCrop((1024, 1024), pad_if_needed=True),
             transforms.RandomHorizontalFlip(p=1),

@@ -1101,7 +1101,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
     elif data_augmentation == "ssd":
         t = [
             transforms.RandomPhotometricDistort(p=1),
-            transforms.RandomZoomOut(fill={"others": (123.0, 117.0, 104.0), datapoints.Mask: 0}, p=1),
+            transforms.RandomZoomOut(fill={"others": (123.0, 117.0, 104.0), tv_tensors.Mask: 0}, p=1),
             transforms.RandomIoUCrop(),
             transforms.RandomHorizontalFlip(p=1),
             to_tensor,

@@ -1121,7 +1121,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
     num_boxes = 5
     H = W = 250
-    image = datapoints.Image(torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8))
+    image = tv_tensors.Image(torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8))
     if image_type is PIL.Image:
         image = to_pil_image(image[0])
     elif image_type is torch.Tensor:

@@ -1133,9 +1133,9 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
     boxes = torch.randint(0, min(H, W) // 2, size=(num_boxes, 4))
     boxes[:, 2:] += boxes[:, :2]
     boxes = boxes.clamp(min=0, max=min(H, W))
-    boxes = datapoints.BoundingBoxes(boxes, format="XYXY", canvas_size=(H, W))
+    boxes = tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=(H, W))
-    masks = datapoints.Mask(torch.randint(0, 2, size=(num_boxes, H, W), dtype=torch.uint8))
+    masks = tv_tensors.Mask(torch.randint(0, 2, size=(num_boxes, H, W), dtype=torch.uint8))

     sample = {
         "image": image,

@@ -1146,10 +1146,10 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
     out = t(sample)

-    if isinstance(to_tensor, transforms.ToTensor) and image_type is not datapoints.Image:
+    if isinstance(to_tensor, transforms.ToTensor) and image_type is not tv_tensors.Image:
         assert is_pure_tensor(out["image"])
     else:
-        assert isinstance(out["image"], datapoints.Image)
+        assert isinstance(out["image"], tv_tensors.Image)
     assert isinstance(out["label"], type(sample["label"]))

     num_boxes_expected = {

@@ -1204,13 +1204,13 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
     boxes = torch.tensor(boxes)
     labels = torch.arange(boxes.shape[0])

-    boxes = datapoints.BoundingBoxes(
+    boxes = tv_tensors.BoundingBoxes(
         boxes,
-        format=datapoints.BoundingBoxFormat.XYXY,
+        format=tv_tensors.BoundingBoxFormat.XYXY,
         canvas_size=(H, W),
     )

-    masks = datapoints.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W)))
+    masks = tv_tensors.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W)))
     whatever = torch.rand(10)
     input_img = torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8)
     sample = {

@@ -1244,8 +1244,8 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
     assert out_image is input_img
     assert out_whatever is whatever
-    assert isinstance(out_boxes, datapoints.BoundingBoxes)
-    assert isinstance(out_masks, datapoints.Mask)
+    assert isinstance(out_boxes, tv_tensors.BoundingBoxes)
+    assert isinstance(out_masks, tv_tensors.Mask)

     if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None):
         assert out_labels is labels

@@ -1266,15 +1266,15 @@ def test_sanitize_bounding_boxes_no_label():
         transforms.SanitizeBoundingBoxes()(img, boxes)

     out_img, out_boxes = transforms.SanitizeBoundingBoxes(labels_getter=None)(img, boxes)
-    assert isinstance(out_img, datapoints.Image)
-    assert isinstance(out_boxes, datapoints.BoundingBoxes)
+    assert isinstance(out_img, tv_tensors.Image)
+    assert isinstance(out_boxes, tv_tensors.BoundingBoxes)

 def test_sanitize_bounding_boxes_errors():
-    good_bbox = datapoints.BoundingBoxes(
+    good_bbox = tv_tensors.BoundingBoxes(
         [[0, 0, 10, 10]],
-        format=datapoints.BoundingBoxFormat.XYXY,
+        format=tv_tensors.BoundingBoxFormat.XYXY,
         canvas_size=(20, 20),
     )
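The degenerate-box assertions above are the observable contract of SanitizeBoundingBoxes: boxes with no area are dropped, and their paired labels go with them. A hedged, runnable sketch of that behavior (values taken from the test; the v2 import path is the one the test file itself uses):

    import torch
    import torchvision.transforms.v2 as transforms
    from torchvision import tv_tensors

    sample = dict(
        boxes=tv_tensors.BoundingBoxes(
            [[0, 0, 0, 0]],  # zero width and height: a degenerate box
            format=tv_tensors.BoundingBoxFormat.XYXY,
            canvas_size=(224, 244),
        ),
        labels=torch.tensor([3]),
    )
    out = transforms.SanitizeBoundingBoxes()(sample)
    assert out["boxes"].shape == (0, 4)  # the degenerate box was removed
    assert out["labels"].numel() == 0    # its label was removed with it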
test/test_transforms_v2_consistency.py

@@ -13,7 +13,7 @@ import torch
 import torchvision.transforms.v2 as v2_transforms
 from common_utils import assert_close, assert_equal, set_rng_seed
 from torch import nn
-from torchvision import datapoints, transforms as legacy_transforms
+from torchvision import transforms as legacy_transforms, tv_tensors
 from torchvision._utils import sequence_to_str
 from torchvision.transforms import functional as legacy_F

@@ -478,15 +478,15 @@ def check_call_consistency(
             output_prototype_image = prototype_transform(image)
         except Exception as exc:
             raise AssertionError(
-                f"Transforming a image datapoint with shape {image_repr} failed in the prototype transform with "
+                f"Transforming a image tv_tensor with shape {image_repr} failed in the prototype transform with "
                 f"the error above. This means there is a consistency bug either in `_get_params` or in the "
-                f"`datapoints.Image` path in `_transform`."
+                f"`tv_tensors.Image` path in `_transform`."
             ) from exc

         assert_close(
             output_prototype_image,
             output_prototype_tensor,
-            msg=lambda msg: f"Output for datapoint and tensor images is not equal: \n\n{msg}",
+            msg=lambda msg: f"Output for tv_tensor and tensor images is not equal: \n\n{msg}",
             **closeness_kwargs,
         )

@@ -747,7 +747,7 @@ class TestAATransforms:
         [
             torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
             PIL.Image.new("RGB", (256, 256), 123),
-            datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
+            tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
         ],
     )
     @pytest.mark.parametrize(

@@ -812,7 +812,7 @@ class TestAATransforms:
         [
             torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
             PIL.Image.new("RGB", (256, 256), 123),
-            datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
+            tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
         ],
     )
     @pytest.mark.parametrize(

@@ -887,7 +887,7 @@ class TestAATransforms:
         [
             torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
             PIL.Image.new("RGB", (256, 256), 123),
-            datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
+            tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
         ],
     )
     @pytest.mark.parametrize(

@@ -964,7 +964,7 @@ class TestAATransforms:
         [
             torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
             PIL.Image.new("RGB", (256, 256), 123),
-            datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
+            tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
         ],
     )
     @pytest.mark.parametrize(

@@ -1030,7 +1030,7 @@ det_transforms = import_transforms_from_references("detection")

 class TestRefDetTransforms:
-    def make_datapoints(self, with_mask=True):
+    def make_tv_tensors(self, with_mask=True):
         size = (600, 800)
         num_objects = 22

@@ -1057,7 +1057,7 @@ class TestRefDetTransforms:
             yield (tensor_image, target)

-        datapoint_image = make_image(size=size, color_space="RGB", dtype=torch.float32)
+        tv_tensor_image = make_image(size=size, color_space="RGB", dtype=torch.float32)
         target = {
             "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
             "labels": make_label(extra_dims=(num_objects,), categories=80),

@@ -1065,7 +1065,7 @@ class TestRefDetTransforms:
         if with_mask:
             target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

-        yield (datapoint_image, target)
+        yield (tv_tensor_image, target)

     @pytest.mark.parametrize(
         "t_ref, t, data_kwargs",

@@ -1095,7 +1095,7 @@ class TestRefDetTransforms:
         ],
     )
     def test_transform(self, t_ref, t, data_kwargs):
-        for dp in self.make_datapoints(**data_kwargs):
+        for dp in self.make_tv_tensors(**data_kwargs):
             # We should use prototype transform first as reference transform performs inplace target update
             torch.manual_seed(12)

@@ -1135,7 +1135,7 @@ class PadIfSmaller(v2_transforms.Transform):

 class TestRefSegTransforms:
-    def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8):
+    def make_tv_tensors(self, supports_pil=True, image_dtype=torch.uint8):
         size = (256, 460)
         num_categories = 21

@@ -1145,13 +1145,13 @@ class TestRefSegTransforms:
             conv_fns.extend([torch.Tensor, lambda x: x])

         for conv_fn in conv_fns:
-            datapoint_image = make_image(size=size, color_space="RGB", dtype=image_dtype)
-            datapoint_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)
+            tv_tensor_image = make_image(size=size, color_space="RGB", dtype=image_dtype)
+            tv_tensor_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)

-            dp = (conv_fn(datapoint_image), datapoint_mask)
+            dp = (conv_fn(tv_tensor_image), tv_tensor_mask)
             dp_ref = (
-                to_pil_image(datapoint_image) if supports_pil else datapoint_image.as_subclass(torch.Tensor),
-                to_pil_image(datapoint_mask),
+                to_pil_image(tv_tensor_image) if supports_pil else tv_tensor_image.as_subclass(torch.Tensor),
+                to_pil_image(tv_tensor_mask),
             )

             yield dp, dp_ref

@@ -1161,7 +1161,7 @@ class TestRefSegTransforms:
         random.seed(seed)

     def check(self, t, t_ref, data_kwargs=None):
-        for dp, dp_ref in self.make_datapoints(**data_kwargs or dict()):
+        for dp, dp_ref in self.make_tv_tensors(**data_kwargs or dict()):
             self.set_seed()
             actual = actual_image, actual_mask = t(dp)

@@ -1192,7 +1192,7 @@ class TestRefSegTransforms:
             seg_transforms.RandomCrop(size=480),
             v2_transforms.Compose(
                 [
-                    PadIfSmaller(size=480, fill={datapoints.Mask: 255, "others": 0}),
+                    PadIfSmaller(size=480, fill={tv_tensors.Mask: 255, "others": 0}),
                     v2_transforms.RandomCrop(size=480),
                 ]
             ),
test/test_transforms_v2_functional.py

@@ -10,7 +10,7 @@ import torch
 from common_utils import assert_close, cache, cpu_and_cuda, needs_cuda, set_rng_seed
 from torch.utils._pytree import tree_map
-from torchvision import datapoints
+from torchvision import tv_tensors
 from torchvision.transforms.functional import _get_perspective_coeffs
 from torchvision.transforms.v2 import functional as F
 from torchvision.transforms.v2._utils import is_pure_tensor

@@ -164,22 +164,22 @@ class TestKernels:
     def test_batched_vs_single(self, test_id, info, args_kwargs, device):
         (batched_input, *other_args), kwargs = args_kwargs.load(device)

-        datapoint_type = datapoints.Image if is_pure_tensor(batched_input) else type(batched_input)
+        tv_tensor_type = tv_tensors.Image if is_pure_tensor(batched_input) else type(batched_input)
         # This dictionary contains the number of rightmost dimensions that contain the actual data.
         # Everything to the left is considered a batch dimension.
         data_dims = {
-            datapoints.Image: 3,
-            datapoints.BoundingBoxes: 1,
+            tv_tensors.Image: 3,
+            tv_tensors.BoundingBoxes: 1,
             # `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection masks
             # it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both a grouped under one
             # type all kernels should also work without differentiating between the two. Thus, we go with 2 here as
             # common ground.
-            datapoints.Mask: 2,
-            datapoints.Video: 4,
-        }.get(datapoint_type)
+            tv_tensors.Mask: 2,
+            tv_tensors.Video: 4,
+        }.get(tv_tensor_type)
         if data_dims is None:
             raise pytest.UsageError(
-                f"The number of data dimensions cannot be determined for input of type {datapoint_type.__name__}."
+                f"The number of data dimensions cannot be determined for input of type {tv_tensor_type.__name__}."
             ) from None
         elif batched_input.ndim <= data_dims:
             pytest.skip("Input is not batched.")

@@ -305,8 +305,8 @@ def spy_on(mocker):
 class TestDispatchers:
     image_sample_inputs = make_info_args_kwargs_parametrization(
-        [info for info in DISPATCHER_INFOS if datapoints.Image in info.kernels],
-        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
+        [info for info in DISPATCHER_INFOS if tv_tensors.Image in info.kernels],
+        args_kwargs_fn=lambda info: info.sample_inputs(tv_tensors.Image),
     )

     @make_info_args_kwargs_parametrization(

@@ -328,8 +328,8 @@ class TestDispatchers:
     def test_scripted_smoke(self, info, args_kwargs, device):
         dispatcher = script(info.dispatcher)

-        (image_datapoint, *other_args), kwargs = args_kwargs.load(device)
-        image_pure_tensor = torch.Tensor(image_datapoint)
+        (image_tv_tensor, *other_args), kwargs = args_kwargs.load(device)
+        image_pure_tensor = torch.Tensor(image_tv_tensor)

         dispatcher(image_pure_tensor, *other_args, **kwargs)

@@ -355,25 +355,25 @@ class TestDispatchers:
     @image_sample_inputs
     def test_pure_tensor_output_type(self, info, args_kwargs):
-        (image_datapoint, *other_args), kwargs = args_kwargs.load()
-        image_pure_tensor = image_datapoint.as_subclass(torch.Tensor)
+        (image_tv_tensor, *other_args), kwargs = args_kwargs.load()
+        image_pure_tensor = image_tv_tensor.as_subclass(torch.Tensor)

         output = info.dispatcher(image_pure_tensor, *other_args, **kwargs)

-        # We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
+        # We cannot use `isinstance` here since all tv_tensors are instances of `torch.Tensor` as well
         assert type(output) is torch.Tensor

     @make_info_args_kwargs_parametrization(
         [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
-        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
+        args_kwargs_fn=lambda info: info.sample_inputs(tv_tensors.Image),
     )
     def test_pil_output_type(self, info, args_kwargs):
-        (image_datapoint, *other_args), kwargs = args_kwargs.load()
+        (image_tv_tensor, *other_args), kwargs = args_kwargs.load()

-        if image_datapoint.ndim > 3:
+        if image_tv_tensor.ndim > 3:
             pytest.skip("Input is batched")

-        image_pil = F.to_pil_image(image_datapoint)
+        image_pil = F.to_pil_image(image_tv_tensor)

         output = info.dispatcher(image_pil, *other_args, **kwargs)
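The comment in the hunk above is the crux of the dispatcher output-type tests: every TVTensor subclasses torch.Tensor, so isinstance() cannot tell a plain tensor from a wrapped one, and the tests check the exact type instead. A small sketch of the distinction (as_subclass is the same call the test uses to strip the wrapper):

    import torch
    from torchvision import tv_tensors

    img = tv_tensors.Image(torch.rand(3, 8, 8))
    plain = img.as_subclass(torch.Tensor)  # strip the wrapper, keep the data

    assert isinstance(img, torch.Tensor)   # true for every tv_tensor
    assert type(img) is tv_tensors.Image   # exact type check distinguishes them
    assert type(plain) is torch.Tensor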
@@ -383,38 +383,38 @@ class TestDispatchers:
...
@@ -383,38 +383,38 @@ class TestDispatchers:
DISPATCHER_INFOS
,
DISPATCHER_INFOS
,
args_kwargs_fn
=
lambda
info
:
info
.
sample_inputs
(),
args_kwargs_fn
=
lambda
info
:
info
.
sample_inputs
(),
)
)
def
test_
datapoint
_output_type
(
self
,
info
,
args_kwargs
):
def
test_
tv_tensor
_output_type
(
self
,
info
,
args_kwargs
):
(
datapoint
,
*
other_args
),
kwargs
=
args_kwargs
.
load
()
(
tv_tensor
,
*
other_args
),
kwargs
=
args_kwargs
.
load
()
output
=
info
.
dispatcher
(
datapoint
,
*
other_args
,
**
kwargs
)
output
=
info
.
dispatcher
(
tv_tensor
,
*
other_args
,
**
kwargs
)
assert
isinstance
(
output
,
type
(
datapoint
))
assert
isinstance
(
output
,
type
(
tv_tensor
))
if
isinstance
(
datapoint
,
datapoint
s
.
BoundingBoxes
)
and
info
.
dispatcher
is
not
F
.
convert_bounding_box_format
:
if
isinstance
(
tv_tensor
,
tv_tensor
s
.
BoundingBoxes
)
and
info
.
dispatcher
is
not
F
.
convert_bounding_box_format
:
assert
output
.
format
==
datapoint
.
format
assert
output
.
format
==
tv_tensor
.
format
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"dispatcher_info"
,
"
datapoint
_type"
,
"kernel_info"
),
(
"dispatcher_info"
,
"
tv_tensor
_type"
,
"kernel_info"
),
[
[
pytest
.
param
(
pytest
.
param
(
dispatcher_info
,
datapoint
_type
,
kernel_info
,
id
=
f
"
{
dispatcher_info
.
id
}
-
{
datapoint
_type
.
__name__
}
"
dispatcher_info
,
tv_tensor
_type
,
kernel_info
,
id
=
f
"
{
dispatcher_info
.
id
}
-
{
tv_tensor
_type
.
__name__
}
"
)
)
for
dispatcher_info
in
DISPATCHER_INFOS
for
dispatcher_info
in
DISPATCHER_INFOS
for
datapoint
_type
,
kernel_info
in
dispatcher_info
.
kernel_infos
.
items
()
for
tv_tensor
_type
,
kernel_info
in
dispatcher_info
.
kernel_infos
.
items
()
],
],
)
)
def
test_dispatcher_kernel_signatures_consistency
(
self
,
dispatcher_info
,
datapoint
_type
,
kernel_info
):
def
test_dispatcher_kernel_signatures_consistency
(
self
,
dispatcher_info
,
tv_tensor
_type
,
kernel_info
):
dispatcher_signature
=
inspect
.
signature
(
dispatcher_info
.
dispatcher
)
dispatcher_signature
=
inspect
.
signature
(
dispatcher_info
.
dispatcher
)
dispatcher_params
=
list
(
dispatcher_signature
.
parameters
.
values
())[
1
:]
dispatcher_params
=
list
(
dispatcher_signature
.
parameters
.
values
())[
1
:]
kernel_signature
=
inspect
.
signature
(
kernel_info
.
kernel
)
kernel_signature
=
inspect
.
signature
(
kernel_info
.
kernel
)
kernel_params
=
list
(
kernel_signature
.
parameters
.
values
())[
1
:]
kernel_params
=
list
(
kernel_signature
.
parameters
.
values
())[
1
:]
# We filter out metadata that is implicitly passed to the dispatcher through the input
datapoint
, but has to be
# We filter out metadata that is implicitly passed to the dispatcher through the input
tv_tensor
, but has to be
# explicitly passed to the kernel.
# explicitly passed to the kernel.
input_type
=
{
v
:
k
for
k
,
v
in
dispatcher_info
.
kernels
.
items
()}.
get
(
kernel_info
.
kernel
)
input_type
=
{
v
:
k
for
k
,
v
in
dispatcher_info
.
kernels
.
items
()}.
get
(
kernel_info
.
kernel
)
explicit_metadata
=
{
explicit_metadata
=
{
datapoint
s
.
BoundingBoxes
:
{
"format"
,
"canvas_size"
},
tv_tensor
s
.
BoundingBoxes
:
{
"format"
,
"canvas_size"
},
}
}
kernel_params
=
[
param
for
param
in
kernel_params
if
param
.
name
not
in
explicit_metadata
.
get
(
input_type
,
set
())]
kernel_params
=
[
param
for
param
in
kernel_params
if
param
.
name
not
in
explicit_metadata
.
get
(
input_type
,
set
())]
...
@@ -445,9 +445,9 @@ class TestDispatchers:
...
@@ -445,9 +445,9 @@ class TestDispatchers:
[
[
info
info
for
info
in
DISPATCHER_INFOS
for
info
in
DISPATCHER_INFOS
if
datapoint
s
.
BoundingBoxes
in
info
.
kernels
and
info
.
dispatcher
is
not
F
.
convert_bounding_box_format
if
tv_tensor
s
.
BoundingBoxes
in
info
.
kernels
and
info
.
dispatcher
is
not
F
.
convert_bounding_box_format
],
],
args_kwargs_fn
=
lambda
info
:
info
.
sample_inputs
(
datapoint
s
.
BoundingBoxes
),
args_kwargs_fn
=
lambda
info
:
info
.
sample_inputs
(
tv_tensor
s
.
BoundingBoxes
),
)
)
def
test_bounding_boxes_format_consistency
(
self
,
info
,
args_kwargs
):
def
test_bounding_boxes_format_consistency
(
self
,
info
,
args_kwargs
):
(
bounding_boxes
,
*
other_args
),
kwargs
=
args_kwargs
.
load
()
(
bounding_boxes
,
*
other_args
),
kwargs
=
args_kwargs
.
load
()
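The dispatcher/kernel split that `test_dispatcher_kernel_signatures_consistency` checks above can be illustrated with a minimal sketch. This sketch is not part of the commit, and the box coordinates and canvas size are made up for illustration: kernels operate on plain tensors and need `format` and `canvas_size` spelled out, while the dispatcher reads the same metadata off the input tv_tensor.

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[10.0, 15.0, 25.0, 35.0]]),
    format=tv_tensors.BoundingBoxFormat.XYXY,
    canvas_size=(64, 76),
)

# Dispatcher: the metadata travels implicitly with the tv_tensor.
flipped = F.horizontal_flip(boxes)

# Kernel: the same metadata must be passed explicitly on a plain tensor.
flipped_plain = F.horizontal_flip_bounding_boxes(
    boxes.as_subclass(torch.Tensor), format=boxes.format, canvas_size=boxes.canvas_size
)

torch.testing.assert_close(flipped.as_subclass(torch.Tensor), flipped_plain)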
...
@@ -497,7 +497,7 @@ class TestClampBoundingBoxes:
        "metadata",
        [
            dict(),
-            dict(format=datapoints.BoundingBoxFormat.XYXY),
+            dict(format=tv_tensors.BoundingBoxFormat.XYXY),
            dict(canvas_size=(1, 1)),
        ],
    )
...
@@ -510,16 +510,16 @@ class TestClampBoundingBoxes:
    @pytest.mark.parametrize(
        "metadata",
        [
-            dict(format=datapoints.BoundingBoxFormat.XYXY),
+            dict(format=tv_tensors.BoundingBoxFormat.XYXY),
            dict(canvas_size=(1, 1)),
-            dict(format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(1, 1)),
+            dict(format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=(1, 1)),
        ],
    )
-    def test_datapoint_explicit_metadata(self, metadata):
-        datapoint = next(make_multiple_bounding_boxes())
+    def test_tv_tensor_explicit_metadata(self, metadata):
+        tv_tensor = next(make_multiple_bounding_boxes())

        with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` must not be passed")):
-            F.clamp_bounding_boxes(datapoint, **metadata)
+            F.clamp_bounding_boxes(tv_tensor, **metadata)


class TestConvertFormatBoundingBoxes:
...
@@ -527,7 +527,7 @@ class TestConvertFormatBoundingBoxes:
        ("inpt", "old_format"),
        [
            (next(make_multiple_bounding_boxes()), None),
-            (next(make_multiple_bounding_boxes()).as_subclass(torch.Tensor), datapoints.BoundingBoxFormat.XYXY),
+            (next(make_multiple_bounding_boxes()).as_subclass(torch.Tensor), tv_tensors.BoundingBoxFormat.XYXY),
        ],
    )
    def test_missing_new_format(self, inpt, old_format):
...
@@ -538,14 +538,14 @@ class TestConvertFormatBoundingBoxes:
        pure_tensor = next(make_multiple_bounding_boxes()).as_subclass(torch.Tensor)

        with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")):
-            F.convert_bounding_box_format(pure_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)
+            F.convert_bounding_box_format(pure_tensor, new_format=tv_tensors.BoundingBoxFormat.CXCYWH)

-    def test_datapoint_explicit_metadata(self):
-        datapoint = next(make_multiple_bounding_boxes())
+    def test_tv_tensor_explicit_metadata(self):
+        tv_tensor = next(make_multiple_bounding_boxes())

        with pytest.raises(ValueError, match=re.escape("`old_format` must not be passed")):
            F.convert_bounding_box_format(
-                datapoint, old_format=datapoint.format, new_format=datapoints.BoundingBoxFormat.CXCYWH
+                tv_tensor, old_format=tv_tensor.format, new_format=tv_tensors.BoundingBoxFormat.CXCYWH
            )
...
@@ -579,7 +579,7 @@ def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "format",
-    [datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH],
+    [tv_tensors.BoundingBoxFormat.XYXY, tv_tensors.BoundingBoxFormat.XYWH, tv_tensors.BoundingBoxFormat.CXCYWH],
)
@pytest.mark.parametrize(
    "top, left, height, width, expected_bboxes",
...
@@ -602,7 +602,7 @@ def test_correctness_crop_bounding_boxes(device, format, top, left, height, widt
    # out_box = denormalize_bbox(n_out_box, height, width)
    # expected_bboxes.append(out_box)

-    format = datapoints.BoundingBoxFormat.XYXY
+    format = tv_tensors.BoundingBoxFormat.XYXY
    canvas_size = (64, 76)
    in_boxes = [
        [10.0, 15.0, 25.0, 35.0],
...
@@ -610,11 +610,11 @@ def test_correctness_crop_bounding_boxes(device, format, top, left, height, widt
        [45.0, 46.0, 56.0, 62.0],
    ]
    in_boxes = torch.tensor(in_boxes, device=device)
-    if format != datapoints.BoundingBoxFormat.XYXY:
-        in_boxes = convert_bounding_box_format(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
+    if format != tv_tensors.BoundingBoxFormat.XYXY:
+        in_boxes = convert_bounding_box_format(in_boxes, tv_tensors.BoundingBoxFormat.XYXY, format)

    expected_bboxes = clamp_bounding_boxes(
-        datapoints.BoundingBoxes(expected_bboxes, format="XYXY", canvas_size=canvas_size)
+        tv_tensors.BoundingBoxes(expected_bboxes, format="XYXY", canvas_size=canvas_size)
    ).tolist()

    output_boxes, output_canvas_size = F.crop_bounding_boxes(
...
@@ -626,8 +626,8 @@ def test_correctness_crop_bounding_boxes(device, format, top, left, height, widt
        canvas_size[1],
    )

-    if format != datapoints.BoundingBoxFormat.XYXY:
-        output_boxes = convert_bounding_box_format(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
+    if format != tv_tensors.BoundingBoxFormat.XYXY:
+        output_boxes = convert_bounding_box_format(output_boxes, format, tv_tensors.BoundingBoxFormat.XYXY)

    torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
    torch.testing.assert_close(output_canvas_size, canvas_size)
...
@@ -648,7 +648,7 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "format",
-    [datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH],
+    [tv_tensors.BoundingBoxFormat.XYXY, tv_tensors.BoundingBoxFormat.XYWH, tv_tensors.BoundingBoxFormat.CXCYWH],
)
@pytest.mark.parametrize(
    "top, left, height, width, size",
...
@@ -666,7 +666,7 @@ def test_correctness_resized_crop_bounding_boxes(device, format, top, left, heig
        bbox[3] = (bbox[3] - top_) * size_[0] / height_
        return bbox

-    format = datapoints.BoundingBoxFormat.XYXY
+    format = tv_tensors.BoundingBoxFormat.XYXY
    canvas_size = (100, 100)
    in_boxes = [
        [10.0, 10.0, 20.0, 20.0],
...
@@ -677,16 +677,16 @@ def test_correctness_resized_crop_bounding_boxes(device, format, top, left, heig
        expected_bboxes.append(_compute_expected_bbox(list(in_box), top, left, height, width, size))
    expected_bboxes = torch.tensor(expected_bboxes, device=device)

-    in_boxes = datapoints.BoundingBoxes(
-        in_boxes, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=canvas_size, device=device
+    in_boxes = tv_tensors.BoundingBoxes(
+        in_boxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=canvas_size, device=device
    )
-    if format != datapoints.BoundingBoxFormat.XYXY:
-        in_boxes = convert_bounding_box_format(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
+    if format != tv_tensors.BoundingBoxFormat.XYXY:
+        in_boxes = convert_bounding_box_format(in_boxes, tv_tensors.BoundingBoxFormat.XYXY, format)

    output_boxes, output_canvas_size = F.resized_crop_bounding_boxes(in_boxes, format, top, left, height, width, size)

-    if format != datapoints.BoundingBoxFormat.XYXY:
-        output_boxes = convert_bounding_box_format(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
+    if format != tv_tensors.BoundingBoxFormat.XYXY:
+        output_boxes = convert_bounding_box_format(output_boxes, format, tv_tensors.BoundingBoxFormat.XYXY)

    torch.testing.assert_close(output_boxes, expected_bboxes)
    torch.testing.assert_close(output_canvas_size, size)
...
@@ -713,14 +713,14 @@ def test_correctness_pad_bounding_boxes(device, padding):
        dtype = bbox.dtype
        bbox = (
            bbox.clone()
-            if format == datapoints.BoundingBoxFormat.XYXY
-            else convert_bounding_box_format(bbox, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
+            if format == tv_tensors.BoundingBoxFormat.XYXY
+            else convert_bounding_box_format(bbox, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY)
        )

        bbox[0::2] += pad_left
        bbox[1::2] += pad_up

-        bbox = convert_bounding_box_format(bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format)
+        bbox = convert_bounding_box_format(bbox, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format)
        if bbox.dtype != dtype:
            # Temporary cast to original dtype
            # e.g. float32 -> int
...
@@ -785,7 +785,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
            ]
        )

-        bbox_xyxy = convert_bounding_box_format(bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY)
+        bbox_xyxy = convert_bounding_box_format(bbox, old_format=format_, new_format=tv_tensors.BoundingBoxFormat.XYXY)
        points = np.array(
            [
                [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
...
@@ -807,7 +807,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
        )
        out_bbox = torch.from_numpy(out_bbox)
        out_bbox = convert_bounding_box_format(
-            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_
+            out_bbox, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format_
        )
        return clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_).to(bbox)
...
@@ -846,7 +846,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
def test_correctness_center_crop_bounding_boxes(device, output_size):
    def _compute_expected_bbox(bbox, format_, canvas_size_, output_size_):
        dtype = bbox.dtype
-        bbox = convert_bounding_box_format(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH)
+        bbox = convert_bounding_box_format(bbox.float(), format_, tv_tensors.BoundingBoxFormat.XYWH)

        if len(output_size_) == 1:
            output_size_.append(output_size_[-1])
...
@@ -860,7 +860,7 @@ def test_correctness_center_crop_bounding_boxes(device, output_size):
            bbox[3].item(),
        ]
        out_bbox = torch.tensor(out_bbox)
-        out_bbox = convert_bounding_box_format(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_)
+        out_bbox = convert_bounding_box_format(out_bbox, tv_tensors.BoundingBoxFormat.XYWH, format_)
        out_bbox = clamp_bounding_boxes(out_bbox, format=format_, canvas_size=output_size)
        return out_bbox.to(dtype=dtype, device=bbox.device)
...
@@ -958,7 +958,7 @@ def test_correctness_gaussian_blur_image_tensor(device, canvas_size, dt, ksize,
        torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)
    )

-    image = datapoints.Image(tensor)
+    image = tv_tensors.Image(tensor)

    out = fn(image, kernel_size=ksize, sigma=sigma)
    torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")
...
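For orientation before the next file, a minimal sketch of the renamed namespace these tests exercise (not part of the commit; assumes a torchvision build where `torchvision.tv_tensors` has replaced `torchvision.datapoints`): the wrapper classes keep their behavior, only the module and base-class names change.

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

# tv_tensors.BoundingBoxes is the former datapoints.BoundingBoxes.
boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[10.0, 15.0, 25.0, 35.0]]),
    format=tv_tensors.BoundingBoxFormat.XYXY,
    canvas_size=(64, 76),
)

# For a tv_tensor input, `old_format` is read from the input and must not be passed.
converted = F.convert_bounding_box_format(boxes, new_format=tv_tensors.BoundingBoxFormat.CXCYWH)
assert isinstance(converted, tv_tensors.BoundingBoxes)
assert converted.format is tv_tensors.BoundingBoxFormat.CXCYWH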
test/test_transforms_v2_refactored.py View file @ d5f4cc38
...
@@ -36,7 +36,7 @@ from torch import nn
...
@@ -36,7 +36,7 @@ from torch import nn
from
torch.testing
import
assert_close
from
torch.testing
import
assert_close
from
torch.utils._pytree
import
tree_map
from
torch.utils._pytree
import
tree_map
from
torch.utils.data
import
DataLoader
,
default_collate
from
torch.utils.data
import
DataLoader
,
default_collate
from
torchvision
import
datapoint
s
from
torchvision
import
tv_tensor
s
from
torchvision.transforms._functional_tensor
import
_max_value
as
get_max_value
from
torchvision.transforms._functional_tensor
import
_max_value
as
get_max_value
from
torchvision.transforms.functional
import
pil_modes_mapping
from
torchvision.transforms.functional
import
pil_modes_mapping
...
@@ -167,7 +167,7 @@ def check_kernel(
...
@@ -167,7 +167,7 @@ def check_kernel(
def
_check_functional_scripted_smoke
(
functional
,
input
,
*
args
,
**
kwargs
):
def
_check_functional_scripted_smoke
(
functional
,
input
,
*
args
,
**
kwargs
):
"""Checks if the functional can be scripted and the scripted version can be called without error."""
"""Checks if the functional can be scripted and the scripted version can be called without error."""
if
not
isinstance
(
input
,
datapoint
s
.
Image
):
if
not
isinstance
(
input
,
tv_tensor
s
.
Image
):
return
return
functional_scripted
=
_script
(
functional
)
functional_scripted
=
_script
(
functional
)
...
@@ -187,7 +187,7 @@ def check_functional(functional, input, *args, check_scripted_smoke=True, **kwar
...
@@ -187,7 +187,7 @@ def check_functional(functional, input, *args, check_scripted_smoke=True, **kwar
assert
isinstance
(
output
,
type
(
input
))
assert
isinstance
(
output
,
type
(
input
))
if
isinstance
(
input
,
datapoint
s
.
BoundingBoxes
):
if
isinstance
(
input
,
tv_tensor
s
.
BoundingBoxes
):
assert
output
.
format
==
input
.
format
assert
output
.
format
==
input
.
format
if
check_scripted_smoke
:
if
check_scripted_smoke
:
...
@@ -199,11 +199,11 @@ def check_functional_kernel_signature_match(functional, *, kernel, input_type):
...
@@ -199,11 +199,11 @@ def check_functional_kernel_signature_match(functional, *, kernel, input_type):
functional_params
=
list
(
inspect
.
signature
(
functional
).
parameters
.
values
())[
1
:]
functional_params
=
list
(
inspect
.
signature
(
functional
).
parameters
.
values
())[
1
:]
kernel_params
=
list
(
inspect
.
signature
(
kernel
).
parameters
.
values
())[
1
:]
kernel_params
=
list
(
inspect
.
signature
(
kernel
).
parameters
.
values
())[
1
:]
if
issubclass
(
input_type
,
datapoints
.
Datapoint
):
if
issubclass
(
input_type
,
tv_tensors
.
TVTensor
):
# We filter out metadata that is implicitly passed to the functional through the input
datapoint
, but has to be
# We filter out metadata that is implicitly passed to the functional through the input
tv_tensor
, but has to be
# explicitly passed to the kernel.
# explicitly passed to the kernel.
explicit_metadata
=
{
explicit_metadata
=
{
datapoint
s
.
BoundingBoxes
:
{
"format"
,
"canvas_size"
},
tv_tensor
s
.
BoundingBoxes
:
{
"format"
,
"canvas_size"
},
}
}
kernel_params
=
[
param
for
param
in
kernel_params
if
param
.
name
not
in
explicit_metadata
.
get
(
input_type
,
set
())]
kernel_params
=
[
param
for
param
in
kernel_params
if
param
.
name
not
in
explicit_metadata
.
get
(
input_type
,
set
())]
...
@@ -264,7 +264,7 @@ def check_transform(transform, input, check_v1_compatibility=True):
...
@@ -264,7 +264,7 @@ def check_transform(transform, input, check_v1_compatibility=True):
output
=
transform
(
input
)
output
=
transform
(
input
)
assert
isinstance
(
output
,
type
(
input
))
assert
isinstance
(
output
,
type
(
input
))
if
isinstance
(
input
,
datapoint
s
.
BoundingBoxes
):
if
isinstance
(
input
,
tv_tensor
s
.
BoundingBoxes
):
assert
output
.
format
==
input
.
format
assert
output
.
format
==
input
.
format
if
check_v1_compatibility
:
if
check_v1_compatibility
:
...
@@ -362,7 +362,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, affine_matrix, new
...
@@ -362,7 +362,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, affine_matrix, new
input_xyxy
=
F
.
convert_bounding_box_format
(
input_xyxy
=
F
.
convert_bounding_box_format
(
bounding_boxes
.
to
(
torch
.
float64
,
copy
=
True
),
bounding_boxes
.
to
(
torch
.
float64
,
copy
=
True
),
old_format
=
format
,
old_format
=
format
,
new_format
=
datapoint
s
.
BoundingBoxFormat
.
XYXY
,
new_format
=
tv_tensor
s
.
BoundingBoxFormat
.
XYXY
,
inplace
=
True
,
inplace
=
True
,
)
)
x1
,
y1
,
x2
,
y2
=
input_xyxy
.
squeeze
(
0
).
tolist
()
x1
,
y1
,
x2
,
y2
=
input_xyxy
.
squeeze
(
0
).
tolist
()
...
@@ -387,7 +387,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, affine_matrix, new
...
@@ -387,7 +387,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, affine_matrix, new
)
)
output
=
F
.
convert_bounding_box_format
(
output
=
F
.
convert_bounding_box_format
(
output_xyxy
,
old_format
=
datapoint
s
.
BoundingBoxFormat
.
XYXY
,
new_format
=
format
output_xyxy
,
old_format
=
tv_tensor
s
.
BoundingBoxFormat
.
XYXY
,
new_format
=
format
)
)
if
clamp
:
if
clamp
:
...
@@ -400,7 +400,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, affine_matrix, new
...
@@ -400,7 +400,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, affine_matrix, new
return
output
return
output
return
datapoint
s
.
BoundingBoxes
(
return
tv_tensor
s
.
BoundingBoxes
(
torch
.
cat
([
affine_bounding_boxes
(
b
)
for
b
in
bounding_boxes
.
reshape
(
-
1
,
4
).
unbind
()],
dim
=
0
).
reshape
(
torch
.
cat
([
affine_bounding_boxes
(
b
)
for
b
in
bounding_boxes
.
reshape
(
-
1
,
4
).
unbind
()],
dim
=
0
).
reshape
(
bounding_boxes
.
shape
bounding_boxes
.
shape
),
),
...
@@ -479,7 +479,7 @@ class TestResize:
...
@@ -479,7 +479,7 @@ class TestResize:
check_scripted_vs_eager
=
not
isinstance
(
size
,
int
),
check_scripted_vs_eager
=
not
isinstance
(
size
,
int
),
)
)
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
datapoint
s
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
tv_tensor
s
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"size"
,
OUTPUT_SIZES
)
@
pytest
.
mark
.
parametrize
(
"size"
,
OUTPUT_SIZES
)
@
pytest
.
mark
.
parametrize
(
"use_max_size"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_max_size"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
int64
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
int64
])
...
@@ -529,10 +529,10 @@ class TestResize:
...
@@ -529,10 +529,10 @@ class TestResize:
[
[
(
F
.
resize_image
,
torch
.
Tensor
),
(
F
.
resize_image
,
torch
.
Tensor
),
(
F
.
_resize_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
_resize_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
resize_image
,
datapoint
s
.
Image
),
(
F
.
resize_image
,
tv_tensor
s
.
Image
),
(
F
.
resize_bounding_boxes
,
datapoint
s
.
BoundingBoxes
),
(
F
.
resize_bounding_boxes
,
tv_tensor
s
.
BoundingBoxes
),
(
F
.
resize_mask
,
datapoint
s
.
Mask
),
(
F
.
resize_mask
,
tv_tensor
s
.
Mask
),
(
F
.
resize_video
,
datapoint
s
.
Video
),
(
F
.
resize_video
,
tv_tensor
s
.
Video
),
],
],
)
)
def
test_functional_signature
(
self
,
kernel
,
input_type
):
def
test_functional_signature
(
self
,
kernel
,
input_type
):
...
@@ -605,7 +605,7 @@ class TestResize:
...
@@ -605,7 +605,7 @@ class TestResize:
new_canvas_size
=
(
new_height
,
new_width
),
new_canvas_size
=
(
new_height
,
new_width
),
)
)
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
datapoint
s
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
tv_tensor
s
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"size"
,
OUTPUT_SIZES
)
@
pytest
.
mark
.
parametrize
(
"size"
,
OUTPUT_SIZES
)
@
pytest
.
mark
.
parametrize
(
"use_max_size"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_max_size"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"fn"
,
[
F
.
resize
,
transform_cls_to_functional
(
transforms
.
Resize
)])
@
pytest
.
mark
.
parametrize
(
"fn"
,
[
F
.
resize
,
transform_cls_to_functional
(
transforms
.
Resize
)])
...
@@ -734,9 +734,9 @@ class TestResize:
...
@@ -734,9 +734,9 @@ class TestResize:
# This identity check is not a requirement. It is here to avoid breaking the behavior by accident. If there
# This identity check is not a requirement. It is here to avoid breaking the behavior by accident. If there
# is a good reason to break this, feel free to downgrade to an equality check.
# is a good reason to break this, feel free to downgrade to an equality check.
if
isinstance
(
input
,
datapoints
.
Datapoint
):
if
isinstance
(
input
,
tv_tensors
.
TVTensor
):
# We can't test identity directly, since that checks for the identity of the Python object. Since all
# We can't test identity directly, since that checks for the identity of the Python object. Since all
#
datapoint
s unwrap before a kernel and wrap again afterwards, the Python object changes. Thus, we check
#
tv_tensor
s unwrap before a kernel and wrap again afterwards, the Python object changes. Thus, we check
# that the underlying storage is the same
# that the underlying storage is the same
assert
output
.
data_ptr
()
==
input
.
data_ptr
()
assert
output
.
data_ptr
()
==
input
.
data_ptr
()
else
:
else
:
...
@@ -782,7 +782,7 @@ class TestResize:
...
@@ -782,7 +782,7 @@ class TestResize:
)
)
if
emulate_channels_last
:
if
emulate_channels_last
:
image
=
datapoint
s
.
wrap
(
image
.
view
(
*
batch_dims
,
*
image
.
shape
[
-
3
:]),
like
=
image
)
image
=
tv_tensor
s
.
wrap
(
image
.
view
(
*
batch_dims
,
*
image
.
shape
[
-
3
:]),
like
=
image
)
return
image
return
image
...
@@ -833,7 +833,7 @@ class TestHorizontalFlip:
...
@@ -833,7 +833,7 @@ class TestHorizontalFlip:
def
test_kernel_image_tensor
(
self
,
dtype
,
device
):
def
test_kernel_image_tensor
(
self
,
dtype
,
device
):
check_kernel
(
F
.
horizontal_flip_image
,
make_image
(
dtype
=
dtype
,
device
=
device
))
check_kernel
(
F
.
horizontal_flip_image
,
make_image
(
dtype
=
dtype
,
device
=
device
))
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
datapoint
s
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
tv_tensor
s
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
int64
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
int64
])
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
def
test_kernel_bounding_boxes
(
self
,
format
,
dtype
,
device
):
def
test_kernel_bounding_boxes
(
self
,
format
,
dtype
,
device
):
...
@@ -864,10 +864,10 @@ class TestHorizontalFlip:
...
@@ -864,10 +864,10 @@ class TestHorizontalFlip:
[
[
(
F
.
horizontal_flip_image
,
torch
.
Tensor
),
(
F
.
horizontal_flip_image
,
torch
.
Tensor
),
(
F
.
_horizontal_flip_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
_horizontal_flip_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
horizontal_flip_image
,
datapoint
s
.
Image
),
(
F
.
horizontal_flip_image
,
tv_tensor
s
.
Image
),
(
F
.
horizontal_flip_bounding_boxes
,
datapoint
s
.
BoundingBoxes
),
(
F
.
horizontal_flip_bounding_boxes
,
tv_tensor
s
.
BoundingBoxes
),
(
F
.
horizontal_flip_mask
,
datapoint
s
.
Mask
),
(
F
.
horizontal_flip_mask
,
tv_tensor
s
.
Mask
),
(
F
.
horizontal_flip_video
,
datapoint
s
.
Video
),
(
F
.
horizontal_flip_video
,
tv_tensor
s
.
Video
),
],
],
)
)
def
test_functional_signature
(
self
,
kernel
,
input_type
):
def
test_functional_signature
(
self
,
kernel
,
input_type
):
...
@@ -902,7 +902,7 @@ class TestHorizontalFlip:
...
@@ -902,7 +902,7 @@ class TestHorizontalFlip:
return
reference_affine_bounding_boxes_helper
(
bounding_boxes
,
affine_matrix
=
affine_matrix
)
return
reference_affine_bounding_boxes_helper
(
bounding_boxes
,
affine_matrix
=
affine_matrix
)
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
datapoint
s
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
tv_tensor
s
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"fn"
,
[
F
.
horizontal_flip
,
transform_cls_to_functional
(
transforms
.
RandomHorizontalFlip
,
p
=
1
)]
"fn"
,
[
F
.
horizontal_flip
,
transform_cls_to_functional
(
transforms
.
RandomHorizontalFlip
,
p
=
1
)]
)
)
...
@@ -999,7 +999,7 @@ class TestAffine:
...
@@ -999,7 +999,7 @@ class TestAffine:
shear
=
_EXHAUSTIVE_TYPE_AFFINE_KWARGS
[
"shear"
],
shear
=
_EXHAUSTIVE_TYPE_AFFINE_KWARGS
[
"shear"
],
center
=
_EXHAUSTIVE_TYPE_AFFINE_KWARGS
[
"center"
],
center
=
_EXHAUSTIVE_TYPE_AFFINE_KWARGS
[
"center"
],
)
)
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
datapoint
s
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
tv_tensor
s
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
int64
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
int64
])
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
def
test_kernel_bounding_boxes
(
self
,
param
,
value
,
format
,
dtype
,
device
):
def
test_kernel_bounding_boxes
(
self
,
param
,
value
,
format
,
dtype
,
device
):
...
@@ -1032,10 +1032,10 @@ class TestAffine:
...
@@ -1032,10 +1032,10 @@ class TestAffine:
[
[
(
F
.
affine_image
,
torch
.
Tensor
),
(
F
.
affine_image
,
torch
.
Tensor
),
(
F
.
_affine_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
_affine_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
affine_image
,
datapoint
s
.
Image
),
(
F
.
affine_image
,
tv_tensor
s
.
Image
),
(
F
.
affine_bounding_boxes
,
datapoint
s
.
BoundingBoxes
),
(
F
.
affine_bounding_boxes
,
tv_tensor
s
.
BoundingBoxes
),
(
F
.
affine_mask
,
datapoint
s
.
Mask
),
(
F
.
affine_mask
,
tv_tensor
s
.
Mask
),
(
F
.
affine_video
,
datapoint
s
.
Video
),
(
F
.
affine_video
,
tv_tensor
s
.
Video
),
],
],
)
)
def
test_functional_signature
(
self
,
kernel
,
input_type
):
def
test_functional_signature
(
self
,
kernel
,
input_type
):
...
@@ -1148,7 +1148,7 @@ class TestAffine:
...
@@ -1148,7 +1148,7 @@ class TestAffine:
),
),
)
)
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
datapoint
s
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
tv_tensor
s
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"angle"
,
_CORRECTNESS_AFFINE_KWARGS
[
"angle"
])
@
pytest
.
mark
.
parametrize
(
"angle"
,
_CORRECTNESS_AFFINE_KWARGS
[
"angle"
])
@
pytest
.
mark
.
parametrize
(
"translate"
,
_CORRECTNESS_AFFINE_KWARGS
[
"translate"
])
@
pytest
.
mark
.
parametrize
(
"translate"
,
_CORRECTNESS_AFFINE_KWARGS
[
"translate"
])
@
pytest
.
mark
.
parametrize
(
"scale"
,
_CORRECTNESS_AFFINE_KWARGS
[
"scale"
])
@
pytest
.
mark
.
parametrize
(
"scale"
,
_CORRECTNESS_AFFINE_KWARGS
[
"scale"
])
...
@@ -1176,7 +1176,7 @@ class TestAffine:
...
@@ -1176,7 +1176,7 @@ class TestAffine:
torch
.
testing
.
assert_close
(
actual
,
expected
)
torch
.
testing
.
assert_close
(
actual
,
expected
)
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
datapoint
s
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
tv_tensor
s
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"center"
,
_CORRECTNESS_AFFINE_KWARGS
[
"center"
])
@
pytest
.
mark
.
parametrize
(
"center"
,
_CORRECTNESS_AFFINE_KWARGS
[
"center"
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
5
)))
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
5
)))
def
test_transform_bounding_boxes_correctness
(
self
,
format
,
center
,
seed
):
def
test_transform_bounding_boxes_correctness
(
self
,
format
,
center
,
seed
):
...
@@ -1283,7 +1283,7 @@ class TestVerticalFlip:
...
@@ -1283,7 +1283,7 @@ class TestVerticalFlip:
def
test_kernel_image_tensor
(
self
,
dtype
,
device
):
def
test_kernel_image_tensor
(
self
,
dtype
,
device
):
check_kernel
(
F
.
vertical_flip_image
,
make_image
(
dtype
=
dtype
,
device
=
device
))
check_kernel
(
F
.
vertical_flip_image
,
make_image
(
dtype
=
dtype
,
device
=
device
))
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
datapoint
s
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
tv_tensor
s
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
int64
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
int64
])
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
def
test_kernel_bounding_boxes
(
self
,
format
,
dtype
,
device
):
def
test_kernel_bounding_boxes
(
self
,
format
,
dtype
,
device
):
...
@@ -1314,10 +1314,10 @@ class TestVerticalFlip:
...
@@ -1314,10 +1314,10 @@ class TestVerticalFlip:
[
[
(
F
.
vertical_flip_image
,
torch
.
Tensor
),
(
F
.
vertical_flip_image
,
torch
.
Tensor
),
(
F
.
_vertical_flip_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
_vertical_flip_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
vertical_flip_image
,
datapoint
s
.
Image
),
(
F
.
vertical_flip_image
,
tv_tensor
s
.
Image
),
(
F
.
vertical_flip_bounding_boxes
,
datapoint
s
.
BoundingBoxes
),
(
F
.
vertical_flip_bounding_boxes
,
tv_tensor
s
.
BoundingBoxes
),
(
F
.
vertical_flip_mask
,
datapoint
s
.
Mask
),
(
F
.
vertical_flip_mask
,
tv_tensor
s
.
Mask
),
(
F
.
vertical_flip_video
,
datapoint
s
.
Video
),
(
F
.
vertical_flip_video
,
tv_tensor
s
.
Video
),
],
],
)
)
def
test_functional_signature
(
self
,
kernel
,
input_type
):
def
test_functional_signature
(
self
,
kernel
,
input_type
):
...
@@ -1350,7 +1350,7 @@ class TestVerticalFlip:
...
@@ -1350,7 +1350,7 @@ class TestVerticalFlip:
return
reference_affine_bounding_boxes_helper
(
bounding_boxes
,
affine_matrix
=
affine_matrix
)
return
reference_affine_bounding_boxes_helper
(
bounding_boxes
,
affine_matrix
=
affine_matrix
)
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
datapoint
s
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
tv_tensor
s
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"fn"
,
[
F
.
vertical_flip
,
transform_cls_to_functional
(
transforms
.
RandomVerticalFlip
,
p
=
1
)])
@
pytest
.
mark
.
parametrize
(
"fn"
,
[
F
.
vertical_flip
,
transform_cls_to_functional
(
transforms
.
RandomVerticalFlip
,
p
=
1
)])
def
test_bounding_boxes_correctness
(
self
,
format
,
fn
):
def
test_bounding_boxes_correctness
(
self
,
format
,
fn
):
bounding_boxes
=
make_bounding_boxes
(
format
=
format
)
bounding_boxes
=
make_bounding_boxes
(
format
=
format
)
...
@@ -1419,7 +1419,7 @@ class TestRotate:
...
@@ -1419,7 +1419,7 @@ class TestRotate:
expand
=
[
False
,
True
],
expand
=
[
False
,
True
],
center
=
_EXHAUSTIVE_TYPE_AFFINE_KWARGS
[
"center"
],
center
=
_EXHAUSTIVE_TYPE_AFFINE_KWARGS
[
"center"
],
)
)
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
datapoint
s
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
tv_tensor
s
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
uint8
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
uint8
])
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
def
test_kernel_bounding_boxes
(
self
,
param
,
value
,
format
,
dtype
,
device
):
def
test_kernel_bounding_boxes
(
self
,
param
,
value
,
format
,
dtype
,
device
):
...
@@ -1456,10 +1456,10 @@ class TestRotate:
...
@@ -1456,10 +1456,10 @@ class TestRotate:
[
[
(
F
.
rotate_image
,
torch
.
Tensor
),
(
F
.
rotate_image
,
torch
.
Tensor
),
(
F
.
_rotate_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
_rotate_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
rotate_image
,
datapoint
s
.
Image
),
(
F
.
rotate_image
,
tv_tensor
s
.
Image
),
(
F
.
rotate_bounding_boxes
,
datapoint
s
.
BoundingBoxes
),
(
F
.
rotate_bounding_boxes
,
tv_tensor
s
.
BoundingBoxes
),
(
F
.
rotate_mask
,
datapoint
s
.
Mask
),
(
F
.
rotate_mask
,
tv_tensor
s
.
Mask
),
(
F
.
rotate_video
,
datapoint
s
.
Video
),
(
F
.
rotate_video
,
tv_tensor
s
.
Video
),
],
],
)
)
def
test_functional_signature
(
self
,
kernel
,
input_type
):
def
test_functional_signature
(
self
,
kernel
,
input_type
):
...
@@ -1553,11 +1553,11 @@ class TestRotate:
...
@@ -1553,11 +1553,11 @@ class TestRotate:
def
_recenter_bounding_boxes_after_expand
(
self
,
bounding_boxes
,
*
,
recenter_xy
):
def
_recenter_bounding_boxes_after_expand
(
self
,
bounding_boxes
,
*
,
recenter_xy
):
x
,
y
=
recenter_xy
x
,
y
=
recenter_xy
if
bounding_boxes
.
format
is
datapoint
s
.
BoundingBoxFormat
.
XYXY
:
if
bounding_boxes
.
format
is
tv_tensor
s
.
BoundingBoxFormat
.
XYXY
:
translate
=
[
x
,
y
,
x
,
y
]
translate
=
[
x
,
y
,
x
,
y
]
else
:
else
:
translate
=
[
x
,
y
,
0.0
,
0.0
]
translate
=
[
x
,
y
,
0.0
,
0.0
]
return
datapoint
s
.
wrap
(
return
tv_tensor
s
.
wrap
(
(
bounding_boxes
.
to
(
torch
.
float64
)
-
torch
.
tensor
(
translate
)).
to
(
bounding_boxes
.
dtype
),
like
=
bounding_boxes
(
bounding_boxes
.
to
(
torch
.
float64
)
-
torch
.
tensor
(
translate
)).
to
(
bounding_boxes
.
dtype
),
like
=
bounding_boxes
)
)
...
@@ -1590,7 +1590,7 @@ class TestRotate:
...
@@ -1590,7 +1590,7 @@ class TestRotate:
bounding_boxes
bounding_boxes
)
)
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
datapoint
s
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
tv_tensor
s
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"angle"
,
_CORRECTNESS_AFFINE_KWARGS
[
"angle"
])
@
pytest
.
mark
.
parametrize
(
"angle"
,
_CORRECTNESS_AFFINE_KWARGS
[
"angle"
])
@
pytest
.
mark
.
parametrize
(
"expand"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"expand"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"center"
,
_CORRECTNESS_AFFINE_KWARGS
[
"center"
])
@
pytest
.
mark
.
parametrize
(
"center"
,
_CORRECTNESS_AFFINE_KWARGS
[
"center"
])
...
@@ -1603,7 +1603,7 @@ class TestRotate:
...
@@ -1603,7 +1603,7 @@ class TestRotate:
torch
.
testing
.
assert_close
(
actual
,
expected
)
torch
.
testing
.
assert_close
(
actual
,
expected
)
torch
.
testing
.
assert_close
(
F
.
get_size
(
actual
),
F
.
get_size
(
expected
),
atol
=
2
if
expand
else
0
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
F
.
get_size
(
actual
),
F
.
get_size
(
expected
),
atol
=
2
if
expand
else
0
,
rtol
=
0
)
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
datapoint
s
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
tv_tensor
s
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"expand"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"expand"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"center"
,
_CORRECTNESS_AFFINE_KWARGS
[
"center"
])
@
pytest
.
mark
.
parametrize
(
"center"
,
_CORRECTNESS_AFFINE_KWARGS
[
"center"
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
5
)))
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
5
)))
...
@@ -1861,7 +1861,7 @@ class TestToDtype:
...
@@ -1861,7 +1861,7 @@ class TestToDtype:
# make sure "others" works as a catch-all and that None means no conversion
# make sure "others" works as a catch-all and that None means no conversion
sample
,
inpt_dtype
,
bbox_dtype
,
mask_dtype
=
self
.
make_inpt_with_bbox_and_mask
(
make_input
)
sample
,
inpt_dtype
,
bbox_dtype
,
mask_dtype
=
self
.
make_inpt_with_bbox_and_mask
(
make_input
)
out
=
transforms
.
ToDtype
(
dtype
=
{
datapoint
s
.
Mask
:
torch
.
int64
,
"others"
:
None
})(
sample
)
out
=
transforms
.
ToDtype
(
dtype
=
{
tv_tensor
s
.
Mask
:
torch
.
int64
,
"others"
:
None
})(
sample
)
assert
out
[
"inpt"
].
dtype
==
inpt_dtype
assert
out
[
"inpt"
].
dtype
==
inpt_dtype
assert
out
[
"bbox"
].
dtype
==
bbox_dtype
assert
out
[
"bbox"
].
dtype
==
bbox_dtype
assert
out
[
"mask"
].
dtype
!=
mask_dtype
assert
out
[
"mask"
].
dtype
!=
mask_dtype
...
@@ -1874,7 +1874,7 @@ class TestToDtype:
...
@@ -1874,7 +1874,7 @@ class TestToDtype:
sample
,
inpt_dtype
,
bbox_dtype
,
mask_dtype
=
self
.
make_inpt_with_bbox_and_mask
(
make_input
)
sample
,
inpt_dtype
,
bbox_dtype
,
mask_dtype
=
self
.
make_inpt_with_bbox_and_mask
(
make_input
)
out
=
transforms
.
ToDtype
(
out
=
transforms
.
ToDtype
(
dtype
=
{
type
(
sample
[
"inpt"
]):
torch
.
float32
,
datapoint
s
.
Mask
:
torch
.
int64
,
"others"
:
None
},
scale
=
True
dtype
=
{
type
(
sample
[
"inpt"
]):
torch
.
float32
,
tv_tensor
s
.
Mask
:
torch
.
int64
,
"others"
:
None
},
scale
=
True
)(
sample
)
)(
sample
)
assert
out
[
"inpt"
].
dtype
!=
inpt_dtype
assert
out
[
"inpt"
].
dtype
!=
inpt_dtype
assert
out
[
"inpt"
].
dtype
==
torch
.
float32
assert
out
[
"inpt"
].
dtype
==
torch
.
float32
...
@@ -1888,9 +1888,9 @@ class TestToDtype:
...
@@ -1888,9 +1888,9 @@ class TestToDtype:
sample
,
inpt_dtype
,
bbox_dtype
,
mask_dtype
=
self
.
make_inpt_with_bbox_and_mask
(
make_input
)
sample
,
inpt_dtype
,
bbox_dtype
,
mask_dtype
=
self
.
make_inpt_with_bbox_and_mask
(
make_input
)
with
pytest
.
raises
(
ValueError
,
match
=
"No dtype was specified for"
):
with
pytest
.
raises
(
ValueError
,
match
=
"No dtype was specified for"
):
out
=
transforms
.
ToDtype
(
dtype
=
{
datapoint
s
.
Mask
:
torch
.
float32
})(
sample
)
out
=
transforms
.
ToDtype
(
dtype
=
{
tv_tensor
s
.
Mask
:
torch
.
float32
})(
sample
)
with
pytest
.
warns
(
UserWarning
,
match
=
re
.
escape
(
"plain `torch.Tensor` will *not* be transformed"
)):
with
pytest
.
warns
(
UserWarning
,
match
=
re
.
escape
(
"plain `torch.Tensor` will *not* be transformed"
)):
transforms
.
ToDtype
(
dtype
=
{
torch
.
Tensor
:
torch
.
float32
,
datapoint
s
.
Image
:
torch
.
float32
})
transforms
.
ToDtype
(
dtype
=
{
torch
.
Tensor
:
torch
.
float32
,
tv_tensor
s
.
Image
:
torch
.
float32
})
with
pytest
.
warns
(
UserWarning
,
match
=
"no scaling will be done"
):
with
pytest
.
warns
(
UserWarning
,
match
=
"no scaling will be done"
):
out
=
transforms
.
ToDtype
(
dtype
=
{
"others"
:
None
},
scale
=
True
)(
sample
)
out
=
transforms
.
ToDtype
(
dtype
=
{
"others"
:
None
},
scale
=
True
)(
sample
)
assert
out
[
"inpt"
].
dtype
==
inpt_dtype
assert
out
[
"inpt"
].
dtype
==
inpt_dtype
...
@@ -1923,8 +1923,8 @@ class TestAdjustBrightness:
...
@@ -1923,8 +1923,8 @@ class TestAdjustBrightness:
[
[
(
F
.
adjust_brightness_image
,
torch
.
Tensor
),
(
F
.
adjust_brightness_image
,
torch
.
Tensor
),
(
F
.
_adjust_brightness_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
_adjust_brightness_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
adjust_brightness_image
,
datapoint
s
.
Image
),
(
F
.
adjust_brightness_image
,
tv_tensor
s
.
Image
),
(
F
.
adjust_brightness_video
,
datapoint
s
.
Video
),
(
F
.
adjust_brightness_video
,
tv_tensor
s
.
Video
),
],
],
)
)
def
test_functional_signature
(
self
,
kernel
,
input_type
):
def
test_functional_signature
(
self
,
kernel
,
input_type
):
...
@@ -2028,8 +2028,8 @@ class TestCutMixMixUp:
...
@@ -2028,8 +2028,8 @@ class TestCutMixMixUp:
for
input_with_bad_type
in
(
for
input_with_bad_type
in
(
F
.
to_pil_image
(
imgs
[
0
]),
F
.
to_pil_image
(
imgs
[
0
]),
datapoint
s
.
Mask
(
torch
.
rand
(
12
,
12
)),
tv_tensor
s
.
Mask
(
torch
.
rand
(
12
,
12
)),
datapoint
s
.
BoundingBoxes
(
torch
.
rand
(
2
,
4
),
format
=
"XYXY"
,
canvas_size
=
12
),
tv_tensor
s
.
BoundingBoxes
(
torch
.
rand
(
2
,
4
),
format
=
"XYXY"
,
canvas_size
=
12
),
):
):
with
pytest
.
raises
(
ValueError
,
match
=
"does not support PIL images, "
):
with
pytest
.
raises
(
ValueError
,
match
=
"does not support PIL images, "
):
cutmix_mixup
(
input_with_bad_type
)
cutmix_mixup
(
input_with_bad_type
)
...
@@ -2172,12 +2172,12 @@ class TestShapeGetters:
...
@@ -2172,12 +2172,12 @@ class TestShapeGetters:
class
TestRegisterKernel
:
class
TestRegisterKernel
:
@
pytest
.
mark
.
parametrize
(
"functional"
,
(
F
.
resize
,
"resize"
))
@
pytest
.
mark
.
parametrize
(
"functional"
,
(
F
.
resize
,
"resize"
))
def
test_register_kernel
(
self
,
functional
):
def
test_register_kernel
(
self
,
functional
):
class
Custom
Datapoint
(
datapoints
.
Datapoint
):
class
Custom
TVTensor
(
tv_tensors
.
TVTensor
):
pass
pass
kernel_was_called
=
False
kernel_was_called
=
False
@
F
.
register_kernel
(
functional
,
Custom
Datapoint
)
@
F
.
register_kernel
(
functional
,
Custom
TVTensor
)
def
new_resize
(
dp
,
*
args
,
**
kwargs
):
def
new_resize
(
dp
,
*
args
,
**
kwargs
):
nonlocal
kernel_was_called
nonlocal
kernel_was_called
kernel_was_called
=
True
kernel_was_called
=
True
...
@@ -2185,38 +2185,38 @@ class TestRegisterKernel:
...
@@ -2185,38 +2185,38 @@ class TestRegisterKernel:
t
=
transforms
.
Resize
(
size
=
(
224
,
224
),
antialias
=
True
)
t
=
transforms
.
Resize
(
size
=
(
224
,
224
),
antialias
=
True
)
my_dp
=
Custom
Datapoint
(
torch
.
rand
(
3
,
10
,
10
))
my_dp
=
Custom
TVTensor
(
torch
.
rand
(
3
,
10
,
10
))
out
=
t
(
my_dp
)
out
=
t
(
my_dp
)
assert
out
is
my_dp
assert
out
is
my_dp
assert
kernel_was_called
assert
kernel_was_called
# Sanity check to make sure we didn't override the kernel of other types
# Sanity check to make sure we didn't override the kernel of other types
t
(
torch
.
rand
(
3
,
10
,
10
)).
shape
==
(
3
,
224
,
224
)
t
(
torch
.
rand
(
3
,
10
,
10
)).
shape
==
(
3
,
224
,
224
)
t
(
datapoint
s
.
Image
(
torch
.
rand
(
3
,
10
,
10
))).
shape
==
(
3
,
224
,
224
)
t
(
tv_tensor
s
.
Image
(
torch
.
rand
(
3
,
10
,
10
))).
shape
==
(
3
,
224
,
224
)
def
test_errors
(
self
):
def
test_errors
(
self
):
with
pytest
.
raises
(
ValueError
,
match
=
"Could not find functional with name"
):
with
pytest
.
raises
(
ValueError
,
match
=
"Could not find functional with name"
):
F
.
register_kernel
(
"bad_name"
,
datapoint
s
.
Image
)
F
.
register_kernel
(
"bad_name"
,
tv_tensor
s
.
Image
)
with
pytest
.
raises
(
ValueError
,
match
=
"Kernels can only be registered on functionals"
):
with
pytest
.
raises
(
ValueError
,
match
=
"Kernels can only be registered on functionals"
):
F
.
register_kernel
(
datapoint
s
.
Image
,
F
.
resize
)
F
.
register_kernel
(
tv_tensor
s
.
Image
,
F
.
resize
)
with
pytest
.
raises
(
ValueError
,
match
=
"Kernels can only be registered for subclasses"
):
with
pytest
.
raises
(
ValueError
,
match
=
"Kernels can only be registered for subclasses"
):
F
.
register_kernel
(
F
.
resize
,
object
)
F
.
register_kernel
(
F
.
resize
,
object
)
with
pytest
.
raises
(
ValueError
,
match
=
"cannot be registered for the builtin
datapoint
classes"
):
with
pytest
.
raises
(
ValueError
,
match
=
"cannot be registered for the builtin
tv_tensor
classes"
):
F
.
register_kernel
(
F
.
resize
,
datapoint
s
.
Image
)(
F
.
resize_image
)
F
.
register_kernel
(
F
.
resize
,
tv_tensor
s
.
Image
)(
F
.
resize_image
)
class
Custom
Datapoint
(
datapoints
.
Datapoint
):
class
Custom
TVTensor
(
tv_tensors
.
TVTensor
):
pass
pass
def
resize_custom_
datapoint
():
def
resize_custom_
tv_tensor
():
pass
pass
F
.
register_kernel
(
F
.
resize
,
Custom
Datapoint
)(
resize_custom_
datapoint
)
F
.
register_kernel
(
F
.
resize
,
Custom
TVTensor
)(
resize_custom_
tv_tensor
)
with
pytest
.
raises
(
ValueError
,
match
=
"already has a kernel registered for type"
):
with
pytest
.
raises
(
ValueError
,
match
=
"already has a kernel registered for type"
):
F
.
register_kernel
(
F
.
resize
,
Custom
Datapoint
)(
resize_custom_
datapoint
)
F
.
register_kernel
(
F
.
resize
,
Custom
TVTensor
)(
resize_custom_
tv_tensor
)
class
TestGetKernel
:
class
TestGetKernel
:
...
@@ -2225,10 +2225,10 @@ class TestGetKernel:
...
@@ -2225,10 +2225,10 @@ class TestGetKernel:
KERNELS
=
{
KERNELS
=
{
torch
.
Tensor
:
F
.
resize_image
,
torch
.
Tensor
:
F
.
resize_image
,
PIL
.
Image
.
Image
:
F
.
_resize_image_pil
,
PIL
.
Image
.
Image
:
F
.
_resize_image_pil
,
datapoint
s
.
Image
:
F
.
resize_image
,
tv_tensor
s
.
Image
:
F
.
resize_image
,
datapoint
s
.
BoundingBoxes
:
F
.
resize_bounding_boxes
,
tv_tensor
s
.
BoundingBoxes
:
F
.
resize_bounding_boxes
,
datapoint
s
.
Mask
:
F
.
resize_mask
,
tv_tensor
s
.
Mask
:
F
.
resize_mask
,
datapoint
s
.
Video
:
F
.
resize_video
,
tv_tensor
s
.
Video
:
F
.
resize_video
,
}
}
@
pytest
.
mark
.
parametrize
(
"input_type"
,
[
str
,
int
,
object
])
@
pytest
.
mark
.
parametrize
(
"input_type"
,
[
str
,
int
,
object
])
...
@@ -2244,57 +2244,57 @@ class TestGetKernel:
...
@@ -2244,57 +2244,57 @@ class TestGetKernel:
pass
pass
for
input_type
,
kernel
in
self
.
KERNELS
.
items
():
for
input_type
,
kernel
in
self
.
KERNELS
.
items
():
_register_kernel_internal
(
resize_with_pure_kernels
,
input_type
,
datapoint
_wrapper
=
False
)(
kernel
)
_register_kernel_internal
(
resize_with_pure_kernels
,
input_type
,
tv_tensor
_wrapper
=
False
)(
kernel
)
assert
_get_kernel
(
resize_with_pure_kernels
,
input_type
)
is
kernel
assert
_get_kernel
(
resize_with_pure_kernels
,
input_type
)
is
kernel
def
test_builtin_
datapoint
_subclass
(
self
):
def
test_builtin_
tv_tensor
_subclass
(
self
):
# We cannot use F.resize together with self.KERNELS mapping here directly here, since this is only the
# We cannot use F.resize together with self.KERNELS mapping here directly here, since this is only the
# ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize functional
# ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize functional
# here, register the kernels without wrapper, and check if subclasses of our builtin
datapoint
s get dispatched
# here, register the kernels without wrapper, and check if subclasses of our builtin
tv_tensor
s get dispatched
# to the kernel of the corresponding superclass
# to the kernel of the corresponding superclass
def
resize_with_pure_kernels
():
def
resize_with_pure_kernels
():
pass
pass
class
MyImage
(
datapoint
s
.
Image
):
class
MyImage
(
tv_tensor
s
.
Image
):
pass
pass
class
MyBoundingBoxes
(
datapoint
s
.
BoundingBoxes
):
class
MyBoundingBoxes
(
tv_tensor
s
.
BoundingBoxes
):
pass
pass
class
MyMask
(
datapoint
s
.
Mask
):
class
MyMask
(
tv_tensor
s
.
Mask
):
pass
pass
class
MyVideo
(
datapoint
s
.
Video
):
class
MyVideo
(
tv_tensor
s
.
Video
):
pass
pass
for
custom_
datapoint
_subclass
in
[
for
custom_
tv_tensor
_subclass
in
[
MyImage
,
MyImage
,
MyBoundingBoxes
,
MyBoundingBoxes
,
MyMask
,
MyMask
,
MyVideo
,
MyVideo
,
]:
]:
builtin_
datapoint
_class
=
custom_
datapoint
_subclass
.
__mro__
[
1
]
builtin_
tv_tensor
_class
=
custom_
tv_tensor
_subclass
.
__mro__
[
1
]
builtin_
datapoint
_kernel
=
self
.
KERNELS
[
builtin_
datapoint
_class
]
builtin_
tv_tensor
_kernel
=
self
.
KERNELS
[
builtin_
tv_tensor
_class
]
_register_kernel_internal
(
resize_with_pure_kernels
,
builtin_
datapoint_class
,
datapoint
_wrapper
=
False
)(
_register_kernel_internal
(
resize_with_pure_kernels
,
builtin_
tv_tensor_class
,
tv_tensor
_wrapper
=
False
)(
builtin_
datapoint
_kernel
builtin_
tv_tensor
_kernel
)
)
assert
_get_kernel
(
resize_with_pure_kernels
,
custom_
datapoint
_subclass
)
is
builtin_
datapoint
_kernel
assert
_get_kernel
(
resize_with_pure_kernels
,
custom_
tv_tensor
_subclass
)
is
builtin_
tv_tensor
_kernel
def
test_
datapoint
_subclass
(
self
):
def
test_
tv_tensor
_subclass
(
self
):
class
My
Datapoint
(
datapoints
.
Datapoint
):
class
My
TVTensor
(
tv_tensors
.
TVTensor
):
pass
pass
with
pytest
.
raises
(
TypeError
,
match
=
"supports inputs of type"
):
with
pytest
.
raises
(
TypeError
,
match
=
"supports inputs of type"
):
_get_kernel
(
F
.
resize
,
My
Datapoint
)
_get_kernel
(
F
.
resize
,
My
TVTensor
)
def
resize_my_
datapoint
():
def
resize_my_
tv_tensor
():
pass
pass
_register_kernel_internal
(
F
.
resize
,
My
Datapoint
,
datapoint
_wrapper
=
False
)(
resize_my_
datapoint
)
_register_kernel_internal
(
F
.
resize
,
My
TVTensor
,
tv_tensor
_wrapper
=
False
)(
resize_my_
tv_tensor
)
assert
_get_kernel
(
F
.
resize
,
My
Datapoint
)
is
resize_my_
datapoint
assert
_get_kernel
(
F
.
resize
,
My
TVTensor
)
is
resize_my_
tv_tensor
def
test_pil_image_subclass
(
self
):
def
test_pil_image_subclass
(
self
):
opened_image
=
PIL
.
Image
.
open
(
Path
(
__file__
).
parent
/
"assets"
/
"encode_jpeg"
/
"grace_hopper_517x606.jpg"
)
opened_image
=
PIL
.
Image
.
open
(
Path
(
__file__
).
parent
/
"assets"
/
"encode_jpeg"
/
"grace_hopper_517x606.jpg"
)
...
@@ -2342,8 +2342,8 @@ class TestPermuteChannels:
...
@@ -2342,8 +2342,8 @@ class TestPermuteChannels:
[
[
(
F
.
permute_channels_image
,
torch
.
Tensor
),
(
F
.
permute_channels_image
,
torch
.
Tensor
),
(
F
.
_permute_channels_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
_permute_channels_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
permute_channels_image
,
datapoint
s
.
Image
),
(
F
.
permute_channels_image
,
tv_tensor
s
.
Image
),
(
F
.
permute_channels_video
,
datapoint
s
.
Video
),
(
F
.
permute_channels_video
,
tv_tensor
s
.
Video
),
],
],
)
)
def
test_functional_signature
(
self
,
kernel
,
input_type
):
def
test_functional_signature
(
self
,
kernel
,
input_type
):
...
@@ -2352,7 +2352,7 @@ class TestPermuteChannels:
...
@@ -2352,7 +2352,7 @@ class TestPermuteChannels:
def
reference_image_correctness
(
self
,
image
,
permutation
):
def
reference_image_correctness
(
self
,
image
,
permutation
):
channel_images
=
image
.
split
(
1
,
dim
=-
3
)
channel_images
=
image
.
split
(
1
,
dim
=-
3
)
permuted_channel_images
=
[
channel_images
[
channel_idx
]
for
channel_idx
in
permutation
]
permuted_channel_images
=
[
channel_images
[
channel_idx
]
for
channel_idx
in
permutation
]
return
datapoint
s
.
Image
(
torch
.
concat
(
permuted_channel_images
,
dim
=-
3
))
return
tv_tensor
s
.
Image
(
torch
.
concat
(
permuted_channel_images
,
dim
=-
3
))
@
pytest
.
mark
.
parametrize
(
"permutation"
,
[[
2
,
0
,
1
],
[
1
,
2
,
0
],
[
2
,
0
,
1
],
[
0
,
1
,
2
]])
@
pytest
.
mark
.
parametrize
(
"permutation"
,
[[
2
,
0
,
1
],
[
1
,
2
,
0
],
[
2
,
0
,
1
],
[
0
,
1
,
2
]])
@
pytest
.
mark
.
parametrize
(
"batch_dims"
,
[(),
(
2
,),
(
2
,
1
)])
@
pytest
.
mark
.
parametrize
(
"batch_dims"
,
[(),
(
2
,),
(
2
,
1
)])
...
@@ -2392,7 +2392,7 @@ class TestElastic:
            check_scripted_vs_eager=not (param == "fill" and isinstance(value, (int, float))),
        )

-    @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_bounding_boxes(self, format, dtype, device):
...
@@ -2428,10 +2428,10 @@ class TestElastic:
        [
            (F.elastic_image, torch.Tensor),
            (F._elastic_image_pil, PIL.Image.Image),
-            (F.elastic_image, datapoints.Image),
+            (F.elastic_image, tv_tensors.Image),
-            (F.elastic_bounding_boxes, datapoints.BoundingBoxes),
+            (F.elastic_bounding_boxes, tv_tensors.BoundingBoxes),
-            (F.elastic_mask, datapoints.Mask),
+            (F.elastic_mask, tv_tensors.Mask),
-            (F.elastic_video, datapoints.Video),
+            (F.elastic_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
...
@@ -2481,7 +2481,7 @@ class TestToPureTensor:
        out = transforms.ToPureTensor()(input)

        for input_value, out_value in zip(input.values(), out.values()):
-            if isinstance(input_value, datapoints.Datapoint):
+            if isinstance(input_value, tv_tensors.TVTensor):
-                assert isinstance(out_value, torch.Tensor) and not isinstance(out_value, datapoints.Datapoint)
+                assert isinstance(out_value, torch.Tensor) and not isinstance(out_value, tv_tensors.TVTensor)
            else:
                assert isinstance(out_value, type(input_value))
test/test_transforms_v2_utils.py
View file @ d5f4cc38
...
@@ -6,46 +6,46 @@ import torch
import torchvision.transforms.v2._utils
from common_utils import DEFAULT_SIZE, make_bounding_boxes, make_detection_mask, make_image

-from torchvision import datapoints
+from torchvision import tv_tensors
from torchvision.transforms.v2._utils import has_all, has_any
from torchvision.transforms.v2.functional import to_pil_image

IMAGE = make_image(DEFAULT_SIZE, color_space="RGB")
-BOUNDING_BOX = make_bounding_boxes(DEFAULT_SIZE, format=datapoints.BoundingBoxFormat.XYXY)
+BOUNDING_BOX = make_bounding_boxes(DEFAULT_SIZE, format=tv_tensors.BoundingBoxFormat.XYXY)
MASK = make_detection_mask(DEFAULT_SIZE)
@pytest.mark.parametrize(
    ("sample", "types", "expected"),
    [
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image,), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image,), True),
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes,), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes,), True),
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Mask,), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Mask,), True),
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), True),
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes, datapoints.Mask), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
-        ((MASK,), (datapoints.Image, datapoints.BoundingBoxes), False),
+        ((MASK,), (tv_tensors.Image, tv_tensors.BoundingBoxes), False),
-        ((BOUNDING_BOX,), (datapoints.Image, datapoints.Mask), False),
+        ((BOUNDING_BOX,), (tv_tensors.Image, tv_tensors.Mask), False),
-        ((IMAGE,), (datapoints.BoundingBoxes, datapoints.Mask), False),
+        ((IMAGE,), (tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
        (
            (IMAGE, BOUNDING_BOX, MASK),
-            (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask),
+            (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
            True,
        ),
-        ((), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False),
+        ((), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
-        ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, datapoints.Image),), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, tv_tensors.Image),), True),
        ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
        ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
-        ((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor), True),
+        ((IMAGE,), (tv_tensors.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor), True),
        (
            (torch.Tensor(IMAGE),),
-            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
+            (tv_tensors.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
            True,
        ),
        (
            (to_pil_image(IMAGE),),
-            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
+            (tv_tensors.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
            True,
        ),
    ],
...
@@ -57,31 +57,31 @@ def test_has_any(sample, types, expected):
@pytest.mark.parametrize(
    ("sample", "types", "expected"),
    [
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image,), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image,), True),
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes,), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes,), True),
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Mask,), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Mask,), True),
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), True),
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes, datapoints.Mask), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
        (
            (IMAGE, BOUNDING_BOX, MASK),
-            (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask),
+            (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
            True,
        ),
-        ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes), False),
+        ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), False),
-        ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), False),
+        ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), False),
-        ((IMAGE, MASK), (datapoints.BoundingBoxes, datapoints.Mask), False),
+        ((IMAGE, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
        (
            (IMAGE, BOUNDING_BOX, MASK),
-            (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask),
+            (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
            True,
        ),
-        ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False),
+        ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
-        ((IMAGE, MASK), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False),
+        ((IMAGE, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
-        ((IMAGE, BOUNDING_BOX), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False),
+        ((IMAGE, BOUNDING_BOX), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
        (
            (IMAGE, BOUNDING_BOX, MASK),
-            (lambda obj: isinstance(obj, (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask)),),
+            (lambda obj: isinstance(obj, (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask)),),
            True,
        ),
        ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
...
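For orientation, a sketch of the semantics these cases pin down (`has_any`/`has_all` are private helpers, and the signature `has_any(flat_inputs, *types_or_checks)` is an assumption): each entry may be a class or a predicate; `has_any` passes if any entry matches any sample item, `has_all` only if every entry matches some item.

    from torchvision import tv_tensors
    from torchvision.transforms.v2._utils import has_all, has_any

    sample = (IMAGE, BOUNDING_BOX)  # constants from the top of this test file
    assert has_any(sample, tv_tensors.Image, tv_tensors.Mask)      # Image matches
    assert not has_all(sample, tv_tensors.Image, tv_tensors.Mask)  # no Mask present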
test/test_datapoints.py → test/test_tv_tensors.py
View file @ d5f4cc38
...
@@ -5,7 +5,7 @@ import torch
from common_utils import assert_equal, make_bounding_boxes, make_image, make_segmentation_mask, make_video
from PIL import Image

-from torchvision import datapoints
+from torchvision import tv_tensors


@pytest.fixture(autouse=True)
...
@@ -13,40 +13,40 @@ def restore_tensor_return_type():
    # This is for security, as we should already be restoring the default manually in each test anyway
    # (at least at the time of writing...)
    yield
-    datapoints.set_return_type("Tensor")
+    tv_tensors.set_return_type("Tensor")


@pytest.mark.parametrize("data", [torch.rand(3, 32, 32), Image.new("RGB", (32, 32), color=123)])
def test_image_instance(data):
-    image = datapoints.Image(data)
+    image = tv_tensors.Image(data)
    assert isinstance(image, torch.Tensor)
    assert image.ndim == 3 and image.shape[0] == 3


@pytest.mark.parametrize("data", [torch.randint(0, 10, size=(1, 32, 32)), Image.new("L", (32, 32), color=2)])
def test_mask_instance(data):
-    mask = datapoints.Mask(data)
+    mask = tv_tensors.Mask(data)
    assert isinstance(mask, torch.Tensor)
    assert mask.ndim == 3 and mask.shape[0] == 1


@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]], [1, 2, 3, 4]])
@pytest.mark.parametrize(
-    "format", ["XYXY", "CXCYWH", datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH]
+    "format", ["XYXY", "CXCYWH", tv_tensors.BoundingBoxFormat.XYXY, tv_tensors.BoundingBoxFormat.XYWH]
)
def test_bbox_instance(data, format):
-    bboxes = datapoints.BoundingBoxes(data, format=format, canvas_size=(32, 32))
+    bboxes = tv_tensors.BoundingBoxes(data, format=format, canvas_size=(32, 32))
    assert isinstance(bboxes, torch.Tensor)
    assert bboxes.ndim == 2 and bboxes.shape[1] == 4
    if isinstance(format, str):
-        format = datapoints.BoundingBoxFormat[(format.upper())]
+        format = tv_tensors.BoundingBoxFormat[(format.upper())]
    assert bboxes.format == format


def test_bbox_dim_error():
    data_3d = [[[1, 2, 3, 4]]]
    with pytest.raises(ValueError, match="Expected a 1D or 2D tensor, got 3D"):
-        datapoints.BoundingBoxes(data_3d, format="XYXY", canvas_size=(32, 32))
+        tv_tensors.BoundingBoxes(data_3d, format="XYXY", canvas_size=(32, 32))


@pytest.mark.parametrize(
...
@@ -64,8 +64,8 @@ def test_bbox_dim_error():
    ],
)
def test_new_requires_grad(data, input_requires_grad, expected_requires_grad):
-    datapoint = datapoints.Image(data, requires_grad=input_requires_grad)
+    tv_tensor = tv_tensors.Image(data, requires_grad=input_requires_grad)
-    assert datapoint.requires_grad is expected_requires_grad
+    assert tv_tensor.requires_grad is expected_requires_grad


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
...
@@ -75,7 +75,7 @@ def test_isinstance(make_input):
def test_wrapping_no_copy():
    tensor = torch.rand(3, 16, 16)
-    image = datapoints.Image(tensor)
+    image = tv_tensors.Image(tensor)

    assert image.data_ptr() == tensor.data_ptr()
...
@@ -91,25 +91,25 @@ def test_to_wrapping(make_input):
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
+@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
-def test_to_datapoint_reference(make_input, return_type):
+def test_to_tv_tensor_reference(make_input, return_type):
    tensor = torch.rand((3, 16, 16), dtype=torch.float64)
    dp = make_input()

-    with datapoints.set_return_type(return_type):
+    with tv_tensors.set_return_type(return_type):
        tensor_to = tensor.to(dp)

-    assert type(tensor_to) is (type(dp) if return_type == "datapoint" else torch.Tensor)
+    assert type(tensor_to) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
    assert tensor_to.dtype is dp.dtype
    assert type(tensor) is torch.Tensor


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
+@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
def test_clone_wrapping(make_input, return_type):
    dp = make_input()

-    with datapoints.set_return_type(return_type):
+    with tv_tensors.set_return_type(return_type):
        dp_clone = dp.clone()

    assert type(dp_clone) is type(dp)
...
@@ -117,13 +117,13 @@ def test_clone_wrapping(make_input, return_type):
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
+@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
def test_requires_grad__wrapping(make_input, return_type):
    dp = make_input(dtype=torch.float)

    assert not dp.requires_grad

-    with datapoints.set_return_type(return_type):
+    with tv_tensors.set_return_type(return_type):
        dp_requires_grad = dp.requires_grad_(True)

    assert type(dp_requires_grad) is type(dp)
...
@@ -132,54 +132,54 @@ def test_requires_grad__wrapping(make_input, return_type):
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
+@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
def test_detach_wrapping(make_input, return_type):
    dp = make_input(dtype=torch.float).requires_grad_(True)

-    with datapoints.set_return_type(return_type):
+    with tv_tensors.set_return_type(return_type):
        dp_detached = dp.detach()

    assert type(dp_detached) is type(dp)


-@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
+@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
def test_force_subclass_with_metadata(return_type):
-    # Sanity checks for the ops in _FORCE_TORCHFUNCTION_SUBCLASS and datapoints with metadata
+    # Sanity checks for the ops in _FORCE_TORCHFUNCTION_SUBCLASS and tv_tensors with metadata
    # Largely the same as above, we additionally check that the metadata is preserved
    format, canvas_size = "XYXY", (32, 32)
-    bbox = datapoints.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format=format, canvas_size=canvas_size)
+    bbox = tv_tensors.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format=format, canvas_size=canvas_size)

-    datapoints.set_return_type(return_type)
+    tv_tensors.set_return_type(return_type)
    bbox = bbox.clone()
-    if return_type == "datapoint":
+    if return_type == "tv_tensor":
        assert bbox.format, bbox.canvas_size == (format, canvas_size)

    bbox = bbox.to(torch.float64)
-    if return_type == "datapoint":
+    if return_type == "tv_tensor":
        assert bbox.format, bbox.canvas_size == (format, canvas_size)

    bbox = bbox.detach()
-    if return_type == "datapoint":
+    if return_type == "tv_tensor":
        assert bbox.format, bbox.canvas_size == (format, canvas_size)

    assert not bbox.requires_grad
    bbox.requires_grad_(True)
-    if return_type == "datapoint":
+    if return_type == "tv_tensor":
        assert bbox.format, bbox.canvas_size == (format, canvas_size)
        assert bbox.requires_grad
-    datapoints.set_return_type("tensor")
+    tv_tensors.set_return_type("tensor")


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
+@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
def test_other_op_no_wrapping(make_input, return_type):
    dp = make_input()

-    with datapoints.set_return_type(return_type):
+    with tv_tensors.set_return_type(return_type):
        # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
        output = dp * 2

-    assert type(output) is (type(dp) if return_type == "datapoint" else torch.Tensor)
+    assert type(output) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
...
@@ -200,15 +200,15 @@ def test_no_tensor_output_op_no_wrapping(make_input, op):
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
+@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
def test_inplace_op_no_wrapping(make_input, return_type):
    dp = make_input()
    original_type = type(dp)

-    with datapoints.set_return_type(return_type):
+    with tv_tensors.set_return_type(return_type):
        output = dp.add_(0)

-    assert type(output) is (type(dp) if return_type == "datapoint" else torch.Tensor)
+    assert type(output) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
    assert type(dp) is original_type
...
@@ -219,7 +219,7 @@ def test_wrap(make_input):
    # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
    output = dp * 2

-    dp_new = datapoints.wrap(output, like=dp)
+    dp_new = tv_tensors.wrap(output, like=dp)

    assert type(dp_new) is type(dp)
    assert dp_new.data_ptr() == output.data_ptr()
...
@@ -243,7 +243,7 @@ def test_deepcopy(make_input, requires_grad):
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
+@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
@pytest.mark.parametrize(
    "op",
    (
...
@@ -265,10 +265,10 @@ def test_deepcopy(make_input, requires_grad):
def test_usual_operations(make_input, return_type, op):
    dp = make_input()
-    with datapoints.set_return_type(return_type):
+    with tv_tensors.set_return_type(return_type):
        out = op(dp)
-        assert type(out) is (type(dp) if return_type == "datapoint" else torch.Tensor)
+        assert type(out) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
-        if isinstance(dp, datapoints.BoundingBoxes) and return_type == "datapoint":
+        if isinstance(dp, tv_tensors.BoundingBoxes) and return_type == "tv_tensor":
            assert hasattr(out, "format")
            assert hasattr(out, "canvas_size")
...
@@ -286,22 +286,22 @@ def test_set_return_type():
    assert type(img + 3) is torch.Tensor

-    with datapoints.set_return_type("datapoint"):
+    with tv_tensors.set_return_type("tv_tensor"):
-        assert type(img + 3) is datapoints.Image
+        assert type(img + 3) is tv_tensors.Image
    assert type(img + 3) is torch.Tensor

-    datapoints.set_return_type("datapoint")
+    tv_tensors.set_return_type("tv_tensor")
-    assert type(img + 3) is datapoints.Image
+    assert type(img + 3) is tv_tensors.Image

-    with datapoints.set_return_type("tensor"):
+    with tv_tensors.set_return_type("tensor"):
        assert type(img + 3) is torch.Tensor
-        with datapoints.set_return_type("datapoint"):
+        with tv_tensors.set_return_type("tv_tensor"):
-            assert type(img + 3) is datapoints.Image
+            assert type(img + 3) is tv_tensors.Image
-            datapoints.set_return_type("tensor")
+            tv_tensors.set_return_type("tensor")
            assert type(img + 3) is torch.Tensor
        assert type(img + 3) is torch.Tensor
    # Exiting a context manager will restore the return type as it was prior to entering it,
-    # regardless of whether the "global" datapoints.set_return_type() was called within the context manager.
+    # regardless of whether the "global" tv_tensors.set_return_type() was called within the context manager.
-    assert type(img + 3) is datapoints.Image
+    assert type(img + 3) is tv_tensors.Image

-    datapoints.set_return_type("tensor")
+    tv_tensors.set_return_type("tensor")
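The behavior these renamed tests verify, condensed into a standalone sketch using the same public API that appears in the diff:

    import torch
    from torchvision import tv_tensors

    img = tv_tensors.Image(torch.rand(3, 8, 8))

    # Default: plain tensor ops unwrap the subclass back to torch.Tensor.
    assert type(img + 3) is torch.Tensor

    # As a context manager, set_return_type is scoped and restored on exit.
    with tv_tensors.set_return_type("tv_tensor"):
        assert type(img + 3) is tv_tensors.Image
    assert type(img + 3) is torch.Tensor

    # tv_tensors.wrap re-attaches the subclass (and its metadata) to a raw result.
    out = tv_tensors.wrap(img * 2, like=img)
    assert type(out) is tv_tensors.Image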
test/transforms_v2_dispatcher_infos.py
View file @ d5f4cc38
...
@@ -2,7 +2,7 @@ import collections.abc
import pytest

import torchvision.transforms.v2.functional as F
-from torchvision import datapoints
+from torchvision import tv_tensors
from transforms_v2_kernel_infos import KERNEL_INFOS, pad_xfail_jit_fill_condition
from transforms_v2_legacy_utils import InfoBase, TestMark
...
@@ -44,19 +44,19 @@ class DispatcherInfo(InfoBase):
        self.pil_kernel_info = pil_kernel_info

        kernel_infos = {}
-        for datapoint_type, kernel in self.kernels.items():
+        for tv_tensor_type, kernel in self.kernels.items():
            kernel_info = self._KERNEL_INFO_MAP.get(kernel)
            if not kernel_info:
                raise pytest.UsageError(
-                    f"Can't register {kernel.__name__} for type {datapoint_type} since there is no `KernelInfo` for it. "
+                    f"Can't register {kernel.__name__} for type {tv_tensor_type} since there is no `KernelInfo` for it. "
                    f"Please add a `KernelInfo` for it in `transforms_v2_kernel_infos.py`."
                )
-            kernel_infos[datapoint_type] = kernel_info
+            kernel_infos[tv_tensor_type] = kernel_info
        self.kernel_infos = kernel_infos

-    def sample_inputs(self, *datapoint_types, filter_metadata=True):
+    def sample_inputs(self, *tv_tensor_types, filter_metadata=True):
-        for datapoint_type in datapoint_types or self.kernel_infos.keys():
+        for tv_tensor_type in tv_tensor_types or self.kernel_infos.keys():
-            kernel_info = self.kernel_infos.get(datapoint_type)
+            kernel_info = self.kernel_infos.get(tv_tensor_type)
            if not kernel_info:
                raise pytest.UsageError(f"There is no kernel registered for type {type.__name__}")
...
@@ -69,12 +69,12 @@ class DispatcherInfo(InfoBase):
            import itertools

            for args_kwargs in sample_inputs:
-                if hasattr(datapoint_type, "__annotations__"):
+                if hasattr(tv_tensor_type, "__annotations__"):
                    for name in itertools.chain(
-                        datapoint_type.__annotations__.keys(),
+                        tv_tensor_type.__annotations__.keys(),
                        # FIXME: this seems ok for conversion dispatchers, but we should probably handle this on a
                        # per-dispatcher level. However, so far there is no option for that.
-                        (f"old_{name}" for name in datapoint_type.__annotations__.keys()),
+                        (f"old_{name}" for name in tv_tensor_type.__annotations__.keys()),
                    ):
                        if name in args_kwargs.kwargs:
                            del args_kwargs.kwargs[name]
...
@@ -97,9 +97,9 @@ def xfail_jit_python_scalar_arg(name, *, reason=None):
    )

-skip_dispatch_datapoint = TestMark(
+skip_dispatch_tv_tensor = TestMark(
-    ("TestDispatchers", "test_dispatch_datapoint"),
+    ("TestDispatchers", "test_dispatch_tv_tensor"),
-    pytest.mark.skip(reason="Dispatcher doesn't support arbitrary datapoint dispatch."),
+    pytest.mark.skip(reason="Dispatcher doesn't support arbitrary tv_tensor dispatch."),
)

multi_crop_skips = [
...
@@ -107,9 +107,9 @@ multi_crop_skips = [
        ("TestDispatchers", test_name),
        pytest.mark.skip(reason="Multi-crop dispatchers return a sequence of items rather than a single one."),
    )
-    for test_name in ["test_pure_tensor_output_type", "test_pil_output_type", "test_datapoint_output_type"]
+    for test_name in ["test_pure_tensor_output_type", "test_pil_output_type", "test_tv_tensor_output_type"]
]
-multi_crop_skips.append(skip_dispatch_datapoint)
+multi_crop_skips.append(skip_dispatch_tv_tensor)


def xfails_pil(reason, *, condition=None):
...
@@ -142,30 +142,30 @@ DISPATCHER_INFOS = [
    DispatcherInfo(
        F.crop,
        kernels={
-            datapoints.Image: F.crop_image,
+            tv_tensors.Image: F.crop_image,
-            datapoints.Video: F.crop_video,
+            tv_tensors.Video: F.crop_video,
-            datapoints.BoundingBoxes: F.crop_bounding_boxes,
+            tv_tensors.BoundingBoxes: F.crop_bounding_boxes,
-            datapoints.Mask: F.crop_mask,
+            tv_tensors.Mask: F.crop_mask,
        },
        pil_kernel_info=PILKernelInfo(F._crop_image_pil, kernel_name="crop_image_pil"),
    ),
    DispatcherInfo(
        F.resized_crop,
        kernels={
-            datapoints.Image: F.resized_crop_image,
+            tv_tensors.Image: F.resized_crop_image,
-            datapoints.Video: F.resized_crop_video,
+            tv_tensors.Video: F.resized_crop_video,
-            datapoints.BoundingBoxes: F.resized_crop_bounding_boxes,
+            tv_tensors.BoundingBoxes: F.resized_crop_bounding_boxes,
-            datapoints.Mask: F.resized_crop_mask,
+            tv_tensors.Mask: F.resized_crop_mask,
        },
        pil_kernel_info=PILKernelInfo(F._resized_crop_image_pil),
    ),
    DispatcherInfo(
        F.pad,
        kernels={
-            datapoints.Image: F.pad_image,
+            tv_tensors.Image: F.pad_image,
-            datapoints.Video: F.pad_video,
+            tv_tensors.Video: F.pad_video,
-            datapoints.BoundingBoxes: F.pad_bounding_boxes,
+            tv_tensors.BoundingBoxes: F.pad_bounding_boxes,
-            datapoints.Mask: F.pad_mask,
+            tv_tensors.Mask: F.pad_mask,
        },
        pil_kernel_info=PILKernelInfo(F._pad_image_pil, kernel_name="pad_image_pil"),
        test_marks=[
...
@@ -184,10 +184,10 @@ DISPATCHER_INFOS = [
    DispatcherInfo(
        F.perspective,
        kernels={
-            datapoints.Image: F.perspective_image,
+            tv_tensors.Image: F.perspective_image,
-            datapoints.Video: F.perspective_video,
+            tv_tensors.Video: F.perspective_video,
-            datapoints.BoundingBoxes: F.perspective_bounding_boxes,
+            tv_tensors.BoundingBoxes: F.perspective_bounding_boxes,
-            datapoints.Mask: F.perspective_mask,
+            tv_tensors.Mask: F.perspective_mask,
        },
        pil_kernel_info=PILKernelInfo(F._perspective_image_pil),
        test_marks=[
...
@@ -198,10 +198,10 @@ DISPATCHER_INFOS = [
    DispatcherInfo(
        F.elastic,
        kernels={
-            datapoints.Image: F.elastic_image,
+            tv_tensors.Image: F.elastic_image,
-            datapoints.Video: F.elastic_video,
+            tv_tensors.Video: F.elastic_video,
-            datapoints.BoundingBoxes: F.elastic_bounding_boxes,
+            tv_tensors.BoundingBoxes: F.elastic_bounding_boxes,
-            datapoints.Mask: F.elastic_mask,
+            tv_tensors.Mask: F.elastic_mask,
        },
        pil_kernel_info=PILKernelInfo(F._elastic_image_pil),
        test_marks=[xfail_jit_python_scalar_arg("fill")],
...
@@ -209,10 +209,10 @@ DISPATCHER_INFOS = [
    DispatcherInfo(
        F.center_crop,
        kernels={
-            datapoints.Image: F.center_crop_image,
+            tv_tensors.Image: F.center_crop_image,
-            datapoints.Video: F.center_crop_video,
+            tv_tensors.Video: F.center_crop_video,
-            datapoints.BoundingBoxes: F.center_crop_bounding_boxes,
+            tv_tensors.BoundingBoxes: F.center_crop_bounding_boxes,
-            datapoints.Mask: F.center_crop_mask,
+            tv_tensors.Mask: F.center_crop_mask,
        },
        pil_kernel_info=PILKernelInfo(F._center_crop_image_pil),
        test_marks=[
...
@@ -222,8 +222,8 @@ DISPATCHER_INFOS = [
    DispatcherInfo(
        F.gaussian_blur,
        kernels={
-            datapoints.Image: F.gaussian_blur_image,
+            tv_tensors.Image: F.gaussian_blur_image,
-            datapoints.Video: F.gaussian_blur_video,
+            tv_tensors.Video: F.gaussian_blur_video,
        },
        pil_kernel_info=PILKernelInfo(F._gaussian_blur_image_pil),
        test_marks=[
...
@@ -234,99 +234,99 @@ DISPATCHER_INFOS = [
    DispatcherInfo(
        F.equalize,
        kernels={
-            datapoints.Image: F.equalize_image,
+            tv_tensors.Image: F.equalize_image,
-            datapoints.Video: F.equalize_video,
+            tv_tensors.Video: F.equalize_video,
        },
        pil_kernel_info=PILKernelInfo(F._equalize_image_pil, kernel_name="equalize_image_pil"),
    ),
    DispatcherInfo(
        F.invert,
        kernels={
-            datapoints.Image: F.invert_image,
+            tv_tensors.Image: F.invert_image,
-            datapoints.Video: F.invert_video,
+            tv_tensors.Video: F.invert_video,
        },
        pil_kernel_info=PILKernelInfo(F._invert_image_pil, kernel_name="invert_image_pil"),
    ),
    DispatcherInfo(
        F.posterize,
        kernels={
-            datapoints.Image: F.posterize_image,
+            tv_tensors.Image: F.posterize_image,
-            datapoints.Video: F.posterize_video,
+            tv_tensors.Video: F.posterize_video,
        },
        pil_kernel_info=PILKernelInfo(F._posterize_image_pil, kernel_name="posterize_image_pil"),
    ),
    DispatcherInfo(
        F.solarize,
        kernels={
-            datapoints.Image: F.solarize_image,
+            tv_tensors.Image: F.solarize_image,
-            datapoints.Video: F.solarize_video,
+            tv_tensors.Video: F.solarize_video,
        },
        pil_kernel_info=PILKernelInfo(F._solarize_image_pil, kernel_name="solarize_image_pil"),
    ),
    DispatcherInfo(
        F.autocontrast,
        kernels={
-            datapoints.Image: F.autocontrast_image,
+            tv_tensors.Image: F.autocontrast_image,
-            datapoints.Video: F.autocontrast_video,
+            tv_tensors.Video: F.autocontrast_video,
        },
        pil_kernel_info=PILKernelInfo(F._autocontrast_image_pil, kernel_name="autocontrast_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_sharpness,
        kernels={
-            datapoints.Image: F.adjust_sharpness_image,
+            tv_tensors.Image: F.adjust_sharpness_image,
-            datapoints.Video: F.adjust_sharpness_video,
+            tv_tensors.Video: F.adjust_sharpness_video,
        },
        pil_kernel_info=PILKernelInfo(F._adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"),
    ),
    DispatcherInfo(
        F.erase,
        kernels={
-            datapoints.Image: F.erase_image,
+            tv_tensors.Image: F.erase_image,
-            datapoints.Video: F.erase_video,
+            tv_tensors.Video: F.erase_video,
        },
        pil_kernel_info=PILKernelInfo(F._erase_image_pil),
        test_marks=[
-            skip_dispatch_datapoint,
+            skip_dispatch_tv_tensor,
        ],
    ),
    DispatcherInfo(
        F.adjust_contrast,
        kernels={
-            datapoints.Image: F.adjust_contrast_image,
+            tv_tensors.Image: F.adjust_contrast_image,
-            datapoints.Video: F.adjust_contrast_video,
+            tv_tensors.Video: F.adjust_contrast_video,
        },
        pil_kernel_info=PILKernelInfo(F._adjust_contrast_image_pil, kernel_name="adjust_contrast_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_gamma,
        kernels={
-            datapoints.Image: F.adjust_gamma_image,
+            tv_tensors.Image: F.adjust_gamma_image,
-            datapoints.Video: F.adjust_gamma_video,
+            tv_tensors.Video: F.adjust_gamma_video,
        },
        pil_kernel_info=PILKernelInfo(F._adjust_gamma_image_pil, kernel_name="adjust_gamma_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_hue,
        kernels={
-            datapoints.Image: F.adjust_hue_image,
+            tv_tensors.Image: F.adjust_hue_image,
-            datapoints.Video: F.adjust_hue_video,
+            tv_tensors.Video: F.adjust_hue_video,
        },
        pil_kernel_info=PILKernelInfo(F._adjust_hue_image_pil, kernel_name="adjust_hue_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_saturation,
        kernels={
-            datapoints.Image: F.adjust_saturation_image,
+            tv_tensors.Image: F.adjust_saturation_image,
-            datapoints.Video: F.adjust_saturation_video,
+            tv_tensors.Video: F.adjust_saturation_video,
        },
        pil_kernel_info=PILKernelInfo(F._adjust_saturation_image_pil, kernel_name="adjust_saturation_image_pil"),
    ),
    DispatcherInfo(
        F.five_crop,
        kernels={
-            datapoints.Image: F.five_crop_image,
+            tv_tensors.Image: F.five_crop_image,
-            datapoints.Video: F.five_crop_video,
+            tv_tensors.Video: F.five_crop_video,
        },
        pil_kernel_info=PILKernelInfo(F._five_crop_image_pil),
        test_marks=[
...
@@ -337,8 +337,8 @@ DISPATCHER_INFOS = [
    DispatcherInfo(
        F.ten_crop,
        kernels={
-            datapoints.Image: F.ten_crop_image,
+            tv_tensors.Image: F.ten_crop_image,
-            datapoints.Video: F.ten_crop_video,
+            tv_tensors.Video: F.ten_crop_video,
        },
        test_marks=[
            xfail_jit_python_scalar_arg("size"),
...
@@ -349,8 +349,8 @@ DISPATCHER_INFOS = [
    DispatcherInfo(
        F.normalize,
        kernels={
-            datapoints.Image: F.normalize_image,
+            tv_tensors.Image: F.normalize_image,
-            datapoints.Video: F.normalize_video,
+            tv_tensors.Video: F.normalize_video,
        },
        test_marks=[
            xfail_jit_python_scalar_arg("mean"),
...
@@ -360,24 +360,24 @@ DISPATCHER_INFOS = [
    DispatcherInfo(
        F.uniform_temporal_subsample,
        kernels={
-            datapoints.Video: F.uniform_temporal_subsample_video,
+            tv_tensors.Video: F.uniform_temporal_subsample_video,
        },
        test_marks=[
-            skip_dispatch_datapoint,
+            skip_dispatch_tv_tensor,
        ],
    ),
    DispatcherInfo(
        F.clamp_bounding_boxes,
-        kernels={datapoints.BoundingBoxes: F.clamp_bounding_boxes},
+        kernels={tv_tensors.BoundingBoxes: F.clamp_bounding_boxes},
        test_marks=[
-            skip_dispatch_datapoint,
+            skip_dispatch_tv_tensor,
        ],
    ),
    DispatcherInfo(
        F.convert_bounding_box_format,
-        kernels={datapoints.BoundingBoxes: F.convert_bounding_box_format},
+        kernels={tv_tensors.BoundingBoxes: F.convert_bounding_box_format},
        test_marks=[
-            skip_dispatch_datapoint,
+            skip_dispatch_tv_tensor,
        ],
    ),
]
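What these tables encode, as a runnable sketch: a v2 functional inspects the input type and routes to the matching per-type kernel, e.g. `F.crop` on `BoundingBoxes` goes through `F.crop_bounding_boxes` (public API as of torchvision 0.16; exact values illustrative):

    from torchvision import tv_tensors
    from torchvision.transforms.v2 import functional as F

    boxes = tv_tensors.BoundingBoxes(
        [[0, 0, 5, 5], [2, 2, 7, 7]],
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=(32, 32),
    )
    out = F.crop(boxes, top=0, left=0, height=16, width=16)  # dispatched to crop_bounding_boxes
    assert type(out) is tv_tensors.BoundingBoxes and out.canvas_size == (16, 16)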
test/transforms_v2_kernel_infos.py
View file @ d5f4cc38
...
@@ -7,7 +7,7 @@ import pytest
import torch.testing
import torchvision.ops
import torchvision.transforms.v2.functional as F
-from torchvision import datapoints
+from torchvision import tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value, _parse_pad_padding
from transforms_v2_legacy_utils import (
    ArgsKwargs,
...
@@ -193,7 +193,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
        bbox_xyxy = F.convert_bounding_box_format(
            bbox.as_subclass(torch.Tensor),
            old_format=format_,
-            new_format=datapoints.BoundingBoxFormat.XYXY,
+            new_format=tv_tensors.BoundingBoxFormat.XYXY,
            inplace=True,
        )
        points = np.array(
...
@@ -215,7 +215,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
            dtype=bbox_xyxy.dtype,
        )
        out_bbox = F.convert_bounding_box_format(
-            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
+            out_bbox, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
        )
        # It is important to clamp before casting, especially for CXCYWH format, dtype=int64
        out_bbox = F.clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_)
...
@@ -228,7 +228,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
def sample_inputs_convert_bounding_box_format():
-    formats = list(datapoints.BoundingBoxFormat)
+    formats = list(tv_tensors.BoundingBoxFormat)
    for bounding_boxes_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
        yield ArgsKwargs(bounding_boxes_loader, old_format=bounding_boxes_loader.format, new_format=new_format)
...
@@ -659,7 +659,7 @@ def sample_inputs_perspective_bounding_boxes():
        coefficients=_PERSPECTIVE_COEFFS[0],
    )

-    format = datapoints.BoundingBoxFormat.XYXY
+    format = tv_tensors.BoundingBoxFormat.XYXY
    loader = make_bounding_box_loader(format=format)
    yield ArgsKwargs(
        loader, format=format, canvas_size=loader.canvas_size, startpoints=_STARTPOINTS, endpoints=_ENDPOINTS
...
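The "clamp before casting" comment above is the subtle point in this helper. A small sketch of why, using the public kernel (values illustrative):

    from torchvision import tv_tensors
    from torchvision.transforms.v2 import functional as F

    boxes = tv_tensors.BoundingBoxes(
        [[-2.6, 1.0, 40.2, 31.5]], format="XYXY", canvas_size=(32, 32)
    )
    # Clamping while still in float keeps coordinates inside the canvas;
    # casting to int64 first would truncate -2.6 to -2 instead of clamping to 0.
    clamped = F.clamp_bounding_boxes(boxes)
    assert clamped.tolist() == [[0.0, 1.0, 32.0, 31.5]]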
test/transforms_v2_legacy_utils.py
View file @
d5f4cc38
...
@@ -27,7 +27,7 @@ import PIL.Image
...
@@ -27,7 +27,7 @@ import PIL.Image
import
pytest
import
pytest
import
torch
import
torch
from
torchvision
import
datapoint
s
from
torchvision
import
tv_tensor
s
from
torchvision.transforms._functional_tensor
import
_max_value
as
get_max_value
from
torchvision.transforms._functional_tensor
import
_max_value
as
get_max_value
from
torchvision.transforms.v2.functional
import
to_dtype_image
,
to_image
,
to_pil_image
from
torchvision.transforms.v2.functional
import
to_dtype_image
,
to_image
,
to_pil_image
...
@@ -82,7 +82,7 @@ def make_image(
...
@@ -82,7 +82,7 @@ def make_image(
if
color_space
in
{
"GRAY_ALPHA"
,
"RGBA"
}:
if
color_space
in
{
"GRAY_ALPHA"
,
"RGBA"
}:
data
[...,
-
1
,
:,
:]
=
max_value
data
[...,
-
1
,
:,
:]
=
max_value
return
datapoint
s
.
Image
(
data
)
return
tv_tensor
s
.
Image
(
data
)
def
make_image_tensor
(
*
args
,
**
kwargs
):
def
make_image_tensor
(
*
args
,
**
kwargs
):
...
@@ -96,7 +96,7 @@ def make_image_pil(*args, **kwargs):
...
@@ -96,7 +96,7 @@ def make_image_pil(*args, **kwargs):
def
make_bounding_boxes
(
def
make_bounding_boxes
(
canvas_size
=
DEFAULT_SIZE
,
canvas_size
=
DEFAULT_SIZE
,
*
,
*
,
format
=
datapoint
s
.
BoundingBoxFormat
.
XYXY
,
format
=
tv_tensor
s
.
BoundingBoxFormat
.
XYXY
,
batch_dims
=
(),
batch_dims
=
(),
dtype
=
None
,
dtype
=
None
,
device
=
"cpu"
,
device
=
"cpu"
,
...
@@ -107,12 +107,12 @@ def make_bounding_boxes(
...
@@ -107,12 +107,12 @@ def make_bounding_boxes(
return
torch
.
stack
([
torch
.
randint
(
max_value
-
v
,
())
for
v
in
values
.
flatten
().
tolist
()]).
reshape
(
values
.
shape
)
return
torch
.
stack
([
torch
.
randint
(
max_value
-
v
,
())
for
v
in
values
.
flatten
().
tolist
()]).
reshape
(
values
.
shape
)
if
isinstance
(
format
,
str
):
if
isinstance
(
format
,
str
):
format
=
datapoint
s
.
BoundingBoxFormat
[
format
]
format
=
tv_tensor
s
.
BoundingBoxFormat
[
format
]
dtype
=
dtype
or
torch
.
float32
dtype
=
dtype
or
torch
.
float32
if
any
(
dim
==
0
for
dim
in
batch_dims
):
if
any
(
dim
==
0
for
dim
in
batch_dims
):
return
datapoint
s
.
BoundingBoxes
(
return
tv_tensor
s
.
BoundingBoxes
(
torch
.
empty
(
*
batch_dims
,
4
,
dtype
=
dtype
,
device
=
device
),
format
=
format
,
canvas_size
=
canvas_size
torch
.
empty
(
*
batch_dims
,
4
,
dtype
=
dtype
,
device
=
device
),
format
=
format
,
canvas_size
=
canvas_size
)
)
...
@@ -120,28 +120,28 @@ def make_bounding_boxes(
...
@@ -120,28 +120,28 @@ def make_bounding_boxes(
y
=
sample_position
(
h
,
canvas_size
[
0
])
y
=
sample_position
(
h
,
canvas_size
[
0
])
x
=
sample_position
(
w
,
canvas_size
[
1
])
x
=
sample_position
(
w
,
canvas_size
[
1
])
if
format
is
datapoint
s
.
BoundingBoxFormat
.
XYWH
:
if
format
is
tv_tensor
s
.
BoundingBoxFormat
.
XYWH
:
parts
=
(
x
,
y
,
w
,
h
)
parts
=
(
x
,
y
,
w
,
h
)
elif
format
is
datapoint
s
.
BoundingBoxFormat
.
XYXY
:
elif
format
is
tv_tensor
s
.
BoundingBoxFormat
.
XYXY
:
x1
,
y1
=
x
,
y
x1
,
y1
=
x
,
y
x2
=
x1
+
w
x2
=
x1
+
w
y2
=
y1
+
h
y2
=
y1
+
h
parts
=
(
x1
,
y1
,
x2
,
y2
)
parts
=
(
x1
,
y1
,
x2
,
y2
)
elif
format
is
datapoint
s
.
BoundingBoxFormat
.
CXCYWH
:
elif
format
is
tv_tensor
s
.
BoundingBoxFormat
.
CXCYWH
:
cx
=
x
+
w
/
2
cx
=
x
+
w
/
2
cy
=
y
+
h
/
2
cy
=
y
+
h
/
2
parts
=
(
cx
,
cy
,
w
,
h
)
parts
=
(
cx
,
cy
,
w
,
h
)
else
:
else
:
raise
ValueError
(
f
"Format
{
format
}
is not supported"
)
raise
ValueError
(
f
"Format
{
format
}
is not supported"
)
return
datapoint
s
.
BoundingBoxes
(
return
tv_tensor
s
.
BoundingBoxes
(
torch
.
stack
(
parts
,
dim
=-
1
).
to
(
dtype
=
dtype
,
device
=
device
),
format
=
format
,
canvas_size
=
canvas_size
torch
.
stack
(
parts
,
dim
=-
1
).
to
(
dtype
=
dtype
,
device
=
device
),
format
=
format
,
canvas_size
=
canvas_size
)
)
def
make_detection_mask
(
size
=
DEFAULT_SIZE
,
*
,
num_objects
=
5
,
batch_dims
=
(),
dtype
=
None
,
device
=
"cpu"
):
def
make_detection_mask
(
size
=
DEFAULT_SIZE
,
*
,
num_objects
=
5
,
batch_dims
=
(),
dtype
=
None
,
device
=
"cpu"
):
"""Make a "detection" mask, i.e. (*, N, H, W), where each object is encoded as one of N boolean masks"""
"""Make a "detection" mask, i.e. (*, N, H, W), where each object is encoded as one of N boolean masks"""
return
datapoint
s
.
Mask
(
return
tv_tensor
s
.
Mask
(
torch
.
testing
.
make_tensor
(
torch
.
testing
.
make_tensor
(
(
*
batch_dims
,
num_objects
,
*
size
),
(
*
batch_dims
,
num_objects
,
*
size
),
low
=
0
,
low
=
0
,
...
@@ -154,7 +154,7 @@ def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtyp
...
@@ -154,7 +154,7 @@ def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtyp
def
make_segmentation_mask
(
size
=
DEFAULT_SIZE
,
*
,
num_categories
=
10
,
batch_dims
=
(),
dtype
=
None
,
device
=
"cpu"
):
def
make_segmentation_mask
(
size
=
DEFAULT_SIZE
,
*
,
num_categories
=
10
,
batch_dims
=
(),
dtype
=
None
,
device
=
"cpu"
):
"""Make a "segmentation" mask, i.e. (*, H, W), where the category is encoded as pixel value"""
"""Make a "segmentation" mask, i.e. (*, H, W), where the category is encoded as pixel value"""
return
datapoint
s
.
Mask
(
return
tv_tensor
s
.
Mask
(
torch
.
testing
.
make_tensor
(
torch
.
testing
.
make_tensor
(
(
*
batch_dims
,
*
size
),
(
*
batch_dims
,
*
size
),
low
=
0
,
low
=
0
,
...
@@ -166,7 +166,7 @@ def make_segmentation_mask(size=DEFAULT_SIZE, *, num_categories=10, batch_dims=(
...
@@ -166,7 +166,7 @@ def make_segmentation_mask(size=DEFAULT_SIZE, *, num_categories=10, batch_dims=(
def
make_video
(
size
=
DEFAULT_SIZE
,
*
,
num_frames
=
3
,
batch_dims
=
(),
**
kwargs
):
def
make_video
(
size
=
DEFAULT_SIZE
,
*
,
num_frames
=
3
,
batch_dims
=
(),
**
kwargs
):
return
datapoint
s
.
Video
(
make_image
(
size
,
batch_dims
=
(
*
batch_dims
,
num_frames
),
**
kwargs
))
return
tv_tensor
s
.
Video
(
make_image
(
size
,
batch_dims
=
(
*
batch_dims
,
num_frames
),
**
kwargs
))
def
make_video_tensor
(
*
args
,
**
kwargs
):
def
make_video_tensor
(
*
args
,
**
kwargs
):
...
@@ -335,7 +335,7 @@ def make_image_loader_for_interpolation(
...
@@ -335,7 +335,7 @@ def make_image_loader_for_interpolation(
image_tensor
=
image_tensor
.
to
(
device
=
device
)
image_tensor
=
image_tensor
.
to
(
device
=
device
)
image_tensor
=
to_dtype_image
(
image_tensor
,
dtype
=
dtype
,
scale
=
True
)
image_tensor
=
to_dtype_image
(
image_tensor
,
dtype
=
dtype
,
scale
=
True
)
return
datapoint
s
.
Image
(
image_tensor
)
return
tv_tensor
s
.
Image
(
image_tensor
)
return
ImageLoader
(
fn
,
shape
=
(
num_channels
,
*
size
),
dtype
=
dtype
,
memory_format
=
memory_format
)
return
ImageLoader
(
fn
,
shape
=
(
num_channels
,
*
size
),
dtype
=
dtype
,
memory_format
=
memory_format
)
...
@@ -352,7 +352,7 @@ def make_image_loaders_for_interpolation(
...
@@ -352,7 +352,7 @@ def make_image_loaders_for_interpolation(
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
BoundingBoxesLoader
(
TensorLoader
):
class
BoundingBoxesLoader
(
TensorLoader
):
format
:
datapoint
s
.
BoundingBoxFormat
format
:
tv_tensor
s
.
BoundingBoxFormat
spatial_size
:
Tuple
[
int
,
int
]
spatial_size
:
Tuple
[
int
,
int
]
canvas_size
:
Tuple
[
int
,
int
]
=
dataclasses
.
field
(
init
=
False
)
canvas_size
:
Tuple
[
int
,
int
]
=
dataclasses
.
field
(
init
=
False
)
...
@@ -362,7 +362,7 @@ class BoundingBoxesLoader(TensorLoader):
 def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtype=torch.float32):
     if isinstance(format, str):
-        format = datapoints.BoundingBoxFormat[format]
+        format = tv_tensors.BoundingBoxFormat[format]
     spatial_size = _parse_size(spatial_size, name="spatial_size")
...
@@ -381,7 +381,7 @@ def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORT
 def make_bounding_box_loaders(
     *,
     extra_dims=tuple(d for d in DEFAULT_EXTRA_DIMS if len(d) < 2),
-    formats=tuple(datapoints.BoundingBoxFormat),
+    formats=tuple(tv_tensors.BoundingBoxFormat),
     spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
     dtypes=(torch.float32, torch.float64, torch.int64),
 ):
...
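Both bounding-box hunks lean on BoundingBoxFormat being a plain Python enum: make_bounding_box_loader looks a format up by its string name, and make_bounding_box_loaders enumerates every member. Both idioms, sketched against torchvision >= 0.16:

from torchvision import tv_tensors

fmt = tv_tensors.BoundingBoxFormat["XYXY"]  # lookup by name, as in make_bounding_box_loader
assert fmt is tv_tensors.BoundingBoxFormat.XYXY

formats = tuple(tv_tensors.BoundingBoxFormat)  # XYXY, XYWH, CXCYWH, as in make_bounding_box_loaders
assert len(formats) == 3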
torchvision/datasets/__init__.py
...
@@ -137,7 +137,7 @@ __all__ = (
 # Ref: https://peps.python.org/pep-0562/
 def __getattr__(name):
     if name in ("wrap_dataset_for_transforms_v2",):
-        from torchvision.datapoints._dataset_wrapper import wrap_dataset_for_transforms_v2
+        from torchvision.tv_tensors._dataset_wrapper import wrap_dataset_for_transforms_v2

         return wrap_dataset_for_transforms_v2
...
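The hook above is the PEP 562 mechanism referenced in the comment: a module-level __getattr__ runs only when normal attribute lookup fails, which keeps the wrapper import lazy and avoids pulling in the transforms v2 machinery when torchvision.datasets is imported. A standalone toy illustration of the pattern (mypkg and expensive_helper are made up):

# hypothetical mypkg/__init__.py
def __getattr__(name):
    if name == "expensive_helper":
        from mypkg._impl import expensive_helper  # imported on first access only

        return expensive_helper
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")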
torchvision/prototype/__init__.py
-from . import datapoints, models, transforms, utils
+from . import models, transforms, tv_tensors, utils
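After this one-line change, prototype-only types such as Label are reached through the renamed submodule. A usage sketch, assuming a source build where the prototype namespace is available and Label keeps the categories keyword it has at this commit:

from torchvision.prototype import tv_tensors as prototype_tv_tensors

label = prototype_tv_tensors.Label(2, categories=["cat", "dog", "bird"])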
torchvision/prototype/datasets/_builtin/caltech.py
...
@@ -6,8 +6,6 @@ import numpy as np
 import torch
 from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper
-from torchvision.datapoints import BoundingBoxes
-from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,
...
@@ -16,6 +14,8 @@ from torchvision.prototype.datasets.utils._internal import (
     read_categories_file,
     read_mat,
 )
+from torchvision.prototype.tv_tensors import Label
+from torchvision.tv_tensors import BoundingBoxes

 from .._api import register_dataset, register_info
...
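The split above is the rule applied throughout these dataset files: stable kinds move to torchvision.tv_tensors, while prototype-only kinds stay under torchvision.prototype.tv_tensors. A hedged construction sketch using the two new imports:

import torch
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import BoundingBoxes

boxes = BoundingBoxes(
    torch.tensor([[0.0, 0.0, 10.0, 10.0]]),
    format="XYXY",         # a string or a tv_tensors.BoundingBoxFormat member
    canvas_size=(32, 32),  # (height, width) of the image the boxes live in
)
label = Label(3)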
torchvision/prototype/datasets/_builtin/celeba.py
...
@@ -4,8 +4,6 @@ from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Tupl
 import torch
 from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper
-from torchvision.datapoints import BoundingBoxes
-from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
...
@@ -14,6 +12,8 @@ from torchvision.prototype.datasets.utils._internal import (
     INFINITE_BUFFER_SIZE,
     path_accessor,
 )
+from torchvision.prototype.tv_tensors import Label
+from torchvision.tv_tensors import BoundingBoxes

 from .._api import register_dataset, register_info
...
torchvision/prototype/datasets/_builtin/cifar.py
...
@@ -6,8 +6,6 @@ from typing import Any, BinaryIO, cast, Dict, Iterator, List, Optional, Tuple, U
 import numpy as np
 from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper
-from torchvision.datapoints import Image
-from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,
...
@@ -15,6 +13,8 @@ from torchvision.prototype.datasets.utils._internal import (
     path_comparator,
     read_categories_file,
 )
+from torchvision.prototype.tv_tensors import Label
+from torchvision.tv_tensors import Image

 from .._api import register_dataset, register_info
...
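For the CIFAR pipeline this means the decoded numpy array can be wrapped directly: Image accepts anything torch.as_tensor does. A minimal sketch (torchvision >= 0.16 assumed; the array stands in for a decoded CIFAR image):

import numpy as np
from torchvision.tv_tensors import Image

arr = np.zeros((3, 32, 32), dtype=np.uint8)  # CHW uint8, like a decoded CIFAR image
img = Image(arr)
assert img.shape == (3, 32, 32)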
torchvision/prototype/datasets/_builtin/clevr.py
...
@@ -2,7 +2,6 @@ import pathlib
 from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union

 from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, JsonParser, Mapper, UnBatcher
-from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
...
@@ -12,6 +11,7 @@ from torchvision.prototype.datasets.utils._internal import (
     path_accessor,
     path_comparator,
 )
+from torchvision.prototype.tv_tensors import Label

 from .._api import register_dataset, register_info
...
torchvision/prototype/datasets/_builtin/coco.py
...
@@ -14,8 +14,6 @@ from torchdata.datapipes.iter import (
     Mapper,
     UnBatcher,
 )
-from torchvision.datapoints import BoundingBoxes, Mask
-from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
...
@@ -26,6 +24,8 @@ from torchvision.prototype.datasets.utils._internal import (
     path_accessor,
     read_categories_file,
 )
+from torchvision.prototype.tv_tensors import Label
+from torchvision.tv_tensors import BoundingBoxes, Mask

 from .._api import register_dataset, register_info
...
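Since COCO yields both box and mask annotations, the two re-imported kinds typically travel together in one sample. An illustrative sketch with made-up values:

import torch
from torchvision.tv_tensors import BoundingBoxes, Mask

sample = {
    "boxes": BoundingBoxes(torch.tensor([[10, 10, 20, 20]]), format="XYXY", canvas_size=(64, 64)),
    "masks": Mask(torch.zeros(1, 64, 64, dtype=torch.bool)),  # one boolean mask per object
}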
torchvision/prototype/datasets/_builtin/country211.py
...
@@ -2,7 +2,6 @@ import pathlib
 from typing import Any, Dict, List, Tuple, Union

 from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper
-from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,
...
@@ -10,6 +9,7 @@ from torchvision.prototype.datasets.utils._internal import (
     path_comparator,
     read_categories_file,
 )
+from torchvision.prototype.tv_tensors import Label

 from .._api import register_dataset, register_info
...
torchvision/prototype/datasets/_builtin/cub200.py
...
@@ -15,8 +15,6 @@ from torchdata.datapipes.iter import (
     Mapper,
 )
 from torchdata.datapipes.map import IterToMapConverter
-from torchvision.datapoints import BoundingBoxes
-from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
...
@@ -28,6 +26,8 @@ from torchvision.prototype.datasets.utils._internal import (
     read_categories_file,
     read_mat,
 )
+from torchvision.prototype.tv_tensors import Label
+from torchvision.tv_tensors import BoundingBoxes

 from .._api import register_dataset, register_info
...
torchvision/prototype/datasets/_builtin/dtd.py
...
@@ -3,7 +3,6 @@ import pathlib
 from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union

 from torchdata.datapipes.iter import CSVParser, Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
-from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
...
@@ -13,6 +12,7 @@ from torchvision.prototype.datasets.utils._internal import (
     path_comparator,
     read_categories_file,
 )
+from torchvision.prototype.tv_tensors import Label

 from .._api import register_dataset, register_info
...