Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
41f9c1e6
Unverified
Commit
41f9c1e6
authored
Aug 17, 2023
by
Nicolas Hug
Committed by
GitHub
Aug 17, 2023
Browse files
Simple tensor -> pure tensor (#7846)
parent
4025fc5e
Changes
20
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
89 additions
and
89 deletions
+89
-89
test/test_prototype_datasets_builtin.py
test/test_prototype_datasets_builtin.py
+5
-5
test/test_prototype_transforms.py
test/test_prototype_transforms.py
+3
-3
test/test_transforms_v2.py
test/test_transforms_v2.py
+19
-19
test/test_transforms_v2_consistency.py
test/test_transforms_v2_consistency.py
+1
-1
test/test_transforms_v2_functional.py
test/test_transforms_v2_functional.py
+13
-13
test/test_transforms_v2_utils.py
test/test_transforms_v2_utils.py
+3
-3
test/transforms_v2_dispatcher_infos.py
test/transforms_v2_dispatcher_infos.py
+1
-1
torchvision/prototype/transforms/_augment.py
torchvision/prototype/transforms/_augment.py
+3
-3
torchvision/prototype/transforms/_geometry.py
torchvision/prototype/transforms/_geometry.py
+2
-2
torchvision/prototype/transforms/_misc.py
torchvision/prototype/transforms/_misc.py
+3
-3
torchvision/transforms/v2/_augment.py
torchvision/transforms/v2/_augment.py
+3
-3
torchvision/transforms/v2/_auto_augment.py
torchvision/transforms/v2/_auto_augment.py
+2
-2
torchvision/transforms/v2/_geometry.py
torchvision/transforms/v2/_geometry.py
+2
-2
torchvision/transforms/v2/_misc.py
torchvision/transforms/v2/_misc.py
+4
-4
torchvision/transforms/v2/_transform.py
torchvision/transforms/v2/_transform.py
+11
-11
torchvision/transforms/v2/_type_conversion.py
torchvision/transforms/v2/_type_conversion.py
+3
-3
torchvision/transforms/v2/functional/__init__.py
torchvision/transforms/v2/functional/__init__.py
+1
-1
torchvision/transforms/v2/functional/_meta.py
torchvision/transforms/v2/functional/_meta.py
+6
-6
torchvision/transforms/v2/functional/_utils.py
torchvision/transforms/v2/functional/_utils.py
+1
-1
torchvision/transforms/v2/utils.py
torchvision/transforms/v2/utils.py
+3
-3
No files found.
test/test_prototype_datasets_builtin.py
View file @
41f9c1e6
...
@@ -25,7 +25,7 @@ from torchvision.prototype import datasets
...
@@ -25,7 +25,7 @@ from torchvision.prototype import datasets
from
torchvision.prototype.datapoints
import
Label
from
torchvision.prototype.datapoints
import
Label
from
torchvision.prototype.datasets.utils
import
EncodedImage
from
torchvision.prototype.datasets.utils
import
EncodedImage
from
torchvision.prototype.datasets.utils._internal
import
INFINITE_BUFFER_SIZE
from
torchvision.prototype.datasets.utils._internal
import
INFINITE_BUFFER_SIZE
from
torchvision.transforms.v2.utils
import
is_
simpl
e_tensor
from
torchvision.transforms.v2.utils
import
is_
pur
e_tensor
def
assert_samples_equal
(
*
args
,
msg
=
None
,
**
kwargs
):
def
assert_samples_equal
(
*
args
,
msg
=
None
,
**
kwargs
):
...
@@ -140,18 +140,18 @@ class TestCommon:
...
@@ -140,18 +140,18 @@ class TestCommon:
raise
AssertionError
(
make_msg_and_close
(
"The following streams were not closed after a full iteration:"
))
raise
AssertionError
(
make_msg_and_close
(
"The following streams were not closed after a full iteration:"
))
@
parametrize_dataset_mocks
(
DATASET_MOCKS
)
@
parametrize_dataset_mocks
(
DATASET_MOCKS
)
def
test_no_unaccompanied_
simpl
e_tensors
(
self
,
dataset_mock
,
config
):
def
test_no_unaccompanied_
pur
e_tensors
(
self
,
dataset_mock
,
config
):
dataset
,
_
=
dataset_mock
.
load
(
config
)
dataset
,
_
=
dataset_mock
.
load
(
config
)
sample
=
next_consume
(
iter
(
dataset
))
sample
=
next_consume
(
iter
(
dataset
))
simpl
e_tensors
=
{
key
for
key
,
value
in
sample
.
items
()
if
is_
simpl
e_tensor
(
value
)}
pur
e_tensors
=
{
key
for
key
,
value
in
sample
.
items
()
if
is_
pur
e_tensor
(
value
)}
if
simpl
e_tensors
and
not
any
(
if
pur
e_tensors
and
not
any
(
isinstance
(
item
,
(
datapoints
.
Image
,
datapoints
.
Video
,
EncodedImage
))
for
item
in
sample
.
values
()
isinstance
(
item
,
(
datapoints
.
Image
,
datapoints
.
Video
,
EncodedImage
))
for
item
in
sample
.
values
()
):
):
raise
AssertionError
(
raise
AssertionError
(
f
"The values of key(s) "
f
"The values of key(s) "
f
"
{
sequence_to_str
(
sorted
(
simpl
e_tensors
),
separate_last
=
'and '
)
}
contained
simpl
e tensors, "
f
"
{
sequence_to_str
(
sorted
(
pur
e_tensors
),
separate_last
=
'and '
)
}
contained
pur
e tensors, "
f
"but didn't find any (encoded) image or video."
f
"but didn't find any (encoded) image or video."
)
)
...
...
test/test_prototype_transforms.py
View file @
41f9c1e6
...
@@ -18,7 +18,7 @@ from prototype_common_utils import make_label
...
@@ -18,7 +18,7 @@ from prototype_common_utils import make_label
from
torchvision.datapoints
import
BoundingBoxes
,
BoundingBoxFormat
,
Image
,
Mask
,
Video
from
torchvision.datapoints
import
BoundingBoxes
,
BoundingBoxFormat
,
Image
,
Mask
,
Video
from
torchvision.prototype
import
datapoints
,
transforms
from
torchvision.prototype
import
datapoints
,
transforms
from
torchvision.transforms.v2.functional
import
clamp_bounding_boxes
,
InterpolationMode
,
pil_to_tensor
,
to_pil_image
from
torchvision.transforms.v2.functional
import
clamp_bounding_boxes
,
InterpolationMode
,
pil_to_tensor
,
to_pil_image
from
torchvision.transforms.v2.utils
import
check_type
,
is_
simpl
e_tensor
from
torchvision.transforms.v2.utils
import
check_type
,
is_
pur
e_tensor
BATCH_EXTRA_DIMS
=
[
extra_dims
for
extra_dims
in
DEFAULT_EXTRA_DIMS
if
extra_dims
]
BATCH_EXTRA_DIMS
=
[
extra_dims
for
extra_dims
in
DEFAULT_EXTRA_DIMS
if
extra_dims
]
...
@@ -296,7 +296,7 @@ class TestPermuteDimensions:
...
@@ -296,7 +296,7 @@ class TestPermuteDimensions:
value_type
=
type
(
value
)
value_type
=
type
(
value
)
transformed_value
=
transformed_sample
[
key
]
transformed_value
=
transformed_sample
[
key
]
if
check_type
(
value
,
(
Image
,
is_
simpl
e_tensor
,
Video
)):
if
check_type
(
value
,
(
Image
,
is_
pur
e_tensor
,
Video
)):
if
transform
.
dims
.
get
(
value_type
)
is
not
None
:
if
transform
.
dims
.
get
(
value_type
)
is
not
None
:
assert
transformed_value
.
permute
(
inverse_dims
[
value_type
]).
equal
(
value
)
assert
transformed_value
.
permute
(
inverse_dims
[
value_type
]).
equal
(
value
)
assert
type
(
transformed_value
)
==
torch
.
Tensor
assert
type
(
transformed_value
)
==
torch
.
Tensor
...
@@ -341,7 +341,7 @@ class TestTransposeDimensions:
...
@@ -341,7 +341,7 @@ class TestTransposeDimensions:
transformed_value
=
transformed_sample
[
key
]
transformed_value
=
transformed_sample
[
key
]
transposed_dims
=
transform
.
dims
.
get
(
value_type
)
transposed_dims
=
transform
.
dims
.
get
(
value_type
)
if
check_type
(
value
,
(
Image
,
is_
simpl
e_tensor
,
Video
)):
if
check_type
(
value
,
(
Image
,
is_
pur
e_tensor
,
Video
)):
if
transposed_dims
is
not
None
:
if
transposed_dims
is
not
None
:
assert
transformed_value
.
transpose
(
*
transposed_dims
).
equal
(
value
)
assert
transformed_value
.
transpose
(
*
transposed_dims
).
equal
(
value
)
assert
type
(
transformed_value
)
==
torch
.
Tensor
assert
type
(
transformed_value
)
==
torch
.
Tensor
...
...
test/test_transforms_v2.py
View file @
41f9c1e6
...
@@ -29,7 +29,7 @@ from torchvision import datapoints
...
@@ -29,7 +29,7 @@ from torchvision import datapoints
from
torchvision.ops.boxes
import
box_iou
from
torchvision.ops.boxes
import
box_iou
from
torchvision.transforms.functional
import
to_pil_image
from
torchvision.transforms.functional
import
to_pil_image
from
torchvision.transforms.v2
import
functional
as
F
from
torchvision.transforms.v2
import
functional
as
F
from
torchvision.transforms.v2.utils
import
check_type
,
is_
simpl
e_tensor
,
query_chw
from
torchvision.transforms.v2.utils
import
check_type
,
is_
pur
e_tensor
,
query_chw
def
make_vanilla_tensor_images
(
*
args
,
**
kwargs
):
def
make_vanilla_tensor_images
(
*
args
,
**
kwargs
):
...
@@ -71,7 +71,7 @@ def auto_augment_adapter(transform, input, device):
...
@@ -71,7 +71,7 @@ def auto_augment_adapter(transform, input, device):
if
isinstance
(
value
,
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)):
if
isinstance
(
value
,
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)):
# AA transforms don't support bounding boxes or masks
# AA transforms don't support bounding boxes or masks
continue
continue
elif
check_type
(
value
,
(
datapoints
.
Image
,
datapoints
.
Video
,
is_
simpl
e_tensor
,
PIL
.
Image
.
Image
)):
elif
check_type
(
value
,
(
datapoints
.
Image
,
datapoints
.
Video
,
is_
pur
e_tensor
,
PIL
.
Image
.
Image
)):
if
image_or_video_found
:
if
image_or_video_found
:
# AA transforms only support a single image or video
# AA transforms only support a single image or video
continue
continue
...
@@ -101,7 +101,7 @@ def normalize_adapter(transform, input, device):
...
@@ -101,7 +101,7 @@ def normalize_adapter(transform, input, device):
if
isinstance
(
value
,
PIL
.
Image
.
Image
):
if
isinstance
(
value
,
PIL
.
Image
.
Image
):
# normalize doesn't support PIL images
# normalize doesn't support PIL images
continue
continue
elif
check_type
(
value
,
(
datapoints
.
Image
,
datapoints
.
Video
,
is_
simpl
e_tensor
)):
elif
check_type
(
value
,
(
datapoints
.
Image
,
datapoints
.
Video
,
is_
pur
e_tensor
)):
# normalize doesn't support integer images
# normalize doesn't support integer images
value
=
F
.
to_dtype
(
value
,
torch
.
float32
,
scale
=
True
)
value
=
F
.
to_dtype
(
value
,
torch
.
float32
,
scale
=
True
)
adapted_input
[
key
]
=
value
adapted_input
[
key
]
=
value
...
@@ -357,19 +357,19 @@ class TestSmoke:
...
@@ -357,19 +357,19 @@ class TestSmoke:
3
,
3
,
),
),
)
)
def
test_
simpl
e_tensor_heuristic
(
flat_inputs
):
def
test_
pur
e_tensor_heuristic
(
flat_inputs
):
def
split_on_
simpl
e_tensor
(
to_split
):
def
split_on_
pur
e_tensor
(
to_split
):
# This takes a sequence that is structurally aligned with `flat_inputs` and splits its items into three parts:
# This takes a sequence that is structurally aligned with `flat_inputs` and splits its items into three parts:
# 1. The first
simpl
e tensor. If none is present, this will be `None`
# 1. The first
pur
e tensor. If none is present, this will be `None`
# 2. A list of the remaining
simpl
e tensors
# 2. A list of the remaining
pur
e tensors
# 3. A list of all other items
# 3. A list of all other items
simpl
e_tensors
=
[]
pur
e_tensors
=
[]
others
=
[]
others
=
[]
# Splitting always happens on the original `flat_inputs` to avoid any erroneous type changes by the transform to
# Splitting always happens on the original `flat_inputs` to avoid any erroneous type changes by the transform to
# affect the splitting.
# affect the splitting.
for
item
,
inpt
in
zip
(
to_split
,
flat_inputs
):
for
item
,
inpt
in
zip
(
to_split
,
flat_inputs
):
(
simpl
e_tensors
if
is_
simpl
e_tensor
(
inpt
)
else
others
).
append
(
item
)
(
pur
e_tensors
if
is_
pur
e_tensor
(
inpt
)
else
others
).
append
(
item
)
return
simpl
e_tensors
[
0
]
if
simpl
e_tensors
else
None
,
simpl
e_tensors
[
1
:],
others
return
pur
e_tensors
[
0
]
if
pur
e_tensors
else
None
,
pur
e_tensors
[
1
:],
others
class
CopyCloneTransform
(
transforms
.
Transform
):
class
CopyCloneTransform
(
transforms
.
Transform
):
def
_transform
(
self
,
inpt
,
params
):
def
_transform
(
self
,
inpt
,
params
):
...
@@ -385,20 +385,20 @@ def test_simple_tensor_heuristic(flat_inputs):
...
@@ -385,20 +385,20 @@ def test_simple_tensor_heuristic(flat_inputs):
assert_equal
(
output
,
inpt
)
assert_equal
(
output
,
inpt
)
return
True
return
True
first_
simpl
e_tensor_input
,
other_
simpl
e_tensor_inputs
,
other_inputs
=
split_on_
simpl
e_tensor
(
flat_inputs
)
first_
pur
e_tensor_input
,
other_
pur
e_tensor_inputs
,
other_inputs
=
split_on_
pur
e_tensor
(
flat_inputs
)
transform
=
CopyCloneTransform
()
transform
=
CopyCloneTransform
()
transformed_sample
=
transform
(
flat_inputs
)
transformed_sample
=
transform
(
flat_inputs
)
first_
simpl
e_tensor_output
,
other_
simpl
e_tensor_outputs
,
other_outputs
=
split_on_
simpl
e_tensor
(
transformed_sample
)
first_
pur
e_tensor_output
,
other_
pur
e_tensor_outputs
,
other_outputs
=
split_on_
pur
e_tensor
(
transformed_sample
)
if
first_
simpl
e_tensor_input
is
not
None
:
if
first_
pur
e_tensor_input
is
not
None
:
if
other_inputs
:
if
other_inputs
:
assert
not
transform
.
was_applied
(
first_
simpl
e_tensor_output
,
first_
simpl
e_tensor_input
)
assert
not
transform
.
was_applied
(
first_
pur
e_tensor_output
,
first_
pur
e_tensor_input
)
else
:
else
:
assert
transform
.
was_applied
(
first_
simpl
e_tensor_output
,
first_
simpl
e_tensor_input
)
assert
transform
.
was_applied
(
first_
pur
e_tensor_output
,
first_
pur
e_tensor_input
)
for
output
,
inpt
in
zip
(
other_
simpl
e_tensor_outputs
,
other_
simpl
e_tensor_inputs
):
for
output
,
inpt
in
zip
(
other_
pur
e_tensor_outputs
,
other_
pur
e_tensor_inputs
):
assert
not
transform
.
was_applied
(
output
,
inpt
)
assert
not
transform
.
was_applied
(
output
,
inpt
)
for
input
,
output
in
zip
(
other_inputs
,
other_outputs
):
for
input
,
output
in
zip
(
other_inputs
,
other_outputs
):
...
@@ -1004,7 +1004,7 @@ def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
...
@@ -1004,7 +1004,7 @@ def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
image
=
to_pil_image
(
image
[
0
])
image
=
to_pil_image
(
image
[
0
])
elif
image_type
is
torch
.
Tensor
:
elif
image_type
is
torch
.
Tensor
:
image
=
image
.
as_subclass
(
torch
.
Tensor
)
image
=
image
.
as_subclass
(
torch
.
Tensor
)
assert
is_
simpl
e_tensor
(
image
)
assert
is_
pur
e_tensor
(
image
)
label
=
1
if
label_type
is
int
else
torch
.
tensor
([
1
])
label
=
1
if
label_type
is
int
else
torch
.
tensor
([
1
])
...
@@ -1125,7 +1125,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
...
@@ -1125,7 +1125,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
image
=
to_pil_image
(
image
[
0
])
image
=
to_pil_image
(
image
[
0
])
elif
image_type
is
torch
.
Tensor
:
elif
image_type
is
torch
.
Tensor
:
image
=
image
.
as_subclass
(
torch
.
Tensor
)
image
=
image
.
as_subclass
(
torch
.
Tensor
)
assert
is_
simpl
e_tensor
(
image
)
assert
is_
pur
e_tensor
(
image
)
label
=
torch
.
randint
(
0
,
10
,
size
=
(
num_boxes
,))
label
=
torch
.
randint
(
0
,
10
,
size
=
(
num_boxes
,))
...
@@ -1146,7 +1146,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
...
@@ -1146,7 +1146,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
out
=
t
(
sample
)
out
=
t
(
sample
)
if
isinstance
(
to_tensor
,
transforms
.
ToTensor
)
and
image_type
is
not
datapoints
.
Image
:
if
isinstance
(
to_tensor
,
transforms
.
ToTensor
)
and
image_type
is
not
datapoints
.
Image
:
assert
is_
simpl
e_tensor
(
out
[
"image"
])
assert
is_
pur
e_tensor
(
out
[
"image"
])
else
:
else
:
assert
isinstance
(
out
[
"image"
],
datapoints
.
Image
)
assert
isinstance
(
out
[
"image"
],
datapoints
.
Image
)
assert
isinstance
(
out
[
"label"
],
type
(
sample
[
"label"
]))
assert
isinstance
(
out
[
"label"
],
type
(
sample
[
"label"
]))
...
...
test/test_transforms_v2_consistency.py
View file @
41f9c1e6
...
@@ -602,7 +602,7 @@ def check_call_consistency(
...
@@ -602,7 +602,7 @@ def check_call_consistency(
raise
AssertionError
(
raise
AssertionError
(
f
"Transforming a tensor image with shape
{
image_repr
}
failed in the prototype transform with "
f
"Transforming a tensor image with shape
{
image_repr
}
failed in the prototype transform with "
f
"the error above. This means there is a consistency bug either in `_get_params` or in the "
f
"the error above. This means there is a consistency bug either in `_get_params` or in the "
f
"`is_
simpl
e_tensor` path in `_transform`."
f
"`is_
pur
e_tensor` path in `_transform`."
)
from
exc
)
from
exc
assert_close
(
assert_close
(
...
...
test/test_transforms_v2_functional.py
View file @
41f9c1e6
...
@@ -24,7 +24,7 @@ from torchvision.transforms.functional import _get_perspective_coeffs
...
@@ -24,7 +24,7 @@ from torchvision.transforms.functional import _get_perspective_coeffs
from
torchvision.transforms.v2
import
functional
as
F
from
torchvision.transforms.v2
import
functional
as
F
from
torchvision.transforms.v2.functional._geometry
import
_center_crop_compute_padding
from
torchvision.transforms.v2.functional._geometry
import
_center_crop_compute_padding
from
torchvision.transforms.v2.functional._meta
import
clamp_bounding_boxes
,
convert_format_bounding_boxes
from
torchvision.transforms.v2.functional._meta
import
clamp_bounding_boxes
,
convert_format_bounding_boxes
from
torchvision.transforms.v2.utils
import
is_
simpl
e_tensor
from
torchvision.transforms.v2.utils
import
is_
pur
e_tensor
from
transforms_v2_dispatcher_infos
import
DISPATCHER_INFOS
from
transforms_v2_dispatcher_infos
import
DISPATCHER_INFOS
from
transforms_v2_kernel_infos
import
KERNEL_INFOS
from
transforms_v2_kernel_infos
import
KERNEL_INFOS
...
@@ -168,7 +168,7 @@ class TestKernels:
...
@@ -168,7 +168,7 @@ class TestKernels:
def
test_batched_vs_single
(
self
,
test_id
,
info
,
args_kwargs
,
device
):
def
test_batched_vs_single
(
self
,
test_id
,
info
,
args_kwargs
,
device
):
(
batched_input
,
*
other_args
),
kwargs
=
args_kwargs
.
load
(
device
)
(
batched_input
,
*
other_args
),
kwargs
=
args_kwargs
.
load
(
device
)
datapoint_type
=
datapoints
.
Image
if
is_
simpl
e_tensor
(
batched_input
)
else
type
(
batched_input
)
datapoint_type
=
datapoints
.
Image
if
is_
pur
e_tensor
(
batched_input
)
else
type
(
batched_input
)
# This dictionary contains the number of rightmost dimensions that contain the actual data.
# This dictionary contains the number of rightmost dimensions that contain the actual data.
# Everything to the left is considered a batch dimension.
# Everything to the left is considered a batch dimension.
data_dims
=
{
data_dims
=
{
...
@@ -333,9 +333,9 @@ class TestDispatchers:
...
@@ -333,9 +333,9 @@ class TestDispatchers:
dispatcher
=
script
(
info
.
dispatcher
)
dispatcher
=
script
(
info
.
dispatcher
)
(
image_datapoint
,
*
other_args
),
kwargs
=
args_kwargs
.
load
(
device
)
(
image_datapoint
,
*
other_args
),
kwargs
=
args_kwargs
.
load
(
device
)
image_
simpl
e_tensor
=
torch
.
Tensor
(
image_datapoint
)
image_
pur
e_tensor
=
torch
.
Tensor
(
image_datapoint
)
dispatcher
(
image_
simpl
e_tensor
,
*
other_args
,
**
kwargs
)
dispatcher
(
image_
pur
e_tensor
,
*
other_args
,
**
kwargs
)
# TODO: We need this until the dispatchers below also have `DispatcherInfo`'s. If they do, `test_scripted_smoke`
# TODO: We need this until the dispatchers below also have `DispatcherInfo`'s. If they do, `test_scripted_smoke`
# replaces this test for them.
# replaces this test for them.
...
@@ -358,11 +358,11 @@ class TestDispatchers:
...
@@ -358,11 +358,11 @@ class TestDispatchers:
script
(
dispatcher
)
script
(
dispatcher
)
@
image_sample_inputs
@
image_sample_inputs
def
test_
simpl
e_tensor_output_type
(
self
,
info
,
args_kwargs
):
def
test_
pur
e_tensor_output_type
(
self
,
info
,
args_kwargs
):
(
image_datapoint
,
*
other_args
),
kwargs
=
args_kwargs
.
load
()
(
image_datapoint
,
*
other_args
),
kwargs
=
args_kwargs
.
load
()
image_
simpl
e_tensor
=
image_datapoint
.
as_subclass
(
torch
.
Tensor
)
image_
pur
e_tensor
=
image_datapoint
.
as_subclass
(
torch
.
Tensor
)
output
=
info
.
dispatcher
(
image_
simpl
e_tensor
,
*
other_args
,
**
kwargs
)
output
=
info
.
dispatcher
(
image_
pur
e_tensor
,
*
other_args
,
**
kwargs
)
# We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
# We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
assert
type
(
output
)
is
torch
.
Tensor
assert
type
(
output
)
is
torch
.
Tensor
...
@@ -505,11 +505,11 @@ class TestClampBoundingBoxes:
...
@@ -505,11 +505,11 @@ class TestClampBoundingBoxes:
dict
(
canvas_size
=
(
1
,
1
)),
dict
(
canvas_size
=
(
1
,
1
)),
],
],
)
)
def
test_
simpl
e_tensor_insufficient_metadata
(
self
,
metadata
):
def
test_
pur
e_tensor_insufficient_metadata
(
self
,
metadata
):
simpl
e_tensor
=
next
(
make_bounding_boxes
()).
as_subclass
(
torch
.
Tensor
)
pur
e_tensor
=
next
(
make_bounding_boxes
()).
as_subclass
(
torch
.
Tensor
)
with
pytest
.
raises
(
ValueError
,
match
=
re
.
escape
(
"`format` and `canvas_size` has to be passed"
)):
with
pytest
.
raises
(
ValueError
,
match
=
re
.
escape
(
"`format` and `canvas_size` has to be passed"
)):
F
.
clamp_bounding_boxes
(
simpl
e_tensor
,
**
metadata
)
F
.
clamp_bounding_boxes
(
pur
e_tensor
,
**
metadata
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"metadata"
,
"metadata"
,
...
@@ -538,11 +538,11 @@ class TestConvertFormatBoundingBoxes:
...
@@ -538,11 +538,11 @@ class TestConvertFormatBoundingBoxes:
with
pytest
.
raises
(
TypeError
,
match
=
re
.
escape
(
"missing 1 required argument: 'new_format'"
)):
with
pytest
.
raises
(
TypeError
,
match
=
re
.
escape
(
"missing 1 required argument: 'new_format'"
)):
F
.
convert_format_bounding_boxes
(
inpt
,
old_format
)
F
.
convert_format_bounding_boxes
(
inpt
,
old_format
)
def
test_
simpl
e_tensor_insufficient_metadata
(
self
):
def
test_
pur
e_tensor_insufficient_metadata
(
self
):
simpl
e_tensor
=
next
(
make_bounding_boxes
()).
as_subclass
(
torch
.
Tensor
)
pur
e_tensor
=
next
(
make_bounding_boxes
()).
as_subclass
(
torch
.
Tensor
)
with
pytest
.
raises
(
ValueError
,
match
=
re
.
escape
(
"`old_format` has to be passed"
)):
with
pytest
.
raises
(
ValueError
,
match
=
re
.
escape
(
"`old_format` has to be passed"
)):
F
.
convert_format_bounding_boxes
(
simpl
e_tensor
,
new_format
=
datapoints
.
BoundingBoxFormat
.
CXCYWH
)
F
.
convert_format_bounding_boxes
(
pur
e_tensor
,
new_format
=
datapoints
.
BoundingBoxFormat
.
CXCYWH
)
def
test_datapoint_explicit_metadata
(
self
):
def
test_datapoint_explicit_metadata
(
self
):
datapoint
=
next
(
make_bounding_boxes
())
datapoint
=
next
(
make_bounding_boxes
())
...
...
test/test_transforms_v2_utils.py
View file @
41f9c1e6
...
@@ -37,15 +37,15 @@ MASK = make_detection_mask(DEFAULT_SIZE)
...
@@ -37,15 +37,15 @@ MASK = make_detection_mask(DEFAULT_SIZE)
((
IMAGE
,
BOUNDING_BOX
,
MASK
),
(
lambda
obj
:
isinstance
(
obj
,
datapoints
.
Image
),),
True
),
((
IMAGE
,
BOUNDING_BOX
,
MASK
),
(
lambda
obj
:
isinstance
(
obj
,
datapoints
.
Image
),),
True
),
((
IMAGE
,
BOUNDING_BOX
,
MASK
),
(
lambda
_
:
False
,),
False
),
((
IMAGE
,
BOUNDING_BOX
,
MASK
),
(
lambda
_
:
False
,),
False
),
((
IMAGE
,
BOUNDING_BOX
,
MASK
),
(
lambda
_
:
True
,),
True
),
((
IMAGE
,
BOUNDING_BOX
,
MASK
),
(
lambda
_
:
True
,),
True
),
((
IMAGE
,),
(
datapoints
.
Image
,
PIL
.
Image
.
Image
,
torchvision
.
transforms
.
v2
.
utils
.
is_
simpl
e_tensor
),
True
),
((
IMAGE
,),
(
datapoints
.
Image
,
PIL
.
Image
.
Image
,
torchvision
.
transforms
.
v2
.
utils
.
is_
pur
e_tensor
),
True
),
(
(
(
torch
.
Tensor
(
IMAGE
),),
(
torch
.
Tensor
(
IMAGE
),),
(
datapoints
.
Image
,
PIL
.
Image
.
Image
,
torchvision
.
transforms
.
v2
.
utils
.
is_
simpl
e_tensor
),
(
datapoints
.
Image
,
PIL
.
Image
.
Image
,
torchvision
.
transforms
.
v2
.
utils
.
is_
pur
e_tensor
),
True
,
True
,
),
),
(
(
(
to_pil_image
(
IMAGE
),),
(
to_pil_image
(
IMAGE
),),
(
datapoints
.
Image
,
PIL
.
Image
.
Image
,
torchvision
.
transforms
.
v2
.
utils
.
is_
simpl
e_tensor
),
(
datapoints
.
Image
,
PIL
.
Image
.
Image
,
torchvision
.
transforms
.
v2
.
utils
.
is_
pur
e_tensor
),
True
,
True
,
),
),
],
],
...
...
test/transforms_v2_dispatcher_infos.py
View file @
41f9c1e6
...
@@ -107,7 +107,7 @@ multi_crop_skips = [
...
@@ -107,7 +107,7 @@ multi_crop_skips = [
(
"TestDispatchers"
,
test_name
),
(
"TestDispatchers"
,
test_name
),
pytest
.
mark
.
skip
(
reason
=
"Multi-crop dispatchers return a sequence of items rather than a single one."
),
pytest
.
mark
.
skip
(
reason
=
"Multi-crop dispatchers return a sequence of items rather than a single one."
),
)
)
for
test_name
in
[
"test_
simpl
e_tensor_output_type"
,
"test_pil_output_type"
,
"test_datapoint_output_type"
]
for
test_name
in
[
"test_
pur
e_tensor_output_type"
,
"test_pil_output_type"
,
"test_datapoint_output_type"
]
]
]
multi_crop_skips
.
append
(
skip_dispatch_datapoint
)
multi_crop_skips
.
append
(
skip_dispatch_datapoint
)
...
...
torchvision/prototype/transforms/_augment.py
View file @
41f9c1e6
...
@@ -9,7 +9,7 @@ from torchvision.prototype import datapoints as proto_datapoints
...
@@ -9,7 +9,7 @@ from torchvision.prototype import datapoints as proto_datapoints
from
torchvision.transforms.v2
import
functional
as
F
,
InterpolationMode
,
Transform
from
torchvision.transforms.v2
import
functional
as
F
,
InterpolationMode
,
Transform
from
torchvision.transforms.v2.functional._geometry
import
_check_interpolation
from
torchvision.transforms.v2.functional._geometry
import
_check_interpolation
from
torchvision.transforms.v2.utils
import
is_
simpl
e_tensor
from
torchvision.transforms.v2.utils
import
is_
pur
e_tensor
class
SimpleCopyPaste
(
Transform
):
class
SimpleCopyPaste
(
Transform
):
...
@@ -109,7 +109,7 @@ class SimpleCopyPaste(Transform):
...
@@ -109,7 +109,7 @@ class SimpleCopyPaste(Transform):
# with List[image], List[BoundingBoxes], List[Mask], List[Label]
# with List[image], List[BoundingBoxes], List[Mask], List[Label]
images
,
bboxes
,
masks
,
labels
=
[],
[],
[],
[]
images
,
bboxes
,
masks
,
labels
=
[],
[],
[],
[]
for
obj
in
flat_sample
:
for
obj
in
flat_sample
:
if
isinstance
(
obj
,
datapoints
.
Image
)
or
is_
simpl
e_tensor
(
obj
):
if
isinstance
(
obj
,
datapoints
.
Image
)
or
is_
pur
e_tensor
(
obj
):
images
.
append
(
obj
)
images
.
append
(
obj
)
elif
isinstance
(
obj
,
PIL
.
Image
.
Image
):
elif
isinstance
(
obj
,
PIL
.
Image
.
Image
):
images
.
append
(
F
.
to_image
(
obj
))
images
.
append
(
F
.
to_image
(
obj
))
...
@@ -146,7 +146,7 @@ class SimpleCopyPaste(Transform):
...
@@ -146,7 +146,7 @@ class SimpleCopyPaste(Transform):
elif
isinstance
(
obj
,
PIL
.
Image
.
Image
):
elif
isinstance
(
obj
,
PIL
.
Image
.
Image
):
flat_sample
[
i
]
=
F
.
to_pil_image
(
output_images
[
c0
])
flat_sample
[
i
]
=
F
.
to_pil_image
(
output_images
[
c0
])
c0
+=
1
c0
+=
1
elif
is_
simpl
e_tensor
(
obj
):
elif
is_
pur
e_tensor
(
obj
):
flat_sample
[
i
]
=
output_images
[
c0
]
flat_sample
[
i
]
=
output_images
[
c0
]
c0
+=
1
c0
+=
1
elif
isinstance
(
obj
,
datapoints
.
BoundingBoxes
):
elif
isinstance
(
obj
,
datapoints
.
BoundingBoxes
):
...
...
torchvision/prototype/transforms/_geometry.py
View file @
41f9c1e6
...
@@ -7,7 +7,7 @@ from torchvision import datapoints
...
@@ -7,7 +7,7 @@ from torchvision import datapoints
from
torchvision.prototype.datapoints
import
Label
,
OneHotLabel
from
torchvision.prototype.datapoints
import
Label
,
OneHotLabel
from
torchvision.transforms.v2
import
functional
as
F
,
Transform
from
torchvision.transforms.v2
import
functional
as
F
,
Transform
from
torchvision.transforms.v2._utils
import
_FillType
,
_get_fill
,
_setup_fill_arg
,
_setup_size
from
torchvision.transforms.v2._utils
import
_FillType
,
_get_fill
,
_setup_fill_arg
,
_setup_size
from
torchvision.transforms.v2.utils
import
get_bounding_boxes
,
has_any
,
is_
simpl
e_tensor
,
query_size
from
torchvision.transforms.v2.utils
import
get_bounding_boxes
,
has_any
,
is_
pur
e_tensor
,
query_size
class
FixedSizeCrop
(
Transform
):
class
FixedSizeCrop
(
Transform
):
...
@@ -32,7 +32,7 @@ class FixedSizeCrop(Transform):
...
@@ -32,7 +32,7 @@ class FixedSizeCrop(Transform):
flat_inputs
,
flat_inputs
,
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
,
datapoints
.
Image
,
datapoints
.
Image
,
is_
simpl
e_tensor
,
is_
pur
e_tensor
,
datapoints
.
Video
,
datapoints
.
Video
,
):
):
raise
TypeError
(
raise
TypeError
(
...
...
torchvision/prototype/transforms/_misc.py
View file @
41f9c1e6
...
@@ -8,7 +8,7 @@ import torch
...
@@ -8,7 +8,7 @@ import torch
from
torchvision
import
datapoints
from
torchvision
import
datapoints
from
torchvision.transforms.v2
import
Transform
from
torchvision.transforms.v2
import
Transform
from
torchvision.transforms.v2.utils
import
is_
simpl
e_tensor
from
torchvision.transforms.v2.utils
import
is_
pur
e_tensor
T
=
TypeVar
(
"T"
)
T
=
TypeVar
(
"T"
)
...
@@ -25,7 +25,7 @@ def _get_defaultdict(default: T) -> Dict[Any, T]:
...
@@ -25,7 +25,7 @@ def _get_defaultdict(default: T) -> Dict[Any, T]:
class
PermuteDimensions
(
Transform
):
class
PermuteDimensions
(
Transform
):
_transformed_types
=
(
is_
simpl
e_tensor
,
datapoints
.
Image
,
datapoints
.
Video
)
_transformed_types
=
(
is_
pur
e_tensor
,
datapoints
.
Image
,
datapoints
.
Video
)
def
__init__
(
self
,
dims
:
Union
[
Sequence
[
int
],
Dict
[
Type
,
Optional
[
Sequence
[
int
]]]])
->
None
:
def
__init__
(
self
,
dims
:
Union
[
Sequence
[
int
],
Dict
[
Type
,
Optional
[
Sequence
[
int
]]]])
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -47,7 +47,7 @@ class PermuteDimensions(Transform):
...
@@ -47,7 +47,7 @@ class PermuteDimensions(Transform):
class
TransposeDimensions
(
Transform
):
class
TransposeDimensions
(
Transform
):
_transformed_types
=
(
is_
simpl
e_tensor
,
datapoints
.
Image
,
datapoints
.
Video
)
_transformed_types
=
(
is_
pur
e_tensor
,
datapoints
.
Image
,
datapoints
.
Video
)
def
__init__
(
self
,
dims
:
Union
[
Tuple
[
int
,
int
],
Dict
[
Type
,
Optional
[
Tuple
[
int
,
int
]]]])
->
None
:
def
__init__
(
self
,
dims
:
Union
[
Tuple
[
int
,
int
],
Dict
[
Type
,
Optional
[
Tuple
[
int
,
int
]]]])
->
None
:
super
().
__init__
()
super
().
__init__
()
...
...
torchvision/transforms/v2/_augment.py
View file @
41f9c1e6
...
@@ -12,7 +12,7 @@ from torchvision.transforms.v2 import functional as F
...
@@ -12,7 +12,7 @@ from torchvision.transforms.v2 import functional as F
from
._transform
import
_RandomApplyTransform
,
Transform
from
._transform
import
_RandomApplyTransform
,
Transform
from
._utils
import
_parse_labels_getter
from
._utils
import
_parse_labels_getter
from
.utils
import
has_any
,
is_
simpl
e_tensor
,
query_chw
,
query_size
from
.utils
import
has_any
,
is_
pur
e_tensor
,
query_chw
,
query_size
class
RandomErasing
(
_RandomApplyTransform
):
class
RandomErasing
(
_RandomApplyTransform
):
...
@@ -243,7 +243,7 @@ class MixUp(_BaseMixUpCutMix):
...
@@ -243,7 +243,7 @@ class MixUp(_BaseMixUpCutMix):
if
inpt
is
params
[
"labels"
]:
if
inpt
is
params
[
"labels"
]:
return
self
.
_mixup_label
(
inpt
,
lam
=
lam
)
return
self
.
_mixup_label
(
inpt
,
lam
=
lam
)
elif
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
or
is_
simpl
e_tensor
(
inpt
):
elif
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
or
is_
pur
e_tensor
(
inpt
):
self
.
_check_image_or_video
(
inpt
,
batch_size
=
params
[
"batch_size"
])
self
.
_check_image_or_video
(
inpt
,
batch_size
=
params
[
"batch_size"
])
output
=
inpt
.
roll
(
1
,
0
).
mul_
(
1.0
-
lam
).
add_
(
inpt
.
mul
(
lam
))
output
=
inpt
.
roll
(
1
,
0
).
mul_
(
1.0
-
lam
).
add_
(
inpt
.
mul
(
lam
))
...
@@ -310,7 +310,7 @@ class CutMix(_BaseMixUpCutMix):
...
@@ -310,7 +310,7 @@ class CutMix(_BaseMixUpCutMix):
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
if
inpt
is
params
[
"labels"
]:
if
inpt
is
params
[
"labels"
]:
return
self
.
_mixup_label
(
inpt
,
lam
=
params
[
"lam_adjusted"
])
return
self
.
_mixup_label
(
inpt
,
lam
=
params
[
"lam_adjusted"
])
elif
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
or
is_
simpl
e_tensor
(
inpt
):
elif
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
or
is_
pur
e_tensor
(
inpt
):
self
.
_check_image_or_video
(
inpt
,
batch_size
=
params
[
"batch_size"
])
self
.
_check_image_or_video
(
inpt
,
batch_size
=
params
[
"batch_size"
])
x1
,
y1
,
x2
,
y2
=
params
[
"box"
]
x1
,
y1
,
x2
,
y2
=
params
[
"box"
]
...
...
torchvision/transforms/v2/_auto_augment.py
View file @
41f9c1e6
...
@@ -13,7 +13,7 @@ from torchvision.transforms.v2.functional._meta import get_size
...
@@ -13,7 +13,7 @@ from torchvision.transforms.v2.functional._meta import get_size
from
torchvision.transforms.v2.functional._utils
import
_FillType
,
_FillTypeJIT
from
torchvision.transforms.v2.functional._utils
import
_FillType
,
_FillTypeJIT
from
._utils
import
_get_fill
,
_setup_fill_arg
from
._utils
import
_get_fill
,
_setup_fill_arg
from
.utils
import
check_type
,
is_
simpl
e_tensor
from
.utils
import
check_type
,
is_
pur
e_tensor
ImageOrVideo
=
Union
[
torch
.
Tensor
,
PIL
.
Image
.
Image
,
datapoints
.
Image
,
datapoints
.
Video
]
ImageOrVideo
=
Union
[
torch
.
Tensor
,
PIL
.
Image
.
Image
,
datapoints
.
Image
,
datapoints
.
Video
]
...
@@ -50,7 +50,7 @@ class _AutoAugmentBase(Transform):
...
@@ -50,7 +50,7 @@ class _AutoAugmentBase(Transform):
(
(
datapoints
.
Image
,
datapoints
.
Image
,
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
,
is_
simpl
e_tensor
,
is_
pur
e_tensor
,
datapoints
.
Video
,
datapoints
.
Video
,
),
),
):
):
...
...
torchvision/transforms/v2/_geometry.py
View file @
41f9c1e6
...
@@ -24,7 +24,7 @@ from ._utils import (
...
@@ -24,7 +24,7 @@ from ._utils import (
_setup_float_or_seq
,
_setup_float_or_seq
,
_setup_size
,
_setup_size
,
)
)
from
.utils
import
get_bounding_boxes
,
has_all
,
has_any
,
is_
simpl
e_tensor
,
query_size
from
.utils
import
get_bounding_boxes
,
has_all
,
has_any
,
is_
pur
e_tensor
,
query_size
class
RandomHorizontalFlip
(
_RandomApplyTransform
):
class
RandomHorizontalFlip
(
_RandomApplyTransform
):
...
@@ -1149,7 +1149,7 @@ class RandomIoUCrop(Transform):
...
@@ -1149,7 +1149,7 @@ class RandomIoUCrop(Transform):
def
_check_inputs
(
self
,
flat_inputs
:
List
[
Any
])
->
None
:
def
_check_inputs
(
self
,
flat_inputs
:
List
[
Any
])
->
None
:
if
not
(
if
not
(
has_all
(
flat_inputs
,
datapoints
.
BoundingBoxes
)
has_all
(
flat_inputs
,
datapoints
.
BoundingBoxes
)
and
has_any
(
flat_inputs
,
PIL
.
Image
.
Image
,
datapoints
.
Image
,
is_
simpl
e_tensor
)
and
has_any
(
flat_inputs
,
PIL
.
Image
.
Image
,
datapoints
.
Image
,
is_
pur
e_tensor
)
):
):
raise
TypeError
(
raise
TypeError
(
f
"
{
type
(
self
).
__name__
}
() requires input sample to contain tensor or PIL images "
f
"
{
type
(
self
).
__name__
}
() requires input sample to contain tensor or PIL images "
...
...
torchvision/transforms/v2/_misc.py
View file @
41f9c1e6
...
@@ -10,7 +10,7 @@ from torchvision import datapoints, transforms as _transforms
...
@@ -10,7 +10,7 @@ from torchvision import datapoints, transforms as _transforms
from
torchvision.transforms.v2
import
functional
as
F
,
Transform
from
torchvision.transforms.v2
import
functional
as
F
,
Transform
from
._utils
import
_parse_labels_getter
,
_setup_float_or_seq
,
_setup_size
from
._utils
import
_parse_labels_getter
,
_setup_float_or_seq
,
_setup_size
from
.utils
import
get_bounding_boxes
,
has_any
,
is_
simpl
e_tensor
from
.utils
import
get_bounding_boxes
,
has_any
,
is_
pur
e_tensor
# TODO: do we want/need to expose this?
# TODO: do we want/need to expose this?
...
@@ -75,7 +75,7 @@ class LinearTransformation(Transform):
...
@@ -75,7 +75,7 @@ class LinearTransformation(Transform):
_v1_transform_cls
=
_transforms
.
LinearTransformation
_v1_transform_cls
=
_transforms
.
LinearTransformation
_transformed_types
=
(
is_
simpl
e_tensor
,
datapoints
.
Image
,
datapoints
.
Video
)
_transformed_types
=
(
is_
pur
e_tensor
,
datapoints
.
Image
,
datapoints
.
Video
)
def
__init__
(
self
,
transformation_matrix
:
torch
.
Tensor
,
mean_vector
:
torch
.
Tensor
):
def
__init__
(
self
,
transformation_matrix
:
torch
.
Tensor
,
mean_vector
:
torch
.
Tensor
):
super
().
__init__
()
super
().
__init__
()
...
@@ -264,7 +264,7 @@ class ToDtype(Transform):
...
@@ -264,7 +264,7 @@ class ToDtype(Transform):
if
isinstance
(
self
.
dtype
,
torch
.
dtype
):
if
isinstance
(
self
.
dtype
,
torch
.
dtype
):
# For consistency / BC with ConvertImageDtype, we only care about images or videos when dtype
# For consistency / BC with ConvertImageDtype, we only care about images or videos when dtype
# is a simple torch.dtype
# is a simple torch.dtype
if
not
is_
simpl
e_tensor
(
inpt
)
and
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
)):
if
not
is_
pur
e_tensor
(
inpt
)
and
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
)):
return
inpt
return
inpt
dtype
:
Optional
[
torch
.
dtype
]
=
self
.
dtype
dtype
:
Optional
[
torch
.
dtype
]
=
self
.
dtype
...
@@ -281,7 +281,7 @@ class ToDtype(Transform):
...
@@ -281,7 +281,7 @@ class ToDtype(Transform):
'e.g. dtype={datapoints.Mask: torch.int64, "others": None} to pass-through the rest of the inputs.'
'e.g. dtype={datapoints.Mask: torch.int64, "others": None} to pass-through the rest of the inputs.'
)
)
supports_scaling
=
is_
simpl
e_tensor
(
inpt
)
or
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
supports_scaling
=
is_
pur
e_tensor
(
inpt
)
or
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
if
dtype
is
None
:
if
dtype
is
None
:
if
self
.
scale
and
supports_scaling
:
if
self
.
scale
and
supports_scaling
:
warnings
.
warn
(
warnings
.
warn
(
...
...
torchvision/transforms/v2/_transform.py
View file @
41f9c1e6
...
@@ -8,7 +8,7 @@ import torch
...
@@ -8,7 +8,7 @@ import torch
from
torch
import
nn
from
torch
import
nn
from
torch.utils._pytree
import
tree_flatten
,
tree_unflatten
from
torch.utils._pytree
import
tree_flatten
,
tree_unflatten
from
torchvision
import
datapoints
from
torchvision
import
datapoints
from
torchvision.transforms.v2.utils
import
check_type
,
has_any
,
is_
simpl
e_tensor
from
torchvision.transforms.v2.utils
import
check_type
,
has_any
,
is_
pur
e_tensor
from
torchvision.utils
import
_log_api_usage_once
from
torchvision.utils
import
_log_api_usage_once
from
.functional._utils
import
_get_kernel
from
.functional._utils
import
_get_kernel
...
@@ -55,32 +55,32 @@ class Transform(nn.Module):
...
@@ -55,32 +55,32 @@ class Transform(nn.Module):
return
tree_unflatten
(
flat_outputs
,
spec
)
return
tree_unflatten
(
flat_outputs
,
spec
)
def
_needs_transform_list
(
self
,
flat_inputs
:
List
[
Any
])
->
List
[
bool
]:
def
_needs_transform_list
(
self
,
flat_inputs
:
List
[
Any
])
->
List
[
bool
]:
# Below is a heuristic on how to deal with
simpl
e tensor inputs:
# Below is a heuristic on how to deal with
pur
e tensor inputs:
# 1.
Simpl
e tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image
# 1.
Pur
e tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image
# (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample.
# (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample.
# 2. If there is no explicit image or video in the sample, only the first encountered
simpl
e tensor is
# 2. If there is no explicit image or video in the sample, only the first encountered
pur
e tensor is
# transformed as image, while the rest is passed through. The order is defined by the returned `flat_inputs`
# transformed as image, while the rest is passed through. The order is defined by the returned `flat_inputs`
# of `tree_flatten`, which recurses depth-first through the input.
# of `tree_flatten`, which recurses depth-first through the input.
#
#
# This heuristic stems from two requirements:
# This heuristic stems from two requirements:
# 1. We need to keep BC for single input
simpl
e tensors and treat them as images.
# 1. We need to keep BC for single input
pur
e tensors and treat them as images.
# 2. We don't want to treat all
simpl
e tensors as images, because some datasets like `CelebA` or `Widerface`
# 2. We don't want to treat all
pur
e tensors as images, because some datasets like `CelebA` or `Widerface`
# return supplemental numerical data as tensors that cannot be transformed as images.
# return supplemental numerical data as tensors that cannot be transformed as images.
#
#
# The heuristic should work well for most people in practice. The only case where it doesn't is if someone
# The heuristic should work well for most people in practice. The only case where it doesn't is if someone
# tries to transform multiple
simpl
e tensors at the same time, expecting them all to be treated as images.
# tries to transform multiple
pur
e tensors at the same time, expecting them all to be treated as images.
# However, this case wasn't supported by transforms v1 either, so there is no BC concern.
# However, this case wasn't supported by transforms v1 either, so there is no BC concern.
needs_transform_list
=
[]
needs_transform_list
=
[]
transform_
simpl
e_tensor
=
not
has_any
(
flat_inputs
,
datapoints
.
Image
,
datapoints
.
Video
,
PIL
.
Image
.
Image
)
transform_
pur
e_tensor
=
not
has_any
(
flat_inputs
,
datapoints
.
Image
,
datapoints
.
Video
,
PIL
.
Image
.
Image
)
for
inpt
in
flat_inputs
:
for
inpt
in
flat_inputs
:
needs_transform
=
True
needs_transform
=
True
if
not
check_type
(
inpt
,
self
.
_transformed_types
):
if
not
check_type
(
inpt
,
self
.
_transformed_types
):
needs_transform
=
False
needs_transform
=
False
elif
is_
simpl
e_tensor
(
inpt
):
elif
is_
pur
e_tensor
(
inpt
):
if
transform_
simpl
e_tensor
:
if
transform_
pur
e_tensor
:
transform_
simpl
e_tensor
=
False
transform_
pur
e_tensor
=
False
else
:
else
:
needs_transform
=
False
needs_transform
=
False
needs_transform_list
.
append
(
needs_transform
)
needs_transform_list
.
append
(
needs_transform
)
...
...
torchvision/transforms/v2/_type_conversion.py
View file @
41f9c1e6
...
@@ -7,7 +7,7 @@ import torch
...
@@ -7,7 +7,7 @@ import torch
from
torchvision
import
datapoints
from
torchvision
import
datapoints
from
torchvision.transforms.v2
import
functional
as
F
,
Transform
from
torchvision.transforms.v2
import
functional
as
F
,
Transform
from
torchvision.transforms.v2.utils
import
is_
simpl
e_tensor
from
torchvision.transforms.v2.utils
import
is_
pur
e_tensor
class
PILToTensor
(
Transform
):
class
PILToTensor
(
Transform
):
...
@@ -35,7 +35,7 @@ class ToImage(Transform):
...
@@ -35,7 +35,7 @@ class ToImage(Transform):
This transform does not support torchscript.
This transform does not support torchscript.
"""
"""
_transformed_types
=
(
is_
simpl
e_tensor
,
PIL
.
Image
.
Image
,
np
.
ndarray
)
_transformed_types
=
(
is_
pur
e_tensor
,
PIL
.
Image
.
Image
,
np
.
ndarray
)
def
_transform
(
def
_transform
(
self
,
inpt
:
Union
[
torch
.
Tensor
,
PIL
.
Image
.
Image
,
np
.
ndarray
],
params
:
Dict
[
str
,
Any
]
self
,
inpt
:
Union
[
torch
.
Tensor
,
PIL
.
Image
.
Image
,
np
.
ndarray
],
params
:
Dict
[
str
,
Any
]
...
@@ -65,7 +65,7 @@ class ToPILImage(Transform):
...
@@ -65,7 +65,7 @@ class ToPILImage(Transform):
.. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
.. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
"""
"""
_transformed_types
=
(
is_
simpl
e_tensor
,
datapoints
.
Image
,
np
.
ndarray
)
_transformed_types
=
(
is_
pur
e_tensor
,
datapoints
.
Image
,
np
.
ndarray
)
def
__init__
(
self
,
mode
:
Optional
[
str
]
=
None
)
->
None
:
def
__init__
(
self
,
mode
:
Optional
[
str
]
=
None
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
...
torchvision/transforms/v2/functional/__init__.py
View file @
41f9c1e6
from
torchvision.transforms
import
InterpolationMode
# usort: skip
from
torchvision.transforms
import
InterpolationMode
# usort: skip
from
._utils
import
is_
simpl
e_tensor
,
register_kernel
# usort: skip
from
._utils
import
is_
pur
e_tensor
,
register_kernel
# usort: skip
from
._meta
import
(
from
._meta
import
(
clamp_bounding_boxes
,
clamp_bounding_boxes
,
...
...
torchvision/transforms/v2/functional/_meta.py
View file @
41f9c1e6
...
@@ -8,7 +8,7 @@ from torchvision.transforms import _functional_pil as _FP
...
@@ -8,7 +8,7 @@ from torchvision.transforms import _functional_pil as _FP
from
torchvision.utils
import
_log_api_usage_once
from
torchvision.utils
import
_log_api_usage_once
from
._utils
import
_get_kernel
,
_register_kernel_internal
,
is_
simpl
e_tensor
from
._utils
import
_get_kernel
,
_register_kernel_internal
,
is_
pur
e_tensor
def
get_dimensions
(
inpt
:
torch
.
Tensor
)
->
List
[
int
]:
def
get_dimensions
(
inpt
:
torch
.
Tensor
)
->
List
[
int
]:
...
@@ -203,7 +203,7 @@ def convert_format_bounding_boxes(
...
@@ -203,7 +203,7 @@ def convert_format_bounding_boxes(
new_format
:
Optional
[
BoundingBoxFormat
]
=
None
,
new_format
:
Optional
[
BoundingBoxFormat
]
=
None
,
inplace
:
bool
=
False
,
inplace
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# This being a kernel / functional hybrid, we need an option to pass `old_format` explicitly for
simpl
e tensor
# This being a kernel / functional hybrid, we need an option to pass `old_format` explicitly for
pur
e tensor
# inputs as well as extract it from `datapoints.BoundingBoxes` inputs. However, putting a default value on
# inputs as well as extract it from `datapoints.BoundingBoxes` inputs. However, putting a default value on
# `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
# `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
# default error that would be thrown if `new_format` had no default value.
# default error that would be thrown if `new_format` had no default value.
...
@@ -213,9 +213,9 @@ def convert_format_bounding_boxes(
...
@@ -213,9 +213,9 @@ def convert_format_bounding_boxes(
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
convert_format_bounding_boxes
)
_log_api_usage_once
(
convert_format_bounding_boxes
)
if
torch
.
jit
.
is_scripting
()
or
is_
simpl
e_tensor
(
inpt
):
if
torch
.
jit
.
is_scripting
()
or
is_
pur
e_tensor
(
inpt
):
if
old_format
is
None
:
if
old_format
is
None
:
raise
ValueError
(
"For
simpl
e tensor inputs, `old_format` has to be passed."
)
raise
ValueError
(
"For
pur
e tensor inputs, `old_format` has to be passed."
)
return
_convert_format_bounding_boxes
(
inpt
,
old_format
=
old_format
,
new_format
=
new_format
,
inplace
=
inplace
)
return
_convert_format_bounding_boxes
(
inpt
,
old_format
=
old_format
,
new_format
=
new_format
,
inplace
=
inplace
)
elif
isinstance
(
inpt
,
datapoints
.
BoundingBoxes
):
elif
isinstance
(
inpt
,
datapoints
.
BoundingBoxes
):
if
old_format
is
not
None
:
if
old_format
is
not
None
:
...
@@ -256,10 +256,10 @@ def clamp_bounding_boxes(
...
@@ -256,10 +256,10 @@ def clamp_bounding_boxes(
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
clamp_bounding_boxes
)
_log_api_usage_once
(
clamp_bounding_boxes
)
if
torch
.
jit
.
is_scripting
()
or
is_
simpl
e_tensor
(
inpt
):
if
torch
.
jit
.
is_scripting
()
or
is_
pur
e_tensor
(
inpt
):
if
format
is
None
or
canvas_size
is
None
:
if
format
is
None
or
canvas_size
is
None
:
raise
ValueError
(
"For
simpl
e tensor inputs, `format` and `canvas_size` has to be passed."
)
raise
ValueError
(
"For
pur
e tensor inputs, `format` and `canvas_size` has to be passed."
)
return
_clamp_bounding_boxes
(
inpt
,
format
=
format
,
canvas_size
=
canvas_size
)
return
_clamp_bounding_boxes
(
inpt
,
format
=
format
,
canvas_size
=
canvas_size
)
elif
isinstance
(
inpt
,
datapoints
.
BoundingBoxes
):
elif
isinstance
(
inpt
,
datapoints
.
BoundingBoxes
):
if
format
is
not
None
or
canvas_size
is
not
None
:
if
format
is
not
None
or
canvas_size
is
not
None
:
...
...
torchvision/transforms/v2/functional/_utils.py
View file @
41f9c1e6
...
@@ -8,7 +8,7 @@ _FillType = Union[int, float, Sequence[int], Sequence[float], None]
...
@@ -8,7 +8,7 @@ _FillType = Union[int, float, Sequence[int], Sequence[float], None]
_FillTypeJIT
=
Optional
[
List
[
float
]]
_FillTypeJIT
=
Optional
[
List
[
float
]]
def
is_
simpl
e_tensor
(
inpt
:
Any
)
->
bool
:
def
is_
pur
e_tensor
(
inpt
:
Any
)
->
bool
:
return
isinstance
(
inpt
,
torch
.
Tensor
)
and
not
isinstance
(
inpt
,
datapoints
.
Datapoint
)
return
isinstance
(
inpt
,
torch
.
Tensor
)
and
not
isinstance
(
inpt
,
datapoints
.
Datapoint
)
...
...
torchvision/transforms/v2/utils.py
View file @
41f9c1e6
...
@@ -6,7 +6,7 @@ import PIL.Image
...
@@ -6,7 +6,7 @@ import PIL.Image
from
torchvision
import
datapoints
from
torchvision
import
datapoints
from
torchvision._utils
import
sequence_to_str
from
torchvision._utils
import
sequence_to_str
from
torchvision.transforms.v2.functional
import
get_dimensions
,
get_size
,
is_
simpl
e_tensor
from
torchvision.transforms.v2.functional
import
get_dimensions
,
get_size
,
is_
pur
e_tensor
def
get_bounding_boxes
(
flat_inputs
:
List
[
Any
])
->
datapoints
.
BoundingBoxes
:
def
get_bounding_boxes
(
flat_inputs
:
List
[
Any
])
->
datapoints
.
BoundingBoxes
:
...
@@ -21,7 +21,7 @@ def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
...
@@ -21,7 +21,7 @@ def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
chws
=
{
chws
=
{
tuple
(
get_dimensions
(
inpt
))
tuple
(
get_dimensions
(
inpt
))
for
inpt
in
flat_inputs
for
inpt
in
flat_inputs
if
check_type
(
inpt
,
(
is_
simpl
e_tensor
,
datapoints
.
Image
,
PIL
.
Image
.
Image
,
datapoints
.
Video
))
if
check_type
(
inpt
,
(
is_
pur
e_tensor
,
datapoints
.
Image
,
PIL
.
Image
.
Image
,
datapoints
.
Video
))
}
}
if
not
chws
:
if
not
chws
:
raise
TypeError
(
"No image or video was found in the sample"
)
raise
TypeError
(
"No image or video was found in the sample"
)
...
@@ -38,7 +38,7 @@ def query_size(flat_inputs: List[Any]) -> Tuple[int, int]:
...
@@ -38,7 +38,7 @@ def query_size(flat_inputs: List[Any]) -> Tuple[int, int]:
if
check_type
(
if
check_type
(
inpt
,
inpt
,
(
(
is_
simpl
e_tensor
,
is_
pur
e_tensor
,
datapoints
.
Image
,
datapoints
.
Image
,
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
,
datapoints
.
Video
,
datapoints
.
Video
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment