Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
4941c6b6
Unverified
Commit
4941c6b6
authored
Nov 30, 2022
by
Philip Meier
Committed by
GitHub
Nov 30, 2022
Browse files
expose some prototype transforms utils (#6989)
* expose some prototype transforms utils * rename _isinstance
parent
b94f176a
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
88 additions
and
88 deletions
+88
-88
test/test_prototype_transforms.py
test/test_prototype_transforms.py
+3
-3
test/test_prototype_transforms_consistency.py
test/test_prototype_transforms_consistency.py
+1
-1
test/test_prototype_transforms_utils.py
test/test_prototype_transforms_utils.py
+1
-1
torchvision/prototype/transforms/__init__.py
torchvision/prototype/transforms/__init__.py
+1
-1
torchvision/prototype/transforms/_augment.py
torchvision/prototype/transforms/_augment.py
+1
-1
torchvision/prototype/transforms/_auto_augment.py
torchvision/prototype/transforms/_auto_augment.py
+3
-2
torchvision/prototype/transforms/_color.py
torchvision/prototype/transforms/_color.py
+1
-1
torchvision/prototype/transforms/_deprecated.py
torchvision/prototype/transforms/_deprecated.py
+1
-1
torchvision/prototype/transforms/_geometry.py
torchvision/prototype/transforms/_geometry.py
+1
-4
torchvision/prototype/transforms/_misc.py
torchvision/prototype/transforms/_misc.py
+2
-1
torchvision/prototype/transforms/_transform.py
torchvision/prototype/transforms/_transform.py
+3
-5
torchvision/prototype/transforms/_utils.py
torchvision/prototype/transforms/_utils.py
+1
-67
torchvision/prototype/transforms/utils.py
torchvision/prototype/transforms/utils.py
+69
-0
No files found.
test/test_prototype_transforms.py
View file @
4941c6b6
...
...
@@ -23,7 +23,7 @@ from prototype_common_utils import (
)
from
torchvision.ops.boxes
import
box_iou
from
torchvision.prototype
import
features
,
transforms
from
torchvision.prototype.transforms.
_
utils
import
_isinstanc
e
from
torchvision.prototype.transforms.utils
import
check_typ
e
from
torchvision.transforms.functional
import
InterpolationMode
,
pil_to_tensor
,
to_pil_image
BATCH_EXTRA_DIMS
=
[
extra_dims
for
extra_dims
in
DEFAULT_EXTRA_DIMS
if
extra_dims
]
...
...
@@ -1860,7 +1860,7 @@ def test_permute_dimensions(dims, inverse_dims):
value_type
=
type
(
value
)
transformed_value
=
transformed_sample
[
key
]
if
_isinstanc
e
(
value
,
(
features
.
Image
,
features
.
is_simple_tensor
,
features
.
Video
)):
if
check_typ
e
(
value
,
(
features
.
Image
,
features
.
is_simple_tensor
,
features
.
Video
)):
if
transform
.
dims
.
get
(
value_type
)
is
not
None
:
assert
transformed_value
.
permute
(
inverse_dims
[
value_type
]).
equal
(
value
)
assert
type
(
transformed_value
)
==
torch
.
Tensor
...
...
@@ -1893,7 +1893,7 @@ def test_transpose_dimensions(dims):
transformed_value
=
transformed_sample
[
key
]
transposed_dims
=
transform
.
dims
.
get
(
value_type
)
if
_isinstanc
e
(
value
,
(
features
.
Image
,
features
.
is_simple_tensor
,
features
.
Video
)):
if
check_typ
e
(
value
,
(
features
.
Image
,
features
.
is_simple_tensor
,
features
.
Video
)):
if
transposed_dims
is
not
None
:
assert
transformed_value
.
transpose
(
*
transposed_dims
).
equal
(
value
)
assert
type
(
transformed_value
)
==
torch
.
Tensor
...
...
test/test_prototype_transforms_consistency.py
View file @
4941c6b6
...
...
@@ -26,8 +26,8 @@ from torchvision import transforms as legacy_transforms
from
torchvision._utils
import
sequence_to_str
from
torchvision.prototype
import
features
,
transforms
as
prototype_transforms
from
torchvision.prototype.transforms
import
functional
as
prototype_F
from
torchvision.prototype.transforms._utils
import
query_spatial_size
from
torchvision.prototype.transforms.functional
import
to_image_pil
from
torchvision.prototype.transforms.utils
import
query_spatial_size
from
torchvision.transforms
import
functional
as
legacy_F
DEFAULT_MAKE_IMAGES_KWARGS
=
dict
(
color_spaces
=
[
features
.
ColorSpace
.
RGB
],
extra_dims
=
[(
4
,)])
...
...
test/test_prototype_transforms_utils.py
View file @
4941c6b6
...
...
@@ -6,8 +6,8 @@ import torch
from
prototype_common_utils
import
make_bounding_box
,
make_detection_mask
,
make_image
from
torchvision.prototype
import
features
from
torchvision.prototype.transforms._utils
import
has_all
,
has_any
from
torchvision.prototype.transforms.functional
import
to_image_pil
from
torchvision.prototype.transforms.utils
import
has_all
,
has_any
IMAGE
=
make_image
(
color_space
=
features
.
ColorSpace
.
RGB
)
...
...
torchvision/prototype/transforms/__init__.py
View file @
4941c6b6
from
torchvision.transforms
import
AutoAugmentPolicy
,
InterpolationMode
# usort: skip
from
.
import
functional
# usort: skip
from
.
import
functional
,
utils
# usort: skip
from
._transform
import
Transform
# usort: skip
from
._presets
import
StereoMatching
# usort: skip
...
...
torchvision/prototype/transforms/_augment.py
View file @
4941c6b6
...
...
@@ -11,7 +11,7 @@ from torchvision.prototype import features
from
torchvision.prototype.transforms
import
functional
as
F
,
InterpolationMode
from
._transform
import
_RandomApplyTransform
from
.
_
utils
import
has_any
,
query_chw
,
query_spatial_size
from
.utils
import
has_any
,
query_chw
,
query_spatial_size
class
RandomErasing
(
_RandomApplyTransform
):
...
...
torchvision/prototype/transforms/_auto_augment.py
View file @
4941c6b6
...
...
@@ -10,7 +10,8 @@ from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F,
from
torchvision.prototype.transforms.functional._meta
import
get_spatial_size
from
torchvision.transforms
import
functional_tensor
as
_FT
from
._utils
import
_isinstance
,
_setup_fill_arg
from
._utils
import
_setup_fill_arg
from
.utils
import
check_type
class
_AutoAugmentBase
(
Transform
):
...
...
@@ -38,7 +39,7 @@ class _AutoAugmentBase(Transform):
image_or_videos
=
[]
for
idx
,
inpt
in
enumerate
(
flat_inputs
):
if
_isinstanc
e
(
inpt
,
(
features
.
Image
,
PIL
.
Image
.
Image
,
features
.
is_simple_tensor
,
features
.
Video
)):
if
check_typ
e
(
inpt
,
(
features
.
Image
,
PIL
.
Image
.
Image
,
features
.
is_simple_tensor
,
features
.
Video
)):
image_or_videos
.
append
((
idx
,
inpt
))
elif
isinstance
(
inpt
,
unsupported_types
):
raise
TypeError
(
f
"Inputs of type
{
type
(
inpt
).
__name__
}
are not supported by
{
type
(
self
).
__name__
}
()"
)
...
...
torchvision/prototype/transforms/_color.py
View file @
4941c6b6
...
...
@@ -7,7 +7,7 @@ from torchvision.prototype import features
from
torchvision.prototype.transforms
import
functional
as
F
,
Transform
from
._transform
import
_RandomApplyTransform
from
.
_
utils
import
query_chw
from
.utils
import
query_chw
class
ColorJitter
(
Transform
):
...
...
torchvision/prototype/transforms/_deprecated.py
View file @
4941c6b6
...
...
@@ -11,7 +11,7 @@ from torchvision.transforms import functional as _F
from
typing_extensions
import
Literal
from
._transform
import
_RandomApplyTransform
from
.
_
utils
import
query_chw
from
.utils
import
query_chw
class
ToTensor
(
Transform
):
...
...
torchvision/prototype/transforms/_geometry.py
View file @
4941c6b6
...
...
@@ -21,11 +21,8 @@ from ._utils import (
_setup_fill_arg
,
_setup_float_or_seq
,
_setup_size
,
has_all
,
has_any
,
query_bounding_box
,
query_spatial_size
,
)
from
.utils
import
has_all
,
has_any
,
query_bounding_box
,
query_spatial_size
class
RandomHorizontalFlip
(
_RandomApplyTransform
):
...
...
torchvision/prototype/transforms/_misc.py
View file @
4941c6b6
...
...
@@ -7,7 +7,8 @@ from torchvision.ops import remove_small_boxes
from
torchvision.prototype
import
features
from
torchvision.prototype.transforms
import
functional
as
F
,
Transform
from
._utils
import
_get_defaultdict
,
_setup_float_or_seq
,
_setup_size
,
has_any
,
query_bounding_box
from
._utils
import
_get_defaultdict
,
_setup_float_or_seq
,
_setup_size
from
.utils
import
has_any
,
query_bounding_box
class
Identity
(
Transform
):
...
...
torchvision/prototype/transforms/_transform.py
View file @
4941c6b6
...
...
@@ -5,7 +5,7 @@ import PIL.Image
import
torch
from
torch
import
nn
from
torch.utils._pytree
import
tree_flatten
,
tree_unflatten
from
torchvision.prototype.transforms.
_
utils
import
_isinstanc
e
from
torchvision.prototype.transforms.utils
import
check_typ
e
from
torchvision.utils
import
_log_api_usage_once
...
...
@@ -36,8 +36,7 @@ class Transform(nn.Module):
params
=
self
.
_get_params
(
flat_inputs
)
flat_outputs
=
[
self
.
_transform
(
inpt
,
params
)
if
_isinstance
(
inpt
,
self
.
_transformed_types
)
else
inpt
for
inpt
in
flat_inputs
self
.
_transform
(
inpt
,
params
)
if
check_type
(
inpt
,
self
.
_transformed_types
)
else
inpt
for
inpt
in
flat_inputs
]
return
tree_unflatten
(
flat_outputs
,
spec
)
...
...
@@ -80,8 +79,7 @@ class _RandomApplyTransform(Transform):
params
=
self
.
_get_params
(
flat_inputs
)
flat_outputs
=
[
self
.
_transform
(
inpt
,
params
)
if
_isinstance
(
inpt
,
self
.
_transformed_types
)
else
inpt
for
inpt
in
flat_inputs
self
.
_transform
(
inpt
,
params
)
if
check_type
(
inpt
,
self
.
_transformed_types
)
else
inpt
for
inpt
in
flat_inputs
]
return
tree_unflatten
(
flat_outputs
,
spec
)
torchvision/prototype/transforms/_utils.py
View file @
4941c6b6
import
functools
import
numbers
from
collections
import
defaultdict
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Sequence
,
Tupl
e
,
Type
,
TypeVar
,
Union
from
typing
import
Any
,
Dict
,
Sequenc
e
,
Type
,
TypeVar
,
Union
import
PIL.Image
from
torchvision._utils
import
sequence_to_str
from
torchvision.prototype
import
features
from
torchvision.prototype.features._feature
import
FillType
,
FillTypeJIT
from
torchvision.prototype.transforms.functional._meta
import
get_dimensions
,
get_spatial_size
from
torchvision.transforms.transforms
import
_check_sequence_input
,
_setup_angle
,
_setup_size
# noqa: F401
from
typing_extensions
import
Literal
...
...
@@ -100,65 +96,3 @@ def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
def
_check_padding_mode_arg
(
padding_mode
:
Literal
[
"constant"
,
"edge"
,
"reflect"
,
"symmetric"
])
->
None
:
if
padding_mode
not
in
[
"constant"
,
"edge"
,
"reflect"
,
"symmetric"
]:
raise
ValueError
(
"Padding mode should be either constant, edge, reflect or symmetric"
)
def
query_bounding_box
(
flat_inputs
:
List
[
Any
])
->
features
.
BoundingBox
:
bounding_boxes
=
[
inpt
for
inpt
in
flat_inputs
if
isinstance
(
inpt
,
features
.
BoundingBox
)]
if
not
bounding_boxes
:
raise
TypeError
(
"No bounding box was found in the sample"
)
elif
len
(
bounding_boxes
)
>
1
:
raise
ValueError
(
"Found multiple bounding boxes in the sample"
)
return
bounding_boxes
.
pop
()
def
query_chw
(
flat_inputs
:
List
[
Any
])
->
Tuple
[
int
,
int
,
int
]:
chws
=
{
tuple
(
get_dimensions
(
inpt
))
for
inpt
in
flat_inputs
if
isinstance
(
inpt
,
(
features
.
Image
,
PIL
.
Image
.
Image
,
features
.
Video
))
or
features
.
is_simple_tensor
(
inpt
)
}
if
not
chws
:
raise
TypeError
(
"No image or video was found in the sample"
)
elif
len
(
chws
)
>
1
:
raise
ValueError
(
f
"Found multiple CxHxW dimensions in the sample:
{
sequence_to_str
(
sorted
(
chws
))
}
"
)
c
,
h
,
w
=
chws
.
pop
()
return
c
,
h
,
w
def
query_spatial_size
(
flat_inputs
:
List
[
Any
])
->
Tuple
[
int
,
int
]:
sizes
=
{
tuple
(
get_spatial_size
(
inpt
))
for
inpt
in
flat_inputs
if
isinstance
(
inpt
,
(
features
.
Image
,
PIL
.
Image
.
Image
,
features
.
Video
,
features
.
Mask
,
features
.
BoundingBox
))
or
features
.
is_simple_tensor
(
inpt
)
}
if
not
sizes
:
raise
TypeError
(
"No image, video, mask or bounding box was found in the sample"
)
elif
len
(
sizes
)
>
1
:
raise
ValueError
(
f
"Found multiple HxW dimensions in the sample:
{
sequence_to_str
(
sorted
(
sizes
))
}
"
)
h
,
w
=
sizes
.
pop
()
return
h
,
w
def
_isinstance
(
obj
:
Any
,
types_or_checks
:
Tuple
[
Union
[
Type
,
Callable
[[
Any
],
bool
]],
...])
->
bool
:
for
type_or_check
in
types_or_checks
:
if
isinstance
(
obj
,
type_or_check
)
if
isinstance
(
type_or_check
,
type
)
else
type_or_check
(
obj
):
return
True
return
False
def
has_any
(
flat_inputs
:
List
[
Any
],
*
types_or_checks
:
Union
[
Type
,
Callable
[[
Any
],
bool
]])
->
bool
:
for
inpt
in
flat_inputs
:
if
_isinstance
(
inpt
,
types_or_checks
):
return
True
return
False
def
has_all
(
flat_inputs
:
List
[
Any
],
*
types_or_checks
:
Union
[
Type
,
Callable
[[
Any
],
bool
]])
->
bool
:
for
type_or_check
in
types_or_checks
:
for
inpt
in
flat_inputs
:
if
isinstance
(
inpt
,
type_or_check
)
if
isinstance
(
type_or_check
,
type
)
else
type_or_check
(
inpt
):
break
else
:
return
False
return
True
torchvision/prototype/transforms/utils.py
0 → 100644
View file @
4941c6b6
from
typing
import
Any
,
Callable
,
List
,
Tuple
,
Type
,
Union
import
PIL.Image
from
torchvision._utils
import
sequence_to_str
from
torchvision.prototype
import
features
from
torchvision.prototype.transforms.functional
import
get_dimensions
,
get_spatial_size
def
query_bounding_box
(
flat_inputs
:
List
[
Any
])
->
features
.
BoundingBox
:
bounding_boxes
=
[
inpt
for
inpt
in
flat_inputs
if
isinstance
(
inpt
,
features
.
BoundingBox
)]
if
not
bounding_boxes
:
raise
TypeError
(
"No bounding box was found in the sample"
)
elif
len
(
bounding_boxes
)
>
1
:
raise
ValueError
(
"Found multiple bounding boxes in the sample"
)
return
bounding_boxes
.
pop
()
def
query_chw
(
flat_inputs
:
List
[
Any
])
->
Tuple
[
int
,
int
,
int
]:
chws
=
{
tuple
(
get_dimensions
(
inpt
))
for
inpt
in
flat_inputs
if
isinstance
(
inpt
,
(
features
.
Image
,
PIL
.
Image
.
Image
,
features
.
Video
))
or
features
.
is_simple_tensor
(
inpt
)
}
if
not
chws
:
raise
TypeError
(
"No image or video was found in the sample"
)
elif
len
(
chws
)
>
1
:
raise
ValueError
(
f
"Found multiple CxHxW dimensions in the sample:
{
sequence_to_str
(
sorted
(
chws
))
}
"
)
c
,
h
,
w
=
chws
.
pop
()
return
c
,
h
,
w
def
query_spatial_size
(
flat_inputs
:
List
[
Any
])
->
Tuple
[
int
,
int
]:
sizes
=
{
tuple
(
get_spatial_size
(
inpt
))
for
inpt
in
flat_inputs
if
isinstance
(
inpt
,
(
features
.
Image
,
PIL
.
Image
.
Image
,
features
.
Video
,
features
.
Mask
,
features
.
BoundingBox
))
or
features
.
is_simple_tensor
(
inpt
)
}
if
not
sizes
:
raise
TypeError
(
"No image, video, mask or bounding box was found in the sample"
)
elif
len
(
sizes
)
>
1
:
raise
ValueError
(
f
"Found multiple HxW dimensions in the sample:
{
sequence_to_str
(
sorted
(
sizes
))
}
"
)
h
,
w
=
sizes
.
pop
()
return
h
,
w
def
check_type
(
obj
:
Any
,
types_or_checks
:
Tuple
[
Union
[
Type
,
Callable
[[
Any
],
bool
]],
...])
->
bool
:
for
type_or_check
in
types_or_checks
:
if
isinstance
(
obj
,
type_or_check
)
if
isinstance
(
type_or_check
,
type
)
else
type_or_check
(
obj
):
return
True
return
False
def
has_any
(
flat_inputs
:
List
[
Any
],
*
types_or_checks
:
Union
[
Type
,
Callable
[[
Any
],
bool
]])
->
bool
:
for
inpt
in
flat_inputs
:
if
check_type
(
inpt
,
types_or_checks
):
return
True
return
False
def
has_all
(
flat_inputs
:
List
[
Any
],
*
types_or_checks
:
Union
[
Type
,
Callable
[[
Any
],
bool
]])
->
bool
:
for
type_or_check
in
types_or_checks
:
for
inpt
in
flat_inputs
:
if
isinstance
(
inpt
,
type_or_check
)
if
isinstance
(
type_or_check
,
type
)
else
type_or_check
(
inpt
):
break
else
:
return
False
return
True
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment