Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
26ed129d
"git@developer.sourcefind.cn:ox696c/ktransformers.git" did not exist on "36fbeee341e283a93b6befa2a4d9085b7a5dd2b1"
Unverified
Commit
26ed129d
authored
Aug 22, 2023
by
Nicolas Hug
Committed by
GitHub
Aug 22, 2023
Browse files
Make v2.utils private. (#7863)
parent
9c4f7389
Changes
20
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
115 additions
and
107 deletions
+115
-107
references/segmentation/v2_extras.py
references/segmentation/v2_extras.py
+1
-1
test/test_prototype_datasets_builtin.py
test/test_prototype_datasets_builtin.py
+1
-1
test/test_prototype_transforms.py
test/test_prototype_transforms.py
+1
-1
test/test_transforms_v2.py
test/test_transforms_v2.py
+1
-1
test/test_transforms_v2_consistency.py
test/test_transforms_v2_consistency.py
+1
-2
test/test_transforms_v2_functional.py
test/test_transforms_v2_functional.py
+1
-1
test/test_transforms_v2_utils.py
test/test_transforms_v2_utils.py
+5
-5
torchvision/prototype/transforms/_augment.py
torchvision/prototype/transforms/_augment.py
+1
-1
torchvision/prototype/transforms/_geometry.py
torchvision/prototype/transforms/_geometry.py
+10
-2
torchvision/prototype/transforms/_misc.py
torchvision/prototype/transforms/_misc.py
+1
-1
torchvision/transforms/v2/__init__.py
torchvision/transforms/v2/__init__.py
+1
-1
torchvision/transforms/v2/_augment.py
torchvision/transforms/v2/_augment.py
+1
-2
torchvision/transforms/v2/_auto_augment.py
torchvision/transforms/v2/_auto_augment.py
+1
-2
torchvision/transforms/v2/_color.py
torchvision/transforms/v2/_color.py
+1
-1
torchvision/transforms/v2/_geometry.py
torchvision/transforms/v2/_geometry.py
+5
-1
torchvision/transforms/v2/_misc.py
torchvision/transforms/v2/_misc.py
+1
-2
torchvision/transforms/v2/_transform.py
torchvision/transforms/v2/_transform.py
+1
-1
torchvision/transforms/v2/_type_conversion.py
torchvision/transforms/v2/_type_conversion.py
+1
-1
torchvision/transforms/v2/_utils.py
torchvision/transforms/v2/_utils.py
+80
-1
torchvision/transforms/v2/utils.py
torchvision/transforms/v2/utils.py
+0
-79
No files found.
references/segmentation/v2_extras.py
View file @
26ed129d
...
@@ -11,7 +11,7 @@ class PadIfSmaller(v2.Transform):
...
@@ -11,7 +11,7 @@ class PadIfSmaller(v2.Transform):
self
.
fill
=
v2
.
_utils
.
_setup_fill_arg
(
fill
)
self
.
fill
=
v2
.
_utils
.
_setup_fill_arg
(
fill
)
def
_get_params
(
self
,
sample
):
def
_get_params
(
self
,
sample
):
_
,
height
,
width
=
v2
.
utils
.
query_chw
(
sample
)
_
,
height
,
width
=
v2
.
_
utils
.
query_chw
(
sample
)
padding
=
[
0
,
0
,
max
(
self
.
size
-
width
,
0
),
max
(
self
.
size
-
height
,
0
)]
padding
=
[
0
,
0
,
max
(
self
.
size
-
width
,
0
),
max
(
self
.
size
-
height
,
0
)]
needs_padding
=
any
(
padding
)
needs_padding
=
any
(
padding
)
return
dict
(
padding
=
padding
,
needs_padding
=
needs_padding
)
return
dict
(
padding
=
padding
,
needs_padding
=
needs_padding
)
...
...
test/test_prototype_datasets_builtin.py
View file @
26ed129d
...
@@ -25,7 +25,7 @@ from torchvision.prototype import datasets
...
@@ -25,7 +25,7 @@ from torchvision.prototype import datasets
from
torchvision.prototype.datapoints
import
Label
from
torchvision.prototype.datapoints
import
Label
from
torchvision.prototype.datasets.utils
import
EncodedImage
from
torchvision.prototype.datasets.utils
import
EncodedImage
from
torchvision.prototype.datasets.utils._internal
import
INFINITE_BUFFER_SIZE
from
torchvision.prototype.datasets.utils._internal
import
INFINITE_BUFFER_SIZE
from
torchvision.transforms.v2.utils
import
is_pure_tensor
from
torchvision.transforms.v2.
_
utils
import
is_pure_tensor
def
assert_samples_equal
(
*
args
,
msg
=
None
,
**
kwargs
):
def
assert_samples_equal
(
*
args
,
msg
=
None
,
**
kwargs
):
...
...
test/test_prototype_transforms.py
View file @
26ed129d
...
@@ -10,8 +10,8 @@ from prototype_common_utils import make_label
...
@@ -10,8 +10,8 @@ from prototype_common_utils import make_label
from
torchvision.datapoints
import
BoundingBoxes
,
BoundingBoxFormat
,
Image
,
Mask
,
Video
from
torchvision.datapoints
import
BoundingBoxes
,
BoundingBoxFormat
,
Image
,
Mask
,
Video
from
torchvision.prototype
import
datapoints
,
transforms
from
torchvision.prototype
import
datapoints
,
transforms
from
torchvision.transforms.v2._utils
import
check_type
,
is_pure_tensor
from
torchvision.transforms.v2.functional
import
clamp_bounding_boxes
,
InterpolationMode
,
pil_to_tensor
,
to_pil_image
from
torchvision.transforms.v2.functional
import
clamp_bounding_boxes
,
InterpolationMode
,
pil_to_tensor
,
to_pil_image
from
torchvision.transforms.v2.utils
import
check_type
,
is_pure_tensor
from
transforms_v2_legacy_utils
import
(
from
transforms_v2_legacy_utils
import
(
DEFAULT_EXTRA_DIMS
,
DEFAULT_EXTRA_DIMS
,
make_bounding_boxes
,
make_bounding_boxes
,
...
...
test/test_transforms_v2.py
View file @
26ed129d
...
@@ -16,7 +16,7 @@ from torchvision import datapoints
...
@@ -16,7 +16,7 @@ from torchvision import datapoints
from
torchvision.ops.boxes
import
box_iou
from
torchvision.ops.boxes
import
box_iou
from
torchvision.transforms.functional
import
to_pil_image
from
torchvision.transforms.functional
import
to_pil_image
from
torchvision.transforms.v2
import
functional
as
F
from
torchvision.transforms.v2
import
functional
as
F
from
torchvision.transforms.v2.utils
import
check_type
,
is_pure_tensor
,
query_chw
from
torchvision.transforms.v2.
_
utils
import
check_type
,
is_pure_tensor
,
query_chw
from
transforms_v2_legacy_utils
import
(
from
transforms_v2_legacy_utils
import
(
make_bounding_boxes
,
make_bounding_boxes
,
make_detection_mask
,
make_detection_mask
,
...
...
test/test_transforms_v2_consistency.py
View file @
26ed129d
...
@@ -19,9 +19,8 @@ from torchvision._utils import sequence_to_str
...
@@ -19,9 +19,8 @@ from torchvision._utils import sequence_to_str
from
torchvision.transforms
import
functional
as
legacy_F
from
torchvision.transforms
import
functional
as
legacy_F
from
torchvision.transforms.v2
import
functional
as
prototype_F
from
torchvision.transforms.v2
import
functional
as
prototype_F
from
torchvision.transforms.v2._utils
import
_get_fill
from
torchvision.transforms.v2._utils
import
_get_fill
,
query_size
from
torchvision.transforms.v2.functional
import
to_pil_image
from
torchvision.transforms.v2.functional
import
to_pil_image
from
torchvision.transforms.v2.utils
import
query_size
from
transforms_v2_legacy_utils
import
(
from
transforms_v2_legacy_utils
import
(
ArgsKwargs
,
ArgsKwargs
,
make_bounding_boxes
,
make_bounding_boxes
,
...
...
test/test_transforms_v2_functional.py
View file @
26ed129d
...
@@ -13,9 +13,9 @@ from torch.utils._pytree import tree_map
...
@@ -13,9 +13,9 @@ from torch.utils._pytree import tree_map
from
torchvision
import
datapoints
from
torchvision
import
datapoints
from
torchvision.transforms.functional
import
_get_perspective_coeffs
from
torchvision.transforms.functional
import
_get_perspective_coeffs
from
torchvision.transforms.v2
import
functional
as
F
from
torchvision.transforms.v2
import
functional
as
F
from
torchvision.transforms.v2._utils
import
is_pure_tensor
from
torchvision.transforms.v2.functional._geometry
import
_center_crop_compute_padding
from
torchvision.transforms.v2.functional._geometry
import
_center_crop_compute_padding
from
torchvision.transforms.v2.functional._meta
import
clamp_bounding_boxes
,
convert_bounding_box_format
from
torchvision.transforms.v2.functional._meta
import
clamp_bounding_boxes
,
convert_bounding_box_format
from
torchvision.transforms.v2.utils
import
is_pure_tensor
from
transforms_v2_dispatcher_infos
import
DISPATCHER_INFOS
from
transforms_v2_dispatcher_infos
import
DISPATCHER_INFOS
from
transforms_v2_kernel_infos
import
KERNEL_INFOS
from
transforms_v2_kernel_infos
import
KERNEL_INFOS
from
transforms_v2_legacy_utils
import
(
from
transforms_v2_legacy_utils
import
(
...
...
test/test_transforms_v2_utils.py
View file @
26ed129d
...
@@ -3,12 +3,12 @@ import pytest
...
@@ -3,12 +3,12 @@ import pytest
import
torch
import
torch
import
torchvision.transforms.v2.utils
import
torchvision.transforms.v2.
_
utils
from
common_utils
import
DEFAULT_SIZE
,
make_bounding_boxes
,
make_detection_mask
,
make_image
from
common_utils
import
DEFAULT_SIZE
,
make_bounding_boxes
,
make_detection_mask
,
make_image
from
torchvision
import
datapoints
from
torchvision
import
datapoints
from
torchvision.transforms.v2._utils
import
has_all
,
has_any
from
torchvision.transforms.v2.functional
import
to_pil_image
from
torchvision.transforms.v2.functional
import
to_pil_image
from
torchvision.transforms.v2.utils
import
has_all
,
has_any
IMAGE
=
make_image
(
DEFAULT_SIZE
,
color_space
=
"RGB"
)
IMAGE
=
make_image
(
DEFAULT_SIZE
,
color_space
=
"RGB"
)
...
@@ -37,15 +37,15 @@ MASK = make_detection_mask(DEFAULT_SIZE)
...
@@ -37,15 +37,15 @@ MASK = make_detection_mask(DEFAULT_SIZE)
((
IMAGE
,
BOUNDING_BOX
,
MASK
),
(
lambda
obj
:
isinstance
(
obj
,
datapoints
.
Image
),),
True
),
((
IMAGE
,
BOUNDING_BOX
,
MASK
),
(
lambda
obj
:
isinstance
(
obj
,
datapoints
.
Image
),),
True
),
((
IMAGE
,
BOUNDING_BOX
,
MASK
),
(
lambda
_
:
False
,),
False
),
((
IMAGE
,
BOUNDING_BOX
,
MASK
),
(
lambda
_
:
False
,),
False
),
((
IMAGE
,
BOUNDING_BOX
,
MASK
),
(
lambda
_
:
True
,),
True
),
((
IMAGE
,
BOUNDING_BOX
,
MASK
),
(
lambda
_
:
True
,),
True
),
((
IMAGE
,),
(
datapoints
.
Image
,
PIL
.
Image
.
Image
,
torchvision
.
transforms
.
v2
.
utils
.
is_pure_tensor
),
True
),
((
IMAGE
,),
(
datapoints
.
Image
,
PIL
.
Image
.
Image
,
torchvision
.
transforms
.
v2
.
_
utils
.
is_pure_tensor
),
True
),
(
(
(
torch
.
Tensor
(
IMAGE
),),
(
torch
.
Tensor
(
IMAGE
),),
(
datapoints
.
Image
,
PIL
.
Image
.
Image
,
torchvision
.
transforms
.
v2
.
utils
.
is_pure_tensor
),
(
datapoints
.
Image
,
PIL
.
Image
.
Image
,
torchvision
.
transforms
.
v2
.
_
utils
.
is_pure_tensor
),
True
,
True
,
),
),
(
(
(
to_pil_image
(
IMAGE
),),
(
to_pil_image
(
IMAGE
),),
(
datapoints
.
Image
,
PIL
.
Image
.
Image
,
torchvision
.
transforms
.
v2
.
utils
.
is_pure_tensor
),
(
datapoints
.
Image
,
PIL
.
Image
.
Image
,
torchvision
.
transforms
.
v2
.
_
utils
.
is_pure_tensor
),
True
,
True
,
),
),
],
],
...
...
torchvision/prototype/transforms/_augment.py
View file @
26ed129d
...
@@ -7,9 +7,9 @@ from torchvision import datapoints
...
@@ -7,9 +7,9 @@ from torchvision import datapoints
from
torchvision.ops
import
masks_to_boxes
from
torchvision.ops
import
masks_to_boxes
from
torchvision.prototype
import
datapoints
as
proto_datapoints
from
torchvision.prototype
import
datapoints
as
proto_datapoints
from
torchvision.transforms.v2
import
functional
as
F
,
InterpolationMode
,
Transform
from
torchvision.transforms.v2
import
functional
as
F
,
InterpolationMode
,
Transform
from
torchvision.transforms.v2._utils
import
is_pure_tensor
from
torchvision.transforms.v2.functional._geometry
import
_check_interpolation
from
torchvision.transforms.v2.functional._geometry
import
_check_interpolation
from
torchvision.transforms.v2.utils
import
is_pure_tensor
class
SimpleCopyPaste
(
Transform
):
class
SimpleCopyPaste
(
Transform
):
...
...
torchvision/prototype/transforms/_geometry.py
View file @
26ed129d
...
@@ -6,8 +6,16 @@ import torch
...
@@ -6,8 +6,16 @@ import torch
from
torchvision
import
datapoints
from
torchvision
import
datapoints
from
torchvision.prototype.datapoints
import
Label
,
OneHotLabel
from
torchvision.prototype.datapoints
import
Label
,
OneHotLabel
from
torchvision.transforms.v2
import
functional
as
F
,
Transform
from
torchvision.transforms.v2
import
functional
as
F
,
Transform
from
torchvision.transforms.v2._utils
import
_FillType
,
_get_fill
,
_setup_fill_arg
,
_setup_size
from
torchvision.transforms.v2._utils
import
(
from
torchvision.transforms.v2.utils
import
get_bounding_boxes
,
has_any
,
is_pure_tensor
,
query_size
_FillType
,
_get_fill
,
_setup_fill_arg
,
_setup_size
,
get_bounding_boxes
,
has_any
,
is_pure_tensor
,
query_size
,
)
class
FixedSizeCrop
(
Transform
):
class
FixedSizeCrop
(
Transform
):
...
...
torchvision/prototype/transforms/_misc.py
View file @
26ed129d
...
@@ -8,7 +8,7 @@ import torch
...
@@ -8,7 +8,7 @@ import torch
from
torchvision
import
datapoints
from
torchvision
import
datapoints
from
torchvision.transforms.v2
import
Transform
from
torchvision.transforms.v2
import
Transform
from
torchvision.transforms.v2.utils
import
is_pure_tensor
from
torchvision.transforms.v2.
_
utils
import
is_pure_tensor
T
=
TypeVar
(
"T"
)
T
=
TypeVar
(
"T"
)
...
...
torchvision/transforms/v2/__init__.py
View file @
26ed129d
from
torchvision.transforms
import
AutoAugmentPolicy
,
InterpolationMode
# usort: skip
from
torchvision.transforms
import
AutoAugmentPolicy
,
InterpolationMode
# usort: skip
from
.
import
functional
,
utils
# usort: skip
from
.
import
functional
# usort: skip
from
._transform
import
Transform
# usort: skip
from
._transform
import
Transform
# usort: skip
...
...
torchvision/transforms/v2/_augment.py
View file @
26ed129d
...
@@ -11,8 +11,7 @@ from torchvision import datapoints, transforms as _transforms
...
@@ -11,8 +11,7 @@ from torchvision import datapoints, transforms as _transforms
from
torchvision.transforms.v2
import
functional
as
F
from
torchvision.transforms.v2
import
functional
as
F
from
._transform
import
_RandomApplyTransform
,
Transform
from
._transform
import
_RandomApplyTransform
,
Transform
from
._utils
import
_parse_labels_getter
from
._utils
import
_parse_labels_getter
,
has_any
,
is_pure_tensor
,
query_chw
,
query_size
from
.utils
import
has_any
,
is_pure_tensor
,
query_chw
,
query_size
class
RandomErasing
(
_RandomApplyTransform
):
class
RandomErasing
(
_RandomApplyTransform
):
...
...
torchvision/transforms/v2/_auto_augment.py
View file @
26ed129d
...
@@ -12,8 +12,7 @@ from torchvision.transforms.v2.functional._geometry import _check_interpolation
...
@@ -12,8 +12,7 @@ from torchvision.transforms.v2.functional._geometry import _check_interpolation
from
torchvision.transforms.v2.functional._meta
import
get_size
from
torchvision.transforms.v2.functional._meta
import
get_size
from
torchvision.transforms.v2.functional._utils
import
_FillType
,
_FillTypeJIT
from
torchvision.transforms.v2.functional._utils
import
_FillType
,
_FillTypeJIT
from
._utils
import
_get_fill
,
_setup_fill_arg
from
._utils
import
_get_fill
,
_setup_fill_arg
,
check_type
,
is_pure_tensor
from
.utils
import
check_type
,
is_pure_tensor
ImageOrVideo
=
Union
[
torch
.
Tensor
,
PIL
.
Image
.
Image
,
datapoints
.
Image
,
datapoints
.
Video
]
ImageOrVideo
=
Union
[
torch
.
Tensor
,
PIL
.
Image
.
Image
,
datapoints
.
Image
,
datapoints
.
Video
]
...
...
torchvision/transforms/v2/_color.py
View file @
26ed129d
...
@@ -6,7 +6,7 @@ from torchvision import transforms as _transforms
...
@@ -6,7 +6,7 @@ from torchvision import transforms as _transforms
from
torchvision.transforms.v2
import
functional
as
F
,
Transform
from
torchvision.transforms.v2
import
functional
as
F
,
Transform
from
._transform
import
_RandomApplyTransform
from
._transform
import
_RandomApplyTransform
from
.utils
import
query_chw
from
.
_
utils
import
query_chw
class
Grayscale
(
Transform
):
class
Grayscale
(
Transform
):
...
...
torchvision/transforms/v2/_geometry.py
View file @
26ed129d
...
@@ -23,8 +23,12 @@ from ._utils import (
...
@@ -23,8 +23,12 @@ from ._utils import (
_setup_fill_arg
,
_setup_fill_arg
,
_setup_float_or_seq
,
_setup_float_or_seq
,
_setup_size
,
_setup_size
,
get_bounding_boxes
,
has_all
,
has_any
,
is_pure_tensor
,
query_size
,
)
)
from
.utils
import
get_bounding_boxes
,
has_all
,
has_any
,
is_pure_tensor
,
query_size
class
RandomHorizontalFlip
(
_RandomApplyTransform
):
class
RandomHorizontalFlip
(
_RandomApplyTransform
):
...
...
torchvision/transforms/v2/_misc.py
View file @
26ed129d
...
@@ -9,8 +9,7 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
...
@@ -9,8 +9,7 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
from
torchvision
import
datapoints
,
transforms
as
_transforms
from
torchvision
import
datapoints
,
transforms
as
_transforms
from
torchvision.transforms.v2
import
functional
as
F
,
Transform
from
torchvision.transforms.v2
import
functional
as
F
,
Transform
from
._utils
import
_parse_labels_getter
,
_setup_float_or_seq
,
_setup_size
from
._utils
import
_parse_labels_getter
,
_setup_float_or_seq
,
_setup_size
,
get_bounding_boxes
,
has_any
,
is_pure_tensor
from
.utils
import
get_bounding_boxes
,
has_any
,
is_pure_tensor
# TODO: do we want/need to expose this?
# TODO: do we want/need to expose this?
...
...
torchvision/transforms/v2/_transform.py
View file @
26ed129d
...
@@ -8,7 +8,7 @@ import torch
...
@@ -8,7 +8,7 @@ import torch
from
torch
import
nn
from
torch
import
nn
from
torch.utils._pytree
import
tree_flatten
,
tree_unflatten
from
torch.utils._pytree
import
tree_flatten
,
tree_unflatten
from
torchvision
import
datapoints
from
torchvision
import
datapoints
from
torchvision.transforms.v2.utils
import
check_type
,
has_any
,
is_pure_tensor
from
torchvision.transforms.v2.
_
utils
import
check_type
,
has_any
,
is_pure_tensor
from
torchvision.utils
import
_log_api_usage_once
from
torchvision.utils
import
_log_api_usage_once
from
.functional._utils
import
_get_kernel
from
.functional._utils
import
_get_kernel
...
...
torchvision/transforms/v2/_type_conversion.py
View file @
26ed129d
...
@@ -7,7 +7,7 @@ import torch
...
@@ -7,7 +7,7 @@ import torch
from
torchvision
import
datapoints
from
torchvision
import
datapoints
from
torchvision.transforms.v2
import
functional
as
F
,
Transform
from
torchvision.transforms.v2
import
functional
as
F
,
Transform
from
torchvision.transforms.v2.utils
import
is_pure_tensor
from
torchvision.transforms.v2.
_
utils
import
is_pure_tensor
class
PILToTensor
(
Transform
):
class
PILToTensor
(
Transform
):
...
...
torchvision/transforms/v2/_utils.py
View file @
26ed129d
from
__future__
import
annotations
import
collections.abc
import
collections.abc
import
numbers
import
numbers
from
contextlib
import
suppress
from
contextlib
import
suppress
from
typing
import
Any
,
Callable
,
Dict
,
Literal
,
Optional
,
Sequence
,
Type
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Literal
,
Optional
,
Sequence
,
Tuple
,
Type
,
Union
import
PIL.Image
import
torch
import
torch
from
torchvision
import
datapoints
from
torchvision._utils
import
sequence_to_str
from
torchvision.transforms.transforms
import
_check_sequence_input
,
_setup_angle
,
_setup_size
# noqa: F401
from
torchvision.transforms.transforms
import
_check_sequence_input
,
_setup_angle
,
_setup_size
# noqa: F401
from
torchvision.transforms.v2.functional
import
get_dimensions
,
get_size
,
is_pure_tensor
from
torchvision.transforms.v2.functional._utils
import
_FillType
,
_FillTypeJIT
from
torchvision.transforms.v2.functional._utils
import
_FillType
,
_FillTypeJIT
...
@@ -138,3 +147,73 @@ def _parse_labels_getter(
...
@@ -138,3 +147,73 @@ def _parse_labels_getter(
return
lambda
_
:
None
return
lambda
_
:
None
else
:
else
:
raise
ValueError
(
f
"labels_getter should either be 'default', a callable, or None, but got
{
labels_getter
}
."
)
raise
ValueError
(
f
"labels_getter should either be 'default', a callable, or None, but got
{
labels_getter
}
."
)
def
get_bounding_boxes
(
flat_inputs
:
List
[
Any
])
->
datapoints
.
BoundingBoxes
:
# This assumes there is only one bbox per sample as per the general convention
try
:
return
next
(
inpt
for
inpt
in
flat_inputs
if
isinstance
(
inpt
,
datapoints
.
BoundingBoxes
))
except
StopIteration
:
raise
ValueError
(
"No bounding boxes were found in the sample"
)
def
query_chw
(
flat_inputs
:
List
[
Any
])
->
Tuple
[
int
,
int
,
int
]:
chws
=
{
tuple
(
get_dimensions
(
inpt
))
for
inpt
in
flat_inputs
if
check_type
(
inpt
,
(
is_pure_tensor
,
datapoints
.
Image
,
PIL
.
Image
.
Image
,
datapoints
.
Video
))
}
if
not
chws
:
raise
TypeError
(
"No image or video was found in the sample"
)
elif
len
(
chws
)
>
1
:
raise
ValueError
(
f
"Found multiple CxHxW dimensions in the sample:
{
sequence_to_str
(
sorted
(
chws
))
}
"
)
c
,
h
,
w
=
chws
.
pop
()
return
c
,
h
,
w
def
query_size
(
flat_inputs
:
List
[
Any
])
->
Tuple
[
int
,
int
]:
sizes
=
{
tuple
(
get_size
(
inpt
))
for
inpt
in
flat_inputs
if
check_type
(
inpt
,
(
is_pure_tensor
,
datapoints
.
Image
,
PIL
.
Image
.
Image
,
datapoints
.
Video
,
datapoints
.
Mask
,
datapoints
.
BoundingBoxes
,
),
)
}
if
not
sizes
:
raise
TypeError
(
"No image, video, mask or bounding box was found in the sample"
)
elif
len
(
sizes
)
>
1
:
raise
ValueError
(
f
"Found multiple HxW dimensions in the sample:
{
sequence_to_str
(
sorted
(
sizes
))
}
"
)
h
,
w
=
sizes
.
pop
()
return
h
,
w
def
check_type
(
obj
:
Any
,
types_or_checks
:
Tuple
[
Union
[
Type
,
Callable
[[
Any
],
bool
]],
...])
->
bool
:
for
type_or_check
in
types_or_checks
:
if
isinstance
(
obj
,
type_or_check
)
if
isinstance
(
type_or_check
,
type
)
else
type_or_check
(
obj
):
return
True
return
False
def
has_any
(
flat_inputs
:
List
[
Any
],
*
types_or_checks
:
Union
[
Type
,
Callable
[[
Any
],
bool
]])
->
bool
:
for
inpt
in
flat_inputs
:
if
check_type
(
inpt
,
types_or_checks
):
return
True
return
False
def
has_all
(
flat_inputs
:
List
[
Any
],
*
types_or_checks
:
Union
[
Type
,
Callable
[[
Any
],
bool
]])
->
bool
:
for
type_or_check
in
types_or_checks
:
for
inpt
in
flat_inputs
:
if
isinstance
(
inpt
,
type_or_check
)
if
isinstance
(
type_or_check
,
type
)
else
type_or_check
(
inpt
):
break
else
:
return
False
return
True
torchvision/transforms/v2/utils.py
deleted
100644 → 0
View file @
9c4f7389
from
__future__
import
annotations
from
typing
import
Any
,
Callable
,
List
,
Tuple
,
Type
,
Union
import
PIL.Image
from
torchvision
import
datapoints
from
torchvision._utils
import
sequence_to_str
from
torchvision.transforms.v2.functional
import
get_dimensions
,
get_size
,
is_pure_tensor
def
get_bounding_boxes
(
flat_inputs
:
List
[
Any
])
->
datapoints
.
BoundingBoxes
:
# This assumes there is only one bbox per sample as per the general convention
try
:
return
next
(
inpt
for
inpt
in
flat_inputs
if
isinstance
(
inpt
,
datapoints
.
BoundingBoxes
))
except
StopIteration
:
raise
ValueError
(
"No bounding boxes were found in the sample"
)
def
query_chw
(
flat_inputs
:
List
[
Any
])
->
Tuple
[
int
,
int
,
int
]:
chws
=
{
tuple
(
get_dimensions
(
inpt
))
for
inpt
in
flat_inputs
if
check_type
(
inpt
,
(
is_pure_tensor
,
datapoints
.
Image
,
PIL
.
Image
.
Image
,
datapoints
.
Video
))
}
if
not
chws
:
raise
TypeError
(
"No image or video was found in the sample"
)
elif
len
(
chws
)
>
1
:
raise
ValueError
(
f
"Found multiple CxHxW dimensions in the sample:
{
sequence_to_str
(
sorted
(
chws
))
}
"
)
c
,
h
,
w
=
chws
.
pop
()
return
c
,
h
,
w
def
query_size
(
flat_inputs
:
List
[
Any
])
->
Tuple
[
int
,
int
]:
sizes
=
{
tuple
(
get_size
(
inpt
))
for
inpt
in
flat_inputs
if
check_type
(
inpt
,
(
is_pure_tensor
,
datapoints
.
Image
,
PIL
.
Image
.
Image
,
datapoints
.
Video
,
datapoints
.
Mask
,
datapoints
.
BoundingBoxes
,
),
)
}
if
not
sizes
:
raise
TypeError
(
"No image, video, mask or bounding box was found in the sample"
)
elif
len
(
sizes
)
>
1
:
raise
ValueError
(
f
"Found multiple HxW dimensions in the sample:
{
sequence_to_str
(
sorted
(
sizes
))
}
"
)
h
,
w
=
sizes
.
pop
()
return
h
,
w
def
check_type
(
obj
:
Any
,
types_or_checks
:
Tuple
[
Union
[
Type
,
Callable
[[
Any
],
bool
]],
...])
->
bool
:
for
type_or_check
in
types_or_checks
:
if
isinstance
(
obj
,
type_or_check
)
if
isinstance
(
type_or_check
,
type
)
else
type_or_check
(
obj
):
return
True
return
False
def
has_any
(
flat_inputs
:
List
[
Any
],
*
types_or_checks
:
Union
[
Type
,
Callable
[[
Any
],
bool
]])
->
bool
:
for
inpt
in
flat_inputs
:
if
check_type
(
inpt
,
types_or_checks
):
return
True
return
False
def
has_all
(
flat_inputs
:
List
[
Any
],
*
types_or_checks
:
Union
[
Type
,
Callable
[[
Any
],
bool
]])
->
bool
:
for
type_or_check
in
types_or_checks
:
for
inpt
in
flat_inputs
:
if
isinstance
(
inpt
,
type_or_check
)
if
isinstance
(
type_or_check
,
type
)
else
type_or_check
(
inpt
):
break
else
:
return
False
return
True
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment