OpenDAS / vision · Commit 312c3d32 (Unverified)
Authored Aug 01, 2023 by Philip Meier; committed by GitHub on Aug 01, 2023

remove spatial_size (#7734)

parent bdf16222
Changes: 29 files in the full commit; this page shows 9 changed files with 156 additions and 131 deletions (+156 -131).
torchvision/transforms/v2/_auto_augment.py            +5  -5
torchvision/transforms/v2/_geometry.py                +10 -10
torchvision/transforms/v2/_meta.py                    +1  -1
torchvision/transforms/v2/_misc.py                    +1  -1
torchvision/transforms/v2/functional/__init__.py      +6  -6
torchvision/transforms/v2/functional/_deprecated.py   +1  -1
torchvision/transforms/v2/functional/_geometry.py     +45 -45
torchvision/transforms/v2/functional/_meta.py         +71 -53
torchvision/transforms/v2/utils.py                    +16 -9
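Across the files on this page the change is one coordinated rename: the kernels and dispatchers `get_spatial_size*` become `get_size*`, the transform helper `query_spatial_size` becomes `query_size`, and the bounding-box metadata and kernel parameter `spatial_size` becomes `canvas_size`. A minimal before/after sketch of the new spelling, assuming a torchvision build that includes this commit (the v2 API was still in beta, so names may differ in other versions):

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

image = datapoints.Image(torch.rand(3, 480, 640))

# Previously F.get_spatial_size(image); after this commit the generic size
# query is plain `get_size` and still returns [height, width].
assert F.get_size(image) == [480, 640]

# Bounding boxes carry `canvas_size` metadata instead of `spatial_size`.
boxes = datapoints.BoundingBoxes(
    [[10, 10, 100, 100]],
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=(480, 640),  # (height, width) of the image the boxes live on
)
assert F.get_size(boxes) == [480, 640]
```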
torchvision/transforms/v2/_auto_augment.py

@@ -9,7 +9,7 @@ from torchvision import datapoints, transforms as _transforms
 from torchvision.transforms import _functional_tensor as _FT
 from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
 from torchvision.transforms.v2.functional._geometry import _check_interpolation
-from torchvision.transforms.v2.functional._meta import get_spatial_size
+from torchvision.transforms.v2.functional._meta import get_size
 from ._utils import _setup_fill_arg
 from .utils import check_type, is_simple_tensor

@@ -312,7 +312,7 @@ class AutoAugment(_AutoAugmentBase):
     def forward(self, *inputs: Any) -> Any:
         flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
-        height, width = get_spatial_size(image_or_video)
+        height, width = get_size(image_or_video)
 
         policy = self._policies[int(torch.randint(len(self._policies), ()))]

@@ -403,7 +403,7 @@ class RandAugment(_AutoAugmentBase):
     def forward(self, *inputs: Any) -> Any:
         flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
-        height, width = get_spatial_size(image_or_video)
+        height, width = get_size(image_or_video)
 
         for _ in range(self.num_ops):
             transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

@@ -474,7 +474,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
     def forward(self, *inputs: Any) -> Any:
         flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
-        height, width = get_spatial_size(image_or_video)
+        height, width = get_size(image_or_video)
 
         transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

@@ -568,7 +568,7 @@ class AugMix(_AutoAugmentBase):
     def forward(self, *inputs: Any) -> Any:
         flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs)
-        height, width = get_spatial_size(orig_image_or_video)
+        height, width = get_size(orig_image_or_video)
 
         if isinstance(orig_image_or_video, torch.Tensor):
             image_or_video = orig_image_or_video
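Each of the four `forward` methods above changes only at the size query, so the edit is mechanical. A small sketch of the call-site pattern, assuming a build with this commit; `image_or_video` stands in for whatever `_flatten_and_extract_image_or_video` returns:

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2.functional import get_size

image_or_video = datapoints.Video(torch.rand(8, 3, 240, 320))  # (T, C, H, W)

# get_size accepts images and videos alike and returns [height, width],
# matching the trailing (..., H, W) dimensions of the tensor layout.
height, width = get_size(image_or_video)
assert (height, width) == (240, 320)
```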
torchvision/transforms/v2/_geometry.py

@@ -22,7 +22,7 @@ from ._utils import (
     _setup_float_or_seq,
     _setup_size,
 )
-from .utils import has_all, has_any, is_simple_tensor, query_bounding_boxes, query_spatial_size
+from .utils import has_all, has_any, is_simple_tensor, query_bounding_boxes, query_size
 
 
 class RandomHorizontalFlip(_RandomApplyTransform):

@@ -267,7 +267,7 @@ class RandomResizedCrop(Transform):
         self._log_ratio = torch.log(torch.tensor(self.ratio))
 
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        height, width = query_spatial_size(flat_inputs)
+        height, width = query_size(flat_inputs)
         area = height * width
 
         log_ratio = self._log_ratio

@@ -558,7 +558,7 @@ class RandomZoomOut(_RandomApplyTransform):
             raise ValueError(f"Invalid canvas side range provided {side_range}.")
 
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        orig_h, orig_w = query_spatial_size(flat_inputs)
+        orig_h, orig_w = query_size(flat_inputs)
 
         r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
         canvas_width = int(orig_w * r)

@@ -735,7 +735,7 @@ class RandomAffine(Transform):
         self.center = center
 
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        height, width = query_spatial_size(flat_inputs)
+        height, width = query_size(flat_inputs)
 
         angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
         if self.translate is not None:

@@ -859,7 +859,7 @@ class RandomCrop(Transform):
         self.padding_mode = padding_mode
 
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        padded_height, padded_width = query_spatial_size(flat_inputs)
+        padded_height, padded_width = query_size(flat_inputs)
 
         if self.padding is not None:
             pad_left, pad_right, pad_top, pad_bottom = self.padding

@@ -972,7 +972,7 @@ class RandomPerspective(_RandomApplyTransform):
         self._fill = _setup_fill_arg(fill)
 
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        height, width = query_spatial_size(flat_inputs)
+        height, width = query_size(flat_inputs)
 
         distortion_scale = self.distortion_scale

@@ -1072,7 +1072,7 @@ class ElasticTransform(Transform):
         self._fill = _setup_fill_arg(fill)
 
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        size = list(query_spatial_size(flat_inputs))
+        size = list(query_size(flat_inputs))
 
         dx = torch.rand([1, 1] + size) * 2 - 1
         if self.sigma[0] > 0.0:

@@ -1164,7 +1164,7 @@ class RandomIoUCrop(Transform):
         )
 
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        orig_h, orig_w = query_spatial_size(flat_inputs)
+        orig_h, orig_w = query_size(flat_inputs)
         bboxes = query_bounding_boxes(flat_inputs)
 
         while True:

@@ -1282,7 +1282,7 @@ class ScaleJitter(Transform):
         self.antialias = antialias
 
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        orig_height, orig_width = query_spatial_size(flat_inputs)
+        orig_height, orig_width = query_size(flat_inputs)
 
         scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
         r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale

@@ -1347,7 +1347,7 @@ class RandomShortestSize(Transform):
         self.antialias = antialias
 
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        orig_height, orig_width = query_spatial_size(flat_inputs)
+        orig_height, orig_width = query_size(flat_inputs)
 
         min_size = self.min_size[int(torch.randint(len(self.min_size), ()))]
         r = min_size / min(orig_height, orig_width)
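All ten `_get_params` implementations above follow the same recipe: pull the single (height, width) shared by the flattened sample out of `query_size`, then derive the random parameters from it. A hypothetical minimal transform showing that recipe (the class and its `max_fraction` knob are invented for illustration, not part of torchvision):

```python
from typing import Any, Dict, List

import torch
from torchvision.transforms.v2 import Transform
from torchvision.transforms.v2.utils import query_size


class RandomTranslateFraction(Transform):
    """Hypothetical transform: picks a pixel offset as a fraction of the input size."""

    def __init__(self, max_fraction: float = 0.1) -> None:
        super().__init__()
        self.max_fraction = max_fraction

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # query_size scans the flattened sample and returns the one
        # (height, width) shared by all images/videos/masks/boxes in it.
        height, width = query_size(flat_inputs)
        dx = int(torch.empty(1).uniform_(-1, 1).item() * self.max_fraction * width)
        dy = int(torch.empty(1).uniform_(-1, 1).item() * self.max_fraction * height)
        return dict(dx=dx, dy=dy)
```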
torchvision/transforms/v2/_meta.py

@@ -30,7 +30,7 @@ class ConvertBoundingBoxFormat(Transform):
 class ClampBoundingBoxes(Transform):
     """[BETA] Clamp bounding boxes to their corresponding image dimensions.
 
-    The clamping is done according to the bounding boxes' ``spatial_size`` meta-data.
+    The clamping is done according to the bounding boxes' ``canvas_size`` meta-data.
 
     .. v2betastatus:: ClampBoundingBoxes transform
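Behaviorally nothing changes in this transform; the docstring simply follows the metadata rename. For context, a quick sketch of what the transform does under the new name, assuming the beta v2 API at this commit:

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import ClampBoundingBoxes

boxes = datapoints.BoundingBoxes(
    [[-20, 10, 700, 100]],  # x1 is negative, x2 overshoots the canvas width
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=(480, 640),
)

clamped = ClampBoundingBoxes()(boxes)
# Coordinates are clipped to [0, width] and [0, height] taken from canvas_size.
assert clamped.tolist() == [[0, 10, 640, 100]]
```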
torchvision/transforms/v2/_misc.py

@@ -408,7 +408,7 @@ class SanitizeBoundingBoxes(Transform):
         valid = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1)
         # TODO: Do we really need to check for out of bounds here? All
         # transforms should be clamping anyway, so this should never happen?
-        image_h, image_w = boxes.spatial_size
+        image_h, image_w = boxes.canvas_size
         valid &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
         valid &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)
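The surrounding validity mask (context lines above) combines three checks: a minimum box size, non-negative coordinates, and, after the renamed line, containment within the canvas. A standalone re-creation of that logic on plain XYXY tensors, with hypothetical values:

```python
import torch

boxes = torch.tensor([[10.0, 10.0, 50.0, 50.0],     # fine
                      [0.0, 0.0, 2.0, 2.0],          # degenerate: 2x2 < min_size
                      [600.0, 10.0, 700.0, 50.0]])   # right edge beyond the canvas
min_size = 5.0
image_h, image_w = 480, 640  # after the rename: boxes.canvas_size

ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
valid = (ws >= min_size) & (hs >= min_size) & (boxes >= 0).all(dim=-1)
valid &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
valid &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)
assert valid.tolist() == [True, False, False]
```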
torchvision/transforms/v2/functional/__init__.py

@@ -15,12 +15,12 @@ from ._meta import (
     get_num_channels_image_pil,
     get_num_channels_video,
     get_num_channels,
-    get_spatial_size_bounding_boxes,
-    get_spatial_size_image_tensor,
-    get_spatial_size_image_pil,
-    get_spatial_size_mask,
-    get_spatial_size_video,
-    get_spatial_size,
+    get_size_bounding_boxes,
+    get_size_image_tensor,
+    get_size_image_pil,
+    get_size_mask,
+    get_size_video,
+    get_size,
 )  # usort: skip
 
 from ._augment import erase, erase_image_pil, erase_image_tensor, erase_video
torchvision/transforms/v2/functional/_deprecated.py

@@ -19,6 +19,6 @@ def to_tensor(inpt: Any) -> torch.Tensor:
 def get_image_size(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]:
     warnings.warn(
         "The function `get_image_size(...)` is deprecated and will be removed in a future release. "
-        "Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`."
+        "Instead, please use `get_size(...)` which returns `[h, w]` instead of `[w, h]`."
     )
     return _F.get_image_size(inpt)
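The message is worth spelling out because the two functions disagree on axis order: the legacy `get_image_size` returns `[width, height]`, while its replacement `get_size` returns `[height, width]`, consistent with the (..., H, W) tensor layout. A sketch, assuming a build with this commit:

```python
import torch
from torchvision.transforms.v2 import functional as F

image = torch.rand(3, 480, 640)  # (C, H, W)

# Deprecated spelling: returns [w, h] and should emit the warning shown above.
assert F.get_image_size(image) == [640, 480]

# Replacement: returns [h, w].
assert F.get_size(image) == [480, 640]
```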
torchvision/transforms/v2/functional/_geometry.py

@@ -23,7 +23,7 @@ from torchvision.transforms.functional import (
 from torchvision.utils import _log_api_usage_once
 
-from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_spatial_size_image_pil
+from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_size_image_pil
 from ._utils import is_simple_tensor

@@ -52,18 +52,18 @@ def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
 def horizontal_flip_bounding_boxes(
-    bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int]
+    bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, canvas_size: Tuple[int, int]
 ) -> torch.Tensor:
     shape = bounding_boxes.shape
 
     bounding_boxes = bounding_boxes.clone().reshape(-1, 4)
 
     if format == datapoints.BoundingBoxFormat.XYXY:
-        bounding_boxes[:, [2, 0]] = bounding_boxes[:, [0, 2]].sub_(spatial_size[1]).neg_()
+        bounding_boxes[:, [2, 0]] = bounding_boxes[:, [0, 2]].sub_(canvas_size[1]).neg_()
     elif format == datapoints.BoundingBoxFormat.XYWH:
-        bounding_boxes[:, 0].add_(bounding_boxes[:, 2]).sub_(spatial_size[1]).neg_()
+        bounding_boxes[:, 0].add_(bounding_boxes[:, 2]).sub_(canvas_size[1]).neg_()
     else:  # format == datapoints.BoundingBoxFormat.CXCYWH:
-        bounding_boxes[:, 0].sub_(spatial_size[1]).neg_()
+        bounding_boxes[:, 0].sub_(canvas_size[1]).neg_()
 
     return bounding_boxes.reshape(shape)

@@ -102,18 +102,18 @@ def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
 def vertical_flip_bounding_boxes(
-    bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int]
+    bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, canvas_size: Tuple[int, int]
 ) -> torch.Tensor:
     shape = bounding_boxes.shape
 
     bounding_boxes = bounding_boxes.clone().reshape(-1, 4)
 
     if format == datapoints.BoundingBoxFormat.XYXY:
-        bounding_boxes[:, [1, 3]] = bounding_boxes[:, [3, 1]].sub_(spatial_size[0]).neg_()
+        bounding_boxes[:, [1, 3]] = bounding_boxes[:, [3, 1]].sub_(canvas_size[0]).neg_()
     elif format == datapoints.BoundingBoxFormat.XYWH:
-        bounding_boxes[:, 1].add_(bounding_boxes[:, 3]).sub_(spatial_size[0]).neg_()
+        bounding_boxes[:, 1].add_(bounding_boxes[:, 3]).sub_(canvas_size[0]).neg_()
     else:  # format == datapoints.BoundingBoxFormat.CXCYWH:
-        bounding_boxes[:, 1].sub_(spatial_size[0]).neg_()
+        bounding_boxes[:, 1].sub_(canvas_size[0]).neg_()
 
     return bounding_boxes.reshape(shape)
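The XYXY branch is compact enough to deserve unpacking: for a canvas of width W, a horizontally flipped x-coordinate is W - x, and writing the old (x1, x2) pair into columns (2, 0) swaps them so that x1 <= x2 still holds afterwards. A worked check on plain tensors:

```python
import torch

canvas_size = (480, 640)  # (height, width)
boxes = torch.tensor([[10.0, 20.0, 110.0, 120.0]])  # XYXY

# x' = width - x, written as sub_(width).neg_(); the advanced index on the
# right-hand side yields a copy, so the in-place ops do not corrupt `flipped`.
flipped = boxes.clone()
flipped[:, [2, 0]] = flipped[:, [0, 2]].sub_(canvas_size[1]).neg_()
assert flipped.tolist() == [[530.0, 20.0, 630.0, 120.0]]  # 640-110=530, 640-10=630
```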
@@ -146,7 +146,7 @@ vflip = vertical_flip
 def _compute_resized_output_size(
-    spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
+    canvas_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
 ) -> List[int]:
     if isinstance(size, int):
         size = [size]

@@ -155,7 +155,7 @@ def _compute_resized_output_size(
             "max_size should only be passed if size specifies the length of the smaller edge, "
             "i.e. size should be an int or a sequence of length 1 in torchscript mode."
         )
-    return __compute_resized_output_size(spatial_size, size=size, max_size=max_size)
+    return __compute_resized_output_size(canvas_size, size=size, max_size=max_size)
 
 
 def resize_image_tensor(

@@ -275,13 +275,13 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor:
 def resize_bounding_boxes(
-    bounding_boxes: torch.Tensor, spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
+    bounding_boxes: torch.Tensor, canvas_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
-    old_height, old_width = spatial_size
-    new_height, new_width = _compute_resized_output_size(spatial_size, size=size, max_size=max_size)
+    old_height, old_width = canvas_size
+    new_height, new_width = _compute_resized_output_size(canvas_size, size=size, max_size=max_size)
 
     if (new_height, new_width) == (old_height, old_width):
-        return bounding_boxes, spatial_size
+        return bounding_boxes, canvas_size
 
     w_ratio = new_width / old_width
     h_ratio = new_height / old_height
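The remainder of `resize_bounding_boxes` (unchanged by this commit) scales coordinates with the per-axis ratios computed above; `canvas_size` only supplies the old extent and the early-exit return value. The arithmetic, reproduced standalone with hypothetical numbers:

```python
import torch

canvas_size = (480, 640)          # old (height, width)
new_height, new_width = 240, 320  # e.g. from _compute_resized_output_size

old_height, old_width = canvas_size
w_ratio = new_width / old_width
h_ratio = new_height / old_height

boxes = torch.tensor([[10.0, 20.0, 110.0, 120.0]])  # XYXY on the old canvas
ratios = torch.tensor([w_ratio, h_ratio, w_ratio, h_ratio])
resized = boxes * ratios  # x-coordinates scale by w_ratio, y-coordinates by h_ratio
assert resized.tolist() == [[5.0, 10.0, 55.0, 60.0]]
```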
@@ -643,7 +643,7 @@ def affine_image_pil(
     # it is visually better to estimate the center without 0.5 offset
     # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
     if center is None:
-        height, width = get_spatial_size_image_pil(image)
+        height, width = get_size_image_pil(image)
         center = [width * 0.5, height * 0.5]
 
     matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)

@@ -653,7 +653,7 @@ def affine_image_pil(
 def _affine_bounding_boxes_with_expand(
     bounding_boxes: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
-    spatial_size: Tuple[int, int],
+    canvas_size: Tuple[int, int],
     angle: Union[int, float],
     translate: List[float],
     scale: float,

@@ -662,7 +662,7 @@ def _affine_bounding_boxes_with_expand(
     expand: bool = False,
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
     if bounding_boxes.numel() == 0:
-        return bounding_boxes, spatial_size
+        return bounding_boxes, canvas_size
 
     original_shape = bounding_boxes.shape
     original_dtype = bounding_boxes.dtype

@@ -680,7 +680,7 @@ def _affine_bounding_boxes_with_expand(
     )
     if center is None:
-        height, width = spatial_size
+        height, width = canvas_size
         center = [width * 0.5, height * 0.5]
 
     affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False)

@@ -710,7 +710,7 @@ def _affine_bounding_boxes_with_expand(
     if expand:
         # Compute minimum point for transformed image frame:
         # Points are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
-        height, width = spatial_size
+        height, width = canvas_size
         points = torch.tensor(
             [
                 [0.0, 0.0, 1.0],

@@ -728,21 +728,21 @@ def _affine_bounding_boxes_with_expand(
         # Estimate meta-data for image with inverted=True and with center=[0,0]
         affine_vector = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear)
         new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
-        spatial_size = (new_height, new_width)
+        canvas_size = (new_height, new_width)
 
-    out_bboxes = clamp_bounding_boxes(out_bboxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size)
+    out_bboxes = clamp_bounding_boxes(out_bboxes, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=canvas_size)
     out_bboxes = convert_format_bounding_boxes(
         out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
     ).reshape(original_shape)
 
     out_bboxes = out_bboxes.to(original_dtype)
-    return out_bboxes, spatial_size
+    return out_bboxes, canvas_size
 
 
 def affine_bounding_boxes(
     bounding_boxes: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
-    spatial_size: Tuple[int, int],
+    canvas_size: Tuple[int, int],
     angle: Union[int, float],
     translate: List[float],
     scale: float,

@@ -752,7 +752,7 @@ def affine_bounding_boxes(
     out_box, _ = _affine_bounding_boxes_with_expand(
         bounding_boxes,
         format=format,
-        spatial_size=spatial_size,
+        canvas_size=canvas_size,
         angle=angle,
         translate=translate,
         scale=scale,

@@ -930,7 +930,7 @@ def rotate_image_pil(
 def rotate_bounding_boxes(
     bounding_boxes: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
-    spatial_size: Tuple[int, int],
+    canvas_size: Tuple[int, int],
     angle: float,
     expand: bool = False,
     center: Optional[List[float]] = None,

@@ -941,7 +941,7 @@ def rotate_bounding_boxes(
     return _affine_bounding_boxes_with_expand(
         bounding_boxes,
         format=format,
-        spatial_size=spatial_size,
+        canvas_size=canvas_size,
         angle=-angle,
         translate=[0.0, 0.0],
         scale=1.0,

@@ -1168,7 +1168,7 @@ def pad_mask(
 def pad_bounding_boxes(
     bounding_boxes: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
-    spatial_size: Tuple[int, int],
+    canvas_size: Tuple[int, int],
     padding: List[int],
     padding_mode: str = "constant",
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:

@@ -1184,12 +1184,12 @@ def pad_bounding_boxes(
     pad = [left, top, 0, 0]
     bounding_boxes = bounding_boxes + torch.tensor(pad, dtype=bounding_boxes.dtype, device=bounding_boxes.device)
 
-    height, width = spatial_size
+    height, width = canvas_size
     height += top + bottom
     width += left + right
-    spatial_size = (height, width)
+    canvas_size = (height, width)
 
-    return clamp_bounding_boxes(bounding_boxes, format=format, spatial_size=spatial_size), spatial_size
+    return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size
 
 
 def pad_video(

@@ -1261,9 +1261,9 @@ def crop_bounding_boxes(
     sub = [left, top, 0, 0]
     bounding_boxes = bounding_boxes - torch.tensor(sub, dtype=bounding_boxes.dtype, device=bounding_boxes.device)
-    spatial_size = (height, width)
+    canvas_size = (height, width)
 
-    return clamp_bounding_boxes(bounding_boxes, format=format, spatial_size=spatial_size), spatial_size
+    return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size
 
 
 def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:

@@ -1412,7 +1412,7 @@ def perspective_image_pil(
 def perspective_bounding_boxes(
     bounding_boxes: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
-    spatial_size: Tuple[int, int],
+    canvas_size: Tuple[int, int],
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
     coefficients: Optional[List[float]] = None,

@@ -1493,7 +1493,7 @@ def perspective_bounding_boxes(
     out_bboxes = clamp_bounding_boxes(
         torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
         format=datapoints.BoundingBoxFormat.XYXY,
-        spatial_size=spatial_size,
+        canvas_size=canvas_size,
     )
 
     # out_bboxes should be of shape [N boxes, 4]

@@ -1651,7 +1651,7 @@ def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: torch.dtype) -> torch.Tensor:
 def elastic_bounding_boxes(
     bounding_boxes: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
-    spatial_size: Tuple[int, int],
+    canvas_size: Tuple[int, int],
     displacement: torch.Tensor,
 ) -> torch.Tensor:
     if bounding_boxes.numel() == 0:

@@ -1670,7 +1670,7 @@ def elastic_bounding_boxes(
         convert_format_bounding_boxes(bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)
 
-    id_grid = _create_identity_grid(spatial_size, device=device, dtype=dtype)
+    id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype)
     # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
     # This is not an exact inverse of the grid
     inv_grid = id_grid.sub_(displacement)

@@ -1683,7 +1683,7 @@ def elastic_bounding_boxes(
     index_x, index_y = index_xy[:, 0], index_xy[:, 1]
 
     # Transform points:
-    t_size = torch.tensor(spatial_size[::-1], device=displacement.device, dtype=displacement.dtype)
+    t_size = torch.tensor(canvas_size[::-1], device=displacement.device, dtype=displacement.dtype)
     transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)
 
     transformed_points = transformed_points.reshape(-1, 4, 2)

@@ -1691,7 +1691,7 @@ def elastic_bounding_boxes(
     out_bboxes = clamp_bounding_boxes(
         torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
         format=datapoints.BoundingBoxFormat.XYXY,
-        spatial_size=spatial_size,
+        canvas_size=canvas_size,
     )
 
     return convert_format_bounding_boxes(

@@ -1804,13 +1804,13 @@ def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
 @torch.jit.unused
 def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    image_height, image_width = get_spatial_size_image_pil(image)
+    image_height, image_width = get_size_image_pil(image)
 
     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
         image = pad_image_pil(image, padding_ltrb, fill=0)
-        image_height, image_width = get_spatial_size_image_pil(image)
+        image_height, image_width = get_size_image_pil(image)
         if crop_width == image_width and crop_height == image_height:
             return image

@@ -1821,11 +1821,11 @@ def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
 def center_crop_bounding_boxes(
     bounding_boxes: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
-    spatial_size: Tuple[int, int],
+    canvas_size: Tuple[int, int],
     output_size: List[int],
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *spatial_size)
+    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *canvas_size)
     return crop_bounding_boxes(
         bounding_boxes, format, top=crop_top, left=crop_left, height=crop_height, width=crop_width
     )

@@ -1905,7 +1905,7 @@ def resized_crop_bounding_boxes(
     size: List[int],
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
     bounding_boxes, _ = crop_bounding_boxes(bounding_boxes, format, top, left, height, width)
-    return resize_bounding_boxes(bounding_boxes, spatial_size=(height, width), size=size)
+    return resize_bounding_boxes(bounding_boxes, canvas_size=(height, width), size=size)
 
 
 def resized_crop_mask(

@@ -2000,7 +2000,7 @@ def five_crop_image_pil(
     image: PIL.Image.Image, size: List[int]
 ) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
     crop_height, crop_width = _parse_five_crop_size(size)
-    image_height, image_width = get_spatial_size_image_pil(image)
+    image_height, image_width = get_size_image_pil(image)
 
     if crop_width > image_width or crop_height > image_height:
         raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")
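A pattern worth noting across the kernels above: every geometry kernel that can change the canvas (pad, crop, resize, affine or rotate with expand) returns the new size alongside the boxes, and callers thread that tuple back into the datapoint metadata. A usage sketch, under the assumption that the keyword names match the signatures in the hunks above and that `padding` follows torchvision's (left, top, right, bottom) convention:

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

boxes = datapoints.BoundingBoxes(
    [[10, 10, 100, 100]],
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=(480, 640),
)

# Low-level kernels operate on the plain tensor and hand back the grown canvas.
out, new_canvas_size = F.pad_bounding_boxes(
    boxes.as_subclass(torch.Tensor),
    format=boxes.format,
    canvas_size=boxes.canvas_size,
    padding=[8, 16, 8, 16],  # assumed (left, top, right, bottom)
)
assert new_canvas_size == (480 + 16 + 16, 640 + 8 + 8)  # height grows by top+bottom
```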
torchvision/transforms/v2/functional/_meta.py

@@ -26,23 +26,29 @@ def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
 get_dimensions_image_pil = _FP.get_dimensions
 
 
 def get_dimensions_video(video: torch.Tensor) -> List[int]:
     return get_dimensions_image_tensor(video)
 
 
 def get_dimensions(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]:
     if not torch.jit.is_scripting():
         _log_api_usage_once(get_dimensions)
 
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return get_dimensions_image_tensor(inpt)
-    elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
-        channels = inpt.num_channels
-        height, width = inpt.spatial_size
-        return [channels, height, width]
-    elif isinstance(inpt, PIL.Image.Image):
-        return get_dimensions_image_pil(inpt)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+
+    for typ, get_size_fn in {
+        datapoints.Image: get_dimensions_image_tensor,
+        datapoints.Video: get_dimensions_video,
+        PIL.Image.Image: get_dimensions_image_pil,
+    }.items():
+        if isinstance(inpt, typ):
+            return get_size_fn(inpt)
+
+    raise TypeError(
+        f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
+        f"but got {type(inpt)} instead."
+    )

@@ -69,15 +75,19 @@ def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> int:
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return get_num_channels_image_tensor(inpt)
-    elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
-        return inpt.num_channels
-    elif isinstance(inpt, PIL.Image.Image):
-        return get_num_channels_image_pil(inpt)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+
+    for typ, get_size_fn in {
+        datapoints.Image: get_num_channels_image_tensor,
+        datapoints.Video: get_num_channels_video,
+        PIL.Image.Image: get_num_channels_image_pil,
+    }.items():
+        if isinstance(inpt, typ):
+            return get_size_fn(inpt)
+
+    raise TypeError(
+        f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
+        f"but got {type(inpt)} instead."
+    )
 
 
 # We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without

@@ -85,7 +95,7 @@ def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> int:
 get_image_num_channels = get_num_channels
 
 
-def get_spatial_size_image_tensor(image: torch.Tensor) -> List[int]:
+def get_size_image_tensor(image: torch.Tensor) -> List[int]:
     hw = list(image.shape[-2:])
     ndims = len(hw)
     if ndims == 2:

@@ -95,39 +105,48 @@ def get_spatial_size_image_tensor(image: torch.Tensor) -> List[int]:
 @torch.jit.unused
-def get_spatial_size_image_pil(image: PIL.Image.Image) -> List[int]:
+def get_size_image_pil(image: PIL.Image.Image) -> List[int]:
     width, height = _FP.get_image_size(image)
     return [height, width]
 
 
-def get_spatial_size_video(video: torch.Tensor) -> List[int]:
-    return get_spatial_size_image_tensor(video)
+def get_size_video(video: torch.Tensor) -> List[int]:
+    return get_size_image_tensor(video)
 
 
-def get_spatial_size_mask(mask: torch.Tensor) -> List[int]:
-    return get_spatial_size_image_tensor(mask)
+def get_size_mask(mask: torch.Tensor) -> List[int]:
+    return get_size_image_tensor(mask)
 
 
 @torch.jit.unused
-def get_spatial_size_bounding_boxes(bounding_boxes: datapoints.BoundingBoxes) -> List[int]:
-    return list(bounding_boxes.spatial_size)
+def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]:
+    return list(bounding_box.canvas_size)
 
 
-def get_spatial_size(inpt: datapoints._InputTypeJIT) -> List[int]:
+def get_size(inpt: datapoints._InputTypeJIT) -> List[int]:
     if not torch.jit.is_scripting():
-        _log_api_usage_once(get_spatial_size)
+        _log_api_usage_once(get_size)
 
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return get_spatial_size_image_tensor(inpt)
-    elif isinstance(inpt, (datapoints.Image, datapoints.Video, datapoints.BoundingBoxes, datapoints.Mask)):
-        return list(inpt.spatial_size)
-    elif isinstance(inpt, PIL.Image.Image):
-        return get_spatial_size_image_pil(inpt)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+        return get_size_image_tensor(inpt)
+
+    # TODO: This is just the poor mans version of a dispatcher. This will be properly addressed with
+    #  https://github.com/pytorch/vision/pull/7747 when we can register the kernels above without the need to have
+    #  a method on the datapoint class
+    for typ, get_size_fn in {
+        datapoints.Image: get_size_image_tensor,
+        datapoints.BoundingBoxes: get_size_bounding_boxes,
+        datapoints.Mask: get_size_mask,
+        datapoints.Video: get_size_video,
+        PIL.Image.Image: get_size_image_pil,
+    }.items():
+        if isinstance(inpt, typ):
+            return get_size_fn(inpt)
+
+    raise TypeError(
+        f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
+        f"but got {type(inpt)} instead."
+    )
 
 
 def get_num_frames_video(video: torch.Tensor) -> int:
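The rewrite from `elif isinstance(...)` chains to a type-to-kernel dict is what the TODO above calls a poor man's dispatcher: walk the registered types and call the first kernel whose type matches. A generic, self-contained sketch of the pattern (toy kernels, independent of torchvision):

```python
from typing import Any, Callable, Dict, List, Type


def make_dispatcher(kernels: Dict[Type, Callable[[Any], List[int]]]) -> Callable[[Any], List[int]]:
    def dispatch(inpt: Any) -> List[int]:
        # First matching type wins; dict order therefore encodes priority,
        # which matters when registered types are subclasses of one another
        # (the torchvision datapoints are all torch.Tensor subclasses).
        for typ, kernel in kernels.items():
            if isinstance(inpt, typ):
                return kernel(inpt)
        raise TypeError(f"Unsupported input type {type(inpt)}")

    return dispatch


# Usage with toy kernels:
get_size = make_dispatcher({
    list: lambda x: [len(x)],
    str: lambda s: [len(s)],
})
assert get_size([1, 2, 3]) == [3]
assert get_size("abcd") == [4]
```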
@@ -141,7 +160,7 @@ def get_num_frames(inpt: datapoints._VideoTypeJIT) -> int:
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return get_num_frames_video(inpt)
     elif isinstance(inpt, datapoints.Video):
-        return inpt.num_frames
+        return get_num_frames_video(inpt)
     else:
         raise TypeError(f"Input can either be a plain tensor or a `Video` datapoint, but got {type(inpt)} instead.")

@@ -240,7 +259,7 @@ def convert_format_bounding_boxes(
 def _clamp_bounding_boxes(
-    bounding_boxes: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int]
+    bounding_boxes: torch.Tensor, format: BoundingBoxFormat, canvas_size: Tuple[int, int]
 ) -> torch.Tensor:
     # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
     # BoundingBoxFormat instead of converting back and forth

@@ -249,8 +268,8 @@ def _clamp_bounding_boxes(
     xyxy_boxes = convert_format_bounding_boxes(
         bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
     )
-    xyxy_boxes[..., 0::2].clamp_(min=0, max=spatial_size[1])
-    xyxy_boxes[..., 1::2].clamp_(min=0, max=spatial_size[0])
+    xyxy_boxes[..., 0::2].clamp_(min=0, max=canvas_size[1])
+    xyxy_boxes[..., 1::2].clamp_(min=0, max=canvas_size[0])
     out_boxes = convert_format_bounding_boxes(
         xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True
     )

@@ -260,21 +279,20 @@ def _clamp_bounding_boxes(
 def clamp_bounding_boxes(
     inpt: datapoints._InputTypeJIT,
     format: Optional[BoundingBoxFormat] = None,
-    spatial_size: Optional[Tuple[int, int]] = None,
+    canvas_size: Optional[Tuple[int, int]] = None,
 ) -> datapoints._InputTypeJIT:
     if not torch.jit.is_scripting():
         _log_api_usage_once(clamp_bounding_boxes)
 
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        if format is None or spatial_size is None:
-            raise ValueError("For simple tensor inputs, `format` and `spatial_size` has to be passed.")
-        return _clamp_bounding_boxes(inpt, format=format, spatial_size=spatial_size)
+        if format is None or canvas_size is None:
+            raise ValueError("For simple tensor inputs, `format` and `canvas_size` has to be passed.")
+        return _clamp_bounding_boxes(inpt, format=format, canvas_size=canvas_size)
     elif isinstance(inpt, datapoints.BoundingBoxes):
-        if format is not None or spatial_size is not None:
-            raise ValueError("For bounding box datapoint inputs, `format` and `spatial_size` must not be passed.")
-        output = _clamp_bounding_boxes(
-            inpt.as_subclass(torch.Tensor), format=inpt.format, spatial_size=inpt.spatial_size
-        )
+        if format is not None or canvas_size is not None:
+            raise ValueError("For bounding box datapoint inputs, `format` and `canvas_size` must not be passed.")
+        output = _clamp_bounding_boxes(
+            inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
+        )
         return datapoints.BoundingBoxes.wrap_like(inpt, output)
     else:
         raise TypeError(
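The clamping itself (context above) is two in-place operations on the XYXY view: even column indices are the x-coordinates and clamp to [0, canvas_size[1]] (the width); odd indices are y-coordinates and clamp to [0, canvas_size[0]] (the height). A standalone check:

```python
import torch

canvas_size = (480, 640)  # (height, width)
xyxy_boxes = torch.tensor([[-15.0, -10.0, 700.0, 500.0]])

# Strided slices are views, so clamp_ mutates xyxy_boxes in place.
xyxy_boxes[..., 0::2].clamp_(min=0, max=canvas_size[1])  # x1, x2 -> [0, 640]
xyxy_boxes[..., 1::2].clamp_(min=0, max=canvas_size[0])  # y1, y2 -> [0, 480]
assert xyxy_boxes.tolist() == [[0.0, 0.0, 640.0, 480.0]]
```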
torchvision/transforms/v2/utils.py

@@ -6,15 +6,15 @@ import PIL.Image
 from torchvision import datapoints
 from torchvision._utils import sequence_to_str
-from torchvision.transforms.v2.functional import get_dimensions, get_spatial_size, is_simple_tensor
+from torchvision.transforms.v2.functional import get_dimensions, get_size, is_simple_tensor
 
 
 def query_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
     bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes)]
     if not bounding_boxes:
-        raise TypeError("No bounding box was found in the sample")
+        raise TypeError("No bounding boxes were found in the sample")
     elif len(bounding_boxes) > 1:
-        raise ValueError("Found multiple bounding boxes in the sample")
+        raise ValueError("Found multiple bounding boxes instances in the sample")
     return bounding_boxes.pop()

@@ -22,7 +22,7 @@ def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
     chws = {
         tuple(get_dimensions(inpt))
         for inpt in flat_inputs
-        if isinstance(inpt, (datapoints.Image, PIL.Image.Image, datapoints.Video)) or is_simple_tensor(inpt)
+        if check_type(inpt, (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video))
     }
     if not chws:
         raise TypeError("No image or video was found in the sample")

@@ -32,14 +32,21 @@ def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
     return c, h, w
 
 
-def query_spatial_size(flat_inputs: List[Any]) -> Tuple[int, int]:
+def query_size(flat_inputs: List[Any]) -> Tuple[int, int]:
     sizes = {
-        tuple(get_spatial_size(inpt))
+        tuple(get_size(inpt))
         for inpt in flat_inputs
-        if isinstance(inpt, (datapoints.Image, PIL.Image.Image, datapoints.Video, datapoints.Mask, datapoints.BoundingBoxes))
-        or is_simple_tensor(inpt)
+        if check_type(
+            inpt,
+            (
+                is_simple_tensor,
+                datapoints.Image,
+                PIL.Image.Image,
+                datapoints.Video,
+                datapoints.Mask,
+                datapoints.BoundingBoxes,
+            ),
+        )
     }
     if not sizes:
         raise TypeError("No image, video, mask or bounding box was found in the sample")
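Below the changed lines (not shown on this page), `query_size` still asserts that every size-carrying input in the sample agrees; with this commit the check now reads the boxes' `canvas_size` too. A sketch of that behavior, assuming the beta API at this commit:

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2.utils import query_size

image = datapoints.Image(torch.rand(3, 480, 640))
boxes = datapoints.BoundingBoxes(
    [[10, 10, 100, 100]],
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=(480, 640),
)

# Non-spatial entries (labels, strings, ...) are ignored by the type check;
# all spatial entries agree on (480, 640), so the query succeeds.
assert query_size([image, boxes, "a label"]) == (480, 640)

# If the boxes carried a different canvas_size, `sizes` would hold two
# entries and query_size would raise instead of guessing.
```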