OpenDAS / vision · Commit 5dd95944 (unverified)

Remove `color_space` metadata and `ConvertColorSpace()` transform (#7120)

Authored Jan 24, 2023 by Nicolas Hug; committed via GitHub on Jan 24, 2023.
Parent: c206a471

Showing 16 changed files with 106 additions and 506 deletions (+106 −506).
  test/prototype_common_utils.py                              +26  −29
  test/prototype_transforms_kernel_infos.py                   +38  −150
  test/test_prototype_transforms.py                            +4  −43
  test/test_prototype_transforms_consistency.py               +11  −15
  test/test_prototype_transforms_functional.py                 +0  −1
  test/test_prototype_transforms_utils.py                      +1  −1
  torchvision/prototype/datapoints/__init__.py                 +1  −1
  torchvision/prototype/datapoints/_image.py                   +5  −69
  torchvision/prototype/datapoints/_video.py                   +5  −25
  torchvision/prototype/transforms/__init__.py                 +1  −1
  torchvision/prototype/transforms/_deprecated.py              +3  −2
  torchvision/prototype/transforms/_meta.py                    +1  −32
  torchvision/prototype/transforms/functional/__init__.py      +0  −4
  torchvision/prototype/transforms/functional/_deprecated.py   +5  −7
  torchvision/prototype/transforms/functional/_meta.py         +2  −126
  torchvision/transforms/functional.py                         +3  −0
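For orientation (a reading aid, not part of the commit): color spaces become plain strings in the test helpers, and `Image`/`Video` datapoints stop carrying a `color_space` attribute entirely. A minimal sketch of the before/after for downstream code, assuming a torchvision checkout at this commit:

```python
import torch
from torchvision.prototype import datapoints

# Before this commit (now removed):
#   image = datapoints.Image(data, color_space=datapoints.ColorSpace.RGB)
#   image = transforms.ConvertColorSpace("GRAY", old_color_space="RGB")(image)

# After this commit: no color-space metadata is attached or inferred.
data = torch.rand(3, 32, 32)
image = datapoints.Image(data)
# Conversions now go through the (deprecated) grayscale helpers or PIL's
# Image.convert() instead of a dedicated ConvertColorSpace transform.
```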
test/prototype_common_utils.py  (view file @ 5dd95944)

@@ -238,7 +238,6 @@ class TensorLoader:
 @dataclasses.dataclass
 class ImageLoader(TensorLoader):
-    color_space: datapoints.ColorSpace
     spatial_size: Tuple[int, int] = dataclasses.field(init=False)
     num_channels: int = dataclasses.field(init=False)

@@ -248,10 +247,10 @@ class ImageLoader(TensorLoader):
 NUM_CHANNELS_MAP = {
-    datapoints.ColorSpace.GRAY: 1,
-    datapoints.ColorSpace.GRAY_ALPHA: 2,
-    datapoints.ColorSpace.RGB: 3,
-    datapoints.ColorSpace.RGB_ALPHA: 4,
+    "GRAY": 1,
+    "GRAY_ALPHA": 2,
+    "RGB": 3,
+    "RGBA": 4,
 }

@@ -265,7 +264,7 @@ def get_num_channels(color_space):
 def make_image_loader(
     size="random",
     *,
-    color_space=datapoints.ColorSpace.RGB,
+    color_space="RGB",
     extra_dims=(),
     dtype=torch.float32,
     constant_alpha=True,

@@ -276,11 +275,11 @@ def make_image_loader(
     def fn(shape, dtype, device):
         max_value = get_max_value(dtype)
         data = torch.testing.make_tensor(shape, low=0, high=max_value, dtype=dtype, device=device)
-        if color_space in {datapoints.ColorSpace.GRAY_ALPHA, datapoints.ColorSpace.RGB_ALPHA} and constant_alpha:
+        if color_space in {"GRAY_ALPHA", "RGBA"} and constant_alpha:
             data[..., -1, :, :] = max_value
-        return datapoints.Image(data, color_space=color_space)
+        return datapoints.Image(data)

-    return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype, color_space=color_space)
+    return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype)

 make_image = from_loader(make_image_loader)

@@ -290,10 +289,10 @@ def make_image_loaders(
     *,
     sizes=DEFAULT_SPATIAL_SIZES,
     color_spaces=(
-        datapoints.ColorSpace.GRAY,
-        datapoints.ColorSpace.GRAY_ALPHA,
-        datapoints.ColorSpace.RGB,
-        datapoints.ColorSpace.RGB_ALPHA,
+        "GRAY",
+        "GRAY_ALPHA",
+        "RGB",
+        "RGBA",
     ),
     extra_dims=DEFAULT_EXTRA_DIMS,
     dtypes=(torch.float32, torch.uint8),

@@ -306,7 +305,7 @@ def make_image_loaders(
 make_images = from_loaders(make_image_loaders)

-def make_image_loader_for_interpolation(size="random", *, color_space=datapoints.ColorSpace.RGB, dtype=torch.uint8):
+def make_image_loader_for_interpolation(size="random", *, color_space="RGB", dtype=torch.uint8):
     size = _parse_spatial_size(size)
     num_channels = get_num_channels(color_space)

@@ -318,24 +317,24 @@ def make_image_loader_for_interpolation(size="random", *, color_space=datapoints
             .resize((width, height))
             .convert(
                 {
-                    datapoints.ColorSpace.GRAY: "L",
-                    datapoints.ColorSpace.GRAY_ALPHA: "LA",
-                    datapoints.ColorSpace.RGB: "RGB",
-                    datapoints.ColorSpace.RGB_ALPHA: "RGBA",
+                    "GRAY": "L",
+                    "GRAY_ALPHA": "LA",
+                    "RGB": "RGB",
+                    "RGBA": "RGBA",
                 }[color_space]
             )
         )

         image_tensor = convert_dtype_image_tensor(to_image_tensor(image_pil).to(device=device), dtype=dtype)

-        return datapoints.Image(image_tensor, color_space=color_space)
+        return datapoints.Image(image_tensor)

-    return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype, color_space=color_space)
+    return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype)

 def make_image_loaders_for_interpolation(
     sizes=((233, 147),),
-    color_spaces=(datapoints.ColorSpace.RGB,),
+    color_spaces=("RGB",),
     dtypes=(torch.uint8,),
 ):
     for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes):

@@ -583,7 +582,7 @@ class VideoLoader(ImageLoader):
 def make_video_loader(
     size="random",
     *,
-    color_space=datapoints.ColorSpace.RGB,
+    color_space="RGB",
     num_frames="random",
     extra_dims=(),
     dtype=torch.uint8,

@@ -592,12 +591,10 @@ def make_video_loader(
     num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames

     def fn(shape, dtype, device):
-        video = make_image(size=shape[-2:], color_space=color_space, extra_dims=shape[:-3], dtype=dtype, device=device)
-        return datapoints.Video(video, color_space=color_space)
+        video = make_image(size=shape[-2:], extra_dims=shape[:-3], dtype=dtype, device=device)
+        return datapoints.Video(video)

-    return VideoLoader(
-        fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype, color_space=color_space
-    )
+    return VideoLoader(fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype)

 make_video = from_loader(make_video_loader)

@@ -607,8 +604,8 @@ def make_video_loaders(
     *,
     sizes=DEFAULT_SPATIAL_SIZES,
     color_spaces=(
-        datapoints.ColorSpace.GRAY,
-        datapoints.ColorSpace.RGB,
+        "GRAY",
+        "RGB",
     ),
     num_frames=(1, 0, "random"),
     extra_dims=DEFAULT_EXTRA_DIMS,
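For context (not part of the commit): the string-keyed `NUM_CHANNELS_MAP` keeps `get_num_channels` working exactly as before. A minimal sketch of the lookup — the helper body is our assumption, only the map and the names come from the diff above:

```python
# Sketch of the string-based channel lookup used by the test helpers.
NUM_CHANNELS_MAP = {"GRAY": 1, "GRAY_ALPHA": 2, "RGB": 3, "RGBA": 4}

def get_num_channels(color_space):
    num_channels = NUM_CHANNELS_MAP.get(color_space)
    if num_channels is None:
        raise ValueError(f"Can't determine the number of channels for color space {color_space}")
    return num_channels

assert get_num_channels("RGB") == 3  # e.g. an RGB image loader gets shape (3, H, W)
```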
test/prototype_transforms_kernel_infos.py  (view file @ 5dd95944)

@@ -9,7 +9,6 @@ import pytest
 import torch.testing
 import torchvision.ops
 import torchvision.prototype.transforms.functional as F
-from common_utils import cycle_over
 from datasets_utils import combinations_grid
 from prototype_common_utils import (
     ArgsKwargs,

@@ -261,14 +260,12 @@ def _get_resize_sizes(spatial_size):
 def sample_inputs_resize_image_tensor():
-    for image_loader in make_image_loaders(
-        sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]
-    ):
+    for image_loader in make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]):
         for size in _get_resize_sizes(image_loader.spatial_size):
             yield ArgsKwargs(image_loader, size=size)

     for image_loader, interpolation in itertools.product(
-        make_image_loaders(sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB]),
+        make_image_loaders(sizes=["random"], color_spaces=["RGB"]),
         [
             F.InterpolationMode.NEAREST,
             F.InterpolationMode.BILINEAR,

@@ -472,7 +469,7 @@ def float32_vs_uint8_fill_adapter(other_args, kwargs):
 def sample_inputs_affine_image_tensor():
     make_affine_image_loaders = functools.partial(
-        make_image_loaders, sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]
+        make_image_loaders, sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]
     )
     for image_loader, affine_params in itertools.product(make_affine_image_loaders(), _DIVERSE_AFFINE_PARAMS):

@@ -684,69 +681,6 @@ KERNEL_INFOS.append(
 )

-def sample_inputs_convert_color_space_image_tensor():
-    color_spaces = sorted(
-        set(datapoints.ColorSpace) - {datapoints.ColorSpace.OTHER},
-        key=lambda color_space: color_space.value,
-    )
-
-    for old_color_space, new_color_space in cycle_over(color_spaces):
-        for image_loader in make_image_loaders(sizes=["random"], color_spaces=[old_color_space], constant_alpha=True):
-            yield ArgsKwargs(image_loader, old_color_space=old_color_space, new_color_space=new_color_space)
-
-    for color_space in color_spaces:
-        for image_loader in make_image_loaders(
-            sizes=["random"], color_spaces=[color_space], dtypes=[torch.float32], constant_alpha=True
-        ):
-            yield ArgsKwargs(image_loader, old_color_space=color_space, new_color_space=color_space)
-
-
-@pil_reference_wrapper
-def reference_convert_color_space_image_tensor(image_pil, old_color_space, new_color_space):
-    color_space_pil = datapoints.ColorSpace.from_pil_mode(image_pil.mode)
-    if color_space_pil != old_color_space:
-        raise pytest.UsageError(
-            f"Converting the tensor image into an PIL image changed the colorspace "
-            f"from {old_color_space} to {color_space_pil}"
-        )
-
-    return F.convert_color_space_image_pil(image_pil, color_space=new_color_space)
-
-
-def reference_inputs_convert_color_space_image_tensor():
-    for args_kwargs in sample_inputs_convert_color_space_image_tensor():
-        (image_loader, *other_args), kwargs = args_kwargs
-        if len(image_loader.shape) == 3 and image_loader.dtype == torch.uint8:
-            yield args_kwargs
-
-
-def sample_inputs_convert_color_space_video():
-    color_spaces = [datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB]
-
-    for old_color_space, new_color_space in cycle_over(color_spaces):
-        for video_loader in make_video_loaders(sizes=["random"], color_spaces=[old_color_space], num_frames=["random"]):
-            yield ArgsKwargs(video_loader, old_color_space=old_color_space, new_color_space=new_color_space)
-
-
-KERNEL_INFOS.extend(
-    [
-        KernelInfo(
-            F.convert_color_space_image_tensor,
-            sample_inputs_fn=sample_inputs_convert_color_space_image_tensor,
-            reference_fn=reference_convert_color_space_image_tensor,
-            reference_inputs_fn=reference_inputs_convert_color_space_image_tensor,
-            closeness_kwargs={
-                **pil_reference_pixel_difference(),
-                **float32_vs_uint8_pixel_difference(),
-            },
-        ),
-        KernelInfo(
-            F.convert_color_space_video,
-            sample_inputs_fn=sample_inputs_convert_color_space_video,
-        ),
-    ]
-)
-
-
 def sample_inputs_vertical_flip_image_tensor():
     for image_loader in make_image_loaders(sizes=["random"], dtypes=[torch.float32]):
         yield ArgsKwargs(image_loader)

@@ -822,7 +756,7 @@ _ROTATE_ANGLES = [-87, 15, 90]
 def sample_inputs_rotate_image_tensor():
     make_rotate_image_loaders = functools.partial(
-        make_image_loaders, sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]
+        make_image_loaders, sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]
     )

     for image_loader in make_rotate_image_loaders():

@@ -904,7 +838,7 @@ _CROP_PARAMS = combinations_grid(top=[-8, 0, 9], left=[-8, 0, 9], height=[12, 20
 def sample_inputs_crop_image_tensor():
     for image_loader, params in itertools.product(
-        make_image_loaders(sizes=[(16, 17)], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]),
+        make_image_loaders(sizes=[(16, 17)], color_spaces=["RGB"], dtypes=[torch.float32]),
         [
             dict(top=4, left=3, height=7, width=8),
             dict(top=-1, left=3, height=7, width=8),

@@ -1090,7 +1024,7 @@ _PAD_PARAMS = combinations_grid(
 def sample_inputs_pad_image_tensor():
     make_pad_image_loaders = functools.partial(
-        make_image_loaders, sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]
+        make_image_loaders, sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]
     )

     for image_loader, padding in itertools.product(

@@ -1406,7 +1340,7 @@ _CENTER_CROP_OUTPUT_SIZES = [[4, 3], [42, 70], [4], 3, (5, 2), (6,)]
 def sample_inputs_center_crop_image_tensor():
     for image_loader, output_size in itertools.product(
-        make_image_loaders(sizes=[(16, 17)], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]),
+        make_image_loaders(sizes=[(16, 17)], color_spaces=["RGB"], dtypes=[torch.float32]),
         [
             # valid `output_size` types for which cropping is applied to both dimensions
             *[5, (4,), (2, 3), [6], [3, 2]],

@@ -1492,9 +1426,7 @@ KERNEL_INFOS.extend(
 def sample_inputs_gaussian_blur_image_tensor():
-    make_gaussian_blur_image_loaders = functools.partial(
-        make_image_loaders, sizes=[(7, 33)], color_spaces=[datapoints.ColorSpace.RGB]
-    )
+    make_gaussian_blur_image_loaders = functools.partial(make_image_loaders, sizes=[(7, 33)], color_spaces=["RGB"])

     for image_loader, kernel_size in itertools.product(make_gaussian_blur_image_loaders(), [5, (3, 3), [3, 3]]):
         yield ArgsKwargs(image_loader, kernel_size=kernel_size)

@@ -1531,9 +1463,7 @@ KERNEL_INFOS.extend(
 def sample_inputs_equalize_image_tensor():
-    for image_loader in make_image_loaders(
-        sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
-    ):
+    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
         yield ArgsKwargs(image_loader)

@@ -1560,7 +1490,7 @@ def reference_inputs_equalize_image_tensor():
     spatial_size = (256, 256)
     for dtype, color_space, fn in itertools.product(
         [torch.uint8],
-        [datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB],
+        ["GRAY", "RGB"],
         [
             lambda shape, dtype, device: torch.zeros(shape, dtype=dtype, device=device),
             lambda shape, dtype, device: torch.full(

@@ -1585,9 +1515,7 @@ def reference_inputs_equalize_image_tensor():
             ],
         ],
     ):
-        image_loader = ImageLoader(
-            fn, shape=(get_num_channels(color_space), *spatial_size), dtype=dtype, color_space=color_space
-        )
+        image_loader = ImageLoader(fn, shape=(get_num_channels(color_space), *spatial_size), dtype=dtype)
         yield ArgsKwargs(image_loader)

@@ -1615,16 +1543,12 @@ KERNEL_INFOS.extend(
 def sample_inputs_invert_image_tensor():
-    for image_loader in make_image_loaders(
-        sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
-    ):
+    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
         yield ArgsKwargs(image_loader)

 def reference_inputs_invert_image_tensor():
-    for image_loader in make_image_loaders(
-        color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
-    ):
+    for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
         yield ArgsKwargs(image_loader)

@@ -1655,17 +1579,13 @@ _POSTERIZE_BITS = [1, 4, 8]
 def sample_inputs_posterize_image_tensor():
-    for image_loader in make_image_loaders(
-        sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
-    ):
+    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
         yield ArgsKwargs(image_loader, bits=_POSTERIZE_BITS[0])

 def reference_inputs_posterize_image_tensor():
     for image_loader, bits in itertools.product(
-        make_image_loaders(
-            color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
-        ),
+        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
         _POSTERIZE_BITS,
     ):
         yield ArgsKwargs(image_loader, bits=bits)

@@ -1702,16 +1622,12 @@ def _get_solarize_thresholds(dtype):
 def sample_inputs_solarize_image_tensor():
-    for image_loader in make_image_loaders(
-        sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
-    ):
+    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
         yield ArgsKwargs(image_loader, threshold=next(_get_solarize_thresholds(image_loader.dtype)))

 def reference_inputs_solarize_image_tensor():
-    for image_loader in make_image_loaders(
-        color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
-    ):
+    for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
         for threshold in _get_solarize_thresholds(image_loader.dtype):
             yield ArgsKwargs(image_loader, threshold=threshold)

@@ -1745,16 +1661,12 @@ KERNEL_INFOS.extend(
 def sample_inputs_autocontrast_image_tensor():
-    for image_loader in make_image_loaders(
-        sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
-    ):
+    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
         yield ArgsKwargs(image_loader)

 def reference_inputs_autocontrast_image_tensor():
-    for image_loader in make_image_loaders(
-        color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
-    ):
+    for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
         yield ArgsKwargs(image_loader)

@@ -1790,16 +1702,14 @@ _ADJUST_SHARPNESS_FACTORS = [0.1, 0.5]
 def sample_inputs_adjust_sharpness_image_tensor():
     for image_loader in make_image_loaders(
         sizes=["random", (2, 2)],
-        color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB),
+        color_spaces=("GRAY", "RGB"),
     ):
         yield ArgsKwargs(image_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0])

 def reference_inputs_adjust_sharpness_image_tensor():
     for image_loader, sharpness_factor in itertools.product(
-        make_image_loaders(
-            color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
-        ),
+        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
         _ADJUST_SHARPNESS_FACTORS,
     ):
         yield ArgsKwargs(image_loader, sharpness_factor=sharpness_factor)

@@ -1863,17 +1773,13 @@ _ADJUST_BRIGHTNESS_FACTORS = [0.1, 0.5]
 def sample_inputs_adjust_brightness_image_tensor():
-    for image_loader in make_image_loaders(
-        sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
-    ):
+    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
         yield ArgsKwargs(image_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0])

 def reference_inputs_adjust_brightness_image_tensor():
     for image_loader, brightness_factor in itertools.product(
-        make_image_loaders(
-            color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
-        ),
+        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
         _ADJUST_BRIGHTNESS_FACTORS,
     ):
         yield ArgsKwargs(image_loader, brightness_factor=brightness_factor)

@@ -1907,17 +1813,13 @@ _ADJUST_CONTRAST_FACTORS = [0.1, 0.5]
 def sample_inputs_adjust_contrast_image_tensor():
-    for image_loader in make_image_loaders(
-        sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
-    ):
+    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
         yield ArgsKwargs(image_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0])

 def reference_inputs_adjust_contrast_image_tensor():
     for image_loader, contrast_factor in itertools.product(
-        make_image_loaders(
-            color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
-        ),
+        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
         _ADJUST_CONTRAST_FACTORS,
     ):
         yield ArgsKwargs(image_loader, contrast_factor=contrast_factor)

@@ -1959,17 +1861,13 @@ _ADJUST_GAMMA_GAMMAS_GAINS = [
 def sample_inputs_adjust_gamma_image_tensor():
     gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0]
-    for image_loader in make_image_loaders(
-        sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
-    ):
+    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
         yield ArgsKwargs(image_loader, gamma=gamma, gain=gain)

 def reference_inputs_adjust_gamma_image_tensor():
     for image_loader, (gamma, gain) in itertools.product(
-        make_image_loaders(
-            color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
-        ),
+        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
         _ADJUST_GAMMA_GAMMAS_GAINS,
     ):
         yield ArgsKwargs(image_loader, gamma=gamma, gain=gain)

@@ -2007,17 +1905,13 @@ _ADJUST_HUE_FACTORS = [-0.1, 0.5]
 def sample_inputs_adjust_hue_image_tensor():
-    for image_loader in make_image_loaders(
-        sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
-    ):
+    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
         yield ArgsKwargs(image_loader, hue_factor=_ADJUST_HUE_FACTORS[0])

 def reference_inputs_adjust_hue_image_tensor():
     for image_loader, hue_factor in itertools.product(
-        make_image_loaders(
-            color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
-        ),
+        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
         _ADJUST_HUE_FACTORS,
     ):
         yield ArgsKwargs(image_loader, hue_factor=hue_factor)

@@ -2053,17 +1947,13 @@ _ADJUST_SATURATION_FACTORS = [0.1, 0.5]
 def sample_inputs_adjust_saturation_image_tensor():
-    for image_loader in make_image_loaders(
-        sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
-    ):
+    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
         yield ArgsKwargs(image_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0])

 def reference_inputs_adjust_saturation_image_tensor():
     for image_loader, saturation_factor in itertools.product(
-        make_image_loaders(
-            color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
-        ),
+        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
         _ADJUST_SATURATION_FACTORS,
     ):
         yield ArgsKwargs(image_loader, saturation_factor=saturation_factor)

@@ -2128,7 +2018,7 @@ def sample_inputs_five_crop_image_tensor():
     for size in _FIVE_TEN_CROP_SIZES:
         for image_loader in make_image_loaders(
             sizes=[_get_five_ten_crop_spatial_size(size)],
-            color_spaces=[datapoints.ColorSpace.RGB],
+            color_spaces=["RGB"],
             dtypes=[torch.float32],
         ):
             yield ArgsKwargs(image_loader, size=size)

@@ -2152,7 +2042,7 @@ def sample_inputs_ten_crop_image_tensor():
     for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
         for image_loader in make_image_loaders(
             sizes=[_get_five_ten_crop_spatial_size(size)],
-            color_spaces=[datapoints.ColorSpace.RGB],
+            color_spaces=["RGB"],
             dtypes=[torch.float32],
         ):
             yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip)

@@ -2226,7 +2116,7 @@ _NORMALIZE_MEANS_STDS = [
 def sample_inputs_normalize_image_tensor():
     for image_loader, (mean, std) in itertools.product(
-        make_image_loaders(sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]),
+        make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]),
         _NORMALIZE_MEANS_STDS,
     ):
         yield ArgsKwargs(image_loader, mean=mean, std=std)

@@ -2242,7 +2132,7 @@ def reference_normalize_image_tensor(image, mean, std, inplace=False):
 def reference_inputs_normalize_image_tensor():
     yield ArgsKwargs(
-        make_image_loader(size=(32, 32), color_space=datapoints.ColorSpace.RGB, extra_dims=[1]),
+        make_image_loader(size=(32, 32), color_space="RGB", extra_dims=[1]),
         mean=[0.5, 0.5, 0.5],
         std=[1.0, 1.0, 1.0],
     )

@@ -2251,7 +2141,7 @@ def reference_inputs_normalize_image_tensor():
 def sample_inputs_normalize_video():
     mean, std = _NORMALIZE_MEANS_STDS[0]
     for video_loader in make_video_loaders(
-        sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], num_frames=["random"], dtypes=[torch.float32]
+        sizes=["random"], color_spaces=["RGB"], num_frames=["random"], dtypes=[torch.float32]
     ):
         yield ArgsKwargs(video_loader, mean=mean, std=std)

@@ -2285,9 +2175,7 @@ def sample_inputs_convert_dtype_image_tensor():
             # conversion cannot be performed safely
             continue

-        for image_loader in make_image_loaders(
-            sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[input_dtype]
-        ):
+        for image_loader in make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[input_dtype]):
             yield ArgsKwargs(image_loader, dtype=output_dtype)

@@ -2414,7 +2302,7 @@ def reference_uniform_temporal_subsample_video(x, num_samples, temporal_dim=-4):
 def reference_inputs_uniform_temporal_subsample_video():
-    for video_loader in make_video_loaders(sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], num_frames=[10]):
+    for video_loader in make_video_loaders(sizes=["random"], color_spaces=["RGB"], num_frames=[10]):
         for num_samples in range(1, video_loader.shape[-4] + 1):
             yield ArgsKwargs(video_loader, num_samples)
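For readers unfamiliar with this test harness (a sketch, not part of the commit): every `sample_inputs_*` generator above yields `ArgsKwargs` bundles that the kernel-test machinery later unpacks into positional and keyword arguments. The real class lives in `prototype_common_utils`; this stripped-down stand-in only illustrates the calling convention and is our assumption:

```python
# Rough sketch of the ArgsKwargs pattern used throughout this file.
class ArgsKwargs:
    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs

    def __iter__(self):
        # supports `(first_arg, *other_args), kwargs = args_kwargs`
        yield self.args
        yield self.kwargs

def sample_inputs_example():
    for color_space in ("GRAY", "RGB"):  # plain strings after this commit
        yield ArgsKwargs(f"loader<{color_space}>", brightness_factor=0.1)

for args_kwargs in sample_inputs_example():
    (loader, *other_args), kwargs = args_kwargs
    print(loader, kwargs)  # loader<GRAY> {'brightness_factor': 0.1}, ...
```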
test/test_prototype_transforms.py  (view file @ 5dd95944)

@@ -161,8 +161,8 @@ class TestSmoke:
             itertools.chain.from_iterable(
                 fn(
                     color_spaces=[
-                        datapoints.ColorSpace.GRAY,
-                        datapoints.ColorSpace.RGB,
+                        "GRAY",
+                        "RGB",
                     ],
                     dtypes=[torch.uint8],
                     extra_dims=[(), (4,)],

@@ -192,7 +192,7 @@ class TestSmoke:
         (
             transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
             itertools.chain.from_iterable(
-                fn(color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32])
+                fn(color_spaces=["RGB"], dtypes=[torch.float32])
                 for fn in [
                     make_images,
                     make_vanilla_tensor_images,

@@ -221,45 +221,6 @@ class TestSmoke:
     def test_random_resized_crop(self, transform, input):
         transform(input)

-    @parametrize(
-        [
-            (
-                transforms.ConvertColorSpace(color_space=new_color_space, old_color_space=old_color_space),
-                itertools.chain.from_iterable(
-                    [
-                        fn(color_spaces=[old_color_space])
-                        for fn in (
-                            make_images,
-                            make_vanilla_tensor_images,
-                            make_pil_images,
-                            make_videos,
-                        )
-                    ]
-                ),
-            )
-            for old_color_space, new_color_space in itertools.product(
-                [
-                    datapoints.ColorSpace.GRAY,
-                    datapoints.ColorSpace.GRAY_ALPHA,
-                    datapoints.ColorSpace.RGB,
-                    datapoints.ColorSpace.RGB_ALPHA,
-                ],
-                repeat=2,
-            )
-        ]
-    )
-    def test_convert_color_space(self, transform, input):
-        transform(input)
-
-    def test_convert_color_space_unsupported_types(self):
-        transform = transforms.ConvertColorSpace(
-            color_space=datapoints.ColorSpace.RGB, old_color_space=datapoints.ColorSpace.GRAY
-        )
-
-        for inpt in [make_bounding_box(format="XYXY"), make_masks()]:
-            output = transform(inpt)
-            assert output is inpt
-

 @pytest.mark.parametrize("p", [0.0, 1.0])
 class TestRandomHorizontalFlip:

@@ -1558,7 +1519,7 @@ class TestFixedSizeCrop:
         transform = transforms.FixedSizeCrop(size=crop_size)

         flat_inputs = [
-            make_image(size=spatial_size, color_space=datapoints.ColorSpace.RGB),
+            make_image(size=spatial_size, color_space="RGB"),
             make_bounding_box(
                 format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=batch_shape
             ),
test/test_prototype_transforms_consistency.py  (view file @ 5dd95944)

@@ -31,7 +31,7 @@ from torchvision.prototype.transforms.functional import to_image_pil
 from torchvision.prototype.transforms.utils import query_spatial_size
 from torchvision.transforms import functional as legacy_F

-DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[datapoints.ColorSpace.RGB], extra_dims=[(4,)])
+DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)])

 class ConsistencyConfig:

@@ -138,9 +138,7 @@ CONSISTENCY_CONFIGS = [
         ],
         # Make sure that the product of the height, width and number of channels matches the number of elements in
         # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
-        make_images_kwargs=dict(
-            DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=[datapoints.ColorSpace.RGB]
-        ),
+        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"]),
         supports_pil=False,
     ),
     ConsistencyConfig(

@@ -150,9 +148,7 @@ CONSISTENCY_CONFIGS = [
             ArgsKwargs(num_output_channels=1),
             ArgsKwargs(num_output_channels=3),
         ],
-        make_images_kwargs=dict(
-            DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=[datapoints.ColorSpace.RGB, datapoints.ColorSpace.GRAY]
-        ),
+        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
     ),
     ConsistencyConfig(
         prototype_transforms.ConvertDtype,

@@ -174,10 +170,10 @@ CONSISTENCY_CONFIGS = [
         [ArgsKwargs()],
         make_images_kwargs=dict(
             color_spaces=[
-                datapoints.ColorSpace.GRAY,
-                datapoints.ColorSpace.GRAY_ALPHA,
-                datapoints.ColorSpace.RGB,
-                datapoints.ColorSpace.RGB_ALPHA,
+                "GRAY",
+                "GRAY_ALPHA",
+                "RGB",
+                "RGBA",
             ],
             extra_dims=[()],
         ),

@@ -911,7 +907,7 @@ class TestRefDetTransforms:
         size = (600, 800)
         num_objects = 22

-        pil_image = to_image_pil(make_image(size=size, color_space=datapoints.ColorSpace.RGB))
+        pil_image = to_image_pil(make_image(size=size, color_space="RGB"))
         target = {
             "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
             "labels": make_label(extra_dims=(num_objects,), categories=80),

@@ -921,7 +917,7 @@ class TestRefDetTransforms:
         yield (pil_image, target)

-        tensor_image = torch.Tensor(make_image(size=size, color_space=datapoints.ColorSpace.RGB))
+        tensor_image = torch.Tensor(make_image(size=size, color_space="RGB"))
         target = {
             "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
             "labels": make_label(extra_dims=(num_objects,), categories=80),

@@ -931,7 +927,7 @@ class TestRefDetTransforms:
         yield (tensor_image, target)

-        datapoint_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB)
+        datapoint_image = make_image(size=size, color_space="RGB")
         target = {
             "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
             "labels": make_label(extra_dims=(num_objects,), categories=80),

@@ -1015,7 +1011,7 @@ class TestRefSegTransforms:
             conv_fns.extend([torch.Tensor, lambda x: x])

         for conv_fn in conv_fns:
-            datapoint_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB, dtype=image_dtype)
+            datapoint_image = make_image(size=size, color_space="RGB", dtype=image_dtype)
             datapoint_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)

             dp = (conv_fn(datapoint_image), datapoint_mask)
test/test_prototype_transforms_functional.py  (view file @ 5dd95944)

@@ -340,7 +340,6 @@ class TestDispatchers:
     "dispatcher",
     [
         F.clamp_bounding_box,
-        F.convert_color_space,
         F.get_dimensions,
         F.get_image_num_channels,
         F.get_image_size,
test/test_prototype_transforms_utils.py  (view file @ 5dd95944)

@@ -11,7 +11,7 @@ from torchvision.prototype.transforms.functional import to_image_pil
 from torchvision.prototype.transforms.utils import has_all, has_any

-IMAGE = make_image(color_space=datapoints.ColorSpace.RGB)
+IMAGE = make_image(color_space="RGB")
 BOUNDING_BOX = make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, spatial_size=IMAGE.spatial_size)
 MASK = make_detection_mask(size=IMAGE.spatial_size)
torchvision/prototype/datapoints/__init__.py  (view file @ 5dd95944)

 from ._bounding_box import BoundingBox, BoundingBoxFormat
 from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT
-from ._image import ColorSpace, Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
+from ._image import Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
 from ._label import Label, OneHotLabel
 from ._mask import Mask
 from ._video import TensorVideoType, TensorVideoTypeJIT, Video, VideoType, VideoTypeJIT
torchvision/prototype/datapoints/_image.py  (view file @ 5dd95944)

 from __future__ import annotations

-import warnings
 from typing import Any, List, Optional, Tuple, Union

 import PIL.Image
 import torch
-from torchvision._utils import StrEnum
 from torchvision.transforms.functional import InterpolationMode

 from ._datapoint import Datapoint, FillTypeJIT


-class ColorSpace(StrEnum):
-    OTHER = StrEnum.auto()
-    GRAY = StrEnum.auto()
-    GRAY_ALPHA = StrEnum.auto()
-    RGB = StrEnum.auto()
-    RGB_ALPHA = StrEnum.auto()
-
-    @classmethod
-    def from_pil_mode(cls, mode: str) -> ColorSpace:
-        if mode == "L":
-            return cls.GRAY
-        elif mode == "LA":
-            return cls.GRAY_ALPHA
-        elif mode == "RGB":
-            return cls.RGB
-        elif mode == "RGBA":
-            return cls.RGB_ALPHA
-        else:
-            return cls.OTHER
-
-    @staticmethod
-    def from_tensor_shape(shape: List[int]) -> ColorSpace:
-        return _from_tensor_shape(shape)
-
-
-def _from_tensor_shape(shape: List[int]) -> ColorSpace:
-    # Needed as a standalone method for JIT
-    ndim = len(shape)
-    if ndim < 2:
-        return ColorSpace.OTHER
-    elif ndim == 2:
-        return ColorSpace.GRAY
-
-    num_channels = shape[-3]
-    if num_channels == 1:
-        return ColorSpace.GRAY
-    elif num_channels == 2:
-        return ColorSpace.GRAY_ALPHA
-    elif num_channels == 3:
-        return ColorSpace.RGB
-    elif num_channels == 4:
-        return ColorSpace.RGB_ALPHA
-    else:
-        return ColorSpace.OTHER
-
-
 class Image(Datapoint):
-    color_space: ColorSpace
-
     @classmethod
-    def _wrap(cls, tensor: torch.Tensor, *, color_space: ColorSpace) -> Image:
+    def _wrap(cls, tensor: torch.Tensor) -> Image:
         image = tensor.as_subclass(cls)
-        image.color_space = color_space
         return image

     def __new__(
         cls,
         data: Any,
         *,
-        color_space: Optional[Union[ColorSpace, str]] = None,
         dtype: Optional[torch.dtype] = None,
         device: Optional[Union[torch.device, str, int]] = None,
         requires_grad: bool = False,

@@ -81,26 +29,14 @@ class Image(Datapoint):
         elif tensor.ndim == 2:
             tensor = tensor.unsqueeze(0)

-        if color_space is None:
-            color_space = ColorSpace.from_tensor_shape(tensor.shape)  # type: ignore[arg-type]
-            if color_space == ColorSpace.OTHER:
-                warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
-        elif isinstance(color_space, str):
-            color_space = ColorSpace.from_str(color_space.upper())
-        elif not isinstance(color_space, ColorSpace):
-            raise ValueError
-
-        return cls._wrap(tensor, color_space=color_space)
+        return cls._wrap(tensor)

     @classmethod
-    def wrap_like(cls, other: Image, tensor: torch.Tensor, *, color_space: Optional[ColorSpace] = None) -> Image:
-        return cls._wrap(
-            tensor,
-            color_space=color_space if color_space is not None else other.color_space,
-        )
+    def wrap_like(cls, other: Image, tensor: torch.Tensor) -> Image:
+        return cls._wrap(tensor)

     def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
-        return self._make_repr(color_space=self.color_space)
+        return self._make_repr()

     @property
     def spatial_size(self) -> Tuple[int, int]:
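With the `ColorSpace` enum gone, nothing guesses a color space from a tensor's shape anymore. Where downstream code relied on `ColorSpace.from_tensor_shape`, the channel count itself is usually what mattered; a minimal sketch of the equivalent check (the function name is ours, not torchvision's):

```python
import torch

def infer_num_channels(image: torch.Tensor) -> int:
    # Mirrors the removed _from_tensor_shape logic, but stops at the channel
    # count: 1 ~ grayscale, 2 ~ gray+alpha, 3 ~ RGB, 4 ~ RGB+alpha.
    if image.ndim == 2:
        return 1  # an (H, W) tensor is implicitly single-channel
    return image.shape[-3]

assert infer_num_channels(torch.rand(32, 32)) == 1
assert infer_num_channels(torch.rand(3, 32, 32)) == 3
assert infer_num_channels(torch.rand(2, 4, 32, 32)) == 4  # batch of RGBA images
```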
torchvision/prototype/datapoints/_video.py  (view file @ 5dd95944)

 from __future__ import annotations

-import warnings
 from typing import Any, List, Optional, Tuple, Union

 import torch
 from torchvision.transforms.functional import InterpolationMode

 from ._datapoint import Datapoint, FillTypeJIT
-from ._image import ColorSpace


 class Video(Datapoint):
-    color_space: ColorSpace
-
     @classmethod
-    def _wrap(cls, tensor: torch.Tensor, *, color_space: ColorSpace) -> Video:
+    def _wrap(cls, tensor: torch.Tensor) -> Video:
         video = tensor.as_subclass(cls)
-        video.color_space = color_space
         return video

     def __new__(
         cls,
         data: Any,
         *,
-        color_space: Optional[Union[ColorSpace, str]] = None,
         dtype: Optional[torch.dtype] = None,
         device: Optional[Union[torch.device, str, int]] = None,
         requires_grad: bool = False,

@@ -31,28 +25,14 @@ class Video(Datapoint):
         tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
         if data.ndim < 4:
             raise ValueError
-        video = super().__new__(cls, data, requires_grad=requires_grad)
-
-        if color_space is None:
-            color_space = ColorSpace.from_tensor_shape(video.shape)  # type: ignore[arg-type]
-            if color_space == ColorSpace.OTHER:
-                warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
-        elif isinstance(color_space, str):
-            color_space = ColorSpace.from_str(color_space.upper())
-        elif not isinstance(color_space, ColorSpace):
-            raise ValueError
-
-        return cls._wrap(tensor, color_space=color_space)
+        return cls._wrap(tensor)

     @classmethod
-    def wrap_like(cls, other: Video, tensor: torch.Tensor, *, color_space: Optional[ColorSpace] = None) -> Video:
-        return cls._wrap(
-            tensor,
-            color_space=color_space if color_space is not None else other.color_space,
-        )
+    def wrap_like(cls, other: Video, tensor: torch.Tensor) -> Video:
+        return cls._wrap(tensor)

     def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
-        return self._make_repr(color_space=self.color_space)
+        return self._make_repr()

     @property
     def spatial_size(self) -> Tuple[int, int]:
torchvision/prototype/transforms/__init__.py  (view file @ 5dd95944)

@@ -39,7 +39,7 @@ from ._geometry import (
     ScaleJitter,
     TenCrop,
 )
-from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertDtype, ConvertImageDtype
+from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype
 from ._misc import (
     GaussianBlur,
     Identity,
torchvision/prototype/transforms/_deprecated.py  (view file @ 5dd95944)

@@ -28,6 +28,7 @@ class ToTensor(Transform):
         return _F.to_tensor(inpt)

+# TODO: in other PR (?) undeprecate those and make them use _rgb_to_gray?
 class Grayscale(Transform):
     _transformed_types = (
         datapoints.Image,

@@ -62,7 +63,7 @@ class Grayscale(Transform):
     ) -> Union[datapoints.ImageType, datapoints.VideoType]:
         output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
         if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-            output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.GRAY)  # type: ignore[arg-type]
+            output = inpt.wrap_like(inpt, output)  # type: ignore[arg-type]
         return output

@@ -98,5 +99,5 @@ class RandomGrayscale(_RandomApplyTransform):
     ) -> Union[datapoints.ImageType, datapoints.VideoType]:
         output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
         if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-            output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.GRAY)  # type: ignore[arg-type]
+            output = inpt.wrap_like(inpt, output)  # type: ignore[arg-type]
         return output
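Since `wrap_like` no longer takes metadata, re-wrapping a kernel output becomes a one-argument call. A sketch of the pattern the two hunks above converge on, assuming the prototype datapoints API at this commit (the grayscale stand-in is ours):

```python
import torch
from torchvision.prototype import datapoints

image = datapoints.Image(torch.rand(3, 8, 8))
# Stand-in for a grayscale kernel's plain-tensor output:
gray_data = image.as_subclass(torch.Tensor).mean(dim=-3, keepdim=True)

# Before: datapoints.Image.wrap_like(image, gray_data, color_space=datapoints.ColorSpace.GRAY)
output = datapoints.Image.wrap_like(image, gray_data)  # after: nothing to propagate
assert isinstance(output, datapoints.Image)
```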
torchvision/prototype/transforms/_meta.py  (view file @ 5dd95944)

-from typing import Any, Dict, Optional, Union
-
-import PIL.Image
+from typing import Any, Dict, Union

 import torch

@@ -46,35 +44,6 @@ class ConvertDtype(Transform):
 ConvertImageDtype = ConvertDtype

-class ConvertColorSpace(Transform):
-    _transformed_types = (
-        is_simple_tensor,
-        datapoints.Image,
-        PIL.Image.Image,
-        datapoints.Video,
-    )
-
-    def __init__(
-        self,
-        color_space: Union[str, datapoints.ColorSpace],
-        old_color_space: Optional[Union[str, datapoints.ColorSpace]] = None,
-    ) -> None:
-        super().__init__()
-
-        if isinstance(color_space, str):
-            color_space = datapoints.ColorSpace.from_str(color_space)
-        self.color_space = color_space
-
-        if isinstance(old_color_space, str):
-            old_color_space = datapoints.ColorSpace.from_str(old_color_space)
-        self.old_color_space = old_color_space
-
-    def _transform(
-        self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
-    ) -> Union[datapoints.ImageType, datapoints.VideoType]:
-        return F.convert_color_space(inpt, color_space=self.color_space, old_color_space=self.old_color_space)
-
-
 class ClampBoundingBoxes(Transform):
     _transformed_types = (datapoints.BoundingBox,)
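With `ConvertColorSpace` gone there is no v2 transform for general color-space conversion; RGB-to-gray remains covered by the (currently deprecated) `Grayscale`/`RandomGrayscale` transforms touched above, and PIL's `Image.convert` handles the rest. A hedged sketch of the replacement for the common case, assuming `Grayscale` is exported from the prototype namespace as `_deprecated.py` suggests:

```python
import torch
from torchvision.prototype import datapoints, transforms

image = datapoints.Image(torch.rand(3, 16, 16))
# Previously: transforms.ConvertColorSpace("GRAY", old_color_space="RGB")(image)
gray = transforms.Grayscale(num_output_channels=1)(image)  # deprecated but functional per the TODO
assert gray.shape[-3] == 1
```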
torchvision/prototype/transforms/functional/__init__.py  (view file @ 5dd95944)

@@ -7,10 +7,6 @@ from ._utils import is_simple_tensor  # usort: skip
 from ._meta import (
     clamp_bounding_box,
     convert_format_bounding_box,
-    convert_color_space_image_tensor,
-    convert_color_space_image_pil,
-    convert_color_space_video,
-    convert_color_space,
     convert_dtype_image_tensor,
     convert_dtype,
     convert_dtype_video,
torchvision/prototype/transforms/functional/_deprecated.py  (view file @ 5dd95944)

@@ -27,13 +27,11 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
 def rgb_to_grayscale(
     inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1
 ) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        old_color_space = datapoints._image._from_tensor_shape(inpt.shape)  # type: ignore[arg-type]
-    else:
-        old_color_space = None
-
-    if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-        inpt = inpt.as_subclass(torch.Tensor)
+    old_color_space = None  # TODO: remove when un-deprecating
+    if not (torch.jit.is_scripting() or is_simple_tensor(inpt)) and isinstance(
+        inpt, (datapoints.Image, datapoints.Video)
+    ):
+        inpt = inpt.as_subclass(torch.Tensor)

     call = ", num_output_channels=3" if num_output_channels == 3 else ""
     replacement = (
torchvision/prototype/transforms/functional/_meta.py  (view file @ 5dd95944)

-from typing import List, Optional, Tuple, Union
+from typing import List, Tuple, Union

 import PIL.Image
 import torch
 from torchvision.prototype import datapoints
-from torchvision.prototype.datapoints import BoundingBoxFormat, ColorSpace
+from torchvision.prototype.datapoints import BoundingBoxFormat
 from torchvision.transforms import functional_pil as _FP
 from torchvision.transforms.functional_tensor import _max_value

@@ -225,29 +225,6 @@ def clamp_bounding_box(
     return convert_format_bounding_box(xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True)

-def _strip_alpha(image: torch.Tensor) -> torch.Tensor:
-    image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
-    if not torch.all(alpha == _max_value(alpha.dtype)):
-        raise RuntimeError(
-            "Stripping the alpha channel if it contains values other than the max value is not supported."
-        )
-    return image
-
-
-def _add_alpha(image: torch.Tensor, alpha: Optional[torch.Tensor] = None) -> torch.Tensor:
-    if alpha is None:
-        shape = list(image.shape)
-        shape[-3] = 1
-        alpha = torch.full(shape, _max_value(image.dtype), dtype=image.dtype, device=image.device)
-    return torch.cat((image, alpha), dim=-3)
-
-
-def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
-    repeats = [1] * grayscale.ndim
-    repeats[-3] = 3
-    return grayscale.repeat(repeats)
-
-
 def _rgb_to_gray(image: torch.Tensor, cast: bool = True) -> torch.Tensor:
     r, g, b = image.unbind(dim=-3)
     l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)

@@ -257,107 +234,6 @@ def _rgb_to_gray(image: torch.Tensor, cast: bool = True) -> torch.Tensor:
     return l_img

-def convert_color_space_image_tensor(
-    image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace
-) -> torch.Tensor:
-    if new_color_space == old_color_space:
-        return image
-
-    if old_color_space == ColorSpace.OTHER or new_color_space == ColorSpace.OTHER:
-        raise RuntimeError(f"Conversion to or from {ColorSpace.OTHER} is not supported.")
-
-    if old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.GRAY_ALPHA:
-        return _add_alpha(image)
-    elif old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.RGB:
-        return _gray_to_rgb(image)
-    elif old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.RGB_ALPHA:
-        return _add_alpha(_gray_to_rgb(image))
-    elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.GRAY:
-        return _strip_alpha(image)
-    elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB:
-        return _gray_to_rgb(_strip_alpha(image))
-    elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB_ALPHA:
-        image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
-        return _add_alpha(_gray_to_rgb(image), alpha)
-    elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.GRAY:
-        return _rgb_to_gray(image)
-    elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.GRAY_ALPHA:
-        return _add_alpha(_rgb_to_gray(image))
-    elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.RGB_ALPHA:
-        return _add_alpha(image)
-    elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY:
-        return _rgb_to_gray(_strip_alpha(image))
-    elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY_ALPHA:
-        image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
-        return _add_alpha(_rgb_to_gray(image), alpha)
-    elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.RGB:
-        return _strip_alpha(image)
-    else:
-        raise RuntimeError(f"Conversion from {old_color_space} to {new_color_space} is not supported.")
-
-
-_COLOR_SPACE_TO_PIL_MODE = {
-    ColorSpace.GRAY: "L",
-    ColorSpace.GRAY_ALPHA: "LA",
-    ColorSpace.RGB: "RGB",
-    ColorSpace.RGB_ALPHA: "RGBA",
-}
-
-
-@torch.jit.unused
-def convert_color_space_image_pil(image: PIL.Image.Image, color_space: ColorSpace) -> PIL.Image.Image:
-    old_mode = image.mode
-    try:
-        new_mode = _COLOR_SPACE_TO_PIL_MODE[color_space]
-    except KeyError:
-        raise ValueError(f"Conversion from {ColorSpace.from_pil_mode(old_mode)} to {color_space} is not supported.")
-
-    if image.mode == new_mode:
-        return image
-
-    return image.convert(new_mode)
-
-
-def convert_color_space_video(
-    video: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace
-) -> torch.Tensor:
-    return convert_color_space_image_tensor(video, old_color_space=old_color_space, new_color_space=new_color_space)
-
-
-def convert_color_space(
-    inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT],
-    color_space: ColorSpace,
-    old_color_space: Optional[ColorSpace] = None,
-) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(convert_color_space)
-
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        if old_color_space is None:
-            raise RuntimeError(
-                "In order to convert the color space of simple tensors, "
-                "the `old_color_space=...` parameter needs to be passed."
-            )
-        return convert_color_space_image_tensor(inpt, old_color_space=old_color_space, new_color_space=color_space)
-    elif isinstance(inpt, datapoints.Image):
-        output = convert_color_space_image_tensor(
-            inpt.as_subclass(torch.Tensor), old_color_space=inpt.color_space, new_color_space=color_space
-        )
-        return datapoints.Image.wrap_like(inpt, output, color_space=color_space)
-    elif isinstance(inpt, datapoints.Video):
-        output = convert_color_space_video(
-            inpt.as_subclass(torch.Tensor), old_color_space=inpt.color_space, new_color_space=color_space
-        )
-        return datapoints.Video.wrap_like(inpt, output, color_space=color_space)
-    elif isinstance(inpt, PIL.Image.Image):
-        return convert_color_space_image_pil(inpt, color_space=color_space)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
-
-
 def _num_value_bits(dtype: torch.dtype) -> int:
     if dtype == torch.uint8:
         return 8
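The kept `_rgb_to_gray` helper is the weighted-sum luma conversion with the classic ITU-R BT.601 weights. For reference, the arithmetic written out as a standalone sketch (the in-tree version additionally handles dtype casting via its `cast` flag):

```python
import torch

def rgb_to_gray_reference(image: torch.Tensor) -> torch.Tensor:
    # L = 0.2989 * R + 0.587 * G + 0.114 * B, computed over channel dim -3
    r, g, b = image.unbind(dim=-3)
    return (0.2989 * r + 0.587 * g + 0.114 * b).unsqueeze(-3)

image = torch.rand(3, 4, 4)
expected = image[0] * 0.2989 + image[1] * 0.587 + image[2] * 0.114
assert torch.allclose(rgb_to_gray_reference(image).squeeze(-3), expected)
```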
torchvision/transforms/functional.py  (view file @ 5dd95944)

@@ -1234,6 +1234,9 @@ def affine(
     return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill)

+# Looks like to_grayscale() is a stand-alone functional that is never called
+# from the transform classes. Perhaps it's still here for BC? I can't be
+# bothered to dig. Anyway, this can be deprecated as we migrate to V2.
 @torch.jit.unused
 def to_grayscale(img, num_output_channels=1):
     """Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image.
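The new comment flags that `to_grayscale` is a PIL-only functional, while `rgb_to_grayscale` is the dispatching sibling the transform classes actually use. Illustrative call sites — these are stable torchvision v1 APIs, so the signatures are real; only the toy inputs are ours:

```python
import PIL.Image
import torch
from torchvision.transforms import functional as F

pil_img = PIL.Image.new("RGB", (8, 8))
gray_pil = F.to_grayscale(pil_img, num_output_channels=1)  # PIL in, PIL out

tensor_img = torch.rand(3, 8, 8)
gray_tensor = F.rgb_to_grayscale(tensor_img, num_output_channels=1)  # tensors and PIL
assert gray_tensor.shape == (1, 8, 8)
```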