OpenDAS / vision · Commits · 332bff93

Commit 332bff93 (unverified) · authored Jul 31, 2023 by Nicolas Hug, committed by GitHub on Jul 31, 2023
Parent: d4e5aa21

Renaming: `BoundingBox` -> `BoundingBoxes` (#7778)

Changes: 39 files in total; showing 19 changed files with 233 additions and 227 deletions.
torchvision/prototype/datasets/_builtin/caltech.py          +2  -2
torchvision/prototype/datasets/_builtin/celeba.py           +4  -4
torchvision/prototype/datasets/_builtin/coco.py             +2  -2
torchvision/prototype/datasets/_builtin/cub200.py           +5  -5
torchvision/prototype/datasets/_builtin/gtsrb.py            +3  -3
torchvision/prototype/datasets/_builtin/stanford_cars.py    +2  -2
torchvision/prototype/datasets/_builtin/voc.py              +2  -2
torchvision/prototype/transforms/_augment.py                +8  -8
torchvision/prototype/transforms/_geometry.py               +12 -10
torchvision/transforms/v2/__init__.py                       +2  -2
torchvision/transforms/v2/_augment.py                       +1  -1
torchvision/transforms/v2/_auto_augment.py                  +1  -1
torchvision/transforms/v2/_geometry.py                      +28 -28
torchvision/transforms/v2/_meta.py                          +8  -8
torchvision/transforms/v2/_misc.py                          +11 -11
torchvision/transforms/v2/functional/__init__.py            +14 -14
torchvision/transforms/v2/functional/_geometry.py           +91 -89
torchvision/transforms/v2/functional/_meta.py               +34 -32
torchvision/transforms/v2/utils.py                          +3  -3
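The rename is mechanical and the hunks below all repeat the same pattern. As a minimal summary sketch of what downstream imports become (paths and names as they appear in this commit; the comments note the old spellings):

# Datapoint class
from torchvision.datapoints import BoundingBoxes              # was: BoundingBox

# v2 transforms
from torchvision.transforms.v2 import ClampBoundingBoxes      # was: ClampBoundingBox
from torchvision.transforms.v2 import SanitizeBoundingBoxes   # was: SanitizeBoundingBox

# v2 functional kernels follow the same pattern, e.g.
from torchvision.transforms.v2.functional import clamp_bounding_boxes  # was: clamp_bounding_box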

torchvision/prototype/datasets/_builtin/caltech.py

@@ -6,7 +6,7 @@ import numpy as np
 import torch
 from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper
-from torchvision.datapoints import BoundingBox
+from torchvision.datapoints import BoundingBoxes
 from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (

@@ -112,7 +112,7 @@ class Caltech101(Dataset):
             image_path=image_path,
             image=image,
             ann_path=ann_path,
-            bounding_box=BoundingBox(
+            bounding_boxes=BoundingBoxes(
                 ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]],
                 format="xyxy",
                 spatial_size=image.spatial_size,
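The constructor keywords are untouched by the rename; only the class name changes. A minimal runnable sketch of the new spelling, with made-up coordinates and spatial size:

import torch
from torchvision.datapoints import BoundingBoxes

# One hypothetical box in xyxy format on a hypothetical 128x160 image.
boxes = BoundingBoxes(
    torch.tensor([[15, 10, 120, 90]]),  # placeholder coordinates
    format="xyxy",
    spatial_size=(128, 160),  # placeholder (height, width)
)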

torchvision/prototype/datasets/_builtin/celeba.py

@@ -4,7 +4,7 @@ from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Tupl
 import torch
 from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper
-from torchvision.datapoints import BoundingBox
+from torchvision.datapoints import BoundingBoxes
 from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (

@@ -137,15 +137,15 @@ class CelebA(Dataset):
         path, buffer = image_data
         image = EncodedImage.from_file(buffer)
-        (_, identity), (_, attributes), (_, bounding_box), (_, landmarks) = ann_data
+        (_, identity), (_, attributes), (_, bounding_boxes), (_, landmarks) = ann_data
         return dict(
             path=path,
             image=image,
             identity=Label(int(identity["identity"])),
             attributes={attr: value == "1" for attr, value in attributes.items()},
-            bounding_box=BoundingBox(
-                [int(bounding_box[key]) for key in ("x_1", "y_1", "width", "height")],
+            bounding_boxes=BoundingBoxes(
+                [int(bounding_boxes[key]) for key in ("x_1", "y_1", "width", "height")],
                 format="xywh",
                 spatial_size=image.spatial_size,
             ),

torchvision/prototype/datasets/_builtin/coco.py

@@ -14,7 +14,7 @@ from torchdata.datapipes.iter import (
     Mapper,
     UnBatcher,
 )
-from torchvision.datapoints import BoundingBox, Mask
+from torchvision.datapoints import BoundingBoxes, Mask
 from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (

@@ -126,7 +126,7 @@ class Coco(Dataset):
             ),
             areas=torch.as_tensor([ann["area"] for ann in anns]),
             crowds=torch.as_tensor([ann["iscrowd"] for ann in anns], dtype=torch.bool),
-            bounding_boxes=BoundingBox(
+            bounding_boxes=BoundingBoxes(
                 [ann["bbox"] for ann in anns],
                 format="xywh",
                 spatial_size=spatial_size,

torchvision/prototype/datasets/_builtin/cub200.py

@@ -15,7 +15,7 @@ from torchdata.datapipes.iter import (
     Mapper,
 )
 from torchdata.datapipes.map import IterToMapConverter
-from torchvision.datapoints import BoundingBox
+from torchvision.datapoints import BoundingBoxes
 from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (

@@ -134,11 +134,11 @@ class CUB200(Dataset):
     def _2011_prepare_ann(
         self, data: Tuple[str, Tuple[List[str], Tuple[str, BinaryIO]]], spatial_size: Tuple[int, int]
     ) -> Dict[str, Any]:
-        _, (bounding_box_data, segmentation_data) = data
+        _, (bounding_boxes_data, segmentation_data) = data
         segmentation_path, segmentation_buffer = segmentation_data
         return dict(
-            bounding_box=BoundingBox(
-                [float(part) for part in bounding_box_data[1:]], format="xywh", spatial_size=spatial_size
+            bounding_boxes=BoundingBoxes(
+                [float(part) for part in bounding_boxes_data[1:]], format="xywh", spatial_size=spatial_size
             ),
             segmentation_path=segmentation_path,
             segmentation=EncodedImage.from_file(segmentation_buffer),

@@ -158,7 +158,7 @@ class CUB200(Dataset):
         content = read_mat(buffer)
         return dict(
             ann_path=path,
-            bounding_box=BoundingBox(
+            bounding_boxes=BoundingBoxes(
                 [int(content["bbox"][coord]) for coord in ("left", "bottom", "right", "top")],
                 format="xyxy",
                 spatial_size=spatial_size,

torchvision/prototype/datasets/_builtin/gtsrb.py

@@ -2,7 +2,7 @@ import pathlib
 from typing import Any, Dict, List, Optional, Tuple, Union
 from torchdata.datapipes.iter import CSVDictParser, Demultiplexer, Filter, IterDataPipe, Mapper, Zipper
-from torchvision.datapoints import BoundingBox
+from torchvision.datapoints import BoundingBoxes
 from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (

@@ -76,7 +76,7 @@ class GTSRB(Dataset):
         (path, buffer), csv_info = data
         label = int(csv_info["ClassId"])
-        bounding_box = BoundingBox(
+        bounding_boxes = BoundingBoxes(
             [int(csv_info[k]) for k in ("Roi.X1", "Roi.Y1", "Roi.X2", "Roi.Y2")],
             format="xyxy",
             spatial_size=(int(csv_info["Height"]), int(csv_info["Width"])),

@@ -86,7 +86,7 @@ class GTSRB(Dataset):
             "path": path,
             "image": EncodedImage.from_file(buffer),
             "label": Label(label, categories=self._categories),
-            "bounding_box": bounding_box,
+            "bounding_boxes": bounding_boxes,
         }

     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:

torchvision/prototype/datasets/_builtin/stanford_cars.py

@@ -2,7 +2,7 @@ import pathlib
 from typing import Any, BinaryIO, Dict, Iterator, List, Tuple, Union
 from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper, Zipper
-from torchvision.datapoints import BoundingBox
+from torchvision.datapoints import BoundingBoxes
 from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (

@@ -90,7 +90,7 @@ class StanfordCars(Dataset):
             path=path,
             image=image,
             label=Label(target[4] - 1, categories=self._categories),
-            bounding_box=BoundingBox(target[:4], format="xyxy", spatial_size=image.spatial_size),
+            bounding_boxes=BoundingBoxes(target[:4], format="xyxy", spatial_size=image.spatial_size),
         )

     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:

torchvision/prototype/datasets/_builtin/voc.py

@@ -5,7 +5,7 @@ from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union
 from xml.etree import ElementTree
 from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
-from torchvision.datapoints import BoundingBox
+from torchvision.datapoints import BoundingBoxes
 from torchvision.datasets import VOCDetection
 from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource

@@ -103,7 +103,7 @@ class VOC(Dataset):
         anns = self._parse_detection_ann(buffer)
         instances = anns["object"]
         return dict(
-            bounding_boxes=BoundingBox(
+            bounding_boxes=BoundingBoxes(
                 [
                     [int(instance["bndbox"][part]) for part in ("xmin", "ymin", "xmax", "ymax")]
                     for instance in instances

torchvision/prototype/transforms/_augment.py

@@ -26,7 +26,7 @@ class _BaseMixupCutmix(_RandomApplyTransform):
             and has_any(flat_inputs, proto_datapoints.OneHotLabel)
         ):
             raise TypeError(f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels.")
-        if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBox, datapoints.Mask, proto_datapoints.Label):
+        if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBoxes, datapoints.Mask, proto_datapoints.Label):
             raise TypeError(
                 f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels."
             )

@@ -175,7 +175,7 @@ class SimpleCopyPaste(Transform):
         # There is a similar +1 in other reference implementations:
         # https://github.com/pytorch/vision/blob/b6feccbc4387766b76a3e22b13815dbbbfa87c0f/torchvision/models/detection/roi_heads.py#L418-L422
         xyxy_boxes[:, 2:] += 1
-        boxes = F.convert_format_bounding_box(
+        boxes = F.convert_format_bounding_boxes(
             xyxy_boxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox_format, inplace=True
         )
         out_target["boxes"] = torch.cat([boxes, paste_boxes])

@@ -184,7 +184,7 @@ class SimpleCopyPaste(Transform):
         out_target["labels"] = torch.cat([labels, paste_labels])

         # Check for degenerated boxes and remove them
-        boxes = F.convert_format_bounding_box(
+        boxes = F.convert_format_bounding_boxes(
             out_target["boxes"], old_format=bbox_format, new_format=datapoints.BoundingBoxFormat.XYXY
         )
         degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]

@@ -201,14 +201,14 @@ class SimpleCopyPaste(Transform):
         self, flat_sample: List[Any]
     ) -> Tuple[List[datapoints._TensorImageType], List[Dict[str, Any]]]:
         # fetch all images, bboxes, masks and labels from unstructured input
-        # with List[image], List[BoundingBox], List[Mask], List[Label]
+        # with List[image], List[BoundingBoxes], List[Mask], List[Label]
         images, bboxes, masks, labels = [], [], [], []
         for obj in flat_sample:
             if isinstance(obj, datapoints.Image) or is_simple_tensor(obj):
                 images.append(obj)
             elif isinstance(obj, PIL.Image.Image):
                 images.append(F.to_image_tensor(obj))
-            elif isinstance(obj, datapoints.BoundingBox):
+            elif isinstance(obj, datapoints.BoundingBoxes):
                 bboxes.append(obj)
             elif isinstance(obj, datapoints.Mask):
                 masks.append(obj)

@@ -218,7 +218,7 @@ class SimpleCopyPaste(Transform):
         if not (len(images) == len(bboxes) == len(masks) == len(labels)):
             raise TypeError(
                 f"{type(self).__name__}() requires input sample to contain equal sized list of Images, "
-                "BoundingBoxes, Masks and Labels or OneHotLabels."
+                "BoundingBoxeses, Masks and Labels or OneHotLabels."
             )
         targets = []

@@ -244,8 +244,8 @@ class SimpleCopyPaste(Transform):
             elif is_simple_tensor(obj):
                 flat_sample[i] = output_images[c0]
                 c0 += 1
-            elif isinstance(obj, datapoints.BoundingBox):
-                flat_sample[i] = datapoints.BoundingBox.wrap_like(obj, output_targets[c1]["boxes"])
+            elif isinstance(obj, datapoints.BoundingBoxes):
+                flat_sample[i] = datapoints.BoundingBoxes.wrap_like(obj, output_targets[c1]["boxes"])
                 c1 += 1
             elif isinstance(obj, datapoints.Mask):
                 flat_sample[i] = datapoints.Mask.wrap_like(obj, output_targets[c2]["masks"])

torchvision/prototype/transforms/_geometry.py

@@ -7,7 +7,7 @@ from torchvision import datapoints
 from torchvision.prototype.datapoints import Label, OneHotLabel
 from torchvision.transforms.v2 import functional as F, Transform
 from torchvision.transforms.v2._utils import _setup_fill_arg, _setup_size
-from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_bounding_box, query_spatial_size
+from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_bounding_boxes, query_spatial_size

 class FixedSizeCrop(Transform):

@@ -39,9 +39,9 @@ class FixedSizeCrop(Transform):
                 f"{type(self).__name__}() requires input sample to contain an tensor or PIL image or a Video."
             )
-        if has_any(flat_inputs, datapoints.BoundingBox) and not has_any(flat_inputs, Label, OneHotLabel):
+        if has_any(flat_inputs, datapoints.BoundingBoxes) and not has_any(flat_inputs, Label, OneHotLabel):
             raise TypeError(
-                f"If a BoundingBox is contained in the input sample, "
+                f"If a BoundingBoxes is contained in the input sample, "
                 f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel."
             )

@@ -61,13 +61,13 @@ class FixedSizeCrop(Transform):
         bounding_boxes: Optional[torch.Tensor]
         try:
-            bounding_boxes = query_bounding_box(flat_inputs)
+            bounding_boxes = query_bounding_boxes(flat_inputs)
         except ValueError:
             bounding_boxes = None
         if needs_crop and bounding_boxes is not None:
             format = bounding_boxes.format
-            bounding_boxes, spatial_size = F.crop_bounding_box(
+            bounding_boxes, spatial_size = F.crop_bounding_boxes(
                 bounding_boxes.as_subclass(torch.Tensor),
                 format=format,
                 top=top,

@@ -75,8 +75,8 @@ class FixedSizeCrop(Transform):
                 height=new_height,
                 width=new_width,
             )
-            bounding_boxes = F.clamp_bounding_box(bounding_boxes, format=format, spatial_size=spatial_size)
-            height_and_width = F.convert_format_bounding_box(
+            bounding_boxes = F.clamp_bounding_boxes(bounding_boxes, format=format, spatial_size=spatial_size)
+            height_and_width = F.convert_format_bounding_boxes(
                 bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYWH
             )[..., 2:]
             is_valid = torch.all(height_and_width > 0, dim=-1)

@@ -112,10 +112,12 @@ class FixedSizeCrop(Transform):
         if params["is_valid"] is not None:
             if isinstance(inpt, (Label, OneHotLabel, datapoints.Mask)):
                 inpt = inpt.wrap_like(inpt, inpt[params["is_valid"]])  # type: ignore[arg-type]
-            elif isinstance(inpt, datapoints.BoundingBox):
-                inpt = datapoints.BoundingBox.wrap_like(
+            elif isinstance(inpt, datapoints.BoundingBoxes):
+                inpt = datapoints.BoundingBoxes.wrap_like(
                     inpt,
-                    F.clamp_bounding_box(inpt[params["is_valid"]], format=inpt.format, spatial_size=inpt.spatial_size),
+                    F.clamp_bounding_boxes(
+                        inpt[params["is_valid"]], format=inpt.format, spatial_size=inpt.spatial_size
+                    ),
                 )
         if params["needs_pad"]:

torchvision/transforms/v2/__init__.py

@@ -39,7 +39,7 @@ from ._geometry import (
     ScaleJitter,
     TenCrop,
 )
-from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat
+from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat
 from ._misc import (
     ConvertImageDtype,
     GaussianBlur,

@@ -47,7 +47,7 @@ from ._misc import (
     Lambda,
     LinearTransformation,
     Normalize,
-    SanitizeBoundingBox,
+    SanitizeBoundingBoxes,
     ToDtype,
 )
 from ._temporal import UniformTemporalSubsample

torchvision/transforms/v2/_augment.py

@@ -155,7 +155,7 @@ class _BaseMixupCutmix(Transform):
         flat_inputs, spec = tree_flatten(inputs)
         needs_transform_list = self._needs_transform_list(flat_inputs)
-        if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBox, datapoints.Mask):
+        if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBoxes, datapoints.Mask):
             raise ValueError(f"{type(self).__name__}() does not support PIL images, bounding boxes and masks.")
         labels = self._labels_getter(inputs)

torchvision/transforms/v2/_auto_augment.py

@@ -34,7 +34,7 @@ class _AutoAugmentBase(Transform):
     def _flatten_and_extract_image_or_video(
         self,
         inputs: Any,
-        unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBox, datapoints.Mask),
+        unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBoxes, datapoints.Mask),
     ) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[datapoints._ImageType, datapoints._VideoType]]:
         flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
         needs_transform_list = self._needs_transform_list(flat_inputs)

torchvision/transforms/v2/_geometry.py

@@ -22,7 +22,7 @@ from ._utils import (
     _setup_float_or_seq,
     _setup_size,
 )
-from .utils import has_all, has_any, is_simple_tensor, query_bounding_box, query_spatial_size
+from .utils import has_all, has_any, is_simple_tensor, query_bounding_boxes, query_spatial_size

 class RandomHorizontalFlip(_RandomApplyTransform):

@@ -31,7 +31,7 @@ class RandomHorizontalFlip(_RandomApplyTransform):
     .. v2betastatus:: RandomHorizontalFlip transform

     If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
+    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

@@ -51,7 +51,7 @@ class RandomVerticalFlip(_RandomApplyTransform):
     .. v2betastatus:: RandomVerticalFlip transform

     If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
+    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

@@ -71,7 +71,7 @@ class Resize(Transform):
     .. v2betastatus:: Resize transform

     If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
+    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

@@ -165,7 +165,7 @@ class CenterCrop(Transform):
     .. v2betastatus:: CenterCrop transform

     If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
+    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

@@ -193,7 +193,7 @@ class RandomResizedCrop(Transform):
     .. v2betastatus:: RandomResizedCrop transform

     If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
+    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
@@ -371,8 +371,8 @@ class FiveCrop(Transform):
         return F.five_crop(inpt, self.size)

     def _check_inputs(self, flat_inputs: List[Any]) -> None:
-        if has_any(flat_inputs, datapoints.BoundingBox, datapoints.Mask):
-            raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
+        if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask):
+            raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()")

 class TenCrop(Transform):

@@ -414,8 +414,8 @@ class TenCrop(Transform):
         self.vertical_flip = vertical_flip

     def _check_inputs(self, flat_inputs: List[Any]) -> None:
-        if has_any(flat_inputs, datapoints.BoundingBox, datapoints.Mask):
-            raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
+        if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask):
+            raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()")

     def _transform(
         self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
@@ -440,7 +440,7 @@ class Pad(Transform):
     .. v2betastatus:: Pad transform

     If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
+    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

@@ -525,7 +525,7 @@ class RandomZoomOut(_RandomApplyTransform):
         output_height = input_height * r

     If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
+    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

@@ -584,7 +584,7 @@ class RandomRotation(Transform):
     .. v2betastatus:: RandomRotation transform

     If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
+    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

@@ -657,7 +657,7 @@ class RandomAffine(Transform):
     .. v2betastatus:: RandomAffine transform

     If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
+    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

@@ -778,7 +778,7 @@ class RandomCrop(Transform):
     .. v2betastatus:: RandomCrop transform

     If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
+    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

@@ -933,7 +933,7 @@ class RandomPerspective(_RandomApplyTransform):
     .. v2betastatus:: RandomPerspective transform

     If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
+    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

@@ -1019,7 +1019,7 @@ class ElasticTransform(Transform):
     .. v2betastatus:: RandomPerspective transform

     If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
+    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
@@ -1110,15 +1110,15 @@ class RandomIoUCrop(Transform):
     .. v2betastatus:: RandomIoUCrop transform

-    This transformation requires an image or video data and ``datapoints.BoundingBox`` in the input.
+    This transformation requires an image or video data and ``datapoints.BoundingBoxes`` in the input.

     .. warning::
         In order to properly remove the bounding boxes below the IoU threshold, `RandomIoUCrop`
-        must be followed by :class:`~torchvision.transforms.v2.SanitizeBoundingBox`, either immediately
+        must be followed by :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes`, either immediately
         after or later in the transforms pipeline.

     If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
+    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

@@ -1155,7 +1155,7 @@ class RandomIoUCrop(Transform):
     def _check_inputs(self, flat_inputs: List[Any]) -> None:
         if not (
-            has_all(flat_inputs, datapoints.BoundingBox)
+            has_all(flat_inputs, datapoints.BoundingBoxes)
             and has_any(flat_inputs, PIL.Image.Image, datapoints.Image, is_simple_tensor)
         ):
             raise TypeError(

@@ -1165,7 +1165,7 @@ class RandomIoUCrop(Transform):
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         orig_h, orig_w = query_spatial_size(flat_inputs)
-        bboxes = query_bounding_box(flat_inputs)
+        bboxes = query_bounding_boxes(flat_inputs)
         while True:
             # sample an option

@@ -1193,7 +1193,7 @@ class RandomIoUCrop(Transform):
                 continue
             # check for any valid boxes with centers within the crop area
-            xyxy_bboxes = F.convert_format_bounding_box(
+            xyxy_bboxes = F.convert_format_bounding_boxes(
                 bboxes.as_subclass(torch.Tensor), bboxes.format, datapoints.BoundingBoxFormat.XYXY
             )
             cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])

@@ -1220,9 +1220,9 @@ class RandomIoUCrop(Transform):
         output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])
-        if isinstance(output, datapoints.BoundingBox):
+        if isinstance(output, datapoints.BoundingBoxes):
             # We "mark" the invalid boxes as degenreate, and they can be
-            # removed by a later call to SanitizeBoundingBox()
+            # removed by a later call to SanitizeBoundingBoxes()
             output[~params["is_within_crop_area"]] = 0
         return output
@@ -1235,7 +1235,7 @@ class ScaleJitter(Transform):
...
@@ -1235,7 +1235,7 @@ class ScaleJitter(Transform):
.. v2betastatus:: ScaleJitter transform
.. v2betastatus:: ScaleJitter transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox
es
` etc.)
it can have arbitrary number of leading batch dimensions. For example,
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
...
@@ -1301,7 +1301,7 @@ class RandomShortestSize(Transform):
...
@@ -1301,7 +1301,7 @@ class RandomShortestSize(Transform):
.. v2betastatus:: RandomShortestSize transform
.. v2betastatus:: RandomShortestSize transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox
es
` etc.)
it can have arbitrary number of leading batch dimensions. For example,
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
...
@@ -1380,7 +1380,7 @@ class RandomResize(Transform):
...
@@ -1380,7 +1380,7 @@ class RandomResize(Transform):
output_height = size
output_height = size
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox
es
` etc.)
it can have arbitrary number of leading batch dimensions. For example,
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
...
...
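The RandomIoUCrop warning above now points at the renamed SanitizeBoundingBoxes. A minimal pipeline sketch honoring that ordering; the Resize size and antialias flag are arbitrary placeholders, not part of this commit:

from torchvision.transforms import v2

pipeline = v2.Compose(
    [
        v2.RandomIoUCrop(),
        v2.Resize(size=(300, 300), antialias=True),
        v2.SanitizeBoundingBoxes(),  # removes the boxes RandomIoUCrop zeroed out
    ]
)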

torchvision/transforms/v2/_meta.py

@@ -15,7 +15,7 @@ class ConvertBoundingBoxFormat(Transform):
         string values match the enums, e.g. "XYXY" or "XYWH" etc.
     """

-    _transformed_types = (datapoints.BoundingBox,)
+    _transformed_types = (datapoints.BoundingBoxes,)

     def __init__(self, format: Union[str, datapoints.BoundingBoxFormat]) -> None:
         super().__init__()

@@ -23,20 +23,20 @@ class ConvertBoundingBoxFormat(Transform):
             format = datapoints.BoundingBoxFormat[format]
         self.format = format

-    def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox:
-        return F.convert_format_bounding_box(inpt, new_format=self.format)  # type: ignore[return-value]
+    def _transform(self, inpt: datapoints.BoundingBoxes, params: Dict[str, Any]) -> datapoints.BoundingBoxes:
+        return F.convert_format_bounding_boxes(inpt, new_format=self.format)  # type: ignore[return-value]

-class ClampBoundingBox(Transform):
+class ClampBoundingBoxes(Transform):
     """[BETA] Clamp bounding boxes to their corresponding image dimensions.

     The clamping is done according to the bounding boxes' ``spatial_size`` meta-data.

-    .. v2betastatus:: ClampBoundingBox transform
+    .. v2betastatus:: ClampBoundingBoxes transform
     """

-    _transformed_types = (datapoints.BoundingBox,)
+    _transformed_types = (datapoints.BoundingBoxes,)

-    def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox:
-        return F.clamp_bounding_box(inpt)  # type: ignore[return-value]
+    def _transform(self, inpt: datapoints.BoundingBoxes, params: Dict[str, Any]) -> datapoints.BoundingBoxes:
+        return F.clamp_bounding_boxes(inpt)  # type: ignore[return-value]
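As the `_transformed_types` tuples above show, both transforms only act on `BoundingBoxes` datapoints and pass everything else through. A minimal sketch of the renamed classes; the box values are made up, and the clamp result follows from the docstring (clipping to ``spatial_size``):

import torch
from torchvision.datapoints import BoundingBoxes
from torchvision.transforms import v2

# A hypothetical box extending past a 128x160 image.
boxes = BoundingBoxes(torch.tensor([[0, 0, 400, 400]]), format="xyxy", spatial_size=(128, 160))
clamped = v2.ClampBoundingBoxes()(boxes)               # clipped to the 128x160 canvas
converted = v2.ConvertBoundingBoxFormat("XYWH")(clamped)  # string matches the enum name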

torchvision/transforms/v2/_misc.py

@@ -10,7 +10,7 @@ from torchvision import datapoints, transforms as _transforms
 from torchvision.transforms.v2 import functional as F, Transform
 from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size
-from .utils import has_any, is_simple_tensor, query_bounding_box
+from .utils import has_any, is_simple_tensor, query_bounding_boxes

 # TODO: do we want/need to expose this?

@@ -332,16 +332,16 @@ class ConvertImageDtype(Transform):
         return F.to_dtype(inpt, dtype=self.dtype, scale=True)

-class SanitizeBoundingBox(Transform):
+class SanitizeBoundingBoxes(Transform):
     """[BETA] Remove degenerate/invalid bounding boxes and their corresponding labels and masks.

-    .. v2betastatus:: SanitizeBoundingBox transform
+    .. v2betastatus:: SanitizeBoundingBoxes transform

     This transform removes bounding boxes and their associated labels/masks that:

     - are below a given ``min_size``: by default this also removes degenerate boxes that have e.g. X2 <= X1.
     - have any coordinate outside of their corresponding image. You may want to
-      call :class:`~torchvision.transforms.v2.ClampBoundingBox` first to avoid undesired removals.
+      call :class:`~torchvision.transforms.v2.ClampBoundingBoxes` first to avoid undesired removals.

     It is recommended to call it at the end of a pipeline, before passing the
     input to the models. It is critical to call this transform if

@@ -384,10 +384,10 @@ class SanitizeBoundingBox(Transform):
             )

         flat_inputs, spec = tree_flatten(inputs)
-        # TODO: this enforces one single BoundingBox entry.
+        # TODO: this enforces one single BoundingBoxes entry.
         # Assuming this transform needs to be called at the end of *any* pipeline that has bboxes...
         # should we just enforce it for all transforms?? What are the benefits of *not* enforcing this?
-        boxes = query_bounding_box(flat_inputs)
+        boxes = query_bounding_boxes(flat_inputs)

         if boxes.ndim != 2:
             raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}")

@@ -398,8 +398,8 @@ class SanitizeBoundingBox(Transform):
         )

         boxes = cast(
-            datapoints.BoundingBox,
-            F.convert_format_bounding_box(
+            datapoints.BoundingBoxes,
+            F.convert_format_bounding_boxes(
                 boxes,
                 new_format=datapoints.BoundingBoxFormat.XYXY,
             ),

@@ -415,7 +415,7 @@ class SanitizeBoundingBox(Transform):
         params = dict(valid=valid, labels=labels)
         flat_outputs = [
             # Even-though it may look like we're transforming all inputs, we don't:
-            # _transform() will only care about BoundingBoxes and the labels
+            # _transform() will only care about BoundingBoxeses and the labels
             self._transform(inpt, params)
             for inpt in flat_inputs
         ]

@@ -424,9 +424,9 @@ class SanitizeBoundingBox(Transform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         is_label = inpt is not None and inpt is params["labels"]
-        is_bounding_box_or_mask = isinstance(inpt, (datapoints.BoundingBox, datapoints.Mask))
+        is_bounding_boxes_or_mask = isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask))

-        if not (is_label or is_bounding_box_or_mask):
+        if not (is_label or is_bounding_boxes_or_mask):
             return inpt

         output = inpt[params["valid"]]
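`SanitizeBoundingBoxes` filters boxes together with their labels (and masks). A sketch with a hypothetical detection sample; the dict layout with a "labels" key is an assumption about the default labels lookup, not something this commit changes:

import torch
from torchvision.datapoints import BoundingBoxes, Image
from torchvision.transforms import v2

# Hypothetical two-box sample; the second box is degenerate (zero area).
image = Image(torch.zeros(3, 100, 100))
boxes = BoundingBoxes(
    torch.tensor([[10, 10, 50, 50], [30, 30, 30, 30]]), format="xyxy", spatial_size=(100, 100)
)
labels = torch.tensor([1, 2])

sample = {"image": image, "boxes": boxes, "labels": labels}
out = v2.SanitizeBoundingBoxes()(sample)  # drops the degenerate box and its label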
torchvision/transforms/v2/functional/__init__.py
View file @
332bff93
@@ -3,8 +3,8 @@ from torchvision.transforms import InterpolationMode  # usort: skip
 from ._utils import is_simple_tensor  # usort: skip
 from ._meta import (
-    clamp_bounding_box,
-    convert_format_bounding_box,
+    clamp_bounding_boxes,
+    convert_format_bounding_boxes,
     get_dimensions_image_tensor,
     get_dimensions_image_pil,
     get_dimensions,
...
@@ -15,7 +15,7 @@ from ._meta import (
     get_num_channels_image_pil,
     get_num_channels_video,
     get_num_channels,
-    get_spatial_size_bounding_box,
+    get_spatial_size_bounding_boxes,
     get_spatial_size_image_tensor,
     get_spatial_size_image_pil,
     get_spatial_size_mask,
...
@@ -76,25 +76,25 @@ from ._color import (
 )
 from ._geometry import (
     affine,
-    affine_bounding_box,
+    affine_bounding_boxes,
     affine_image_pil,
     affine_image_tensor,
     affine_mask,
     affine_video,
     center_crop,
-    center_crop_bounding_box,
+    center_crop_bounding_boxes,
     center_crop_image_pil,
     center_crop_image_tensor,
     center_crop_mask,
     center_crop_video,
     crop,
-    crop_bounding_box,
+    crop_bounding_boxes,
     crop_image_pil,
     crop_image_tensor,
     crop_mask,
     crop_video,
     elastic,
-    elastic_bounding_box,
+    elastic_bounding_boxes,
     elastic_image_pil,
     elastic_image_tensor,
     elastic_mask,
...
@@ -106,37 +106,37 @@ from ._geometry import (
     five_crop_video,
     hflip,  # TODO: Consider moving all pure alias definitions at the bottom of the file
     horizontal_flip,
-    horizontal_flip_bounding_box,
+    horizontal_flip_bounding_boxes,
     horizontal_flip_image_pil,
     horizontal_flip_image_tensor,
     horizontal_flip_mask,
     horizontal_flip_video,
     pad,
-    pad_bounding_box,
+    pad_bounding_boxes,
     pad_image_pil,
     pad_image_tensor,
     pad_mask,
     pad_video,
     perspective,
-    perspective_bounding_box,
+    perspective_bounding_boxes,
     perspective_image_pil,
     perspective_image_tensor,
     perspective_mask,
     perspective_video,
     resize,
-    resize_bounding_box,
+    resize_bounding_boxes,
     resize_image_pil,
     resize_image_tensor,
     resize_mask,
     resize_video,
     resized_crop,
-    resized_crop_bounding_box,
+    resized_crop_bounding_boxes,
     resized_crop_image_pil,
     resized_crop_image_tensor,
     resized_crop_mask,
     resized_crop_video,
     rotate,
-    rotate_bounding_box,
+    rotate_bounding_boxes,
     rotate_image_pil,
     rotate_image_tensor,
     rotate_mask,
...
@@ -146,7 +146,7 @@ from ._geometry import (
     ten_crop_image_tensor,
     ten_crop_video,
     vertical_flip,
-    vertical_flip_bounding_box,
+    vertical_flip_bounding_boxes,
     vertical_flip_image_pil,
     vertical_flip_image_tensor,
     vertical_flip_mask,
...
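For downstream code the rename is mechanical: every `*_bounding_box` name gains an `es`. A hypothetical before/after import, assuming a torchvision checkout at this commit:

# before this commit:
# from torchvision.transforms.v2.functional import resize_bounding_box
from torchvision.transforms.v2.functional import resize_bounding_boxes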
torchvision/transforms/v2/functional/_geometry.py
@@ -23,7 +23,7 @@ from torchvision.transforms.functional import (
 from torchvision.utils import _log_api_usage_once

-from ._meta import clamp_bounding_box, convert_format_bounding_box, get_spatial_size_image_pil
+from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_spatial_size_image_pil
 from ._utils import is_simple_tensor
...
@@ -51,21 +51,21 @@ def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
     return horizontal_flip_image_tensor(mask)


-def horizontal_flip_bounding_box(
-    bounding_box: torch.Tensor, format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int]
+def horizontal_flip_bounding_boxes(
+    bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int]
 ) -> torch.Tensor:
-    shape = bounding_box.shape
+    shape = bounding_boxes.shape

-    bounding_box = bounding_box.clone().reshape(-1, 4)
+    bounding_boxes = bounding_boxes.clone().reshape(-1, 4)

     if format == datapoints.BoundingBoxFormat.XYXY:
-        bounding_box[:, [2, 0]] = bounding_box[:, [0, 2]].sub_(spatial_size[1]).neg_()
+        bounding_boxes[:, [2, 0]] = bounding_boxes[:, [0, 2]].sub_(spatial_size[1]).neg_()
     elif format == datapoints.BoundingBoxFormat.XYWH:
-        bounding_box[:, 0].add_(bounding_box[:, 2]).sub_(spatial_size[1]).neg_()
+        bounding_boxes[:, 0].add_(bounding_boxes[:, 2]).sub_(spatial_size[1]).neg_()
     else:  # format == datapoints.BoundingBoxFormat.CXCYWH:
-        bounding_box[:, 0].sub_(spatial_size[1]).neg_()
+        bounding_boxes[:, 0].sub_(spatial_size[1]).neg_()

-    return bounding_box.reshape(shape)
+    return bounding_boxes.reshape(shape)


 def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
...
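The XYXY branch of the flip writes `width - x` back in a single fancy-indexed assignment, swapping min and max at the same time. A standalone sketch of the same arithmetic (box and width values invented):

import torch

width = 100
boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0]])  # XYXY

flipped = boxes.clone()
# new_xmin = width - old_xmax, new_xmax = width - old_xmin
flipped[:, [2, 0]] = flipped[:, [0, 2]].sub_(width).neg_()
print(flipped)  # tensor([[70., 20., 90., 40.]])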
@@ -101,21 +101,21 @@ def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
     return vertical_flip_image_tensor(mask)


-def vertical_flip_bounding_box(
-    bounding_box: torch.Tensor, format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int]
+def vertical_flip_bounding_boxes(
+    bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int]
 ) -> torch.Tensor:
-    shape = bounding_box.shape
+    shape = bounding_boxes.shape

-    bounding_box = bounding_box.clone().reshape(-1, 4)
+    bounding_boxes = bounding_boxes.clone().reshape(-1, 4)

     if format == datapoints.BoundingBoxFormat.XYXY:
-        bounding_box[:, [1, 3]] = bounding_box[:, [3, 1]].sub_(spatial_size[0]).neg_()
+        bounding_boxes[:, [1, 3]] = bounding_boxes[:, [3, 1]].sub_(spatial_size[0]).neg_()
     elif format == datapoints.BoundingBoxFormat.XYWH:
-        bounding_box[:, 1].add_(bounding_box[:, 3]).sub_(spatial_size[0]).neg_()
+        bounding_boxes[:, 1].add_(bounding_boxes[:, 3]).sub_(spatial_size[0]).neg_()
     else:  # format == datapoints.BoundingBoxFormat.CXCYWH:
-        bounding_box[:, 1].sub_(spatial_size[0]).neg_()
+        bounding_boxes[:, 1].sub_(spatial_size[0]).neg_()

-    return bounding_box.reshape(shape)
+    return bounding_boxes.reshape(shape)


 def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
...
@@ -274,20 +274,20 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = N
     return output


-def resize_bounding_box(
-    bounding_box: torch.Tensor, spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
+def resize_bounding_boxes(
+    bounding_boxes: torch.Tensor, spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
     old_height, old_width = spatial_size
     new_height, new_width = _compute_resized_output_size(spatial_size, size=size, max_size=max_size)

     if (new_height, new_width) == (old_height, old_width):
-        return bounding_box, spatial_size
+        return bounding_boxes, spatial_size

     w_ratio = new_width / old_width
     h_ratio = new_height / old_height
-    ratios = torch.tensor([w_ratio, h_ratio, w_ratio, h_ratio], device=bounding_box.device)
+    ratios = torch.tensor([w_ratio, h_ratio, w_ratio, h_ratio], device=bounding_boxes.device)
     return (
-        bounding_box.mul(ratios).to(bounding_box.dtype),
+        bounding_boxes.mul(ratios).to(bounding_boxes.dtype),
         (new_height, new_width),
     )
...
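Box resizing is a per-coordinate multiply by the width/height ratios, laid out as [w, h, w, h] to match XYXY order. A sketch with invented sizes:

import torch

old_h, old_w, new_h, new_w = 100, 200, 50, 100
boxes = torch.tensor([[20.0, 40.0, 60.0, 80.0]])  # XYXY on the old canvas

ratios = torch.tensor([new_w / old_w, new_h / old_h, new_w / old_w, new_h / old_h])
print(boxes * ratios)  # tensor([[10., 20., 30., 40.]])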
@@ -650,8 +650,8 @@ def affine_image_pil(
     return _FP.affine(image, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill)


-def _affine_bounding_box_with_expand(
-    bounding_box: torch.Tensor,
+def _affine_bounding_boxes_with_expand(
+    bounding_boxes: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
     spatial_size: Tuple[int, int],
     angle: Union[int, float],
...
@@ -661,17 +661,17 @@ def _affine_bounding_box_with_expand(
     center: Optional[List[float]] = None,
     expand: bool = False,
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
-    if bounding_box.numel() == 0:
-        return bounding_box, spatial_size
+    if bounding_boxes.numel() == 0:
+        return bounding_boxes, spatial_size

-    original_shape = bounding_box.shape
-    original_dtype = bounding_box.dtype
-    bounding_box = bounding_box.clone() if bounding_box.is_floating_point() else bounding_box.float()
-    dtype = bounding_box.dtype
-    device = bounding_box.device
-    bounding_box = (
-        convert_format_bounding_box(
-            bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
+    original_shape = bounding_boxes.shape
+    original_dtype = bounding_boxes.dtype
+    bounding_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float()
+    dtype = bounding_boxes.dtype
+    device = bounding_boxes.device
+    bounding_boxes = (
+        convert_format_bounding_boxes(
+            bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
         )
     ).reshape(-1, 4)
...
@@ -697,7 +697,7 @@ def _affine_bounding_box_with_expand(
     # Tensor of points has shape (N * 4, 3), where N is the number of bboxes
     # Single point structure is similar to
     # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
-    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
+    points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
     points = torch.cat([points, torch.ones(points.shape[0], 1, device=device, dtype=dtype)], dim=-1)
     # 2) Now let's transform the points using affine matrix
     transformed_points = torch.matmul(points, transposed_affine_matrix)
...
@@ -730,8 +730,8 @@ def _affine_bounding_box_with_expand(
         new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
         spatial_size = (new_height, new_width)

-    out_bboxes = clamp_bounding_box(out_bboxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size)
-    out_bboxes = convert_format_bounding_box(
+    out_bboxes = clamp_bounding_boxes(out_bboxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size)
+    out_bboxes = convert_format_bounding_boxes(
         out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
     ).reshape(original_shape)
...
@@ -739,8 +739,8 @@ def _affine_bounding_box_with_expand(
     return out_bboxes, spatial_size


-def affine_bounding_box(
-    bounding_box: torch.Tensor,
+def affine_bounding_boxes(
+    bounding_boxes: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
     spatial_size: Tuple[int, int],
     angle: Union[int, float],
...
@@ -749,8 +749,8 @@ def affine_bounding_box(
     shear: List[float],
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
-    out_box, _ = _affine_bounding_box_with_expand(
-        bounding_box,
+    out_box, _ = _affine_bounding_boxes_with_expand(
+        bounding_boxes,
         format=format,
         spatial_size=spatial_size,
         angle=angle,
...
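The kernel's core trick, per the comments above: expand each XYXY box into its four corners in homogeneous coordinates, transform the whole point cloud with one matmul, then take per-box min/max to get the new enclosing boxes. A self-contained sketch (the 90-degree rotation matrix is an arbitrary choice for the demo):

import math
import torch

boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0]])  # XYXY, N = 1

# (N, 4) -> four corners per box -> (N * 4, 2) -> homogeneous (N * 4, 3)
points = boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
points = torch.cat([points, torch.ones(points.shape[0], 1)], dim=-1)

theta = math.radians(90.0)
affine = torch.tensor(
    [[math.cos(theta), -math.sin(theta), 0.0],
     [math.sin(theta), math.cos(theta), 0.0]]
)
transformed = points @ affine.T  # (N * 4, 2)

# min/max over each group of four corners re-boxes the transformed points
mins, maxs = torch.aminmax(transformed.reshape(-1, 4, 2), dim=1)
print(torch.cat([mins, maxs], dim=1))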
@@ -927,8 +927,8 @@ def rotate_image_pil(
     )


-def rotate_bounding_box(
-    bounding_box: torch.Tensor,
+def rotate_bounding_boxes(
+    bounding_boxes: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
     spatial_size: Tuple[int, int],
     angle: float,
...
@@ -938,8 +938,8 @@ def rotate_bounding_box(
     if center is not None and expand:
         warnings.warn("The provided center argument has no effect on the result if expand is True")

-    return _affine_bounding_box_with_expand(
-        bounding_box,
+    return _affine_bounding_boxes_with_expand(
+        bounding_boxes,
         format=format,
         spatial_size=spatial_size,
         angle=-angle,
@@ -1165,8 +1165,8 @@ def pad_mask(
...
@@ -1165,8 +1165,8 @@ def pad_mask(
return
output
return
output
def
pad_bounding_box
(
def
pad_bounding_box
es
(
bounding_box
:
torch
.
Tensor
,
bounding_box
es
:
torch
.
Tensor
,
format
:
datapoints
.
BoundingBoxFormat
,
format
:
datapoints
.
BoundingBoxFormat
,
spatial_size
:
Tuple
[
int
,
int
],
spatial_size
:
Tuple
[
int
,
int
],
padding
:
List
[
int
],
padding
:
List
[
int
],
...
@@ -1182,14 +1182,14 @@ def pad_bounding_box(
...
@@ -1182,14 +1182,14 @@ def pad_bounding_box(
pad
=
[
left
,
top
,
left
,
top
]
pad
=
[
left
,
top
,
left
,
top
]
else
:
else
:
pad
=
[
left
,
top
,
0
,
0
]
pad
=
[
left
,
top
,
0
,
0
]
bounding_box
=
bounding_box
+
torch
.
tensor
(
pad
,
dtype
=
bounding_box
.
dtype
,
device
=
bounding_box
.
device
)
bounding_box
es
=
bounding_box
es
+
torch
.
tensor
(
pad
,
dtype
=
bounding_box
es
.
dtype
,
device
=
bounding_box
es
.
device
)
height
,
width
=
spatial_size
height
,
width
=
spatial_size
height
+=
top
+
bottom
height
+=
top
+
bottom
width
+=
left
+
right
width
+=
left
+
right
spatial_size
=
(
height
,
width
)
spatial_size
=
(
height
,
width
)
return
clamp_bounding_box
(
bounding_box
,
format
=
format
,
spatial_size
=
spatial_size
),
spatial_size
return
clamp_bounding_box
es
(
bounding_box
es
,
format
=
format
,
spatial_size
=
spatial_size
),
spatial_size
def
pad_video
(
def
pad_video
(
...
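Left/top padding translates the box coordinates: in XYXY both corners shift, while in XYWH and CXCYWH only the position fields do, which is exactly why the kernel builds `[left, top, left, top]` versus `[left, top, 0, 0]`. Toy sketch with invented values:

import torch

left, top = 5, 10  # hypothetical padding
xyxy = torch.tensor([[0.0, 0.0, 4.0, 4.0]])
xywh = torch.tensor([[0.0, 0.0, 4.0, 4.0]])

print(xyxy + torch.tensor([left, top, left, top]))  # both corners move
print(xywh + torch.tensor([left, top, 0, 0]))       # width/height stay put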
@@ -1245,8 +1245,8 @@ def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, wid
 crop_image_pil = _FP.crop


-def crop_bounding_box(
-    bounding_box: torch.Tensor,
+def crop_bounding_boxes(
+    bounding_boxes: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
     top: int,
     left: int,
...
@@ -1260,10 +1260,10 @@ def crop_bounding_box(
     else:
         sub = [left, top, 0, 0]

-    bounding_box = bounding_box - torch.tensor(sub, dtype=bounding_box.dtype, device=bounding_box.device)
+    bounding_boxes = bounding_boxes - torch.tensor(sub, dtype=bounding_boxes.dtype, device=bounding_boxes.device)
     spatial_size = (height, width)

-    return clamp_bounding_box(bounding_box, format=format, spatial_size=spatial_size), spatial_size
+    return clamp_bounding_boxes(bounding_boxes, format=format, spatial_size=spatial_size), spatial_size


 def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
...
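Cropping is the inverse translation followed by a clamp to the new canvas, as in the kernel's final `clamp_bounding_boxes` call. Sketch with invented numbers:

import torch

top, left, height, width = 10, 5, 50, 50  # hypothetical crop window
boxes = torch.tensor([[0.0, 0.0, 20.0, 20.0]])  # XYXY

shifted = boxes - torch.tensor([left, top, left, top])
# clamp x coordinates to [0, width] and y coordinates to [0, height]
shifted[:, 0::2].clamp_(min=0, max=width)
shifted[:, 1::2].clamp_(min=0, max=height)
print(shifted)  # tensor([[ 0.,  0., 15., 10.]])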
@@ -1409,27 +1409,27 @@ def perspective_image_pil(
     return _FP.perspective(image, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill)


-def perspective_bounding_box(
-    bounding_box: torch.Tensor,
+def perspective_bounding_boxes(
+    bounding_boxes: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
     spatial_size: Tuple[int, int],
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
     coefficients: Optional[List[float]] = None,
 ) -> torch.Tensor:
-    if bounding_box.numel() == 0:
-        return bounding_box
+    if bounding_boxes.numel() == 0:
+        return bounding_boxes

     perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
-    original_shape = bounding_box.shape
-    # TODO: first cast to float if bbox is int64 before convert_format_bounding_box
-    bounding_box = (
-        convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
+    original_shape = bounding_boxes.shape
+    # TODO: first cast to float if bbox is int64 before convert_format_bounding_boxes
+    bounding_boxes = (
+        convert_format_bounding_boxes(bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)

-    dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
-    device = bounding_box.device
+    dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32
+    device = bounding_boxes.device

     # perspective_coeffs are computed as endpoint -> start point
     # We have to invert perspective_coeffs for bboxes:
...
@@ -1475,7 +1475,7 @@ def perspective_bounding_box(
     # Tensor of points has shape (N * 4, 3), where N is the number of bboxes
     # Single point structure is similar to
     # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
-    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
+    points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
     points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1)
     # 2) Now let's transform the points using perspective matrices
     # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
...
@@ -1490,15 +1490,15 @@ def perspective_bounding_box(
     transformed_points = transformed_points.reshape(-1, 4, 2)
     out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
-    out_bboxes = clamp_bounding_box(
-        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype),
+    out_bboxes = clamp_bounding_boxes(
+        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
         format=datapoints.BoundingBoxFormat.XYXY,
         spatial_size=spatial_size,
     )

     # out_bboxes should be of shape [N boxes, 4]
-    return convert_format_bounding_box(
+    return convert_format_bounding_boxes(
         out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
     ).reshape(original_shape)
...
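The projective step divides by the homogeneous denominator, exactly as the `x_out` comment above spells out. A sketch for a single point, with identity coefficients chosen so the output equals the input:

import torch

coeffs = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]  # hypothetical (identity) coefficients
x, y = torch.tensor(10.0), torch.tensor(20.0)

denom = coeffs[6] * x + coeffs[7] * y + 1
x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / denom
y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / denom
print(x_out.item(), y_out.item())  # 10.0 20.0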
@@ -1648,26 +1648,26 @@ def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: to
     return base_grid


-def elastic_bounding_box(
-    bounding_box: torch.Tensor,
+def elastic_bounding_boxes(
+    bounding_boxes: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
     spatial_size: Tuple[int, int],
     displacement: torch.Tensor,
 ) -> torch.Tensor:
-    if bounding_box.numel() == 0:
-        return bounding_box
+    if bounding_boxes.numel() == 0:
+        return bounding_boxes

     # TODO: add in docstring about approximation we are doing for grid inversion
-    device = bounding_box.device
-    dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
+    device = bounding_boxes.device
+    dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32

     if displacement.dtype != dtype or displacement.device != device:
         displacement = displacement.to(dtype=dtype, device=device)

-    original_shape = bounding_box.shape
-    # TODO: first cast to float if bbox is int64 before convert_format_bounding_box
-    bounding_box = (
-        convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
+    original_shape = bounding_boxes.shape
+    # TODO: first cast to float if bbox is int64 before convert_format_bounding_boxes
+    bounding_boxes = (
+        convert_format_bounding_boxes(bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)

     id_grid = _create_identity_grid(spatial_size, device=device, dtype=dtype)
...
@@ -1676,7 +1676,7 @@ def elastic_bounding_box(
     inv_grid = id_grid.sub_(displacement)

     # Get points from bboxes
-    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
+    points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
     if points.is_floating_point():
         points = points.ceil_()
     index_xy = points.to(dtype=torch.long)
...
@@ -1688,13 +1688,13 @@ def elastic_bounding_box(
     transformed_points = transformed_points.reshape(-1, 4, 2)
     out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
-    out_bboxes = clamp_bounding_box(
-        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype),
+    out_bboxes = clamp_bounding_boxes(
+        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
         format=datapoints.BoundingBoxFormat.XYXY,
         spatial_size=spatial_size,
     )

-    return convert_format_bounding_box(
+    return convert_format_bounding_boxes(
         out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
     ).reshape(original_shape)
...
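The TODO above alludes to the approximation: instead of inverting the elastic warp exactly, the kernel subtracts the displacement from the identity grid and reads off where each (integer-rounded) box corner lands. A rough standalone sketch of that idea, using a zero displacement so the lookup is exact; the grid construction here is simplified relative to `_create_identity_grid`:

import torch

H, W = 8, 8
displacement = torch.zeros(1, H, W, 2)  # hypothetical displacement field

# identity grid over the canvas in normalized [-1, 1] coordinates, (1, H, W, 2)
ys = torch.linspace(-1, 1, H)
xs = torch.linspace(-1, 1, W)
gy, gx = torch.meshgrid(ys, xs, indexing="ij")
id_grid = torch.stack([gx, gy], dim=-1).unsqueeze(0)

inv_grid = id_grid - displacement  # approximate inverse mapping

# look up the grid at integer (x, y) box corners, then map back to pixels
corners = torch.tensor([[0, 0], [7, 0], [7, 7], [0, 7]])
sampled = inv_grid[0, corners[:, 1], corners[:, 0]]
pixels = 0.5 * (sampled + 1) * torch.tensor([W - 1, H - 1], dtype=sampled.dtype)
print(pixels)  # with zero displacement, corners map to themselves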
@@ -1818,15 +1818,17 @@ def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL
     return crop_image_pil(image, crop_top, crop_left, crop_height, crop_width)


-def center_crop_bounding_box(
-    bounding_box: torch.Tensor,
+def center_crop_bounding_boxes(
+    bounding_boxes: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
     spatial_size: Tuple[int, int],
     output_size: List[int],
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
     crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *spatial_size)
-    return crop_bounding_box(bounding_box, format, top=crop_top, left=crop_left, height=crop_height, width=crop_width)
+    return crop_bounding_boxes(
+        bounding_boxes, format, top=crop_top, left=crop_left, height=crop_height, width=crop_width
+    )


 def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
...
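The center-crop anchor is plain arithmetic: offset each side by half the size difference, then delegate to the crop kernel. Sketch with invented sizes:

H, W = 100, 200          # hypothetical input canvas
crop_h, crop_w = 50, 60  # hypothetical output size

crop_top = int(round((H - crop_h) / 2.0))
crop_left = int(round((W - crop_w) / 2.0))
print(crop_top, crop_left)  # 25 70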
@@ -1893,8 +1895,8 @@ def resized_crop_image_pil(
     return resize_image_pil(image, size, interpolation=interpolation)


-def resized_crop_bounding_box(
-    bounding_box: torch.Tensor,
+def resized_crop_bounding_boxes(
+    bounding_boxes: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
     top: int,
     left: int,
...
@@ -1902,8 +1904,8 @@ def resized_crop_bounding_box(
     width: int,
     size: List[int],
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
-    bounding_box, _ = crop_bounding_box(bounding_box, format, top, left, height, width)
-    return resize_bounding_box(bounding_box, spatial_size=(height, width), size=size)
+    bounding_boxes, _ = crop_bounding_boxes(bounding_boxes, format, top, left, height, width)
+    return resize_bounding_boxes(bounding_boxes, spatial_size=(height, width), size=size)


 def resized_crop_mask(
...
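`resized_crop_bounding_boxes` is literally crop-then-resize on the coordinates. An inline composition of the two steps, with invented numbers:

import torch

top, left, height, width = 10, 10, 50, 50         # hypothetical crop
new_h, new_w = 100, 100                           # hypothetical resize target
boxes = torch.tensor([[20.0, 20.0, 40.0, 40.0]])  # XYXY

cropped = boxes - torch.tensor([left, top, left, top])
ratios = torch.tensor([new_w / width, new_h / height, new_w / width, new_h / height])
print(cropped * ratios)  # tensor([[20., 20., 60., 60.]])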
torchvision/transforms/v2/functional/_meta.py
@@ -109,8 +109,8 @@ def get_spatial_size_mask(mask: torch.Tensor) -> List[int]:


 @torch.jit.unused
-def get_spatial_size_bounding_box(bounding_box: datapoints.BoundingBox) -> List[int]:
-    return list(bounding_box.spatial_size)
+def get_spatial_size_bounding_boxes(bounding_boxes: datapoints.BoundingBoxes) -> List[int]:
+    return list(bounding_boxes.spatial_size)


 def get_spatial_size(inpt: datapoints._InputTypeJIT) -> List[int]:
...
@@ -119,7 +119,7 @@ def get_spatial_size(inpt: datapoints._InputTypeJIT) -> List[int]:
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return get_spatial_size_image_tensor(inpt)
-    elif isinstance(inpt, (datapoints.Image, datapoints.Video, datapoints.BoundingBox, datapoints.Mask)):
+    elif isinstance(inpt, (datapoints.Image, datapoints.Video, datapoints.BoundingBoxes, datapoints.Mask)):
         return list(inpt.spatial_size)
     elif isinstance(inpt, PIL.Image.Image):
         return get_spatial_size_image_pil(inpt)
...
@@ -185,95 +185,97 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
     return xyxy


-def _convert_format_bounding_box(
-    bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False
+def _convert_format_bounding_boxes(
+    bounding_boxes: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False
 ) -> torch.Tensor:
     if new_format == old_format:
-        return bounding_box
+        return bounding_boxes

     # TODO: Add _xywh_to_cxcywh and _cxcywh_to_xywh to improve performance
     if old_format == BoundingBoxFormat.XYWH:
-        bounding_box = _xywh_to_xyxy(bounding_box, inplace)
+        bounding_boxes = _xywh_to_xyxy(bounding_boxes, inplace)
     elif old_format == BoundingBoxFormat.CXCYWH:
-        bounding_box = _cxcywh_to_xyxy(bounding_box, inplace)
+        bounding_boxes = _cxcywh_to_xyxy(bounding_boxes, inplace)

     if new_format == BoundingBoxFormat.XYWH:
-        bounding_box = _xyxy_to_xywh(bounding_box, inplace)
+        bounding_boxes = _xyxy_to_xywh(bounding_boxes, inplace)
     elif new_format == BoundingBoxFormat.CXCYWH:
-        bounding_box = _xyxy_to_cxcywh(bounding_box, inplace)
+        bounding_boxes = _xyxy_to_cxcywh(bounding_boxes, inplace)

-    return bounding_box
+    return bounding_boxes


-def convert_format_bounding_box(
+def convert_format_bounding_boxes(
     inpt: datapoints._InputTypeJIT,
     old_format: Optional[BoundingBoxFormat] = None,
     new_format: Optional[BoundingBoxFormat] = None,
     inplace: bool = False,
 ) -> datapoints._InputTypeJIT:
     # This being a kernel / dispatcher hybrid, we need an option to pass `old_format` explicitly for simple tensor
-    # inputs as well as extract it from `datapoints.BoundingBox` inputs. However, putting a default value on
+    # inputs as well as extract it from `datapoints.BoundingBoxes` inputs. However, putting a default value on
     # `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
     # default error that would be thrown if `new_format` had no default value.
     if new_format is None:
-        raise TypeError("convert_format_bounding_box() missing 1 required argument: 'new_format'")
+        raise TypeError("convert_format_bounding_boxes() missing 1 required argument: 'new_format'")

     if not torch.jit.is_scripting():
-        _log_api_usage_once(convert_format_bounding_box)
+        _log_api_usage_once(convert_format_bounding_boxes)

     if torch.jit.is_scripting() or is_simple_tensor(inpt):
         if old_format is None:
             raise ValueError("For simple tensor inputs, `old_format` has to be passed.")
-        return _convert_format_bounding_box(inpt, old_format=old_format, new_format=new_format, inplace=inplace)
-    elif isinstance(inpt, datapoints.BoundingBox):
+        return _convert_format_bounding_boxes(inpt, old_format=old_format, new_format=new_format, inplace=inplace)
+    elif isinstance(inpt, datapoints.BoundingBoxes):
         if old_format is not None:
             raise ValueError("For bounding box datapoint inputs, `old_format` must not be passed.")
-        output = _convert_format_bounding_box(
+        output = _convert_format_bounding_boxes(
             inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=new_format, inplace=inplace
         )
-        return datapoints.BoundingBox.wrap_like(inpt, output, format=new_format)
+        return datapoints.BoundingBoxes.wrap_like(inpt, output, format=new_format)
     else:
         raise TypeError(
             f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
         )
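Every conversion routes through XYXY internally. A sketch of the XYWH and XYXY legs on plain tensors (mirroring what `_xywh_to_xyxy` and `_xyxy_to_xywh` do):

import torch

xywh = torch.tensor([[10.0, 20.0, 30.0, 40.0]])  # x, y, w, h

xyxy = xywh.clone()
xyxy[:, 2:] += xyxy[:, :2]  # far corner = origin + extent
print(xyxy)  # tensor([[10., 20., 40., 60.]])

back = xyxy.clone()
back[:, 2:] -= back[:, :2]  # invert by subtracting the origin again
print(back)  # tensor([[10., 20., 30., 40.]])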
-def _clamp_bounding_box(
-    bounding_box: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int]
+def _clamp_bounding_boxes(
+    bounding_boxes: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int]
 ) -> torch.Tensor:
     # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
     # BoundingBoxFormat instead of converting back and forth
-    in_dtype = bounding_box.dtype
-    bounding_box = bounding_box.clone() if bounding_box.is_floating_point() else bounding_box.float()
-    xyxy_boxes = convert_format_bounding_box(
-        bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
+    in_dtype = bounding_boxes.dtype
+    bounding_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float()
+    xyxy_boxes = convert_format_bounding_boxes(
+        bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
     )
     xyxy_boxes[..., 0::2].clamp_(min=0, max=spatial_size[1])
     xyxy_boxes[..., 1::2].clamp_(min=0, max=spatial_size[0])
-    out_boxes = convert_format_bounding_box(
+    out_boxes = convert_format_bounding_boxes(
         xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True
     )
     return out_boxes.to(in_dtype)


-def clamp_bounding_box(
+def clamp_bounding_boxes(
     inpt: datapoints._InputTypeJIT,
     format: Optional[BoundingBoxFormat] = None,
     spatial_size: Optional[Tuple[int, int]] = None,
 ) -> datapoints._InputTypeJIT:
     if not torch.jit.is_scripting():
-        _log_api_usage_once(clamp_bounding_box)
+        _log_api_usage_once(clamp_bounding_boxes)

     if torch.jit.is_scripting() or is_simple_tensor(inpt):
         if format is None or spatial_size is None:
             raise ValueError("For simple tensor inputs, `format` and `spatial_size` has to be passed.")
-        return _clamp_bounding_box(inpt, format=format, spatial_size=spatial_size)
-    elif isinstance(inpt, datapoints.BoundingBox):
+        return _clamp_bounding_boxes(inpt, format=format, spatial_size=spatial_size)
+    elif isinstance(inpt, datapoints.BoundingBoxes):
         if format is not None or spatial_size is not None:
             raise ValueError("For bounding box datapoint inputs, `format` and `spatial_size` must not be passed.")
-        output = _clamp_bounding_box(inpt.as_subclass(torch.Tensor), format=inpt.format, spatial_size=inpt.spatial_size)
-        return datapoints.BoundingBox.wrap_like(inpt, output)
+        output = _clamp_bounding_boxes(
+            inpt.as_subclass(torch.Tensor), format=inpt.format, spatial_size=inpt.spatial_size
+        )
+        return datapoints.BoundingBoxes.wrap_like(inpt, output)
     else:
         raise TypeError(
             f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
...
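Clamping itself is two strided in-place ops in XYXY space, as in `_clamp_bounding_boxes` above: even columns (x) clamp to [0, width], odd columns (y) to [0, height]. Sketch with invented values:

import torch

height, width = 50, 80  # hypothetical canvas
boxes = torch.tensor([[-5.0, 10.0, 100.0, 60.0]])  # XYXY, partly out of canvas

boxes[..., 0::2].clamp_(min=0, max=width)
boxes[..., 1::2].clamp_(min=0, max=height)
print(boxes)  # tensor([[ 0., 10., 80., 50.]])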
torchvision/transforms/v2/utils.py
@@ -9,8 +9,8 @@ from torchvision._utils import sequence_to_str
 from torchvision.transforms.v2.functional import get_dimensions, get_spatial_size, is_simple_tensor


-def query_bounding_box(flat_inputs: List[Any]) -> datapoints.BoundingBox:
-    bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBox)]
+def query_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
+    bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes)]
     if not bounding_boxes:
         raise TypeError("No bounding box was found in the sample")
     elif len(bounding_boxes) > 1:
...
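`query_bounding_boxes` scans the flattened sample for exactly one boxes entry. A generic sketch with a stand-in class and an invented error message for the more-than-one case (the real function checks against `datapoints.BoundingBoxes`):

from typing import Any, List

class Boxes:  # stand-in for datapoints.BoundingBoxes
    pass

def query_boxes_sketch(flat_inputs: List[Any]) -> Boxes:
    found = [inpt for inpt in flat_inputs if isinstance(inpt, Boxes)]
    if not found:
        raise TypeError("No bounding box was found in the sample")
    if len(found) > 1:
        raise ValueError("Only one bounding boxes entry per sample is supported")
    return found[0]

print(query_boxes_sketch(["image", Boxes()]))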
@@ -37,7 +37,7 @@ def query_spatial_size(flat_inputs: List[Any]) -> Tuple[int, int]:
         tuple(get_spatial_size(inpt))
         for inpt in flat_inputs
         if isinstance(
-            inpt, (datapoints.Image, PIL.Image.Image, datapoints.Video, datapoints.Mask, datapoints.BoundingBox)
+            inpt, (datapoints.Image, PIL.Image.Image, datapoints.Video, datapoints.Mask, datapoints.BoundingBoxes)
         )
         or is_simple_tensor(inpt)
     }
...