Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
7de63171
Unverified
Commit
7de63171
authored
Aug 26, 2022
by
Philip Meier
Committed by
GitHub
Aug 26, 2022
Browse files
move simple_tensor to features module (#6507)
* move simple_tensor to features module * fix test
parent
13ea9018
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
46 additions
and
51 deletions
+46
-51
test/test_prototype_transforms_utils.py
test/test_prototype_transforms_utils.py
+4
-4
torchvision/prototype/features/__init__.py
torchvision/prototype/features/__init__.py
+1
-1
torchvision/prototype/features/_feature.py
torchvision/prototype/features/_feature.py
+4
-0
torchvision/prototype/transforms/_augment.py
torchvision/prototype/transforms/_augment.py
+6
-6
torchvision/prototype/transforms/_auto_augment.py
torchvision/prototype/transforms/_auto_augment.py
+2
-2
torchvision/prototype/transforms/_color.py
torchvision/prototype/transforms/_color.py
+2
-2
torchvision/prototype/transforms/_deprecated.py
torchvision/prototype/transforms/_deprecated.py
+4
-4
torchvision/prototype/transforms/_geometry.py
torchvision/prototype/transforms/_geometry.py
+5
-5
torchvision/prototype/transforms/_meta.py
torchvision/prototype/transforms/_meta.py
+3
-5
torchvision/prototype/transforms/_misc.py
torchvision/prototype/transforms/_misc.py
+3
-3
torchvision/prototype/transforms/_transform.py
torchvision/prototype/transforms/_transform.py
+7
-3
torchvision/prototype/transforms/_type_conversion.py
torchvision/prototype/transforms/_type_conversion.py
+2
-4
torchvision/prototype/transforms/_utils.py
torchvision/prototype/transforms/_utils.py
+2
-9
torchvision/prototype/transforms/functional/_deprecated.py
torchvision/prototype/transforms/functional/_deprecated.py
+1
-3
No files found.
test/test_prototype_transforms_utils.py
View file @
7de63171
...
@@ -6,7 +6,7 @@ import torch
...
@@ -6,7 +6,7 @@ import torch
from
test_prototype_transforms_functional
import
make_bounding_box
,
make_image
,
make_segmentation_mask
from
test_prototype_transforms_functional
import
make_bounding_box
,
make_image
,
make_segmentation_mask
from
torchvision.prototype
import
features
from
torchvision.prototype
import
features
from
torchvision.prototype.transforms._utils
import
has_all
,
has_any
,
is_simple_tensor
from
torchvision.prototype.transforms._utils
import
has_all
,
has_any
from
torchvision.prototype.transforms.functional
import
to_image_pil
from
torchvision.prototype.transforms.functional
import
to_image_pil
...
@@ -36,9 +36,9 @@ SEGMENTATION_MASK = make_segmentation_mask(size=IMAGE.image_size)
...
@@ -36,9 +36,9 @@ SEGMENTATION_MASK = make_segmentation_mask(size=IMAGE.image_size)
((
IMAGE
,
BOUNDING_BOX
,
SEGMENTATION_MASK
),
(
lambda
obj
:
isinstance
(
obj
,
features
.
Image
),),
True
),
((
IMAGE
,
BOUNDING_BOX
,
SEGMENTATION_MASK
),
(
lambda
obj
:
isinstance
(
obj
,
features
.
Image
),),
True
),
((
IMAGE
,
BOUNDING_BOX
,
SEGMENTATION_MASK
),
(
lambda
_
:
False
,),
False
),
((
IMAGE
,
BOUNDING_BOX
,
SEGMENTATION_MASK
),
(
lambda
_
:
False
,),
False
),
((
IMAGE
,
BOUNDING_BOX
,
SEGMENTATION_MASK
),
(
lambda
_
:
True
,),
True
),
((
IMAGE
,
BOUNDING_BOX
,
SEGMENTATION_MASK
),
(
lambda
_
:
True
,),
True
),
((
IMAGE
,),
(
features
.
Image
,
PIL
.
Image
.
Image
,
is_simple_tensor
),
True
),
((
IMAGE
,),
(
features
.
Image
,
PIL
.
Image
.
Image
,
features
.
is_simple_tensor
),
True
),
((
torch
.
Tensor
(
IMAGE
),),
(
features
.
Image
,
PIL
.
Image
.
Image
,
is_simple_tensor
),
True
),
((
torch
.
Tensor
(
IMAGE
),),
(
features
.
Image
,
PIL
.
Image
.
Image
,
features
.
is_simple_tensor
),
True
),
((
to_image_pil
(
IMAGE
),),
(
features
.
Image
,
PIL
.
Image
.
Image
,
is_simple_tensor
),
True
),
((
to_image_pil
(
IMAGE
),),
(
features
.
Image
,
PIL
.
Image
.
Image
,
features
.
is_simple_tensor
),
True
),
],
],
)
)
def
test_has_any
(
sample
,
types
,
expected
):
def
test_has_any
(
sample
,
types
,
expected
):
...
...
torchvision/prototype/features/__init__.py
View file @
7de63171
from
._bounding_box
import
BoundingBox
,
BoundingBoxFormat
from
._bounding_box
import
BoundingBox
,
BoundingBoxFormat
from
._encoded
import
EncodedData
,
EncodedImage
,
EncodedVideo
from
._encoded
import
EncodedData
,
EncodedImage
,
EncodedVideo
from
._feature
import
_Feature
from
._feature
import
_Feature
,
is_simple_tensor
from
._image
import
ColorSpace
,
Image
from
._image
import
ColorSpace
,
Image
from
._label
import
Label
,
OneHotLabel
from
._label
import
Label
,
OneHotLabel
from
._segmentation_mask
import
SegmentationMask
from
._segmentation_mask
import
SegmentationMask
torchvision/prototype/features/_feature.py
View file @
7de63171
...
@@ -10,6 +10,10 @@ from torchvision.transforms import InterpolationMode
...
@@ -10,6 +10,10 @@ from torchvision.transforms import InterpolationMode
F
=
TypeVar
(
"F"
,
bound
=
"_Feature"
)
F
=
TypeVar
(
"F"
,
bound
=
"_Feature"
)
def
is_simple_tensor
(
inpt
:
Any
)
->
bool
:
return
isinstance
(
inpt
,
torch
.
Tensor
)
and
not
isinstance
(
inpt
,
_Feature
)
class
_Feature
(
torch
.
Tensor
):
class
_Feature
(
torch
.
Tensor
):
__F
:
Optional
[
ModuleType
]
=
None
__F
:
Optional
[
ModuleType
]
=
None
...
...
torchvision/prototype/transforms/_augment.py
View file @
7de63171
...
@@ -13,7 +13,7 @@ from torchvision.prototype.transforms import functional as F
...
@@ -13,7 +13,7 @@ from torchvision.prototype.transforms import functional as F
from
torchvision.transforms.functional
import
InterpolationMode
,
pil_to_tensor
from
torchvision.transforms.functional
import
InterpolationMode
,
pil_to_tensor
from
._transform
import
_RandomApplyTransform
from
._transform
import
_RandomApplyTransform
from
._utils
import
has_any
,
is_simple_tensor
,
query_chw
from
._utils
import
has_any
,
query_chw
class
RandomErasing
(
_RandomApplyTransform
):
class
RandomErasing
(
_RandomApplyTransform
):
...
@@ -102,7 +102,7 @@ class _BaseMixupCutmix(_RandomApplyTransform):
...
@@ -102,7 +102,7 @@ class _BaseMixupCutmix(_RandomApplyTransform):
self
.
_dist
=
torch
.
distributions
.
Beta
(
torch
.
tensor
([
alpha
]),
torch
.
tensor
([
alpha
]))
self
.
_dist
=
torch
.
distributions
.
Beta
(
torch
.
tensor
([
alpha
]),
torch
.
tensor
([
alpha
]))
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
if
not
(
has_any
(
inputs
,
features
.
Image
,
is_simple_tensor
)
and
has_any
(
inputs
,
features
.
OneHotLabel
)):
if
not
(
has_any
(
inputs
,
features
.
Image
,
features
.
is_simple_tensor
)
and
has_any
(
inputs
,
features
.
OneHotLabel
)):
raise
TypeError
(
f
"
{
type
(
self
).
__name__
}
() is only defined for tensor images and one-hot labels."
)
raise
TypeError
(
f
"
{
type
(
self
).
__name__
}
() is only defined for tensor images and one-hot labels."
)
if
has_any
(
inputs
,
features
.
BoundingBox
,
features
.
SegmentationMask
,
features
.
Label
):
if
has_any
(
inputs
,
features
.
BoundingBox
,
features
.
SegmentationMask
,
features
.
Label
):
raise
TypeError
(
raise
TypeError
(
...
@@ -124,7 +124,7 @@ class RandomMixup(_BaseMixupCutmix):
...
@@ -124,7 +124,7 @@ class RandomMixup(_BaseMixupCutmix):
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
lam
=
params
[
"lam"
]
lam
=
params
[
"lam"
]
if
isinstance
(
inpt
,
features
.
Image
)
or
is_simple_tensor
(
inpt
):
if
isinstance
(
inpt
,
features
.
Image
)
or
features
.
is_simple_tensor
(
inpt
):
if
inpt
.
ndim
<
4
:
if
inpt
.
ndim
<
4
:
raise
ValueError
(
"Need a batch of images"
)
raise
ValueError
(
"Need a batch of images"
)
output
=
inpt
.
clone
()
output
=
inpt
.
clone
()
...
@@ -164,7 +164,7 @@ class RandomCutmix(_BaseMixupCutmix):
...
@@ -164,7 +164,7 @@ class RandomCutmix(_BaseMixupCutmix):
return
dict
(
box
=
box
,
lam_adjusted
=
lam_adjusted
)
return
dict
(
box
=
box
,
lam_adjusted
=
lam_adjusted
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
if
isinstance
(
inpt
,
features
.
Image
)
or
is_simple_tensor
(
inpt
):
if
isinstance
(
inpt
,
features
.
Image
)
or
features
.
is_simple_tensor
(
inpt
):
box
=
params
[
"box"
]
box
=
params
[
"box"
]
if
inpt
.
ndim
<
4
:
if
inpt
.
ndim
<
4
:
raise
ValueError
(
"Need a batch of images"
)
raise
ValueError
(
"Need a batch of images"
)
...
@@ -276,7 +276,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
...
@@ -276,7 +276,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
# with List[image], List[BoundingBox], List[SegmentationMask], List[Label]
# with List[image], List[BoundingBox], List[SegmentationMask], List[Label]
images
,
bboxes
,
masks
,
labels
=
[],
[],
[],
[]
images
,
bboxes
,
masks
,
labels
=
[],
[],
[],
[]
for
obj
in
flat_sample
:
for
obj
in
flat_sample
:
if
isinstance
(
obj
,
features
.
Image
)
or
is_simple_tensor
(
obj
):
if
isinstance
(
obj
,
features
.
Image
)
or
features
.
is_simple_tensor
(
obj
):
images
.
append
(
obj
)
images
.
append
(
obj
)
elif
isinstance
(
obj
,
PIL
.
Image
.
Image
):
elif
isinstance
(
obj
,
PIL
.
Image
.
Image
):
images
.
append
(
pil_to_tensor
(
obj
))
images
.
append
(
pil_to_tensor
(
obj
))
...
@@ -310,7 +310,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
...
@@ -310,7 +310,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
elif
isinstance
(
obj
,
PIL
.
Image
.
Image
):
elif
isinstance
(
obj
,
PIL
.
Image
.
Image
):
flat_sample
[
i
]
=
F
.
to_image_pil
(
output_images
[
c0
])
flat_sample
[
i
]
=
F
.
to_image_pil
(
output_images
[
c0
])
c0
+=
1
c0
+=
1
elif
is_simple_tensor
(
obj
):
elif
features
.
is_simple_tensor
(
obj
):
flat_sample
[
i
]
=
output_images
[
c0
]
flat_sample
[
i
]
=
output_images
[
c0
]
c0
+=
1
c0
+=
1
elif
isinstance
(
obj
,
features
.
BoundingBox
):
elif
isinstance
(
obj
,
features
.
BoundingBox
):
...
...
torchvision/prototype/transforms/_auto_augment.py
View file @
7de63171
...
@@ -11,7 +11,7 @@ from torchvision.prototype.transforms import functional as F, Transform
...
@@ -11,7 +11,7 @@ from torchvision.prototype.transforms import functional as F, Transform
from
torchvision.transforms.autoaugment
import
AutoAugmentPolicy
from
torchvision.transforms.autoaugment
import
AutoAugmentPolicy
from
torchvision.transforms.functional
import
InterpolationMode
,
pil_to_tensor
,
to_pil_image
from
torchvision.transforms.functional
import
InterpolationMode
,
pil_to_tensor
,
to_pil_image
from
._utils
import
_isinstance
,
get_chw
,
is_simple_tensor
from
._utils
import
_isinstance
,
get_chw
K
=
TypeVar
(
"K"
)
K
=
TypeVar
(
"K"
)
V
=
TypeVar
(
"V"
)
V
=
TypeVar
(
"V"
)
...
@@ -44,7 +44,7 @@ class _AutoAugmentBase(Transform):
...
@@ -44,7 +44,7 @@ class _AutoAugmentBase(Transform):
sample_flat
,
_
=
tree_flatten
(
sample
)
sample_flat
,
_
=
tree_flatten
(
sample
)
images
=
[]
images
=
[]
for
id
,
inpt
in
enumerate
(
sample_flat
):
for
id
,
inpt
in
enumerate
(
sample_flat
):
if
_isinstance
(
inpt
,
(
features
.
Image
,
PIL
.
Image
.
Image
,
is_simple_tensor
)):
if
_isinstance
(
inpt
,
(
features
.
Image
,
PIL
.
Image
.
Image
,
features
.
is_simple_tensor
)):
images
.
append
((
id
,
inpt
))
images
.
append
((
id
,
inpt
))
elif
isinstance
(
inpt
,
unsupported_types
):
elif
isinstance
(
inpt
,
unsupported_types
):
raise
TypeError
(
f
"Inputs of type
{
type
(
inpt
).
__name__
}
are not supported by
{
type
(
self
).
__name__
}
()"
)
raise
TypeError
(
f
"Inputs of type
{
type
(
inpt
).
__name__
}
are not supported by
{
type
(
self
).
__name__
}
()"
)
...
...
torchvision/prototype/transforms/_color.py
View file @
7de63171
...
@@ -8,7 +8,7 @@ from torchvision.prototype.transforms import functional as F, Transform
...
@@ -8,7 +8,7 @@ from torchvision.prototype.transforms import functional as F, Transform
from
torchvision.transforms
import
functional
as
_F
from
torchvision.transforms
import
functional
as
_F
from
._transform
import
_RandomApplyTransform
from
._transform
import
_RandomApplyTransform
from
._utils
import
is_simple_tensor
,
query_chw
from
._utils
import
query_chw
T
=
TypeVar
(
"T"
,
features
.
Image
,
torch
.
Tensor
,
PIL
.
Image
.
Image
)
T
=
TypeVar
(
"T"
,
features
.
Image
,
torch
.
Tensor
,
PIL
.
Image
.
Image
)
...
@@ -112,7 +112,7 @@ class RandomPhotometricDistort(Transform):
...
@@ -112,7 +112,7 @@ class RandomPhotometricDistort(Transform):
)
)
def
_permute_channels
(
self
,
inpt
:
Any
,
*
,
permutation
:
torch
.
Tensor
)
->
Any
:
def
_permute_channels
(
self
,
inpt
:
Any
,
*
,
permutation
:
torch
.
Tensor
)
->
Any
:
if
not
(
isinstance
(
inpt
,
(
features
.
Image
,
PIL
.
Image
.
Image
))
or
is_simple_tensor
(
inpt
)):
if
not
(
isinstance
(
inpt
,
(
features
.
Image
,
PIL
.
Image
.
Image
))
or
features
.
is_simple_tensor
(
inpt
)):
return
inpt
return
inpt
image
=
inpt
image
=
inpt
...
...
torchvision/prototype/transforms/_deprecated.py
View file @
7de63171
...
@@ -11,7 +11,7 @@ from torchvision.transforms import functional as _F
...
@@ -11,7 +11,7 @@ from torchvision.transforms import functional as _F
from
typing_extensions
import
Literal
from
typing_extensions
import
Literal
from
._transform
import
_RandomApplyTransform
from
._transform
import
_RandomApplyTransform
from
._utils
import
is_simple_tensor
,
query_chw
from
._utils
import
query_chw
class
ToTensor
(
Transform
):
class
ToTensor
(
Transform
):
...
@@ -43,7 +43,7 @@ class PILToTensor(Transform):
...
@@ -43,7 +43,7 @@ class PILToTensor(Transform):
class
ToPILImage
(
Transform
):
class
ToPILImage
(
Transform
):
_transformed_types
=
(
is_simple_tensor
,
features
.
Image
,
np
.
ndarray
)
_transformed_types
=
(
features
.
is_simple_tensor
,
features
.
Image
,
np
.
ndarray
)
def
__init__
(
self
,
mode
:
Optional
[
str
]
=
None
)
->
None
:
def
__init__
(
self
,
mode
:
Optional
[
str
]
=
None
)
->
None
:
warnings
.
warn
(
warnings
.
warn
(
...
@@ -58,7 +58,7 @@ class ToPILImage(Transform):
...
@@ -58,7 +58,7 @@ class ToPILImage(Transform):
class
Grayscale
(
Transform
):
class
Grayscale
(
Transform
):
_transformed_types
=
(
features
.
Image
,
PIL
.
Image
.
Image
,
is_simple_tensor
)
_transformed_types
=
(
features
.
Image
,
PIL
.
Image
.
Image
,
features
.
is_simple_tensor
)
def
__init__
(
self
,
num_output_channels
:
Literal
[
1
,
3
]
=
1
)
->
None
:
def
__init__
(
self
,
num_output_channels
:
Literal
[
1
,
3
]
=
1
)
->
None
:
deprecation_msg
=
(
deprecation_msg
=
(
...
@@ -86,7 +86,7 @@ class Grayscale(Transform):
...
@@ -86,7 +86,7 @@ class Grayscale(Transform):
class
RandomGrayscale
(
_RandomApplyTransform
):
class
RandomGrayscale
(
_RandomApplyTransform
):
_transformed_types
=
(
features
.
Image
,
PIL
.
Image
.
Image
,
is_simple_tensor
)
_transformed_types
=
(
features
.
Image
,
PIL
.
Image
.
Image
,
features
.
is_simple_tensor
)
def
__init__
(
self
,
p
:
float
=
0.1
)
->
None
:
def
__init__
(
self
,
p
:
float
=
0.1
)
->
None
:
warnings
.
warn
(
warnings
.
warn
(
...
...
torchvision/prototype/transforms/_geometry.py
View file @
7de63171
...
@@ -15,7 +15,7 @@ from torchvision.transforms.transforms import _check_sequence_input, _setup_angl
...
@@ -15,7 +15,7 @@ from torchvision.transforms.transforms import _check_sequence_input, _setup_angl
from
typing_extensions
import
Literal
from
typing_extensions
import
Literal
from
._transform
import
_RandomApplyTransform
from
._transform
import
_RandomApplyTransform
from
._utils
import
has_all
,
has_any
,
is_simple_tensor
,
query_bounding_box
,
query_chw
from
._utils
import
has_all
,
has_any
,
query_bounding_box
,
query_chw
class
RandomHorizontalFlip
(
_RandomApplyTransform
):
class
RandomHorizontalFlip
(
_RandomApplyTransform
):
...
@@ -156,7 +156,7 @@ class FiveCrop(Transform):
...
@@ -156,7 +156,7 @@ class FiveCrop(Transform):
torch.Size([5])
torch.Size([5])
"""
"""
_transformed_types
=
(
features
.
Image
,
PIL
.
Image
.
Image
,
is_simple_tensor
)
_transformed_types
=
(
features
.
Image
,
PIL
.
Image
.
Image
,
features
.
is_simple_tensor
)
def
__init__
(
self
,
size
:
Union
[
int
,
Sequence
[
int
]])
->
None
:
def
__init__
(
self
,
size
:
Union
[
int
,
Sequence
[
int
]])
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -176,7 +176,7 @@ class TenCrop(Transform):
...
@@ -176,7 +176,7 @@ class TenCrop(Transform):
See :class:`~torchvision.prototype.transforms.FiveCrop` for an example.
See :class:`~torchvision.prototype.transforms.FiveCrop` for an example.
"""
"""
_transformed_types
=
(
features
.
Image
,
PIL
.
Image
.
Image
,
is_simple_tensor
)
_transformed_types
=
(
features
.
Image
,
PIL
.
Image
.
Image
,
features
.
is_simple_tensor
)
def
__init__
(
self
,
size
:
Union
[
int
,
Sequence
[
int
]],
vertical_flip
:
bool
=
False
)
->
None
:
def
__init__
(
self
,
size
:
Union
[
int
,
Sequence
[
int
]],
vertical_flip
:
bool
=
False
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -696,7 +696,7 @@ class RandomIoUCrop(Transform):
...
@@ -696,7 +696,7 @@ class RandomIoUCrop(Transform):
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
if
not
(
if
not
(
has_all
(
inputs
,
features
.
BoundingBox
)
has_all
(
inputs
,
features
.
BoundingBox
)
and
has_any
(
inputs
,
PIL
.
Image
.
Image
,
features
.
Image
,
is_simple_tensor
)
and
has_any
(
inputs
,
PIL
.
Image
.
Image
,
features
.
Image
,
features
.
is_simple_tensor
)
and
has_any
(
inputs
,
features
.
Label
,
features
.
OneHotLabel
)
and
has_any
(
inputs
,
features
.
Label
,
features
.
OneHotLabel
)
):
):
raise
TypeError
(
raise
TypeError
(
...
@@ -847,7 +847,7 @@ class FixedSizeCrop(Transform):
...
@@ -847,7 +847,7 @@ class FixedSizeCrop(Transform):
return
inpt
return
inpt
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
if
not
has_any
(
inputs
,
PIL
.
Image
.
Image
,
features
.
Image
,
is_simple_tensor
):
if
not
has_any
(
inputs
,
PIL
.
Image
.
Image
,
features
.
Image
,
features
.
is_simple_tensor
):
raise
TypeError
(
f
"
{
type
(
self
).
__name__
}
() requires input sample to contain an tensor or PIL image."
)
raise
TypeError
(
f
"
{
type
(
self
).
__name__
}
() requires input sample to contain an tensor or PIL image."
)
if
has_any
(
inputs
,
features
.
BoundingBox
)
and
not
has_any
(
inputs
,
features
.
Label
,
features
.
OneHotLabel
):
if
has_any
(
inputs
,
features
.
BoundingBox
)
and
not
has_any
(
inputs
,
features
.
Label
,
features
.
OneHotLabel
):
...
...
torchvision/prototype/transforms/_meta.py
View file @
7de63171
...
@@ -7,8 +7,6 @@ from torchvision.prototype import features
...
@@ -7,8 +7,6 @@ from torchvision.prototype import features
from
torchvision.prototype.transforms
import
functional
as
F
,
Transform
from
torchvision.prototype.transforms
import
functional
as
F
,
Transform
from
torchvision.transforms.functional
import
convert_image_dtype
from
torchvision.transforms.functional
import
convert_image_dtype
from
._utils
import
is_simple_tensor
class
ConvertBoundingBoxFormat
(
Transform
):
class
ConvertBoundingBoxFormat
(
Transform
):
_transformed_types
=
(
features
.
BoundingBox
,)
_transformed_types
=
(
features
.
BoundingBox
,)
...
@@ -25,7 +23,7 @@ class ConvertBoundingBoxFormat(Transform):
...
@@ -25,7 +23,7 @@ class ConvertBoundingBoxFormat(Transform):
class
ConvertImageDtype
(
Transform
):
class
ConvertImageDtype
(
Transform
):
_transformed_types
=
(
is_simple_tensor
,
features
.
Image
)
_transformed_types
=
(
features
.
is_simple_tensor
,
features
.
Image
)
def
__init__
(
self
,
dtype
:
torch
.
dtype
=
torch
.
float32
)
->
None
:
def
__init__
(
self
,
dtype
:
torch
.
dtype
=
torch
.
float32
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -33,11 +31,11 @@ class ConvertImageDtype(Transform):
...
@@ -33,11 +31,11 @@ class ConvertImageDtype(Transform):
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
output
=
convert_image_dtype
(
inpt
,
dtype
=
self
.
dtype
)
output
=
convert_image_dtype
(
inpt
,
dtype
=
self
.
dtype
)
return
output
if
is_simple_tensor
(
inpt
)
else
features
.
Image
.
new_like
(
inpt
,
output
,
dtype
=
self
.
dtype
)
return
output
if
features
.
is_simple_tensor
(
inpt
)
else
features
.
Image
.
new_like
(
inpt
,
output
,
dtype
=
self
.
dtype
)
class
ConvertColorSpace
(
Transform
):
class
ConvertColorSpace
(
Transform
):
_transformed_types
=
(
is_simple_tensor
,
features
.
Image
,
PIL
.
Image
.
Image
)
_transformed_types
=
(
features
.
is_simple_tensor
,
features
.
Image
,
PIL
.
Image
.
Image
)
def
__init__
(
def
__init__
(
self
,
self
,
...
...
torchvision/prototype/transforms/_misc.py
View file @
7de63171
...
@@ -7,7 +7,7 @@ import torch
...
@@ -7,7 +7,7 @@ import torch
from
torchvision.ops
import
remove_small_boxes
from
torchvision.ops
import
remove_small_boxes
from
torchvision.prototype
import
features
from
torchvision.prototype
import
features
from
torchvision.prototype.transforms
import
functional
as
F
,
Transform
from
torchvision.prototype.transforms
import
functional
as
F
,
Transform
from
torchvision.prototype.transforms._utils
import
has_any
,
is_simple_tensor
,
query_bounding_box
from
torchvision.prototype.transforms._utils
import
has_any
,
query_bounding_box
from
torchvision.transforms.transforms
import
_setup_size
from
torchvision.transforms.transforms
import
_setup_size
...
@@ -38,7 +38,7 @@ class Lambda(Transform):
...
@@ -38,7 +38,7 @@ class Lambda(Transform):
class
LinearTransformation
(
Transform
):
class
LinearTransformation
(
Transform
):
_transformed_types
=
(
is_simple_tensor
,
features
.
Image
)
_transformed_types
=
(
features
.
is_simple_tensor
,
features
.
Image
)
def
__init__
(
self
,
transformation_matrix
:
torch
.
Tensor
,
mean_vector
:
torch
.
Tensor
):
def
__init__
(
self
,
transformation_matrix
:
torch
.
Tensor
,
mean_vector
:
torch
.
Tensor
):
super
().
__init__
()
super
().
__init__
()
...
@@ -93,7 +93,7 @@ class LinearTransformation(Transform):
...
@@ -93,7 +93,7 @@ class LinearTransformation(Transform):
class
Normalize
(
Transform
):
class
Normalize
(
Transform
):
_transformed_types
=
(
features
.
Image
,
is_simple_tensor
)
_transformed_types
=
(
features
.
Image
,
features
.
is_simple_tensor
)
def
__init__
(
self
,
mean
:
Sequence
[
float
],
std
:
Sequence
[
float
]):
def
__init__
(
self
,
mean
:
Sequence
[
float
],
std
:
Sequence
[
float
]):
super
().
__init__
()
super
().
__init__
()
...
...
torchvision/prototype/transforms/_transform.py
View file @
7de63171
...
@@ -5,15 +5,19 @@ import PIL.Image
...
@@ -5,15 +5,19 @@ import PIL.Image
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
torch.utils._pytree
import
tree_flatten
,
tree_unflatten
from
torch.utils._pytree
import
tree_flatten
,
tree_unflatten
from
torchvision.prototype
.features
import
_F
eature
from
torchvision.prototype
import
f
eature
s
from
torchvision.prototype.transforms._utils
import
_isinstance
,
is_simple_tensor
from
torchvision.prototype.transforms._utils
import
_isinstance
from
torchvision.utils
import
_log_api_usage_once
from
torchvision.utils
import
_log_api_usage_once
class
Transform
(
nn
.
Module
):
class
Transform
(
nn
.
Module
):
# Class attribute defining transformed types. Other types are passed-through without any transformation
# Class attribute defining transformed types. Other types are passed-through without any transformation
_transformed_types
:
Tuple
[
Union
[
Type
,
Callable
[[
Any
],
bool
]],
...]
=
(
is_simple_tensor
,
_Feature
,
PIL
.
Image
.
Image
)
_transformed_types
:
Tuple
[
Union
[
Type
,
Callable
[[
Any
],
bool
]],
...]
=
(
features
.
is_simple_tensor
,
features
.
_Feature
,
PIL
.
Image
.
Image
,
)
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
...
torchvision/prototype/transforms/_type_conversion.py
View file @
7de63171
...
@@ -7,8 +7,6 @@ from torch.nn.functional import one_hot
...
@@ -7,8 +7,6 @@ from torch.nn.functional import one_hot
from
torchvision.prototype
import
features
from
torchvision.prototype
import
features
from
torchvision.prototype.transforms
import
functional
as
F
,
Transform
from
torchvision.prototype.transforms
import
functional
as
F
,
Transform
from
._utils
import
is_simple_tensor
class
DecodeImage
(
Transform
):
class
DecodeImage
(
Transform
):
_transformed_types
=
(
features
.
EncodedImage
,)
_transformed_types
=
(
features
.
EncodedImage
,)
...
@@ -39,14 +37,14 @@ class LabelToOneHot(Transform):
...
@@ -39,14 +37,14 @@ class LabelToOneHot(Transform):
class
ToImageTensor
(
Transform
):
class
ToImageTensor
(
Transform
):
_transformed_types
=
(
is_simple_tensor
,
PIL
.
Image
.
Image
,
np
.
ndarray
)
_transformed_types
=
(
features
.
is_simple_tensor
,
PIL
.
Image
.
Image
,
np
.
ndarray
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
features
.
Image
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
features
.
Image
:
return
F
.
to_image_tensor
(
inpt
)
return
F
.
to_image_tensor
(
inpt
)
class
ToImagePIL
(
Transform
):
class
ToImagePIL
(
Transform
):
_transformed_types
=
(
is_simple_tensor
,
features
.
Image
,
np
.
ndarray
)
_transformed_types
=
(
features
.
is_simple_tensor
,
features
.
Image
,
np
.
ndarray
)
def
__init__
(
self
,
*
,
mode
:
Optional
[
str
]
=
None
)
->
None
:
def
__init__
(
self
,
*
,
mode
:
Optional
[
str
]
=
None
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
...
torchvision/prototype/transforms/_utils.py
View file @
7de63171
...
@@ -23,7 +23,7 @@ def get_chw(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tupl
...
@@ -23,7 +23,7 @@ def get_chw(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tupl
if
isinstance
(
image
,
features
.
Image
):
if
isinstance
(
image
,
features
.
Image
):
channels
=
image
.
num_channels
channels
=
image
.
num_channels
height
,
width
=
image
.
image_size
height
,
width
=
image
.
image_size
elif
is_simple_tensor
(
image
):
elif
features
.
is_simple_tensor
(
image
):
channels
,
height
,
width
=
get_dimensions_image_tensor
(
image
)
channels
,
height
,
width
=
get_dimensions_image_tensor
(
image
)
elif
isinstance
(
image
,
PIL
.
Image
.
Image
):
elif
isinstance
(
image
,
PIL
.
Image
.
Image
):
channels
,
height
,
width
=
get_dimensions_image_pil
(
image
)
channels
,
height
,
width
=
get_dimensions_image_pil
(
image
)
...
@@ -37,7 +37,7 @@ def query_chw(sample: Any) -> Tuple[int, int, int]:
...
@@ -37,7 +37,7 @@ def query_chw(sample: Any) -> Tuple[int, int, int]:
chws
=
{
chws
=
{
get_chw
(
item
)
get_chw
(
item
)
for
item
in
flat_sample
for
item
in
flat_sample
if
isinstance
(
item
,
(
features
.
Image
,
PIL
.
Image
.
Image
))
or
is_simple_tensor
(
item
)
if
isinstance
(
item
,
(
features
.
Image
,
PIL
.
Image
.
Image
))
or
features
.
is_simple_tensor
(
item
)
}
}
if
not
chws
:
if
not
chws
:
raise
TypeError
(
"No image was found in the sample"
)
raise
TypeError
(
"No image was found in the sample"
)
...
@@ -70,10 +70,3 @@ def has_all(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -
...
@@ -70,10 +70,3 @@ def has_all(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -
else
:
else
:
return
False
return
False
return
True
return
True
# TODO: Given that this is not related to pytree / the Transform object, we should probably move it to somewhere else.
# One possibility is `functional._utils` so both the functionals and the transforms have proper access to it. We could
# also move it `features` since it literally checks for the _Feature type.
def
is_simple_tensor
(
inpt
:
Any
)
->
bool
:
return
isinstance
(
inpt
,
torch
.
Tensor
)
and
not
isinstance
(
inpt
,
features
.
_Feature
)
torchvision/prototype/transforms/functional/_deprecated.py
View file @
7de63171
...
@@ -6,8 +6,6 @@ import PIL.Image
...
@@ -6,8 +6,6 @@ import PIL.Image
from
torchvision.prototype
import
features
from
torchvision.prototype
import
features
from
torchvision.transforms
import
functional
as
_F
from
torchvision.transforms
import
functional
as
_F
from
.._utils
import
is_simple_tensor
def
to_grayscale
(
inpt
:
PIL
.
Image
.
Image
,
num_output_channels
:
int
=
1
)
->
PIL
.
Image
.
Image
:
def
to_grayscale
(
inpt
:
PIL
.
Image
.
Image
,
num_output_channels
:
int
=
1
)
->
PIL
.
Image
.
Image
:
call
=
", num_output_channels=3"
if
num_output_channels
==
3
else
""
call
=
", num_output_channels=3"
if
num_output_channels
==
3
else
""
...
@@ -23,7 +21,7 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
...
@@ -23,7 +21,7 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
def
rgb_to_grayscale
(
inpt
:
Any
,
num_output_channels
:
int
=
1
)
->
Any
:
def
rgb_to_grayscale
(
inpt
:
Any
,
num_output_channels
:
int
=
1
)
->
Any
:
old_color_space
=
features
.
Image
.
guess_color_space
(
inpt
)
if
is_simple_tensor
(
inpt
)
else
None
old_color_space
=
features
.
Image
.
guess_color_space
(
inpt
)
if
features
.
is_simple_tensor
(
inpt
)
else
None
call
=
", num_output_channels=3"
if
num_output_channels
==
3
else
""
call
=
", num_output_channels=3"
if
num_output_channels
==
3
else
""
replacement
=
(
replacement
=
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment