OpenDAS / vision

Commit d5f4cc38 (unverified), authored Aug 30, 2023 by Nicolas Hug; committed by GitHub, Aug 30, 2023

    Datapoint -> TVTensor; datapoint[s] -> tv_tensor[s] (#7894)

parent b9447fdd

Changes: 85. Showing 20 changed files with 455 additions and 455 deletions.
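This commit is a mechanical rename across the code base: the torchvision.datapoints namespace becomes torchvision.tv_tensors, and every datapoint/datapoints identifier (variable names, test helpers, parametrize ids) becomes tv_tensor/tv_tensors. For orientation, a minimal sketch of what the rename means for calling code; the snippet is illustrative and only uses names that appear in the diffs below:

    import torch

    # Before this commit:
    #   from torchvision import datapoints
    #   img = datapoints.Image(torch.rand(3, 32, 32))

    # After this commit:
    from torchvision import tv_tensors

    img = tv_tensors.Image(torch.rand(3, 32, 32))
    boxes = tv_tensors.BoundingBoxes(
        [[0, 0, 10, 10]],
        format=tv_tensors.BoundingBoxFormat.XYXY,  # same enum, new namespace
        canvas_size=(32, 32),
    )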
Files changed:

    test/test_prototype_transforms.py                      +15  -15
    test/test_transforms_v2.py                             +57  -57
    test/test_transforms_v2_consistency.py                 +20  -20
    test/test_transforms_v2_functional.py                  +67  -67
    test/test_transforms_v2_refactored.py                  +100 -100
    test/test_transforms_v2_utils.py                       +32  -32
    test/test_tv_tensors.py                                +52  -52
    test/transforms_v2_dispatcher_infos.py                 +76  -76
    test/transforms_v2_kernel_infos.py                     +5   -5
    test/transforms_v2_legacy_utils.py                     +16  -16
    torchvision/datasets/__init__.py                       +1   -1
    torchvision/prototype/__init__.py                      +1   -1
    torchvision/prototype/datasets/_builtin/caltech.py     +2   -2
    torchvision/prototype/datasets/_builtin/celeba.py      +2   -2
    torchvision/prototype/datasets/_builtin/cifar.py       +2   -2
    torchvision/prototype/datasets/_builtin/clevr.py       +1   -1
    torchvision/prototype/datasets/_builtin/coco.py        +2   -2
    torchvision/prototype/datasets/_builtin/country211.py  +1   -1
    torchvision/prototype/datasets/_builtin/cub200.py      +2   -2
    torchvision/prototype/datasets/_builtin/dtd.py         +1   -1
test/test_prototype_transforms.py (view file @ d5f4cc38)

@@ -7,11 +7,11 @@ import torch
 from common_utils import assert_equal
 from prototype_common_utils import make_label
-from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
-from torchvision.prototype import datapoints, transforms
+from torchvision.prototype import transforms, tv_tensors
 from torchvision.transforms.v2._utils import check_type, is_pure_tensor
 from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_pil_image
+from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
 from transforms_v2_legacy_utils import (
     DEFAULT_EXTRA_DIMS,
     make_bounding_boxes,
@@ -51,7 +51,7 @@ class TestSimpleCopyPaste:
             # images, batch size = 2
             self.create_fake_image(mocker, Image),
             # labels, bboxes, masks
-            mocker.MagicMock(spec=datapoints.Label),
+            mocker.MagicMock(spec=tv_tensors.Label),
             mocker.MagicMock(spec=BoundingBoxes),
             mocker.MagicMock(spec=Mask),
             # labels, bboxes, masks
@@ -63,7 +63,7 @@ class TestSimpleCopyPaste:
         transform._extract_image_targets(flat_sample)

     @pytest.mark.parametrize("image_type", [Image, PIL.Image.Image, torch.Tensor])
-    @pytest.mark.parametrize("label_type", [datapoints.Label, datapoints.OneHotLabel])
+    @pytest.mark.parametrize("label_type", [tv_tensors.Label, tv_tensors.OneHotLabel])
     def test__extract_image_targets(self, image_type, label_type, mocker):
         transform = transforms.SimpleCopyPaste()
@@ -101,7 +101,7 @@ class TestSimpleCopyPaste:
             assert isinstance(target[key], type_)
             assert target[key] in flat_sample

-    @pytest.mark.parametrize("label_type", [datapoints.Label, datapoints.OneHotLabel])
+    @pytest.mark.parametrize("label_type", [tv_tensors.Label, tv_tensors.OneHotLabel])
     def test__copy_paste(self, label_type):
         image = 2 * torch.ones(3, 32, 32)
         masks = torch.zeros(2, 32, 32)
@@ -111,7 +111,7 @@ class TestSimpleCopyPaste:
         blending = True
         resize_interpolation = InterpolationMode.BILINEAR
         antialias = None
-        if label_type == datapoints.OneHotLabel:
+        if label_type == tv_tensors.OneHotLabel:
             labels = torch.nn.functional.one_hot(labels, num_classes=5)
         target = {
             "boxes": BoundingBoxes(
@@ -126,7 +126,7 @@ class TestSimpleCopyPaste:
         paste_masks[0, 13:19, 12:18] = 1
         paste_masks[1, 15:19, 1:8] = 1
         paste_labels = torch.tensor([3, 4])
-        if label_type == datapoints.OneHotLabel:
+        if label_type == tv_tensors.OneHotLabel:
             paste_labels = torch.nn.functional.one_hot(paste_labels, num_classes=5)
         paste_target = {
             "boxes": BoundingBoxes(
@@ -148,7 +148,7 @@ class TestSimpleCopyPaste:
         torch.testing.assert_close(output_target["boxes"][2:, :], paste_target["boxes"])
         expected_labels = torch.tensor([1, 2, 3, 4])
-        if label_type == datapoints.OneHotLabel:
+        if label_type == tv_tensors.OneHotLabel:
             expected_labels = torch.nn.functional.one_hot(expected_labels, num_classes=5)
         torch.testing.assert_close(output_target["labels"], label_type(expected_labels))
@@ -258,10 +258,10 @@ class TestFixedSizeCrop:
 class TestLabelToOneHot:
     def test__transform(self):
         categories = ["apple", "pear", "pineapple"]
-        labels = datapoints.Label(torch.tensor([0, 1, 2, 1]), categories=categories)
+        labels = tv_tensors.Label(torch.tensor([0, 1, 2, 1]), categories=categories)
         transform = transforms.LabelToOneHot()
         ohe_labels = transform(labels)
-        assert isinstance(ohe_labels, datapoints.OneHotLabel)
+        assert isinstance(ohe_labels, tv_tensors.OneHotLabel)
         assert ohe_labels.shape == (4, 3)
         assert ohe_labels.categories == labels.categories == categories
@@ -383,7 +383,7 @@ det_transforms = import_transforms_from_references("detection")
 def test_fixed_sized_crop_against_detection_reference():
-    def make_datapoints():
+    def make_tv_tensors():
         size = (600, 800)
         num_objects = 22
@@ -405,19 +405,19 @@ def test_fixed_sized_crop_against_detection_reference():
             yield (tensor_image, target)

-        datapoint_image = make_image(size=size, color_space="RGB")
+        tv_tensor_image = make_image(size=size, color_space="RGB")
         target = {
             "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
             "labels": make_label(extra_dims=(num_objects,), categories=80),
             "masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
         }

-        yield (datapoint_image, target)
+        yield (tv_tensor_image, target)

     t = transforms.FixedSizeCrop((1024, 1024), fill=0)
     t_ref = det_transforms.FixedSizeCrop((1024, 1024), fill=0)

-    for dp in make_datapoints():
+    for dp in make_tv_tensors():
         # We should use prototype transform first as reference transform performs inplace target update
         torch.manual_seed(12)
         output = t(dp)
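The loop at the end of this hunk encodes an ordering constraint worth noting: the v2 transform t must run before the reference transform t_ref, because the detection reference implementation updates the target dict in place. The page cuts the loop body off; a minimal sketch of the general pattern, assuming the reference transform takes the sample elements as separate arguments and that assert_equal comes from common_utils:

    for dp in make_tv_tensors():
        # Run the v2 transform first: it does not mutate its input.
        torch.manual_seed(12)
        output = t(dp)

        # Reseed so the reference transform draws the same random parameters.
        # It mutates the target in place, which is why it must run second
        # (the call shape here is assumed).
        torch.manual_seed(12)
        expected = t_ref(*dp)

        assert_equal(output, expected)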
test/test_transforms_v2.py (view file @ d5f4cc38)

@@ -13,7 +13,7 @@ import torchvision.transforms.v2 as transforms
 from common_utils import assert_equal, cpu_and_cuda
 from torch.utils._pytree import tree_flatten, tree_unflatten
-from torchvision import datapoints
+from torchvision import tv_tensors
 from torchvision.ops.boxes import box_iou
 from torchvision.transforms.functional import to_pil_image
 from torchvision.transforms.v2 import functional as F
@@ -66,10 +66,10 @@ def auto_augment_adapter(transform, input, device):
     adapted_input = {}
     image_or_video_found = False
     for key, value in input.items():
-        if isinstance(value, (datapoints.BoundingBoxes, datapoints.Mask)):
+        if isinstance(value, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
             # AA transforms don't support bounding boxes or masks
             continue
-        elif check_type(value, (datapoints.Image, datapoints.Video, is_pure_tensor, PIL.Image.Image)):
+        elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor, PIL.Image.Image)):
             if image_or_video_found:
                 # AA transforms only support a single image or video
                 continue
@@ -99,7 +99,7 @@ def normalize_adapter(transform, input, device):
         if isinstance(value, PIL.Image.Image):
            # normalize doesn't support PIL images
            continue
-        elif check_type(value, (datapoints.Image, datapoints.Video, is_pure_tensor)):
+        elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor)):
            # normalize doesn't support integer images
            value = F.to_dtype(value, torch.float32, scale=True)
        adapted_input[key] = value
@@ -142,7 +142,7 @@ class TestSmoke:
             (transforms.Resize([16, 16], antialias=True), None),
             (transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2), antialias=True), None),
             (transforms.ClampBoundingBoxes(), None),
-            (transforms.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.CXCYWH), None),
+            (transforms.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.CXCYWH), None),
             (transforms.ConvertImageDtype(), None),
             (transforms.GaussianBlur(kernel_size=3), None),
             (
@@ -178,19 +178,19 @@ class TestSmoke:
         canvas_size = F.get_size(image_or_video)
         input = dict(
             image_or_video=image_or_video,
-            image_datapoint=make_image(size=canvas_size),
+            image_tv_tensor=make_image(size=canvas_size),
-            video_datapoint=make_video(size=canvas_size),
+            video_tv_tensor=make_video(size=canvas_size),
             image_pil=next(make_pil_images(sizes=[canvas_size], color_spaces=["RGB"])),
             bounding_boxes_xyxy=make_bounding_boxes(
-                format=datapoints.BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(3,)
+                format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(3,)
             ),
             bounding_boxes_xywh=make_bounding_boxes(
-                format=datapoints.BoundingBoxFormat.XYWH, canvas_size=canvas_size, batch_dims=(4,)
+                format=tv_tensors.BoundingBoxFormat.XYWH, canvas_size=canvas_size, batch_dims=(4,)
             ),
             bounding_boxes_cxcywh=make_bounding_boxes(
-                format=datapoints.BoundingBoxFormat.CXCYWH, canvas_size=canvas_size, batch_dims=(5,)
+                format=tv_tensors.BoundingBoxFormat.CXCYWH, canvas_size=canvas_size, batch_dims=(5,)
             ),
-            bounding_boxes_degenerate_xyxy=datapoints.BoundingBoxes(
+            bounding_boxes_degenerate_xyxy=tv_tensors.BoundingBoxes(
                 [
                     [0, 0, 0, 0],  # no height or width
                     [0, 0, 0, 1],  # no height
@@ -199,10 +199,10 @@ class TestSmoke:
                     [0, 2, 1, 1],  # x1 < x2, y1 > y2
                     [2, 2, 1, 1],  # x1 > x2, y1 > y2
                 ],
-                format=datapoints.BoundingBoxFormat.XYXY,
+                format=tv_tensors.BoundingBoxFormat.XYXY,
                 canvas_size=canvas_size,
             ),
-            bounding_boxes_degenerate_xywh=datapoints.BoundingBoxes(
+            bounding_boxes_degenerate_xywh=tv_tensors.BoundingBoxes(
                 [
                     [0, 0, 0, 0],  # no height or width
                     [0, 0, 0, 1],  # no height
@@ -211,10 +211,10 @@ class TestSmoke:
                     [0, 0, -1, 1],  # negative width
                     [0, 0, -1, -1],  # negative height and width
                 ],
-                format=datapoints.BoundingBoxFormat.XYWH,
+                format=tv_tensors.BoundingBoxFormat.XYWH,
                 canvas_size=canvas_size,
             ),
-            bounding_boxes_degenerate_cxcywh=datapoints.BoundingBoxes(
+            bounding_boxes_degenerate_cxcywh=tv_tensors.BoundingBoxes(
                 [
                     [0, 0, 0, 0],  # no height or width
                     [0, 0, 0, 1],  # no height
@@ -223,7 +223,7 @@ class TestSmoke:
                     [0, 0, -1, 1],  # negative width
                     [0, 0, -1, -1],  # negative height and width
                 ],
-                format=datapoints.BoundingBoxFormat.CXCYWH,
+                format=tv_tensors.BoundingBoxFormat.CXCYWH,
                 canvas_size=canvas_size,
             ),
             detection_mask=make_detection_mask(size=canvas_size),
@@ -262,7 +262,7 @@ class TestSmoke:
             else:
                 assert output_item is input_item

-            if isinstance(input_item, datapoints.BoundingBoxes) and not isinstance(
+            if isinstance(input_item, tv_tensors.BoundingBoxes) and not isinstance(
                 transform, transforms.ConvertBoundingBoxFormat
             ):
                 assert output_item.format == input_item.format
@@ -270,9 +270,9 @@ class TestSmoke:
         # Enforce that the transform does not turn a degenerate box marked by RandomIoUCrop (or any other future
         # transform that does this), back into a valid one.
        # TODO: we should test that against all degenerate boxes above
-        for format in list(datapoints.BoundingBoxFormat):
+        for format in list(tv_tensors.BoundingBoxFormat):
             sample = dict(
-                boxes=datapoints.BoundingBoxes([[0, 0, 0, 0]], format=format, canvas_size=(224, 244)),
+                boxes=tv_tensors.BoundingBoxes([[0, 0, 0, 0]], format=format, canvas_size=(224, 244)),
                 labels=torch.tensor([3]),
             )
             assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4)
@@ -652,7 +652,7 @@ class TestRandomErasing:
 class TestTransform:
     @pytest.mark.parametrize(
         "inpt_type",
-        [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
+        [torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
     )
     def test_check_transformed_types(self, inpt_type, mocker):
         # This test ensures that we correctly handle which types to transform and which to bypass
@@ -670,7 +670,7 @@ class TestTransform:
 class TestToImage:
     @pytest.mark.parametrize(
         "inpt_type",
-        [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
+        [torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
     )
     def test__transform(self, inpt_type, mocker):
         fn = mocker.patch(
@@ -681,7 +681,7 @@ class TestToImage:
         inpt = mocker.MagicMock(spec=inpt_type)
         transform = transforms.ToImage()
         transform(inpt)
-        if inpt_type in (datapoints.BoundingBoxes, datapoints.Image, str, int):
+        if inpt_type in (tv_tensors.BoundingBoxes, tv_tensors.Image, str, int):
             assert fn.call_count == 0
         else:
             fn.assert_called_once_with(inpt)
@@ -690,7 +690,7 @@ class TestToImage:
 class TestToPILImage:
     @pytest.mark.parametrize(
         "inpt_type",
-        [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
+        [torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
     )
     def test__transform(self, inpt_type, mocker):
         fn = mocker.patch("torchvision.transforms.v2.functional.to_pil_image")
@@ -698,7 +698,7 @@ class TestToPILImage:
         inpt = mocker.MagicMock(spec=inpt_type)
         transform = transforms.ToPILImage()
         transform(inpt)
-        if inpt_type in (PIL.Image.Image, datapoints.BoundingBoxes, str, int):
+        if inpt_type in (PIL.Image.Image, tv_tensors.BoundingBoxes, str, int):
             assert fn.call_count == 0
         else:
             fn.assert_called_once_with(inpt, mode=transform.mode)
@@ -707,7 +707,7 @@ class TestToPILImage:
 class TestToTensor:
     @pytest.mark.parametrize(
         "inpt_type",
-        [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
+        [torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
     )
     def test__transform(self, inpt_type, mocker):
         fn = mocker.patch("torchvision.transforms.functional.to_tensor")
@@ -716,7 +716,7 @@ class TestToTensor:
         with pytest.warns(UserWarning, match="deprecated and will be removed"):
             transform = transforms.ToTensor()
         transform(inpt)
-        if inpt_type in (datapoints.Image, torch.Tensor, datapoints.BoundingBoxes, str, int):
+        if inpt_type in (tv_tensors.Image, torch.Tensor, tv_tensors.BoundingBoxes, str, int):
             assert fn.call_count == 0
         else:
             fn.assert_called_once_with(inpt)
@@ -757,7 +757,7 @@ class TestRandomIoUCrop:
     def test__get_params(self, device, options):
         orig_h, orig_w = size = (24, 32)
         image = make_image(size)
-        bboxes = datapoints.BoundingBoxes(
+        bboxes = tv_tensors.BoundingBoxes(
             torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23], [1, 20, 10, 23], [20, 1, 23, 10]]),
             format="XYXY",
             canvas_size=size,
@@ -792,8 +792,8 @@ class TestRandomIoUCrop:
     def test__transform_empty_params(self, mocker):
         transform = transforms.RandomIoUCrop(sampler_options=[2.0])
-        image = datapoints.Image(torch.rand(1, 3, 4, 4))
+        image = tv_tensors.Image(torch.rand(1, 3, 4, 4))
-        bboxes = datapoints.BoundingBoxes(torch.tensor([[1, 1, 2, 2]]), format="XYXY", canvas_size=(4, 4))
+        bboxes = tv_tensors.BoundingBoxes(torch.tensor([[1, 1, 2, 2]]), format="XYXY", canvas_size=(4, 4))
         label = torch.tensor([1])
         sample = [image, bboxes, label]
         # Let's mock transform._get_params to control the output:
@@ -827,11 +827,11 @@ class TestRandomIoUCrop:
         # check number of bboxes vs number of labels:
         output_bboxes = output[1]
-        assert isinstance(output_bboxes, datapoints.BoundingBoxes)
+        assert isinstance(output_bboxes, tv_tensors.BoundingBoxes)
         assert (output_bboxes[~is_within_crop_area] == 0).all()

         output_masks = output[2]
-        assert isinstance(output_masks, datapoints.Mask)
+        assert isinstance(output_masks, tv_tensors.Mask)

 class TestScaleJitter:
@@ -899,7 +899,7 @@ class TestLinearTransformation:
         [
             122 * torch.ones(1, 3, 8, 8),
             122.0 * torch.ones(1, 3, 8, 8),
-            datapoints.Image(122 * torch.ones(1, 3, 8, 8)),
+            tv_tensors.Image(122 * torch.ones(1, 3, 8, 8)),
             PIL.Image.new("RGB", (8, 8), (122, 122, 122)),
         ],
     )
@@ -941,7 +941,7 @@ class TestUniformTemporalSubsample:
         [
             torch.zeros(10, 3, 8, 8),
             torch.zeros(1, 10, 3, 8, 8),
-            datapoints.Video(torch.zeros(1, 10, 3, 8, 8)),
+            tv_tensors.Video(torch.zeros(1, 10, 3, 8, 8)),
         ],
     )
     def test__transform(self, inpt):
@@ -971,12 +971,12 @@ def test_antialias_warning():
         transforms.RandomResize(10, 20)(tensor_img)

     with pytest.warns(UserWarning, match=match):
-        F.resized_crop(datapoints.Image(tensor_img), 0, 0, 10, 10, (20, 20))
+        F.resized_crop(tv_tensors.Image(tensor_img), 0, 0, 10, 10, (20, 20))
     with pytest.warns(UserWarning, match=match):
-        F.resize(datapoints.Video(tensor_video), (20, 20))
+        F.resize(tv_tensors.Video(tensor_video), (20, 20))
     with pytest.warns(UserWarning, match=match):
-        F.resized_crop(datapoints.Video(tensor_video), 0, 0, 10, 10, (20, 20))
+        F.resized_crop(tv_tensors.Video(tensor_video), 0, 0, 10, 10, (20, 20))

     with warnings.catch_warnings():
         warnings.simplefilter("error")
@@ -990,17 +990,17 @@ def test_antialias_warning():
         transforms.RandomShortestSize((20, 20), antialias=True)(tensor_img)
         transforms.RandomResize(10, 20, antialias=True)(tensor_img)
-        F.resized_crop(datapoints.Image(tensor_img), 0, 0, 10, 10, (20, 20), antialias=True)
+        F.resized_crop(tv_tensors.Image(tensor_img), 0, 0, 10, 10, (20, 20), antialias=True)
-        F.resized_crop(datapoints.Video(tensor_video), 0, 0, 10, 10, (20, 20), antialias=True)
+        F.resized_crop(tv_tensors.Video(tensor_video), 0, 0, 10, 10, (20, 20), antialias=True)

-@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
+@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, tv_tensors.Image))
 @pytest.mark.parametrize("label_type", (torch.Tensor, int))
 @pytest.mark.parametrize("dataset_return_type", (dict, tuple))
 @pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImage))
 def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
-    image = datapoints.Image(torch.randint(0, 256, size=(1, 3, 250, 250), dtype=torch.uint8))
+    image = tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 250, 250), dtype=torch.uint8))
     if image_type is PIL.Image:
         image = to_pil_image(image[0])
     elif image_type is torch.Tensor:
@@ -1056,7 +1056,7 @@ def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
     assert out_label == label

-@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
+@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, tv_tensors.Image))
 @pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))
 @pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImage))
 @pytest.mark.parametrize("sanitize", (True, False))
@@ -1082,7 +1082,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
             # leaving FixedSizeCrop in prototype for now, and it expects Label
             # classes which we won't release yet.
             # transforms.FixedSizeCrop(
-            #     size=(1024, 1024), fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})
+            #     size=(1024, 1024), fill=defaultdict(lambda: (123.0, 117.0, 104.0), {tv_tensors.Mask: 0})
             # ),
             transforms.RandomCrop((1024, 1024), pad_if_needed=True),
             transforms.RandomHorizontalFlip(p=1),
@@ -1101,7 +1101,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
     elif data_augmentation == "ssd":
         t = [
             transforms.RandomPhotometricDistort(p=1),
-            transforms.RandomZoomOut(fill={"others": (123.0, 117.0, 104.0), datapoints.Mask: 0}, p=1),
+            transforms.RandomZoomOut(fill={"others": (123.0, 117.0, 104.0), tv_tensors.Mask: 0}, p=1),
             transforms.RandomIoUCrop(),
             transforms.RandomHorizontalFlip(p=1),
             to_tensor,
@@ -1121,7 +1121,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
     num_boxes = 5
     H = W = 250
-    image = datapoints.Image(torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8))
+    image = tv_tensors.Image(torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8))
     if image_type is PIL.Image:
         image = to_pil_image(image[0])
     elif image_type is torch.Tensor:
@@ -1133,9 +1133,9 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
     boxes = torch.randint(0, min(H, W) // 2, size=(num_boxes, 4))
     boxes[:, 2:] += boxes[:, :2]
     boxes = boxes.clamp(min=0, max=min(H, W))
-    boxes = datapoints.BoundingBoxes(boxes, format="XYXY", canvas_size=(H, W))
+    boxes = tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=(H, W))
-    masks = datapoints.Mask(torch.randint(0, 2, size=(num_boxes, H, W), dtype=torch.uint8))
+    masks = tv_tensors.Mask(torch.randint(0, 2, size=(num_boxes, H, W), dtype=torch.uint8))
     sample = {
         "image": image,
@@ -1146,10 +1146,10 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
     out = t(sample)

-    if isinstance(to_tensor, transforms.ToTensor) and image_type is not datapoints.Image:
+    if isinstance(to_tensor, transforms.ToTensor) and image_type is not tv_tensors.Image:
         assert is_pure_tensor(out["image"])
     else:
-        assert isinstance(out["image"], datapoints.Image)
+        assert isinstance(out["image"], tv_tensors.Image)
     assert isinstance(out["label"], type(sample["label"]))

     num_boxes_expected = {
@@ -1204,13 +1204,13 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
     boxes = torch.tensor(boxes)
     labels = torch.arange(boxes.shape[0])

-    boxes = datapoints.BoundingBoxes(
+    boxes = tv_tensors.BoundingBoxes(
         boxes,
-        format=datapoints.BoundingBoxFormat.XYXY,
+        format=tv_tensors.BoundingBoxFormat.XYXY,
         canvas_size=(H, W),
     )
-    masks = datapoints.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W)))
+    masks = tv_tensors.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W)))
     whatever = torch.rand(10)
     input_img = torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8)
     sample = {
@@ -1244,8 +1244,8 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
     assert out_image is input_img
     assert out_whatever is whatever
-    assert isinstance(out_boxes, datapoints.BoundingBoxes)
+    assert isinstance(out_boxes, tv_tensors.BoundingBoxes)
-    assert isinstance(out_masks, datapoints.Mask)
+    assert isinstance(out_masks, tv_tensors.Mask)

     if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None):
         assert out_labels is labels
@@ -1266,15 +1266,15 @@ def test_sanitize_bounding_boxes_no_label():
         transforms.SanitizeBoundingBoxes()(img, boxes)

     out_img, out_boxes = transforms.SanitizeBoundingBoxes(labels_getter=None)(img, boxes)
-    assert isinstance(out_img, datapoints.Image)
+    assert isinstance(out_img, tv_tensors.Image)
-    assert isinstance(out_boxes, datapoints.BoundingBoxes)
+    assert isinstance(out_boxes, tv_tensors.BoundingBoxes)

 def test_sanitize_bounding_boxes_errors():
-    good_bbox = datapoints.BoundingBoxes(
+    good_bbox = tv_tensors.BoundingBoxes(
         [[0, 0, 10, 10]],
-        format=datapoints.BoundingBoxFormat.XYXY,
+        format=tv_tensors.BoundingBoxFormat.XYXY,
         canvas_size=(20, 20),
     )
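One behavior these hunks pin down, now phrased in terms of tv_tensors: transforms.SanitizeBoundingBoxes must drop degenerate boxes (zero or negative extent) for every tv_tensors.BoundingBoxFormat. Distilled from the test above into a standalone sketch (illustrative, not part of the diff):

    import torch
    import torchvision.transforms.v2 as transforms
    from torchvision import tv_tensors

    for format in list(tv_tensors.BoundingBoxFormat):
        sample = dict(
            # A [0, 0, 0, 0] box has no width or height in any format.
            boxes=tv_tensors.BoundingBoxes([[0, 0, 0, 0]], format=format, canvas_size=(224, 244)),
            labels=torch.tensor([3]),
        )
        # The degenerate box (and its label) is removed, leaving an empty (0, 4) tensor.
        assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4)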
test/test_transforms_v2_consistency.py (view file @ d5f4cc38)

@@ -13,7 +13,7 @@ import torch
 import torchvision.transforms.v2 as v2_transforms
 from common_utils import assert_close, assert_equal, set_rng_seed
 from torch import nn
-from torchvision import datapoints, transforms as legacy_transforms
+from torchvision import transforms as legacy_transforms, tv_tensors
 from torchvision._utils import sequence_to_str
 from torchvision.transforms import functional as legacy_F
@@ -478,15 +478,15 @@ def check_call_consistency(
             output_prototype_image = prototype_transform(image)
         except Exception as exc:
             raise AssertionError(
-                f"Transforming a image datapoint with shape {image_repr} failed in the prototype transform with "
+                f"Transforming a image tv_tensor with shape {image_repr} failed in the prototype transform with "
                 f"the error above. This means there is a consistency bug either in `_get_params` or in the "
-                f"`datapoints.Image` path in `_transform`."
+                f"`tv_tensors.Image` path in `_transform`."
             ) from exc

         assert_close(
             output_prototype_image,
             output_prototype_tensor,
-            msg=lambda msg: f"Output for datapoint and tensor images is not equal: \n\n{msg}",
+            msg=lambda msg: f"Output for tv_tensor and tensor images is not equal: \n\n{msg}",
             **closeness_kwargs,
         )
@@ -747,7 +747,7 @@ class TestAATransforms:
         [
             torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
             PIL.Image.new("RGB", (256, 256), 123),
-            datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
+            tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
         ],
     )
     @pytest.mark.parametrize(
@@ -812,7 +812,7 @@ class TestAATransforms:
         [
             torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
             PIL.Image.new("RGB", (256, 256), 123),
-            datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
+            tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
         ],
     )
     @pytest.mark.parametrize(
@@ -887,7 +887,7 @@ class TestAATransforms:
         [
             torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
             PIL.Image.new("RGB", (256, 256), 123),
-            datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
+            tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
         ],
     )
     @pytest.mark.parametrize(
@@ -964,7 +964,7 @@ class TestAATransforms:
         [
             torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
             PIL.Image.new("RGB", (256, 256), 123),
-            datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
+            tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
         ],
     )
     @pytest.mark.parametrize(
@@ -1030,7 +1030,7 @@ det_transforms = import_transforms_from_references("detection")
 class TestRefDetTransforms:
-    def make_datapoints(self, with_mask=True):
+    def make_tv_tensors(self, with_mask=True):
         size = (600, 800)
         num_objects = 22
@@ -1057,7 +1057,7 @@ class TestRefDetTransforms:
             yield (tensor_image, target)

-        datapoint_image = make_image(size=size, color_space="RGB", dtype=torch.float32)
+        tv_tensor_image = make_image(size=size, color_space="RGB", dtype=torch.float32)
         target = {
             "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
             "labels": make_label(extra_dims=(num_objects,), categories=80),
@@ -1065,7 +1065,7 @@ class TestRefDetTransforms:
         if with_mask:
             target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

-        yield (datapoint_image, target)
+        yield (tv_tensor_image, target)

     @pytest.mark.parametrize(
         "t_ref, t, data_kwargs",
@@ -1095,7 +1095,7 @@ class TestRefDetTransforms:
         ],
     )
     def test_transform(self, t_ref, t, data_kwargs):
-        for dp in self.make_datapoints(**data_kwargs):
+        for dp in self.make_tv_tensors(**data_kwargs):
             # We should use prototype transform first as reference transform performs inplace target update
             torch.manual_seed(12)
@@ -1135,7 +1135,7 @@ class PadIfSmaller(v2_transforms.Transform):
 class TestRefSegTransforms:
-    def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8):
+    def make_tv_tensors(self, supports_pil=True, image_dtype=torch.uint8):
         size = (256, 460)
         num_categories = 21
@@ -1145,13 +1145,13 @@ class TestRefSegTransforms:
             conv_fns.extend([torch.Tensor, lambda x: x])

         for conv_fn in conv_fns:
-            datapoint_image = make_image(size=size, color_space="RGB", dtype=image_dtype)
+            tv_tensor_image = make_image(size=size, color_space="RGB", dtype=image_dtype)
-            datapoint_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)
+            tv_tensor_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)

-            dp = (conv_fn(datapoint_image), datapoint_mask)
+            dp = (conv_fn(tv_tensor_image), tv_tensor_mask)
             dp_ref = (
-                to_pil_image(datapoint_image) if supports_pil else datapoint_image.as_subclass(torch.Tensor),
+                to_pil_image(tv_tensor_image) if supports_pil else tv_tensor_image.as_subclass(torch.Tensor),
-                to_pil_image(datapoint_mask),
+                to_pil_image(tv_tensor_mask),
             )

             yield dp, dp_ref
@@ -1161,7 +1161,7 @@ class TestRefSegTransforms:
             random.seed(seed)

     def check(self, t, t_ref, data_kwargs=None):
-        for dp, dp_ref in self.make_datapoints(**data_kwargs or dict()):
+        for dp, dp_ref in self.make_tv_tensors(**data_kwargs or dict()):
             self.set_seed()
             actual = actual_image, actual_mask = t(dp)
@@ -1192,7 +1192,7 @@ class TestRefSegTransforms:
             seg_transforms.RandomCrop(size=480),
             v2_transforms.Compose(
                 [
-                    PadIfSmaller(size=480, fill={datapoints.Mask: 255, "others": 0}),
+                    PadIfSmaller(size=480, fill={tv_tensors.Mask: 255, "others": 0}),
                     v2_transforms.RandomCrop(size=480),
                 ]
             ),
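The segmentation hunks rely on the v2 transforms' type-keyed fill argument: a dict mapping input types to fill values, with the string key "others" as the fallback, so masks can be padded with the ignore index while images get zeros. A small sketch, assuming v2_transforms.Pad accepts the same type-keyed fill as the PadIfSmaller and RandomZoomOut calls in these diffs:

    import torchvision.transforms.v2 as v2_transforms
    from torchvision import tv_tensors

    # Pad masks with the ignore index 255 and everything else with 0,
    # mirroring the fill dict used by PadIfSmaller in the test above.
    pad = v2_transforms.Pad(padding=4, fill={tv_tensors.Mask: 255, "others": 0})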
test/test_transforms_v2_functional.py (view file @ d5f4cc38)

@@ -10,7 +10,7 @@ import torch
 from common_utils import assert_close, cache, cpu_and_cuda, needs_cuda, set_rng_seed
 from torch.utils._pytree import tree_map
-from torchvision import datapoints
+from torchvision import tv_tensors
 from torchvision.transforms.functional import _get_perspective_coeffs
 from torchvision.transforms.v2 import functional as F
 from torchvision.transforms.v2._utils import is_pure_tensor
@@ -164,22 +164,22 @@ class TestKernels:
     def test_batched_vs_single(self, test_id, info, args_kwargs, device):
         (batched_input, *other_args), kwargs = args_kwargs.load(device)

-        datapoint_type = datapoints.Image if is_pure_tensor(batched_input) else type(batched_input)
+        tv_tensor_type = tv_tensors.Image if is_pure_tensor(batched_input) else type(batched_input)
         # This dictionary contains the number of rightmost dimensions that contain the actual data.
         # Everything to the left is considered a batch dimension.
         data_dims = {
-            datapoints.Image: 3,
+            tv_tensors.Image: 3,
-            datapoints.BoundingBoxes: 1,
+            tv_tensors.BoundingBoxes: 1,
             # `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection masks
             # it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both a grouped under one
             # type all kernels should also work without differentiating between the two. Thus, we go with 2 here as
             # common ground.
-            datapoints.Mask: 2,
+            tv_tensors.Mask: 2,
-            datapoints.Video: 4,
+            tv_tensors.Video: 4,
-        }.get(datapoint_type)
+        }.get(tv_tensor_type)
         if data_dims is None:
             raise pytest.UsageError(
-                f"The number of data dimensions cannot be determined for input of type {datapoint_type.__name__}."
+                f"The number of data dimensions cannot be determined for input of type {tv_tensor_type.__name__}."
             ) from None
         elif batched_input.ndim <= data_dims:
             pytest.skip("Input is not batched.")
@@ -305,8 +305,8 @@ def spy_on(mocker):
 class TestDispatchers:
     image_sample_inputs = make_info_args_kwargs_parametrization(
-        [info for info in DISPATCHER_INFOS if datapoints.Image in info.kernels],
+        [info for info in DISPATCHER_INFOS if tv_tensors.Image in info.kernels],
-        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
+        args_kwargs_fn=lambda info: info.sample_inputs(tv_tensors.Image),
     )

     @make_info_args_kwargs_parametrization(
@@ -328,8 +328,8 @@ class TestDispatchers:
     def test_scripted_smoke(self, info, args_kwargs, device):
         dispatcher = script(info.dispatcher)

-        (image_datapoint, *other_args), kwargs = args_kwargs.load(device)
+        (image_tv_tensor, *other_args), kwargs = args_kwargs.load(device)
-        image_pure_tensor = torch.Tensor(image_datapoint)
+        image_pure_tensor = torch.Tensor(image_tv_tensor)

         dispatcher(image_pure_tensor, *other_args, **kwargs)
@@ -355,25 +355,25 @@ class TestDispatchers:
     @image_sample_inputs
     def test_pure_tensor_output_type(self, info, args_kwargs):
-        (image_datapoint, *other_args), kwargs = args_kwargs.load()
+        (image_tv_tensor, *other_args), kwargs = args_kwargs.load()
-        image_pure_tensor = image_datapoint.as_subclass(torch.Tensor)
+        image_pure_tensor = image_tv_tensor.as_subclass(torch.Tensor)

         output = info.dispatcher(image_pure_tensor, *other_args, **kwargs)

-        # We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
+        # We cannot use `isinstance` here since all tv_tensors are instances of `torch.Tensor` as well
         assert type(output) is torch.Tensor

     @make_info_args_kwargs_parametrization(
         [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
-        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
+        args_kwargs_fn=lambda info: info.sample_inputs(tv_tensors.Image),
     )
     def test_pil_output_type(self, info, args_kwargs):
-        (image_datapoint, *other_args), kwargs = args_kwargs.load()
+        (image_tv_tensor, *other_args), kwargs = args_kwargs.load()

-        if image_datapoint.ndim > 3:
+        if image_tv_tensor.ndim > 3:
             pytest.skip("Input is batched")

-        image_pil = F.to_pil_image(image_datapoint)
+        image_pil = F.to_pil_image(image_tv_tensor)

         output = info.dispatcher(image_pil, *other_args, **kwargs)
@@ -383,38 +383,38 @@ class TestDispatchers:
         DISPATCHER_INFOS,
         args_kwargs_fn=lambda info: info.sample_inputs(),
     )
-    def test_datapoint_output_type(self, info, args_kwargs):
+    def test_tv_tensor_output_type(self, info, args_kwargs):
-        (datapoint, *other_args), kwargs = args_kwargs.load()
+        (tv_tensor, *other_args), kwargs = args_kwargs.load()

-        output = info.dispatcher(datapoint, *other_args, **kwargs)
+        output = info.dispatcher(tv_tensor, *other_args, **kwargs)

-        assert isinstance(output, type(datapoint))
+        assert isinstance(output, type(tv_tensor))

-        if isinstance(datapoint, datapoints.BoundingBoxes) and info.dispatcher is not F.convert_bounding_box_format:
+        if isinstance(tv_tensor, tv_tensors.BoundingBoxes) and info.dispatcher is not F.convert_bounding_box_format:
-            assert output.format == datapoint.format
+            assert output.format == tv_tensor.format

     @pytest.mark.parametrize(
-        ("dispatcher_info", "datapoint_type", "kernel_info"),
+        ("dispatcher_info", "tv_tensor_type", "kernel_info"),
         [
             pytest.param(
-                dispatcher_info, datapoint_type, kernel_info, id=f"{dispatcher_info.id}-{datapoint_type.__name__}"
+                dispatcher_info, tv_tensor_type, kernel_info, id=f"{dispatcher_info.id}-{tv_tensor_type.__name__}"
             )
             for dispatcher_info in DISPATCHER_INFOS
-            for datapoint_type, kernel_info in dispatcher_info.kernel_infos.items()
+            for tv_tensor_type, kernel_info in dispatcher_info.kernel_infos.items()
         ],
     )
-    def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, datapoint_type, kernel_info):
+    def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, tv_tensor_type, kernel_info):
         dispatcher_signature = inspect.signature(dispatcher_info.dispatcher)
         dispatcher_params = list(dispatcher_signature.parameters.values())[1:]

         kernel_signature = inspect.signature(kernel_info.kernel)
         kernel_params = list(kernel_signature.parameters.values())[1:]

-        # We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be
+        # We filter out metadata that is implicitly passed to the dispatcher through the input tv_tensor, but has to be
         # explicitly passed to the kernel.
         input_type = {v: k for k, v in dispatcher_info.kernels.items()}.get(kernel_info.kernel)
         explicit_metadata = {
-            datapoints.BoundingBoxes: {"format", "canvas_size"},
+            tv_tensors.BoundingBoxes: {"format", "canvas_size"},
         }
         kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]
@@ -445,9 +445,9 @@ class TestDispatchers:
         [
             info
             for info in DISPATCHER_INFOS
-            if datapoints.BoundingBoxes in info.kernels and info.dispatcher is not F.convert_bounding_box_format
+            if tv_tensors.BoundingBoxes in info.kernels and info.dispatcher is not F.convert_bounding_box_format
         ],
-        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.BoundingBoxes),
+        args_kwargs_fn=lambda info: info.sample_inputs(tv_tensors.BoundingBoxes),
     )
     def test_bounding_boxes_format_consistency(self, info, args_kwargs):
         (bounding_boxes, *other_args), kwargs = args_kwargs.load()
@@ -497,7 +497,7 @@ class TestClampBoundingBoxes:
         "metadata",
         [
             dict(),
-            dict(format=datapoints.BoundingBoxFormat.XYXY),
+            dict(format=tv_tensors.BoundingBoxFormat.XYXY),
             dict(canvas_size=(1, 1)),
         ],
     )
@@ -510,16 +510,16 @@ class TestClampBoundingBoxes:
     @pytest.mark.parametrize(
         "metadata",
         [
-            dict(format=datapoints.BoundingBoxFormat.XYXY),
+            dict(format=tv_tensors.BoundingBoxFormat.XYXY),
             dict(canvas_size=(1, 1
)),
dict
(
canvas_size
=
(
1
,
1
)),
dict
(
format
=
datapoint
s
.
BoundingBoxFormat
.
XYXY
,
canvas_size
=
(
1
,
1
)),
dict
(
format
=
tv_tensor
s
.
BoundingBoxFormat
.
XYXY
,
canvas_size
=
(
1
,
1
)),
],
],
)
)
def
test_
datapoint
_explicit_metadata
(
self
,
metadata
):
def
test_
tv_tensor
_explicit_metadata
(
self
,
metadata
):
datapoint
=
next
(
make_multiple_bounding_boxes
())
tv_tensor
=
next
(
make_multiple_bounding_boxes
())
with
pytest
.
raises
(
ValueError
,
match
=
re
.
escape
(
"`format` and `canvas_size` must not be passed"
)):
with
pytest
.
raises
(
ValueError
,
match
=
re
.
escape
(
"`format` and `canvas_size` must not be passed"
)):
F
.
clamp_bounding_boxes
(
datapoint
,
**
metadata
)
F
.
clamp_bounding_boxes
(
tv_tensor
,
**
metadata
)
class
TestConvertFormatBoundingBoxes
:
class
TestConvertFormatBoundingBoxes
:
...
@@ -527,7 +527,7 @@ class TestConvertFormatBoundingBoxes:
...
@@ -527,7 +527,7 @@ class TestConvertFormatBoundingBoxes:
(
"inpt"
,
"old_format"
),
(
"inpt"
,
"old_format"
),
[
[
(
next
(
make_multiple_bounding_boxes
()),
None
),
(
next
(
make_multiple_bounding_boxes
()),
None
),
(
next
(
make_multiple_bounding_boxes
()).
as_subclass
(
torch
.
Tensor
),
datapoint
s
.
BoundingBoxFormat
.
XYXY
),
(
next
(
make_multiple_bounding_boxes
()).
as_subclass
(
torch
.
Tensor
),
tv_tensor
s
.
BoundingBoxFormat
.
XYXY
),
],
],
)
)
def
test_missing_new_format
(
self
,
inpt
,
old_format
):
def
test_missing_new_format
(
self
,
inpt
,
old_format
):
...
@@ -538,14 +538,14 @@ class TestConvertFormatBoundingBoxes:
...
@@ -538,14 +538,14 @@ class TestConvertFormatBoundingBoxes:
pure_tensor
=
next
(
make_multiple_bounding_boxes
()).
as_subclass
(
torch
.
Tensor
)
pure_tensor
=
next
(
make_multiple_bounding_boxes
()).
as_subclass
(
torch
.
Tensor
)
with
pytest
.
raises
(
ValueError
,
match
=
re
.
escape
(
"`old_format` has to be passed"
)):
with
pytest
.
raises
(
ValueError
,
match
=
re
.
escape
(
"`old_format` has to be passed"
)):
F
.
convert_bounding_box_format
(
pure_tensor
,
new_format
=
datapoint
s
.
BoundingBoxFormat
.
CXCYWH
)
F
.
convert_bounding_box_format
(
pure_tensor
,
new_format
=
tv_tensor
s
.
BoundingBoxFormat
.
CXCYWH
)
def
test_
datapoint
_explicit_metadata
(
self
):
def
test_
tv_tensor
_explicit_metadata
(
self
):
datapoint
=
next
(
make_multiple_bounding_boxes
())
tv_tensor
=
next
(
make_multiple_bounding_boxes
())
with
pytest
.
raises
(
ValueError
,
match
=
re
.
escape
(
"`old_format` must not be passed"
)):
with
pytest
.
raises
(
ValueError
,
match
=
re
.
escape
(
"`old_format` must not be passed"
)):
F
.
convert_bounding_box_format
(
F
.
convert_bounding_box_format
(
datapoint
,
old_format
=
datapoint
.
format
,
new_format
=
datapoint
s
.
BoundingBoxFormat
.
CXCYWH
tv_tensor
,
old_format
=
tv_tensor
.
format
,
new_format
=
tv_tensor
s
.
BoundingBoxFormat
.
CXCYWH
)
)
...
@@ -579,7 +579,7 @@ def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
...
@@ -579,7 +579,7 @@ def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"format"
,
"format"
,
[
datapoint
s
.
BoundingBoxFormat
.
XYXY
,
datapoint
s
.
BoundingBoxFormat
.
XYWH
,
datapoint
s
.
BoundingBoxFormat
.
CXCYWH
],
[
tv_tensor
s
.
BoundingBoxFormat
.
XYXY
,
tv_tensor
s
.
BoundingBoxFormat
.
XYWH
,
tv_tensor
s
.
BoundingBoxFormat
.
CXCYWH
],
)
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"top, left, height, width, expected_bboxes"
,
"top, left, height, width, expected_bboxes"
,
...
@@ -602,7 +602,7 @@ def test_correctness_crop_bounding_boxes(device, format, top, left, height, widt
...
@@ -602,7 +602,7 @@ def test_correctness_crop_bounding_boxes(device, format, top, left, height, widt
# out_box = denormalize_bbox(n_out_box, height, width)
# out_box = denormalize_bbox(n_out_box, height, width)
# expected_bboxes.append(out_box)
# expected_bboxes.append(out_box)
format
=
datapoint
s
.
BoundingBoxFormat
.
XYXY
format
=
tv_tensor
s
.
BoundingBoxFormat
.
XYXY
canvas_size
=
(
64
,
76
)
canvas_size
=
(
64
,
76
)
in_boxes
=
[
in_boxes
=
[
[
10.0
,
15.0
,
25.0
,
35.0
],
[
10.0
,
15.0
,
25.0
,
35.0
],
...
@@ -610,11 +610,11 @@ def test_correctness_crop_bounding_boxes(device, format, top, left, height, widt
...
@@ -610,11 +610,11 @@ def test_correctness_crop_bounding_boxes(device, format, top, left, height, widt
[
45.0
,
46.0
,
56.0
,
62.0
],
[
45.0
,
46.0
,
56.0
,
62.0
],
]
]
in_boxes
=
torch
.
tensor
(
in_boxes
,
device
=
device
)
in_boxes
=
torch
.
tensor
(
in_boxes
,
device
=
device
)
if
format
!=
datapoint
s
.
BoundingBoxFormat
.
XYXY
:
if
format
!=
tv_tensor
s
.
BoundingBoxFormat
.
XYXY
:
in_boxes
=
convert_bounding_box_format
(
in_boxes
,
datapoint
s
.
BoundingBoxFormat
.
XYXY
,
format
)
in_boxes
=
convert_bounding_box_format
(
in_boxes
,
tv_tensor
s
.
BoundingBoxFormat
.
XYXY
,
format
)
expected_bboxes
=
clamp_bounding_boxes
(
expected_bboxes
=
clamp_bounding_boxes
(
datapoint
s
.
BoundingBoxes
(
expected_bboxes
,
format
=
"XYXY"
,
canvas_size
=
canvas_size
)
tv_tensor
s
.
BoundingBoxes
(
expected_bboxes
,
format
=
"XYXY"
,
canvas_size
=
canvas_size
)
).
tolist
()
).
tolist
()
output_boxes
,
output_canvas_size
=
F
.
crop_bounding_boxes
(
output_boxes
,
output_canvas_size
=
F
.
crop_bounding_boxes
(
...
@@ -626,8 +626,8 @@ def test_correctness_crop_bounding_boxes(device, format, top, left, height, widt
...
@@ -626,8 +626,8 @@ def test_correctness_crop_bounding_boxes(device, format, top, left, height, widt
canvas_size
[
1
],
canvas_size
[
1
],
)
)
if
format
!=
datapoint
s
.
BoundingBoxFormat
.
XYXY
:
if
format
!=
tv_tensor
s
.
BoundingBoxFormat
.
XYXY
:
output_boxes
=
convert_bounding_box_format
(
output_boxes
,
format
,
datapoint
s
.
BoundingBoxFormat
.
XYXY
)
output_boxes
=
convert_bounding_box_format
(
output_boxes
,
format
,
tv_tensor
s
.
BoundingBoxFormat
.
XYXY
)
torch
.
testing
.
assert_close
(
output_boxes
.
tolist
(),
expected_bboxes
)
torch
.
testing
.
assert_close
(
output_boxes
.
tolist
(),
expected_bboxes
)
torch
.
testing
.
assert_close
(
output_canvas_size
,
canvas_size
)
torch
.
testing
.
assert_close
(
output_canvas_size
,
canvas_size
)
...
@@ -648,7 +648,7 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
...
@@ -648,7 +648,7 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"format"
,
"format"
,
[
datapoint
s
.
BoundingBoxFormat
.
XYXY
,
datapoint
s
.
BoundingBoxFormat
.
XYWH
,
datapoint
s
.
BoundingBoxFormat
.
CXCYWH
],
[
tv_tensor
s
.
BoundingBoxFormat
.
XYXY
,
tv_tensor
s
.
BoundingBoxFormat
.
XYWH
,
tv_tensor
s
.
BoundingBoxFormat
.
CXCYWH
],
)
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"top, left, height, width, size"
,
"top, left, height, width, size"
,
...
@@ -666,7 +666,7 @@ def test_correctness_resized_crop_bounding_boxes(device, format, top, left, heig
...
@@ -666,7 +666,7 @@ def test_correctness_resized_crop_bounding_boxes(device, format, top, left, heig
bbox
[
3
]
=
(
bbox
[
3
]
-
top_
)
*
size_
[
0
]
/
height_
bbox
[
3
]
=
(
bbox
[
3
]
-
top_
)
*
size_
[
0
]
/
height_
return
bbox
return
bbox
format
=
datapoint
s
.
BoundingBoxFormat
.
XYXY
format
=
tv_tensor
s
.
BoundingBoxFormat
.
XYXY
canvas_size
=
(
100
,
100
)
canvas_size
=
(
100
,
100
)
in_boxes
=
[
in_boxes
=
[
[
10.0
,
10.0
,
20.0
,
20.0
],
[
10.0
,
10.0
,
20.0
,
20.0
],
...
@@ -677,16 +677,16 @@ def test_correctness_resized_crop_bounding_boxes(device, format, top, left, heig
...
@@ -677,16 +677,16 @@ def test_correctness_resized_crop_bounding_boxes(device, format, top, left, heig
expected_bboxes
.
append
(
_compute_expected_bbox
(
list
(
in_box
),
top
,
left
,
height
,
width
,
size
))
expected_bboxes
.
append
(
_compute_expected_bbox
(
list
(
in_box
),
top
,
left
,
height
,
width
,
size
))
expected_bboxes
=
torch
.
tensor
(
expected_bboxes
,
device
=
device
)
expected_bboxes
=
torch
.
tensor
(
expected_bboxes
,
device
=
device
)
in_boxes
=
datapoint
s
.
BoundingBoxes
(
in_boxes
=
tv_tensor
s
.
BoundingBoxes
(
in_boxes
,
format
=
datapoint
s
.
BoundingBoxFormat
.
XYXY
,
canvas_size
=
canvas_size
,
device
=
device
in_boxes
,
format
=
tv_tensor
s
.
BoundingBoxFormat
.
XYXY
,
canvas_size
=
canvas_size
,
device
=
device
)
)
if
format
!=
datapoint
s
.
BoundingBoxFormat
.
XYXY
:
if
format
!=
tv_tensor
s
.
BoundingBoxFormat
.
XYXY
:
in_boxes
=
convert_bounding_box_format
(
in_boxes
,
datapoint
s
.
BoundingBoxFormat
.
XYXY
,
format
)
in_boxes
=
convert_bounding_box_format
(
in_boxes
,
tv_tensor
s
.
BoundingBoxFormat
.
XYXY
,
format
)
output_boxes
,
output_canvas_size
=
F
.
resized_crop_bounding_boxes
(
in_boxes
,
format
,
top
,
left
,
height
,
width
,
size
)
output_boxes
,
output_canvas_size
=
F
.
resized_crop_bounding_boxes
(
in_boxes
,
format
,
top
,
left
,
height
,
width
,
size
)
if
format
!=
datapoint
s
.
BoundingBoxFormat
.
XYXY
:
if
format
!=
tv_tensor
s
.
BoundingBoxFormat
.
XYXY
:
output_boxes
=
convert_bounding_box_format
(
output_boxes
,
format
,
datapoint
s
.
BoundingBoxFormat
.
XYXY
)
output_boxes
=
convert_bounding_box_format
(
output_boxes
,
format
,
tv_tensor
s
.
BoundingBoxFormat
.
XYXY
)
torch
.
testing
.
assert_close
(
output_boxes
,
expected_bboxes
)
torch
.
testing
.
assert_close
(
output_boxes
,
expected_bboxes
)
torch
.
testing
.
assert_close
(
output_canvas_size
,
size
)
torch
.
testing
.
assert_close
(
output_canvas_size
,
size
)
...
@@ -713,14 +713,14 @@ def test_correctness_pad_bounding_boxes(device, padding):
...
@@ -713,14 +713,14 @@ def test_correctness_pad_bounding_boxes(device, padding):
dtype
=
bbox
.
dtype
dtype
=
bbox
.
dtype
bbox
=
(
bbox
=
(
bbox
.
clone
()
bbox
.
clone
()
if
format
==
datapoint
s
.
BoundingBoxFormat
.
XYXY
if
format
==
tv_tensor
s
.
BoundingBoxFormat
.
XYXY
else
convert_bounding_box_format
(
bbox
,
old_format
=
format
,
new_format
=
datapoint
s
.
BoundingBoxFormat
.
XYXY
)
else
convert_bounding_box_format
(
bbox
,
old_format
=
format
,
new_format
=
tv_tensor
s
.
BoundingBoxFormat
.
XYXY
)
)
)
bbox
[
0
::
2
]
+=
pad_left
bbox
[
0
::
2
]
+=
pad_left
bbox
[
1
::
2
]
+=
pad_up
bbox
[
1
::
2
]
+=
pad_up
bbox
=
convert_bounding_box_format
(
bbox
,
old_format
=
datapoint
s
.
BoundingBoxFormat
.
XYXY
,
new_format
=
format
)
bbox
=
convert_bounding_box_format
(
bbox
,
old_format
=
tv_tensor
s
.
BoundingBoxFormat
.
XYXY
,
new_format
=
format
)
if
bbox
.
dtype
!=
dtype
:
if
bbox
.
dtype
!=
dtype
:
# Temporary cast to original dtype
# Temporary cast to original dtype
# e.g. float32 -> int
# e.g. float32 -> int
...
@@ -785,7 +785,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
...
@@ -785,7 +785,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
]
]
)
)
bbox_xyxy
=
convert_bounding_box_format
(
bbox
,
old_format
=
format_
,
new_format
=
datapoint
s
.
BoundingBoxFormat
.
XYXY
)
bbox_xyxy
=
convert_bounding_box_format
(
bbox
,
old_format
=
format_
,
new_format
=
tv_tensor
s
.
BoundingBoxFormat
.
XYXY
)
points
=
np
.
array
(
points
=
np
.
array
(
[
[
[
bbox_xyxy
[
0
].
item
(),
bbox_xyxy
[
1
].
item
(),
1.0
],
[
bbox_xyxy
[
0
].
item
(),
bbox_xyxy
[
1
].
item
(),
1.0
],
...
@@ -807,7 +807,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
...
@@ -807,7 +807,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
)
)
out_bbox
=
torch
.
from_numpy
(
out_bbox
)
out_bbox
=
torch
.
from_numpy
(
out_bbox
)
out_bbox
=
convert_bounding_box_format
(
out_bbox
=
convert_bounding_box_format
(
out_bbox
,
old_format
=
datapoint
s
.
BoundingBoxFormat
.
XYXY
,
new_format
=
format_
out_bbox
,
old_format
=
tv_tensor
s
.
BoundingBoxFormat
.
XYXY
,
new_format
=
format_
)
)
return
clamp_bounding_boxes
(
out_bbox
,
format
=
format_
,
canvas_size
=
canvas_size_
).
to
(
bbox
)
return
clamp_bounding_boxes
(
out_bbox
,
format
=
format_
,
canvas_size
=
canvas_size_
).
to
(
bbox
)
...
@@ -846,7 +846,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
...
@@ -846,7 +846,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
def
test_correctness_center_crop_bounding_boxes
(
device
,
output_size
):
def
test_correctness_center_crop_bounding_boxes
(
device
,
output_size
):
def
_compute_expected_bbox
(
bbox
,
format_
,
canvas_size_
,
output_size_
):
def
_compute_expected_bbox
(
bbox
,
format_
,
canvas_size_
,
output_size_
):
dtype
=
bbox
.
dtype
dtype
=
bbox
.
dtype
bbox
=
convert_bounding_box_format
(
bbox
.
float
(),
format_
,
datapoint
s
.
BoundingBoxFormat
.
XYWH
)
bbox
=
convert_bounding_box_format
(
bbox
.
float
(),
format_
,
tv_tensor
s
.
BoundingBoxFormat
.
XYWH
)
if
len
(
output_size_
)
==
1
:
if
len
(
output_size_
)
==
1
:
output_size_
.
append
(
output_size_
[
-
1
])
output_size_
.
append
(
output_size_
[
-
1
])
...
@@ -860,7 +860,7 @@ def test_correctness_center_crop_bounding_boxes(device, output_size):
...
@@ -860,7 +860,7 @@ def test_correctness_center_crop_bounding_boxes(device, output_size):
bbox
[
3
].
item
(),
bbox
[
3
].
item
(),
]
]
out_bbox
=
torch
.
tensor
(
out_bbox
)
out_bbox
=
torch
.
tensor
(
out_bbox
)
out_bbox
=
convert_bounding_box_format
(
out_bbox
,
datapoint
s
.
BoundingBoxFormat
.
XYWH
,
format_
)
out_bbox
=
convert_bounding_box_format
(
out_bbox
,
tv_tensor
s
.
BoundingBoxFormat
.
XYWH
,
format_
)
out_bbox
=
clamp_bounding_boxes
(
out_bbox
,
format
=
format_
,
canvas_size
=
output_size
)
out_bbox
=
clamp_bounding_boxes
(
out_bbox
,
format
=
format_
,
canvas_size
=
output_size
)
return
out_bbox
.
to
(
dtype
=
dtype
,
device
=
bbox
.
device
)
return
out_bbox
.
to
(
dtype
=
dtype
,
device
=
bbox
.
device
)
...
@@ -958,7 +958,7 @@ def test_correctness_gaussian_blur_image_tensor(device, canvas_size, dt, ksize,
...
@@ -958,7 +958,7 @@ def test_correctness_gaussian_blur_image_tensor(device, canvas_size, dt, ksize,
torch
.
tensor
(
true_cv2_results
[
gt_key
]).
reshape
(
shape
[
-
2
],
shape
[
-
1
],
shape
[
-
3
]).
permute
(
2
,
0
,
1
).
to
(
tensor
)
torch
.
tensor
(
true_cv2_results
[
gt_key
]).
reshape
(
shape
[
-
2
],
shape
[
-
1
],
shape
[
-
3
]).
permute
(
2
,
0
,
1
).
to
(
tensor
)
)
)
image
=
datapoint
s
.
Image
(
tensor
)
image
=
tv_tensor
s
.
Image
(
tensor
)
out
=
fn
(
image
,
kernel_size
=
ksize
,
sigma
=
sigma
)
out
=
fn
(
image
,
kernel_size
=
ksize
,
sigma
=
sigma
)
torch
.
testing
.
assert_close
(
out
,
true_out
,
rtol
=
0.0
,
atol
=
1.0
,
msg
=
f
"
{
ksize
}
,
{
sigma
}
"
)
torch
.
testing
.
assert_close
(
out
,
true_out
,
rtol
=
0.0
,
atol
=
1.0
,
msg
=
f
"
{
ksize
}
,
{
sigma
}
"
)
...
...
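Read together, the hunks above pin down the post-rename dispatcher contract: a pure tensor in yields a pure tensor out, a tv_tensor in yields the same tv_tensor subclass out, and bounding-box metadata is either carried on the object or passed explicitly. A minimal sketch of that contract, assuming a torchvision build with the renamed `tv_tensors` namespace (the box values here are illustrative, not taken from the tests):

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

boxes = tv_tensors.BoundingBoxes([[10.0, 15.0, 25.0, 35.0]], format="XYXY", canvas_size=(64, 76))

# tv_tensor input: metadata travels with the object, so only `new_format` is given
out = F.convert_bounding_box_format(boxes, new_format=tv_tensors.BoundingBoxFormat.CXCYWH)
assert type(out) is tv_tensors.BoundingBoxes and out.format == tv_tensors.BoundingBoxFormat.CXCYWH

# pure tensor input: metadata must be passed explicitly, and a plain tensor comes back
plain = boxes.as_subclass(torch.Tensor)
out = F.convert_bounding_box_format(
    plain, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=tv_tensors.BoundingBoxFormat.CXCYWH
)
assert type(out) is torch.Tensor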
test/test_transforms_v2_refactored.py View file @ d5f4cc38 (diff collapsed)
test/test_transforms_v2_utils.py View file @ d5f4cc38
...
@@ -6,46 +6,46 @@ import torch
 import torchvision.transforms.v2._utils
 from common_utils import DEFAULT_SIZE, make_bounding_boxes, make_detection_mask, make_image
-from torchvision import datapoints
+from torchvision import tv_tensors
 from torchvision.transforms.v2._utils import has_all, has_any
 from torchvision.transforms.v2.functional import to_pil_image

 IMAGE = make_image(DEFAULT_SIZE, color_space="RGB")
-BOUNDING_BOX = make_bounding_boxes(DEFAULT_SIZE, format=datapoints.BoundingBoxFormat.XYXY)
+BOUNDING_BOX = make_bounding_boxes(DEFAULT_SIZE, format=tv_tensors.BoundingBoxFormat.XYXY)
 MASK = make_detection_mask(DEFAULT_SIZE)

 @pytest.mark.parametrize(
     ("sample", "types", "expected"),
     [
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image,), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image,), True),
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes,), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes,), True),
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Mask,), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Mask,), True),
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), True),
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes, datapoints.Mask), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
-        ((MASK,), (datapoints.Image, datapoints.BoundingBoxes), False),
+        ((MASK,), (tv_tensors.Image, tv_tensors.BoundingBoxes), False),
-        ((BOUNDING_BOX,), (datapoints.Image, datapoints.Mask), False),
+        ((BOUNDING_BOX,), (tv_tensors.Image, tv_tensors.Mask), False),
-        ((IMAGE,), (datapoints.BoundingBoxes, datapoints.Mask), False),
+        ((IMAGE,), (tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
         (
             (IMAGE, BOUNDING_BOX, MASK),
-            (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask),
+            (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
             True,
         ),
-        ((), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False),
+        ((), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
-        ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, datapoints.Image),), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, tv_tensors.Image),), True),
         ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
         ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
-        ((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor), True),
+        ((IMAGE,), (tv_tensors.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor), True),
         (
             (torch.Tensor(IMAGE),),
-            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
+            (tv_tensors.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
             True,
         ),
         (
             (to_pil_image(IMAGE),),
-            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
+            (tv_tensors.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
             True,
         ),
     ],
...
@@ -57,31 +57,31 @@ def test_has_any(sample, types, expected):
 @pytest.mark.parametrize(
     ("sample", "types", "expected"),
     [
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image,), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image,), True),
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes,), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes,), True),
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Mask,), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Mask,), True),
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), True),
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes, datapoints.Mask), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
         (
             (IMAGE, BOUNDING_BOX, MASK),
-            (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask),
+            (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
             True,
         ),
-        ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes), False),
+        ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), False),
-        ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), False),
+        ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), False),
-        ((IMAGE, MASK), (datapoints.BoundingBoxes, datapoints.Mask), False),
+        ((IMAGE, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
         (
             (IMAGE, BOUNDING_BOX, MASK),
-            (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask),
+            (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
             True,
         ),
-        ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False),
+        ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
-        ((IMAGE, MASK), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False),
+        ((IMAGE, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
-        ((IMAGE, BOUNDING_BOX), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False),
+        ((IMAGE, BOUNDING_BOX), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
         (
             (IMAGE, BOUNDING_BOX, MASK),
-            (lambda obj: isinstance(obj, (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask)),),
+            (lambda obj: isinstance(obj, (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask)),),
             True,
         ),
         ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
...
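The parametrizations above reduce to a simple contract: `has_any` succeeds when at least one input matches any of the listed types or predicate callables, while `has_all` requires every listed type or callable to be matched by some input. A short self-contained sketch under that reading (`has_any`/`has_all` live in the private `torchvision.transforms.v2._utils` module, so this mirrors the test usage rather than a public API):

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2._utils import has_all, has_any

image = tv_tensors.Image(torch.rand(3, 16, 16))
boxes = tv_tensors.BoundingBoxes([[0, 0, 5, 5]], format="XYXY", canvas_size=(16, 16))
mask = tv_tensors.Mask(torch.zeros(1, 16, 16, dtype=torch.uint8))
sample = (image, boxes, mask)

assert has_any(sample, tv_tensors.Image)                          # at least one input matches
assert has_all(sample, tv_tensors.Image, tv_tensors.Mask)         # every listed type is present
assert not has_all((image,), tv_tensors.Image, tv_tensors.Mask)   # Mask is missing here
assert has_any(sample, lambda obj: isinstance(obj, tv_tensors.Mask))  # callables work as checks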
test/test_datapoints.py → test/test_tv_tensors.py View file @ d5f4cc38
...
@@ -5,7 +5,7 @@ import torch
 from common_utils import assert_equal, make_bounding_boxes, make_image, make_segmentation_mask, make_video
 from PIL import Image
-from torchvision import datapoints
+from torchvision import tv_tensors

 @pytest.fixture(autouse=True)
...
@@ -13,40 +13,40 @@ def restore_tensor_return_type():
     # This is for security, as we should already be restoring the default manually in each test anyway
     # (at least at the time of writing...)
     yield
-    datapoints.set_return_type("Tensor")
+    tv_tensors.set_return_type("Tensor")

 @pytest.mark.parametrize("data", [torch.rand(3, 32, 32), Image.new("RGB", (32, 32), color=123)])
 def test_image_instance(data):
-    image = datapoints.Image(data)
+    image = tv_tensors.Image(data)
     assert isinstance(image, torch.Tensor)
     assert image.ndim == 3 and image.shape[0] == 3

 @pytest.mark.parametrize("data", [torch.randint(0, 10, size=(1, 32, 32)), Image.new("L", (32, 32), color=2)])
 def test_mask_instance(data):
-    mask = datapoints.Mask(data)
+    mask = tv_tensors.Mask(data)
     assert isinstance(mask, torch.Tensor)
     assert mask.ndim == 3 and mask.shape[0] == 1

 @pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]], [1, 2, 3, 4]])
 @pytest.mark.parametrize(
-    "format", ["XYXY", "CXCYWH", datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH]
+    "format", ["XYXY", "CXCYWH", tv_tensors.BoundingBoxFormat.XYXY, tv_tensors.BoundingBoxFormat.XYWH]
 )
 def test_bbox_instance(data, format):
-    bboxes = datapoints.BoundingBoxes(data, format=format, canvas_size=(32, 32))
+    bboxes = tv_tensors.BoundingBoxes(data, format=format, canvas_size=(32, 32))
     assert isinstance(bboxes, torch.Tensor)
     assert bboxes.ndim == 2 and bboxes.shape[1] == 4
     if isinstance(format, str):
-        format = datapoints.BoundingBoxFormat[(format.upper())]
+        format = tv_tensors.BoundingBoxFormat[(format.upper())]
     assert bboxes.format == format

 def test_bbox_dim_error():
     data_3d = [[[1, 2, 3, 4]]]
     with pytest.raises(ValueError, match="Expected a 1D or 2D tensor, got 3D"):
-        datapoints.BoundingBoxes(data_3d, format="XYXY", canvas_size=(32, 32))
+        tv_tensors.BoundingBoxes(data_3d, format="XYXY", canvas_size=(32, 32))

 @pytest.mark.parametrize(
...
@@ -64,8 +64,8 @@ def test_bbox_dim_error():
     ],
 )
 def test_new_requires_grad(data, input_requires_grad, expected_requires_grad):
-    datapoint = datapoints.Image(data, requires_grad=input_requires_grad)
+    tv_tensor = tv_tensors.Image(data, requires_grad=input_requires_grad)
-    assert datapoint.requires_grad is expected_requires_grad
+    assert tv_tensor.requires_grad is expected_requires_grad

 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
...
@@ -75,7 +75,7 @@ def test_isinstance(make_input):
 def test_wrapping_no_copy():
     tensor = torch.rand(3, 16, 16)
-    image = datapoints.Image(tensor)
+    image = tv_tensors.Image(tensor)

     assert image.data_ptr() == tensor.data_ptr()
...
@@ -91,25 +91,25 @@ def test_to_wrapping(make_input):
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
+@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
-def test_to_datapoint_reference(make_input, return_type):
+def test_to_tv_tensor_reference(make_input, return_type):
     tensor = torch.rand((3, 16, 16), dtype=torch.float64)
     dp = make_input()

-    with datapoints.set_return_type(return_type):
+    with tv_tensors.set_return_type(return_type):
         tensor_to = tensor.to(dp)

-    assert type(tensor_to) is (type(dp) if return_type == "datapoint" else torch.Tensor)
+    assert type(tensor_to) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
     assert tensor_to.dtype is dp.dtype
     assert type(tensor) is torch.Tensor

 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
+@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
 def test_clone_wrapping(make_input, return_type):
     dp = make_input()

-    with datapoints.set_return_type(return_type):
+    with tv_tensors.set_return_type(return_type):
         dp_clone = dp.clone()

     assert type(dp_clone) is type(dp)
...
@@ -117,13 +117,13 @@ def test_clone_wrapping(make_input, return_type):
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
+@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
 def test_requires_grad__wrapping(make_input, return_type):
     dp = make_input(dtype=torch.float)

     assert not dp.requires_grad

-    with datapoints.set_return_type(return_type):
+    with tv_tensors.set_return_type(return_type):
         dp_requires_grad = dp.requires_grad_(True)

     assert type(dp_requires_grad) is type(dp)
...
@@ -132,54 +132,54 @@ def test_requires_grad__wrapping(make_input, return_type):
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
+@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
 def test_detach_wrapping(make_input, return_type):
     dp = make_input(dtype=torch.float).requires_grad_(True)

-    with datapoints.set_return_type(return_type):
+    with tv_tensors.set_return_type(return_type):
         dp_detached = dp.detach()

     assert type(dp_detached) is type(dp)

-@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
+@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
 def test_force_subclass_with_metadata(return_type):
-    # Sanity checks for the ops in _FORCE_TORCHFUNCTION_SUBCLASS and datapoints with metadata
+    # Sanity checks for the ops in _FORCE_TORCHFUNCTION_SUBCLASS and tv_tensors with metadata
     # Largely the same as above, we additionally check that the metadata is preserved
     format, canvas_size = "XYXY", (32, 32)
-    bbox = datapoints.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format=format, canvas_size=canvas_size)
+    bbox = tv_tensors.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format=format, canvas_size=canvas_size)

-    datapoints.set_return_type(return_type)
+    tv_tensors.set_return_type(return_type)
     bbox = bbox.clone()
-    if return_type == "datapoint":
+    if return_type == "tv_tensor":
         assert bbox.format, bbox.canvas_size == (format, canvas_size)

     bbox = bbox.to(torch.float64)
-    if return_type == "datapoint":
+    if return_type == "tv_tensor":
         assert bbox.format, bbox.canvas_size == (format, canvas_size)

     bbox = bbox.detach()
-    if return_type == "datapoint":
+    if return_type == "tv_tensor":
         assert bbox.format, bbox.canvas_size == (format, canvas_size)

     assert not bbox.requires_grad
     bbox.requires_grad_(True)
-    if return_type == "datapoint":
+    if return_type == "tv_tensor":
         assert bbox.format, bbox.canvas_size == (format, canvas_size)
         assert bbox.requires_grad
-    datapoints.set_return_type("tensor")
+    tv_tensors.set_return_type("tensor")

 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
+@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
 def test_other_op_no_wrapping(make_input, return_type):
     dp = make_input()

-    with datapoints.set_return_type(return_type):
+    with tv_tensors.set_return_type(return_type):
         # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
         output = dp * 2

-    assert type(output) is (type(dp) if return_type == "datapoint" else torch.Tensor)
+    assert type(output) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)

 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
...
@@ -200,15 +200,15 @@ def test_no_tensor_output_op_no_wrapping(make_input, op):
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
+@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
 def test_inplace_op_no_wrapping(make_input, return_type):
     dp = make_input()
     original_type = type(dp)

-    with datapoints.set_return_type(return_type):
+    with tv_tensors.set_return_type(return_type):
         output = dp.add_(0)

-    assert type(output) is (type(dp) if return_type == "datapoint" else torch.Tensor)
+    assert type(output) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
     assert type(dp) is original_type
...
@@ -219,7 +219,7 @@ def test_wrap(make_input):
     # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
     output = dp * 2

-    dp_new = datapoints.wrap(output, like=dp)
+    dp_new = tv_tensors.wrap(output, like=dp)

     assert type(dp_new) is type(dp)
     assert dp_new.data_ptr() == output.data_ptr()
...
@@ -243,7 +243,7 @@ def test_deepcopy(make_input, requires_grad):
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
+@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
 @pytest.mark.parametrize(
     "op",
     (
...
@@ -265,10 +265,10 @@ def test_deepcopy(make_input, requires_grad):
 def test_usual_operations(make_input, return_type, op):
     dp = make_input()
-    with datapoints.set_return_type(return_type):
+    with tv_tensors.set_return_type(return_type):
         out = op(dp)
-        assert type(out) is (type(dp) if return_type == "datapoint" else torch.Tensor)
+        assert type(out) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
-        if isinstance(dp, datapoints.BoundingBoxes) and return_type == "datapoint":
+        if isinstance(dp, tv_tensors.BoundingBoxes) and return_type == "tv_tensor":
             assert hasattr(out, "format")
             assert hasattr(out, "canvas_size")
...
@@ -286,22 +286,22 @@ def test_set_return_type():
     assert type(img + 3) is torch.Tensor

-    with datapoints.set_return_type("datapoint"):
+    with tv_tensors.set_return_type("tv_tensor"):
-        assert type(img + 3) is datapoints.Image
+        assert type(img + 3) is tv_tensors.Image
     assert type(img + 3) is torch.Tensor

-    datapoints.set_return_type("datapoint")
+    tv_tensors.set_return_type("tv_tensor")
-    assert type(img + 3) is datapoints.Image
+    assert type(img + 3) is tv_tensors.Image

-    with datapoints.set_return_type("tensor"):
+    with tv_tensors.set_return_type("tensor"):
         assert type(img + 3) is torch.Tensor
-        with datapoints.set_return_type("datapoint"):
+        with tv_tensors.set_return_type("tv_tensor"):
-            assert type(img + 3) is datapoints.Image
+            assert type(img + 3) is tv_tensors.Image
-            datapoints.set_return_type("tensor")
+            tv_tensors.set_return_type("tensor")
             assert type(img + 3) is torch.Tensor
         assert type(img + 3) is torch.Tensor
     # Exiting a context manager will restore the return type as it was prior to entering it,
-    # regardless of whether the "global" datapoints.set_return_type() was called within the context manager.
+    # regardless of whether the "global" tv_tensors.set_return_type() was called within the context manager.
-    assert type(img + 3) is datapoints.Image
+    assert type(img + 3) is tv_tensors.Image

-    datapoints.set_return_type("tensor")
+    tv_tensors.set_return_type("tensor")
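The renamed test module above exercises the return-type switch end to end. A condensed sketch of the behavior it asserts, assuming the post-rename `tv_tensors` namespace:

import torch
from torchvision import tv_tensors

img = tv_tensors.Image(torch.rand(3, 8, 8))
assert type(img + 3) is torch.Tensor           # default: results unwrap to torch.Tensor

with tv_tensors.set_return_type("tv_tensor"):  # opt back into subclass outputs
    assert type(img + 3) is tv_tensors.Image
assert type(img + 3) is torch.Tensor           # the previous setting is restored on exit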
test/transforms_v2_dispatcher_infos.py View file @ d5f4cc38
...
@@ -2,7 +2,7 @@ import collections.abc
 import pytest
 import torchvision.transforms.v2.functional as F
-from torchvision import datapoints
+from torchvision import tv_tensors
 from transforms_v2_kernel_infos import KERNEL_INFOS, pad_xfail_jit_fill_condition
 from transforms_v2_legacy_utils import InfoBase, TestMark
...
@@ -44,19 +44,19 @@ class DispatcherInfo(InfoBase):
         self.pil_kernel_info = pil_kernel_info

         kernel_infos = {}
-        for datapoint_type, kernel in self.kernels.items():
+        for tv_tensor_type, kernel in self.kernels.items():
             kernel_info = self._KERNEL_INFO_MAP.get(kernel)
             if not kernel_info:
                 raise pytest.UsageError(
-                    f"Can't register {kernel.__name__} for type {datapoint_type} since there is no `KernelInfo` for it. "
+                    f"Can't register {kernel.__name__} for type {tv_tensor_type} since there is no `KernelInfo` for it. "
                     f"Please add a `KernelInfo` for it in `transforms_v2_kernel_infos.py`."
                 )
-            kernel_infos[datapoint_type] = kernel_info
+            kernel_infos[tv_tensor_type] = kernel_info
         self.kernel_infos = kernel_infos

-    def sample_inputs(self, *datapoint_types, filter_metadata=True):
+    def sample_inputs(self, *tv_tensor_types, filter_metadata=True):
-        for datapoint_type in datapoint_types or self.kernel_infos.keys():
+        for tv_tensor_type in tv_tensor_types or self.kernel_infos.keys():
-            kernel_info = self.kernel_infos.get(datapoint_type)
+            kernel_info = self.kernel_infos.get(tv_tensor_type)
             if not kernel_info:
                 raise pytest.UsageError(f"There is no kernel registered for type {type.__name__}")
...
@@ -69,12 +69,12 @@ class DispatcherInfo(InfoBase):
             import itertools

             for args_kwargs in sample_inputs:
-                if hasattr(datapoint_type, "__annotations__"):
+                if hasattr(tv_tensor_type, "__annotations__"):
                     for name in itertools.chain(
-                        datapoint_type.__annotations__.keys(),
+                        tv_tensor_type.__annotations__.keys(),
                         # FIXME: this seems ok for conversion dispatchers, but we should probably handle this on a
                         # per-dispatcher level. However, so far there is no option for that.
-                        (f"old_{name}" for name in datapoint_type.__annotations__.keys()),
+                        (f"old_{name}" for name in tv_tensor_type.__annotations__.keys()),
                     ):
                         if name in args_kwargs.kwargs:
                             del args_kwargs.kwargs[name]
...
@@ -97,9 +97,9 @@ def xfail_jit_python_scalar_arg(name, *, reason=None):
     )

-skip_dispatch_datapoint = TestMark(
+skip_dispatch_tv_tensor = TestMark(
-    ("TestDispatchers", "test_dispatch_datapoint"),
+    ("TestDispatchers", "test_dispatch_tv_tensor"),
-    pytest.mark.skip(reason="Dispatcher doesn't support arbitrary datapoint dispatch."),
+    pytest.mark.skip(reason="Dispatcher doesn't support arbitrary tv_tensor dispatch."),
 )

 multi_crop_skips = [
...
@@ -107,9 +107,9 @@ multi_crop_skips = [
         ("TestDispatchers", test_name),
         pytest.mark.skip(reason="Multi-crop dispatchers return a sequence of items rather than a single one."),
     )
-    for test_name in ["test_pure_tensor_output_type", "test_pil_output_type", "test_datapoint_output_type"]
+    for test_name in ["test_pure_tensor_output_type", "test_pil_output_type", "test_tv_tensor_output_type"]
 ]
-multi_crop_skips.append(skip_dispatch_datapoint)
+multi_crop_skips.append(skip_dispatch_tv_tensor)

 def xfails_pil(reason, *, condition=None):
...
@@ -142,30 +142,30 @@ DISPATCHER_INFOS = [
     DispatcherInfo(
         F.crop,
         kernels={
-            datapoints.Image: F.crop_image,
+            tv_tensors.Image: F.crop_image,
-            datapoints.Video: F.crop_video,
+            tv_tensors.Video: F.crop_video,
-            datapoints.BoundingBoxes: F.crop_bounding_boxes,
+            tv_tensors.BoundingBoxes: F.crop_bounding_boxes,
-            datapoints.Mask: F.crop_mask,
+            tv_tensors.Mask: F.crop_mask,
         },
         pil_kernel_info=PILKernelInfo(F._crop_image_pil, kernel_name="crop_image_pil"),
     ),
     DispatcherInfo(
         F.resized_crop,
         kernels={
-            datapoints.Image: F.resized_crop_image,
+            tv_tensors.Image: F.resized_crop_image,
-            datapoints.Video: F.resized_crop_video,
+            tv_tensors.Video: F.resized_crop_video,
-            datapoints.BoundingBoxes: F.resized_crop_bounding_boxes,
+            tv_tensors.BoundingBoxes: F.resized_crop_bounding_boxes,
-            datapoints.Mask: F.resized_crop_mask,
+            tv_tensors.Mask: F.resized_crop_mask,
         },
         pil_kernel_info=PILKernelInfo(F._resized_crop_image_pil),
     ),
     DispatcherInfo(
         F.pad,
         kernels={
-            datapoints.Image: F.pad_image,
+            tv_tensors.Image: F.pad_image,
-            datapoints.Video: F.pad_video,
+            tv_tensors.Video: F.pad_video,
-            datapoints.BoundingBoxes: F.pad_bounding_boxes,
+            tv_tensors.BoundingBoxes: F.pad_bounding_boxes,
-            datapoints.Mask: F.pad_mask,
+            tv_tensors.Mask: F.pad_mask,
         },
         pil_kernel_info=PILKernelInfo(F._pad_image_pil, kernel_name="pad_image_pil"),
         test_marks=[
...
@@ -184,10 +184,10 @@ DISPATCHER_INFOS = [
     DispatcherInfo(
         F.perspective,
         kernels={
-            datapoints.Image: F.perspective_image,
+            tv_tensors.Image: F.perspective_image,
-            datapoints.Video: F.perspective_video,
+            tv_tensors.Video: F.perspective_video,
-            datapoints.BoundingBoxes: F.perspective_bounding_boxes,
+            tv_tensors.BoundingBoxes: F.perspective_bounding_boxes,
-            datapoints.Mask: F.perspective_mask,
+            tv_tensors.Mask: F.perspective_mask,
         },
         pil_kernel_info=PILKernelInfo(F._perspective_image_pil),
         test_marks=[
...
@@ -198,10 +198,10 @@ DISPATCHER_INFOS = [
     DispatcherInfo(
         F.elastic,
         kernels={
-            datapoints.Image: F.elastic_image,
+            tv_tensors.Image: F.elastic_image,
-            datapoints.Video: F.elastic_video,
+            tv_tensors.Video: F.elastic_video,
-            datapoints.BoundingBoxes: F.elastic_bounding_boxes,
+            tv_tensors.BoundingBoxes: F.elastic_bounding_boxes,
-            datapoints.Mask: F.elastic_mask,
+            tv_tensors.Mask: F.elastic_mask,
         },
         pil_kernel_info=PILKernelInfo(F._elastic_image_pil),
         test_marks=[xfail_jit_python_scalar_arg("fill")],
...
@@ -209,10 +209,10 @@ DISPATCHER_INFOS = [
     DispatcherInfo(
         F.center_crop,
         kernels={
-            datapoints.Image: F.center_crop_image,
+            tv_tensors.Image: F.center_crop_image,
-            datapoints.Video: F.center_crop_video,
+            tv_tensors.Video: F.center_crop_video,
-            datapoints.BoundingBoxes: F.center_crop_bounding_boxes,
+            tv_tensors.BoundingBoxes: F.center_crop_bounding_boxes,
-            datapoints.Mask: F.center_crop_mask,
+            tv_tensors.Mask: F.center_crop_mask,
         },
         pil_kernel_info=PILKernelInfo(F._center_crop_image_pil),
         test_marks=[
...
@@ -222,8 +222,8 @@ DISPATCHER_INFOS = [
     DispatcherInfo(
         F.gaussian_blur,
         kernels={
-            datapoints.Image: F.gaussian_blur_image,
+            tv_tensors.Image: F.gaussian_blur_image,
-            datapoints.Video: F.gaussian_blur_video,
+            tv_tensors.Video: F.gaussian_blur_video,
         },
         pil_kernel_info=PILKernelInfo(F._gaussian_blur_image_pil),
         test_marks=[
...
@@ -234,99 +234,99 @@ DISPATCHER_INFOS = [
     DispatcherInfo(
         F.equalize,
         kernels={
-            datapoints.Image: F.equalize_image,
+            tv_tensors.Image: F.equalize_image,
-            datapoints.Video: F.equalize_video,
+            tv_tensors.Video: F.equalize_video,
         },
         pil_kernel_info=PILKernelInfo(F._equalize_image_pil, kernel_name="equalize_image_pil"),
     ),
     DispatcherInfo(
         F.invert,
         kernels={
-            datapoints.Image: F.invert_image,
+            tv_tensors.Image: F.invert_image,
-            datapoints.Video: F.invert_video,
+            tv_tensors.Video: F.invert_video,
         },
         pil_kernel_info=PILKernelInfo(F._invert_image_pil, kernel_name="invert_image_pil"),
     ),
     DispatcherInfo(
         F.posterize,
         kernels={
-            datapoints.Image: F.posterize_image,
+            tv_tensors.Image: F.posterize_image,
-            datapoints.Video: F.posterize_video,
+            tv_tensors.Video: F.posterize_video,
         },
         pil_kernel_info=PILKernelInfo(F._posterize_image_pil, kernel_name="posterize_image_pil"),
     ),
     DispatcherInfo(
         F.solarize,
         kernels={
-            datapoints.Image: F.solarize_image,
+            tv_tensors.Image: F.solarize_image,
-            datapoints.Video: F.solarize_video,
+            tv_tensors.Video: F.solarize_video,
         },
         pil_kernel_info=PILKernelInfo(F._solarize_image_pil, kernel_name="solarize_image_pil"),
     ),
     DispatcherInfo(
         F.autocontrast,
         kernels={
-            datapoints.Image: F.autocontrast_image,
+            tv_tensors.Image: F.autocontrast_image,
-            datapoints.Video: F.autocontrast_video,
+            tv_tensors.Video: F.autocontrast_video,
         },
         pil_kernel_info=PILKernelInfo(F._autocontrast_image_pil, kernel_name="autocontrast_image_pil"),
     ),
     DispatcherInfo(
         F.adjust_sharpness,
         kernels={
-            datapoints.Image: F.adjust_sharpness_image,
+            tv_tensors.Image: F.adjust_sharpness_image,
-            datapoints.Video: F.adjust_sharpness_video,
+            tv_tensors.Video: F.adjust_sharpness_video,
         },
         pil_kernel_info=PILKernelInfo(F._adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"),
     ),
     DispatcherInfo(
         F.erase,
         kernels={
-            datapoints.Image: F.erase_image,
+            tv_tensors.Image: F.erase_image,
-            datapoints.Video: F.erase_video,
+            tv_tensors.Video: F.erase_video,
         },
         pil_kernel_info=PILKernelInfo(F._erase_image_pil),
         test_marks=[
-            skip_dispatch_datapoint,
+            skip_dispatch_tv_tensor,
         ],
     ),
     DispatcherInfo(
         F.adjust_contrast,
         kernels={
-            datapoints.Image: F.adjust_contrast_image,
+            tv_tensors.Image: F.adjust_contrast_image,
-            datapoints.Video: F.adjust_contrast_video,
+            tv_tensors.Video: F.adjust_contrast_video,
         },
         pil_kernel_info=PILKernelInfo(F._adjust_contrast_image_pil, kernel_name="adjust_contrast_image_pil"),
     ),
     DispatcherInfo(
         F.adjust_gamma,
         kernels={
-            datapoints.Image: F.adjust_gamma_image,
+            tv_tensors.Image: F.adjust_gamma_image,
-            datapoints.Video: F.adjust_gamma_video,
+            tv_tensors.Video: F.adjust_gamma_video,
         },
         pil_kernel_info=PILKernelInfo(F._adjust_gamma_image_pil, kernel_name="adjust_gamma_image_pil"),
     ),
     DispatcherInfo(
         F.adjust_hue,
         kernels={
-            datapoints.Image: F.adjust_hue_image,
+            tv_tensors.Image: F.adjust_hue_image,
-            datapoints.Video: F.adjust_hue_video,
+            tv_tensors.Video: F.adjust_hue_video,
         },
         pil_kernel_info=PILKernelInfo(F._adjust_hue_image_pil, kernel_name="adjust_hue_image_pil"),
     ),
     DispatcherInfo(
         F.adjust_saturation,
         kernels={
-            datapoints.Image: F.adjust_saturation_image,
+            tv_tensors.Image: F.adjust_saturation_image,
-            datapoints.Video: F.adjust_saturation_video,
+            tv_tensors.Video: F.adjust_saturation_video,
         },
         pil_kernel_info=PILKernelInfo(F._adjust_saturation_image_pil, kernel_name="adjust_saturation_image_pil"),
     ),
     DispatcherInfo(
         F.five_crop,
         kernels={
-            datapoints.Image: F.five_crop_image,
+            tv_tensors.Image: F.five_crop_image,
-            datapoints.Video: F.five_crop_video,
+            tv_tensors.Video: F.five_crop_video,
         },
         pil_kernel_info=PILKernelInfo(F._five_crop_image_pil),
         test_marks=[
...
@@ -337,8 +337,8 @@ DISPATCHER_INFOS = [
...
@@ -337,8 +337,8 @@ DISPATCHER_INFOS = [
DispatcherInfo
(
DispatcherInfo
(
F
.
ten_crop
,
F
.
ten_crop
,
kernels
=
{
kernels
=
{
datapoint
s
.
Image
:
F
.
ten_crop_image
,
tv_tensor
s
.
Image
:
F
.
ten_crop_image
,
datapoint
s
.
Video
:
F
.
ten_crop_video
,
tv_tensor
s
.
Video
:
F
.
ten_crop_video
,
},
},
test_marks
=
[
test_marks
=
[
xfail_jit_python_scalar_arg
(
"size"
),
xfail_jit_python_scalar_arg
(
"size"
),
...
@@ -349,8 +349,8 @@ DISPATCHER_INFOS = [
...
@@ -349,8 +349,8 @@ DISPATCHER_INFOS = [
DispatcherInfo
(
DispatcherInfo
(
F
.
normalize
,
F
.
normalize
,
kernels
=
{
kernels
=
{
datapoint
s
.
Image
:
F
.
normalize_image
,
tv_tensor
s
.
Image
:
F
.
normalize_image
,
datapoint
s
.
Video
:
F
.
normalize_video
,
tv_tensor
s
.
Video
:
F
.
normalize_video
,
},
},
test_marks
=
[
test_marks
=
[
xfail_jit_python_scalar_arg
(
"mean"
),
xfail_jit_python_scalar_arg
(
"mean"
),
...
@@ -360,24 +360,24 @@ DISPATCHER_INFOS = [
...
@@ -360,24 +360,24 @@ DISPATCHER_INFOS = [
DispatcherInfo
(
DispatcherInfo
(
F
.
uniform_temporal_subsample
,
F
.
uniform_temporal_subsample
,
kernels
=
{
kernels
=
{
datapoint
s
.
Video
:
F
.
uniform_temporal_subsample_video
,
tv_tensor
s
.
Video
:
F
.
uniform_temporal_subsample_video
,
},
},
test_marks
=
[
test_marks
=
[
skip_dispatch_
datapoint
,
skip_dispatch_
tv_tensor
,
],
],
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
clamp_bounding_boxes
,
F
.
clamp_bounding_boxes
,
kernels
=
{
datapoint
s
.
BoundingBoxes
:
F
.
clamp_bounding_boxes
},
kernels
=
{
tv_tensor
s
.
BoundingBoxes
:
F
.
clamp_bounding_boxes
},
test_marks
=
[
test_marks
=
[
skip_dispatch_
datapoint
,
skip_dispatch_
tv_tensor
,
],
],
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
convert_bounding_box_format
,
F
.
convert_bounding_box_format
,
kernels
=
{
datapoint
s
.
BoundingBoxes
:
F
.
convert_bounding_box_format
},
kernels
=
{
tv_tensor
s
.
BoundingBoxes
:
F
.
convert_bounding_box_format
},
test_marks
=
[
test_marks
=
[
skip_dispatch_
datapoint
,
skip_dispatch_
tv_tensor
,
],
],
),
),
]
]
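For orientation: each kernels table above maps an input type to the kernel its dispatcher routes to, and after this rename the keys are tv_tensors types. A minimal sketch of the user-visible behavior, using only names that appear in this diff:

import torch
import torchvision.transforms.v2.functional as F
from torchvision import tv_tensors

# Wrapping a plain tensor as a tv_tensors.Image makes the dispatcher
# route F.gaussian_blur to the image kernel (F.gaussian_blur_image).
img = tv_tensors.Image(torch.rand(3, 32, 32))
blurred = F.gaussian_blur(img, kernel_size=[3, 3])
print(type(blurred).__name__)  # Image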
test/transforms_v2_kernel_infos.py View file @ d5f4cc38
...
@@ -7,7 +7,7 @@ import pytest
import torch.testing
import torchvision.ops
import torchvision.transforms.v2.functional as F
-from torchvision import datapoints
+from torchvision import tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value, _parse_pad_padding
from transforms_v2_legacy_utils import (
    ArgsKwargs,
...
@@ -193,7 +193,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
    bbox_xyxy = F.convert_bounding_box_format(
        bbox.as_subclass(torch.Tensor),
        old_format=format_,
-        new_format=datapoints.BoundingBoxFormat.XYXY,
+        new_format=tv_tensors.BoundingBoxFormat.XYXY,
        inplace=True,
    )
    points = np.array(
...
@@ -215,7 +215,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
        dtype=bbox_xyxy.dtype,
    )
    out_bbox = F.convert_bounding_box_format(
-        out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
+        out_bbox, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
    )
    # It is important to clamp before casting, especially for CXCYWH format, dtype=int64
    out_bbox = F.clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_)
...
@@ -228,7 +228,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
def sample_inputs_convert_bounding_box_format():
-    formats = list(datapoints.BoundingBoxFormat)
+    formats = list(tv_tensors.BoundingBoxFormat)
    for bounding_boxes_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
        yield ArgsKwargs(bounding_boxes_loader, old_format=bounding_boxes_loader.format, new_format=new_format)
...
@@ -659,7 +659,7 @@ def sample_inputs_perspective_bounding_boxes():
        coefficients=_PERSPECTIVE_COEFFS[0],
    )
-    format = datapoints.BoundingBoxFormat.XYXY
+    format = tv_tensors.BoundingBoxFormat.XYXY
    loader = make_bounding_box_loader(format=format)
    yield ArgsKwargs(
        loader, format=format, canvas_size=loader.canvas_size, startpoints=_STARTPOINTS, endpoints=_ENDPOINTS
...
test/transforms_v2_legacy_utils.py View file @ d5f4cc38
...
@@ -27,7 +27,7 @@ import PIL.Image
import pytest
import torch
-from torchvision import datapoints
+from torchvision import tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.v2.functional import to_dtype_image, to_image, to_pil_image
...
@@ -82,7 +82,7 @@ def make_image(
    if color_space in {"GRAY_ALPHA", "RGBA"}:
        data[..., -1, :, :] = max_value
-    return datapoints.Image(data)
+    return tv_tensors.Image(data)

def make_image_tensor(*args, **kwargs):
...
@@ -96,7 +96,7 @@ def make_image_pil(*args, **kwargs):
def make_bounding_boxes(
    canvas_size=DEFAULT_SIZE,
    *,
-    format=datapoints.BoundingBoxFormat.XYXY,
+    format=tv_tensors.BoundingBoxFormat.XYXY,
    batch_dims=(),
    dtype=None,
    device="cpu",
...
@@ -107,12 +107,12 @@ def make_bounding_boxes(
        return torch.stack([torch.randint(max_value - v, ()) for v in values.flatten().tolist()]).reshape(values.shape)
    if isinstance(format, str):
-        format = datapoints.BoundingBoxFormat[format]
+        format = tv_tensors.BoundingBoxFormat[format]
    dtype = dtype or torch.float32
    if any(dim == 0 for dim in batch_dims):
-        return datapoints.BoundingBoxes(
+        return tv_tensors.BoundingBoxes(
            torch.empty(*batch_dims, 4, dtype=dtype, device=device), format=format, canvas_size=canvas_size
        )
...
@@ -120,28 +120,28 @@ def make_bounding_boxes(
    y = sample_position(h, canvas_size[0])
    x = sample_position(w, canvas_size[1])
-    if format is datapoints.BoundingBoxFormat.XYWH:
+    if format is tv_tensors.BoundingBoxFormat.XYWH:
        parts = (x, y, w, h)
-    elif format is datapoints.BoundingBoxFormat.XYXY:
+    elif format is tv_tensors.BoundingBoxFormat.XYXY:
        x1, y1 = x, y
        x2 = x1 + w
        y2 = y1 + h
        parts = (x1, y1, x2, y2)
-    elif format is datapoints.BoundingBoxFormat.CXCYWH:
+    elif format is tv_tensors.BoundingBoxFormat.CXCYWH:
        cx = x + w / 2
        cy = y + h / 2
        parts = (cx, cy, w, h)
    else:
        raise ValueError(f"Format {format} is not supported")
-    return datapoints.BoundingBoxes(
+    return tv_tensors.BoundingBoxes(
        torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, canvas_size=canvas_size
    )

def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtype=None, device="cpu"):
    """Make a "detection" mask, i.e. (*, N, H, W), where each object is encoded as one of N boolean masks"""
-    return datapoints.Mask(
+    return tv_tensors.Mask(
        torch.testing.make_tensor(
            (*batch_dims, num_objects, *size),
            low=0,
...
@@ -154,7 +154,7 @@ def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtyp
def make_segmentation_mask(size=DEFAULT_SIZE, *, num_categories=10, batch_dims=(), dtype=None, device="cpu"):
    """Make a "segmentation" mask, i.e. (*, H, W), where the category is encoded as pixel value"""
-    return datapoints.Mask(
+    return tv_tensors.Mask(
        torch.testing.make_tensor(
            (*batch_dims, *size),
            low=0,
...
@@ -166,7 +166,7 @@ def make_segmentation_mask(size=DEFAULT_SIZE, *, num_categories=10, batch_dims=(
def make_video(size=DEFAULT_SIZE, *, num_frames=3, batch_dims=(), **kwargs):
-    return datapoints.Video(make_image(size, batch_dims=(*batch_dims, num_frames), **kwargs))
+    return tv_tensors.Video(make_image(size, batch_dims=(*batch_dims, num_frames), **kwargs))

def make_video_tensor(*args, **kwargs):
...
@@ -335,7 +335,7 @@ def make_image_loader_for_interpolation(
        image_tensor = image_tensor.to(device=device)
        image_tensor = to_dtype_image(image_tensor, dtype=dtype, scale=True)
-        return datapoints.Image(image_tensor)
+        return tv_tensors.Image(image_tensor)
    return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype, memory_format=memory_format)
...
@@ -352,7 +352,7 @@ def make_image_loaders_for_interpolation(
@dataclasses.dataclass
class BoundingBoxesLoader(TensorLoader):
-    format: datapoints.BoundingBoxFormat
+    format: tv_tensors.BoundingBoxFormat
    spatial_size: Tuple[int, int]
    canvas_size: Tuple[int, int] = dataclasses.field(init=False)
...
@@ -362,7 +362,7 @@ class BoundingBoxesLoader(TensorLoader):
def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtype=torch.float32):
    if isinstance(format, str):
-        format = datapoints.BoundingBoxFormat[format]
+        format = tv_tensors.BoundingBoxFormat[format]
    spatial_size = _parse_size(spatial_size, name="spatial_size")
...
@@ -381,7 +381,7 @@ def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORT
def make_bounding_box_loaders(
    *,
    extra_dims=tuple(d for d in DEFAULT_EXTRA_DIMS if len(d) < 2),
-    formats=tuple(datapoints.BoundingBoxFormat),
+    formats=tuple(tv_tensors.BoundingBoxFormat),
    spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
    dtypes=(torch.float32, torch.float64, torch.int64),
):
...
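The helpers above build test inputs by wrapping plain tensors in tv_tensors. As a minimal standalone sketch of the constructor they call, using only arguments that appear verbatim in this diff (the box coordinates are made up for illustration):

import torch
from torchvision import tv_tensors

boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[10.0, 10.0, 20.0, 20.0]]),  # one box in XYXY coordinates
    format=tv_tensors.BoundingBoxFormat.XYXY,
    canvas_size=(32, 32),  # height, width of the image the boxes live on
)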
torchvision/datasets/__init__.py View file @ d5f4cc38
...
@@ -137,7 +137,7 @@ __all__ = (
# Ref: https://peps.python.org/pep-0562/
def __getattr__(name):
    if name in ("wrap_dataset_for_transforms_v2",):
-        from torchvision.datapoints._dataset_wrapper import wrap_dataset_for_transforms_v2
+        from torchvision.tv_tensors._dataset_wrapper import wrap_dataset_for_transforms_v2
        return wrap_dataset_for_transforms_v2
...
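Only the wrapper's home module changes here; the call site stays the same. An illustrative use (the dataset choice and paths are placeholders, not part of this commit):

from torchvision import datasets

# "root" and "annFile" are placeholder paths for illustration.
ds = datasets.CocoDetection(root="path/to/images", annFile="path/to/annotations.json")
ds = datasets.wrap_dataset_for_transforms_v2(ds)  # resolved lazily via __getattr__ above
img, target = ds[0]  # target now contains tv_tensors such as BoundingBoxes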
torchvision/prototype/__init__.py View file @ d5f4cc38
-from . import datapoints, models, transforms, utils
+from . import models, transforms, tv_tensors, utils
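The prototype namespace now re-exports tv_tensors instead of datapoints. A minimal sketch, assuming the prototype Label API itself is unchanged by this commit (only its namespace moves):

from torchvision.prototype import tv_tensors

# Constructor signature assumed from the prototype datapoints.Label it replaces.
label = tv_tensors.Label(0, categories=["airplane", "bird"])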
torchvision/prototype/datasets/_builtin/caltech.py View file @ d5f4cc38
...
@@ -6,8 +6,6 @@ import numpy as np
import torch
from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper
-from torchvision.datapoints import BoundingBoxes
-from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
    hint_sharding,
...
@@ -16,6 +14,8 @@ from torchvision.prototype.datasets.utils._internal import (
    read_categories_file,
    read_mat,
)
+from torchvision.prototype.tv_tensors import Label
+from torchvision.tv_tensors import BoundingBoxes
from .._api import register_dataset, register_info
...
torchvision/prototype/datasets/_builtin/celeba.py View file @ d5f4cc38
...
@@ -4,8 +4,6 @@ from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Tupl
import torch
from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper
-from torchvision.datapoints import BoundingBoxes
-from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
    getitem,
...
@@ -14,6 +12,8 @@ from torchvision.prototype.datasets.utils._internal import (
    INFINITE_BUFFER_SIZE,
    path_accessor,
)
+from torchvision.prototype.tv_tensors import Label
+from torchvision.tv_tensors import BoundingBoxes
from .._api import register_dataset, register_info
...
torchvision/prototype/datasets/_builtin/cifar.py View file @ d5f4cc38
...
@@ -6,8 +6,6 @@ from typing import Any, BinaryIO, cast, Dict, Iterator, List, Optional, Tuple, U
import numpy as np
from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper
-from torchvision.datapoints import Image
-from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
    hint_sharding,
...
@@ -15,6 +13,8 @@ from torchvision.prototype.datasets.utils._internal import (
    path_comparator,
    read_categories_file,
)
+from torchvision.prototype.tv_tensors import Label
+from torchvision.tv_tensors import Image
from .._api import register_dataset, register_info
...
torchvision/prototype/datasets/_builtin/clevr.py View file @ d5f4cc38
...
@@ -2,7 +2,6 @@ import pathlib
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, JsonParser, Mapper, UnBatcher
-from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
    getitem,
...
@@ -12,6 +11,7 @@ from torchvision.prototype.datasets.utils._internal import (
    path_accessor,
    path_comparator,
)
+from torchvision.prototype.tv_tensors import Label
from .._api import register_dataset, register_info
...
torchvision/prototype/datasets/_builtin/coco.py View file @ d5f4cc38
...
@@ -14,8 +14,6 @@ from torchdata.datapipes.iter import (
    Mapper,
    UnBatcher,
)
-from torchvision.datapoints import BoundingBoxes, Mask
-from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
    getitem,
...
@@ -26,6 +24,8 @@ from torchvision.prototype.datasets.utils._internal import (
    path_accessor,
    read_categories_file,
)
+from torchvision.prototype.tv_tensors import Label
+from torchvision.tv_tensors import BoundingBoxes, Mask
from .._api import register_dataset, register_info
...
torchvision/prototype/datasets/_builtin/country211.py View file @ d5f4cc38
...
@@ -2,7 +2,6 @@ import pathlib
from typing import Any, Dict, List, Tuple, Union
from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper
-from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
    hint_sharding,
...
@@ -10,6 +9,7 @@ from torchvision.prototype.datasets.utils._internal import (
    path_comparator,
    read_categories_file,
)
+from torchvision.prototype.tv_tensors import Label
from .._api import register_dataset, register_info
...
torchvision/prototype/datasets/_builtin/cub200.py View file @ d5f4cc38
...
@@ -15,8 +15,6 @@ from torchdata.datapipes.iter import (
    Mapper,
)
from torchdata.datapipes.map import IterToMapConverter
-from torchvision.datapoints import BoundingBoxes
-from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
    getitem,
...
@@ -28,6 +26,8 @@ from torchvision.prototype.datasets.utils._internal import (
    read_categories_file,
    read_mat,
)
+from torchvision.prototype.tv_tensors import Label
+from torchvision.tv_tensors import BoundingBoxes
from .._api import register_dataset, register_info
...
torchvision/prototype/datasets/_builtin/dtd.py View file @ d5f4cc38
...
@@ -3,7 +3,6 @@ import pathlib
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
from torchdata.datapipes.iter import CSVParser, Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
-from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
    getitem,
...
@@ -13,6 +12,7 @@ from torchvision.prototype.datasets.utils._internal import (
    path_comparator,
    read_categories_file,
)
+from torchvision.prototype.tv_tensors import Label
from .._api import register_dataset, register_info
...