OpenDAS / vision / Commits / 5dd95944

Unverified commit 5dd95944, authored Jan 24, 2023 by Nicolas Hug, committed by GitHub on Jan 24, 2023
parent c206a471

Remove `color_space` metadata and `ConvertColorSpace()` transform (#7120)

Showing 16 changed files with 106 additions and 506 deletions (+106 -506):
test/prototype_common_utils.py (+26 -29)
test/prototype_transforms_kernel_infos.py (+38 -150)
test/test_prototype_transforms.py (+4 -43)
test/test_prototype_transforms_consistency.py (+11 -15)
test/test_prototype_transforms_functional.py (+0 -1)
test/test_prototype_transforms_utils.py (+1 -1)
torchvision/prototype/datapoints/__init__.py (+1 -1)
torchvision/prototype/datapoints/_image.py (+5 -69)
torchvision/prototype/datapoints/_video.py (+5 -25)
torchvision/prototype/transforms/__init__.py (+1 -1)
torchvision/prototype/transforms/_deprecated.py (+3 -2)
torchvision/prototype/transforms/_meta.py (+1 -32)
torchvision/prototype/transforms/functional/__init__.py (+0 -4)
torchvision/prototype/transforms/functional/_deprecated.py (+5 -7)
torchvision/prototype/transforms/functional/_meta.py (+2 -126)
torchvision/transforms/functional.py (+3 -0)
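The commit removes per-image color-space metadata wholesale. A hedged before/after sketch of the user-facing change, grounded in the diffs below (tensor values are invented; the commented "before" lines only work on revisions prior to this commit, since `datapoints.ColorSpace` and `ConvertColorSpace` no longer exist afterwards):

    import torch
    from torchvision.prototype import datapoints

    # Before: images carried color-space metadata, and a dedicated transform
    # converted between color spaces.
    #   img = datapoints.Image(torch.rand(1, 32, 32), color_space=datapoints.ColorSpace.GRAY)
    #   rgb = transforms.ConvertColorSpace(color_space="RGB", old_color_space="GRAY")(img)

    # After: an Image is a plain tensor subclass with no channel semantics
    # attached; color conversions go through the remaining kernels, e.g.
    # rgb_to_grayscale (still used by Grayscale in _deprecated.py below).
    img = datapoints.Image(torch.rand(1, 32, 32))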
test/prototype_common_utils.py

@@ -238,7 +238,6 @@ class TensorLoader:
 @dataclasses.dataclass
 class ImageLoader(TensorLoader):
-    color_space: datapoints.ColorSpace
     spatial_size: Tuple[int, int] = dataclasses.field(init=False)
     num_channels: int = dataclasses.field(init=False)
@@ -248,10 +247,10 @@ class ImageLoader(TensorLoader):
 NUM_CHANNELS_MAP = {
-    datapoints.ColorSpace.GRAY: 1,
-    datapoints.ColorSpace.GRAY_ALPHA: 2,
-    datapoints.ColorSpace.RGB: 3,
-    datapoints.ColorSpace.RGB_ALPHA: 4,
+    "GRAY": 1,
+    "GRAY_ALPHA": 2,
+    "RGB": 3,
+    "RGBA": 4,
 }
@@ -265,7 +264,7 @@ def get_num_channels(color_space):
 def make_image_loader(
     size="random",
     *,
-    color_space=datapoints.ColorSpace.RGB,
+    color_space="RGB",
     extra_dims=(),
     dtype=torch.float32,
     constant_alpha=True,
@@ -276,11 +275,11 @@ def make_image_loader(
     def fn(shape, dtype, device):
         max_value = get_max_value(dtype)
         data = torch.testing.make_tensor(shape, low=0, high=max_value, dtype=dtype, device=device)
-        if color_space in {datapoints.ColorSpace.GRAY_ALPHA, datapoints.ColorSpace.RGB_ALPHA} and constant_alpha:
+        if color_space in {"GRAY_ALPHA", "RGBA"} and constant_alpha:
             data[..., -1, :, :] = max_value
-        return datapoints.Image(data, color_space=color_space)
+        return datapoints.Image(data)

-    return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype, color_space=color_space)
+    return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype)

 make_image = from_loader(make_image_loader)
@@ -290,10 +289,10 @@ def make_image_loaders(
     *,
     sizes=DEFAULT_SPATIAL_SIZES,
     color_spaces=(
-        datapoints.ColorSpace.GRAY,
-        datapoints.ColorSpace.GRAY_ALPHA,
-        datapoints.ColorSpace.RGB,
-        datapoints.ColorSpace.RGB_ALPHA,
+        "GRAY",
+        "GRAY_ALPHA",
+        "RGB",
+        "RGBA",
     ),
     extra_dims=DEFAULT_EXTRA_DIMS,
     dtypes=(torch.float32, torch.uint8),
@@ -306,7 +305,7 @@ def make_image_loaders(
 make_images = from_loaders(make_image_loaders)

-def make_image_loader_for_interpolation(size="random", *, color_space=datapoints.ColorSpace.RGB, dtype=torch.uint8):
+def make_image_loader_for_interpolation(size="random", *, color_space="RGB", dtype=torch.uint8):
     size = _parse_spatial_size(size)
     num_channels = get_num_channels(color_space)
@@ -318,24 +317,24 @@ def make_image_loader_for_interpolation(size="random", *, color_space=datapoints
             .resize((width, height))
             .convert(
                 {
-                    datapoints.ColorSpace.GRAY: "L",
-                    datapoints.ColorSpace.GRAY_ALPHA: "LA",
-                    datapoints.ColorSpace.RGB: "RGB",
-                    datapoints.ColorSpace.RGB_ALPHA: "RGBA",
+                    "GRAY": "L",
+                    "GRAY_ALPHA": "LA",
+                    "RGB": "RGB",
+                    "RGBA": "RGBA",
                 }[color_space]
             )
         )

         image_tensor = convert_dtype_image_tensor(to_image_tensor(image_pil).to(device=device), dtype=dtype)

-        return datapoints.Image(image_tensor, color_space=color_space)
+        return datapoints.Image(image_tensor)

-    return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype, color_space=color_space)
+    return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype)

 def make_image_loaders_for_interpolation(
     sizes=((233, 147),),
-    color_spaces=(datapoints.ColorSpace.RGB,),
+    color_spaces=("RGB",),
     dtypes=(torch.uint8,),
 ):
     for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes):
@@ -583,7 +582,7 @@ class VideoLoader(ImageLoader):
 def make_video_loader(
     size="random",
     *,
-    color_space=datapoints.ColorSpace.RGB,
+    color_space="RGB",
     num_frames="random",
     extra_dims=(),
     dtype=torch.uint8,
@@ -592,12 +591,10 @@ def make_video_loader(
     num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames

     def fn(shape, dtype, device):
-        video = make_image(size=shape[-2:], color_space=color_space, extra_dims=shape[:-3], dtype=dtype, device=device)
-        return datapoints.Video(video, color_space=color_space)
+        video = make_image(size=shape[-2:], extra_dims=shape[:-3], dtype=dtype, device=device)
+        return datapoints.Video(video)

-    return VideoLoader(
-        fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype, color_space=color_space
-    )
+    return VideoLoader(fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype)

 make_video = from_loader(make_video_loader)
@@ -607,8 +604,8 @@ def make_video_loaders(
     *,
     sizes=DEFAULT_SPATIAL_SIZES,
     color_spaces=(
-        datapoints.ColorSpace.GRAY,
-        datapoints.ColorSpace.RGB,
+        "GRAY",
+        "RGB",
     ),
     num_frames=(1, 0, "random"),
     extra_dims=DEFAULT_EXTRA_DIMS,
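With `NUM_CHANNELS_MAP` keyed by plain strings, the test helpers now take string color spaces. A minimal usage sketch of the updated helper (the argument values are invented; `make_image_loader` is the helper defined in this test module):

    import torch
    from prototype_common_utils import make_image_loader

    # "GRAY_ALPHA" maps to 2 channels in NUM_CHANNELS_MAP, so with no extra
    # dims the resulting loader shape is (2, 16, 16).
    loader = make_image_loader(size=(16, 16), color_space="GRAY_ALPHA", dtype=torch.uint8)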
test/prototype_transforms_kernel_infos.py

@@ -9,7 +9,6 @@ import pytest
 import torch.testing
 import torchvision.ops
 import torchvision.prototype.transforms.functional as F
-from common_utils import cycle_over
 from datasets_utils import combinations_grid
 from prototype_common_utils import (
     ArgsKwargs,
@@ -261,14 +260,12 @@ def _get_resize_sizes(spatial_size):
 def sample_inputs_resize_image_tensor():
-    for image_loader in make_image_loaders(
-        sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]
-    ):
+    for image_loader in make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]):
         for size in _get_resize_sizes(image_loader.spatial_size):
             yield ArgsKwargs(image_loader, size=size)

     for image_loader, interpolation in itertools.product(
-        make_image_loaders(sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB]),
+        make_image_loaders(sizes=["random"], color_spaces=["RGB"]),
         [
             F.InterpolationMode.NEAREST,
             F.InterpolationMode.BILINEAR,
@@ -472,7 +469,7 @@ def float32_vs_uint8_fill_adapter(other_args, kwargs):
 def sample_inputs_affine_image_tensor():
     make_affine_image_loaders = functools.partial(
-        make_image_loaders, sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]
+        make_image_loaders, sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]
     )
     for image_loader, affine_params in itertools.product(make_affine_image_loaders(), _DIVERSE_AFFINE_PARAMS):
@@ -684,69 +681,6 @@ KERNEL_INFOS.append(
 )

-def sample_inputs_convert_color_space_image_tensor():
-    color_spaces = sorted(
-        set(datapoints.ColorSpace) - {datapoints.ColorSpace.OTHER}, key=lambda color_space: color_space.value
-    )
-
-    for old_color_space, new_color_space in cycle_over(color_spaces):
-        for image_loader in make_image_loaders(sizes=["random"], color_spaces=[old_color_space], constant_alpha=True):
-            yield ArgsKwargs(image_loader, old_color_space=old_color_space, new_color_space=new_color_space)
-
-    for color_space in color_spaces:
-        for image_loader in make_image_loaders(
-            sizes=["random"], color_spaces=[color_space], dtypes=[torch.float32], constant_alpha=True
-        ):
-            yield ArgsKwargs(image_loader, old_color_space=color_space, new_color_space=color_space)
-
-@pil_reference_wrapper
-def reference_convert_color_space_image_tensor(image_pil, old_color_space, new_color_space):
-    color_space_pil = datapoints.ColorSpace.from_pil_mode(image_pil.mode)
-    if color_space_pil != old_color_space:
-        raise pytest.UsageError(
-            f"Converting the tensor image into an PIL image changed the colorspace "
-            f"from {old_color_space} to {color_space_pil}"
-        )
-
-    return F.convert_color_space_image_pil(image_pil, color_space=new_color_space)
-
-def reference_inputs_convert_color_space_image_tensor():
-    for args_kwargs in sample_inputs_convert_color_space_image_tensor():
-        (image_loader, *other_args), kwargs = args_kwargs
-        if len(image_loader.shape) == 3 and image_loader.dtype == torch.uint8:
-            yield args_kwargs
-
-def sample_inputs_convert_color_space_video():
-    color_spaces = [datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB]
-
-    for old_color_space, new_color_space in cycle_over(color_spaces):
-        for video_loader in make_video_loaders(sizes=["random"], color_spaces=[old_color_space], num_frames=["random"]):
-            yield ArgsKwargs(video_loader, old_color_space=old_color_space, new_color_space=new_color_space)
-
-KERNEL_INFOS.extend(
-    [
-        KernelInfo(
-            F.convert_color_space_image_tensor,
-            sample_inputs_fn=sample_inputs_convert_color_space_image_tensor,
-            reference_fn=reference_convert_color_space_image_tensor,
-            reference_inputs_fn=reference_inputs_convert_color_space_image_tensor,
-            closeness_kwargs={
-                **pil_reference_pixel_difference(),
-                **float32_vs_uint8_pixel_difference(),
-            },
-        ),
-        KernelInfo(
-            F.convert_color_space_video,
-            sample_inputs_fn=sample_inputs_convert_color_space_video,
-        ),
-    ]
-)
-
 def sample_inputs_vertical_flip_image_tensor():
     for image_loader in make_image_loaders(sizes=["random"], dtypes=[torch.float32]):
         yield ArgsKwargs(image_loader)
@@ -822,7 +756,7 @@ _ROTATE_ANGLES = [-87, 15, 90]
 def sample_inputs_rotate_image_tensor():
     make_rotate_image_loaders = functools.partial(
-        make_image_loaders, sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]
+        make_image_loaders, sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]
     )
     for image_loader in make_rotate_image_loaders():
@@ -904,7 +838,7 @@ _CROP_PARAMS = combinations_grid(top=[-8, 0, 9], left=[-8, 0, 9], height=[12, 20
 def sample_inputs_crop_image_tensor():
     for image_loader, params in itertools.product(
-        make_image_loaders(sizes=[(16, 17)], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]),
+        make_image_loaders(sizes=[(16, 17)], color_spaces=["RGB"], dtypes=[torch.float32]),
         [
             dict(top=4, left=3, height=7, width=8),
             dict(top=-1, left=3, height=7, width=8),
@@ -1090,7 +1024,7 @@ _PAD_PARAMS = combinations_grid(
 def sample_inputs_pad_image_tensor():
     make_pad_image_loaders = functools.partial(
-        make_image_loaders, sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]
+        make_image_loaders, sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]
     )
     for image_loader, padding in itertools.product(
@@ -1406,7 +1340,7 @@ _CENTER_CROP_OUTPUT_SIZES = [[4, 3], [42, 70], [4], 3, (5, 2), (6,)]
 def sample_inputs_center_crop_image_tensor():
     for image_loader, output_size in itertools.product(
-        make_image_loaders(sizes=[(16, 17)], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]),
+        make_image_loaders(sizes=[(16, 17)], color_spaces=["RGB"], dtypes=[torch.float32]),
         [
             # valid `output_size` types for which cropping is applied to both dimensions
             *[5, (4,), (2, 3), [6], [3, 2]],
@@ -1492,9 +1426,7 @@ KERNEL_INFOS.extend(
 def sample_inputs_gaussian_blur_image_tensor():
-    make_gaussian_blur_image_loaders = functools.partial(
-        make_image_loaders, sizes=[(7, 33)], color_spaces=[datapoints.ColorSpace.RGB]
-    )
+    make_gaussian_blur_image_loaders = functools.partial(make_image_loaders, sizes=[(7, 33)], color_spaces=["RGB"])

     for image_loader, kernel_size in itertools.product(make_gaussian_blur_image_loaders(), [5, (3, 3), [3, 3]]):
         yield ArgsKwargs(image_loader, kernel_size=kernel_size)
@@ -1531,9 +1463,7 @@ KERNEL_INFOS.extend(
 def sample_inputs_equalize_image_tensor():
-    for image_loader in make_image_loaders(
-        sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
-    ):
+    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
         yield ArgsKwargs(image_loader)
@@ -1560,7 +1490,7 @@ def reference_inputs_equalize_image_tensor():
     spatial_size = (256, 256)
     for dtype, color_space, fn in itertools.product(
         [torch.uint8],
-        [datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB],
+        ["GRAY", "RGB"],
         [
             lambda shape, dtype, device: torch.zeros(shape, dtype=dtype, device=device),
             lambda shape, dtype, device: torch.full(
@@ -1585,9 +1515,7 @@ def reference_inputs_equalize_image_tensor():
         ],
     ):
-        image_loader = ImageLoader(
-            fn, shape=(get_num_channels(color_space), *spatial_size), dtype=dtype, color_space=color_space
-        )
+        image_loader = ImageLoader(fn, shape=(get_num_channels(color_space), *spatial_size), dtype=dtype)
         yield ArgsKwargs(image_loader)
@@ -1615,16 +1543,12 @@ KERNEL_INFOS.extend(
 def sample_inputs_invert_image_tensor():
-    for image_loader in make_image_loaders(
-        sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
-    ):
+    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
         yield ArgsKwargs(image_loader)

 def reference_inputs_invert_image_tensor():
-    for image_loader in make_image_loaders(
-        color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
-    ):
+    for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
         yield ArgsKwargs(image_loader)
@@ -1655,17 +1579,13 @@ _POSTERIZE_BITS = [1, 4, 8]
 def sample_inputs_posterize_image_tensor():
-    for image_loader in make_image_loaders(
-        sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
-    ):
+    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
         yield ArgsKwargs(image_loader, bits=_POSTERIZE_BITS[0])

 def reference_inputs_posterize_image_tensor():
     for image_loader, bits in itertools.product(
-        make_image_loaders(
-            color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
-        ),
+        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
         _POSTERIZE_BITS,
     ):
         yield ArgsKwargs(image_loader, bits=bits)
@@ -1702,16 +1622,12 @@ def _get_solarize_thresholds(dtype):
 def sample_inputs_solarize_image_tensor():
-    for image_loader in make_image_loaders(
-        sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
-    ):
+    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
         yield ArgsKwargs(image_loader, threshold=next(_get_solarize_thresholds(image_loader.dtype)))

 def reference_inputs_solarize_image_tensor():
-    for image_loader in make_image_loaders(
-        color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
-    ):
+    for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
         for threshold in _get_solarize_thresholds(image_loader.dtype):
             yield ArgsKwargs(image_loader, threshold=threshold)
@@ -1745,16 +1661,12 @@ KERNEL_INFOS.extend(
 def sample_inputs_autocontrast_image_tensor():
-    for image_loader in make_image_loaders(
-        sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
-    ):
+    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
         yield ArgsKwargs(image_loader)

 def reference_inputs_autocontrast_image_tensor():
-    for image_loader in make_image_loaders(
-        color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
-    ):
+    for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
         yield ArgsKwargs(image_loader)
@@ -1790,16 +1702,14 @@ _ADJUST_SHARPNESS_FACTORS = [0.1, 0.5]
 def sample_inputs_adjust_sharpness_image_tensor():
     for image_loader in make_image_loaders(
         sizes=["random", (2, 2)],
-        color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB),
+        color_spaces=("GRAY", "RGB"),
     ):
         yield ArgsKwargs(image_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0])

 def reference_inputs_adjust_sharpness_image_tensor():
     for image_loader, sharpness_factor in itertools.product(
-        make_image_loaders(
-            color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
-        ),
+        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
         _ADJUST_SHARPNESS_FACTORS,
     ):
         yield ArgsKwargs(image_loader, sharpness_factor=sharpness_factor)
@@ -1863,17 +1773,13 @@ _ADJUST_BRIGHTNESS_FACTORS = [0.1, 0.5]
 def sample_inputs_adjust_brightness_image_tensor():
-    for image_loader in make_image_loaders(
-        sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
-    ):
+    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
         yield ArgsKwargs(image_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0])

 def reference_inputs_adjust_brightness_image_tensor():
     for image_loader, brightness_factor in itertools.product(
-        make_image_loaders(
-            color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
-        ),
+        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
         _ADJUST_BRIGHTNESS_FACTORS,
     ):
         yield ArgsKwargs(image_loader, brightness_factor=brightness_factor)
@@ -1907,17 +1813,13 @@ _ADJUST_CONTRAST_FACTORS = [0.1, 0.5]
 def sample_inputs_adjust_contrast_image_tensor():
-    for image_loader in make_image_loaders(
-        sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
-    ):
+    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
         yield ArgsKwargs(image_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0])

 def reference_inputs_adjust_contrast_image_tensor():
     for image_loader, contrast_factor in itertools.product(
-        make_image_loaders(
-            color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
-        ),
+        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
         _ADJUST_CONTRAST_FACTORS,
     ):
         yield ArgsKwargs(image_loader, contrast_factor=contrast_factor)
@@ -1959,17 +1861,13 @@ _ADJUST_GAMMA_GAMMAS_GAINS = [
 def sample_inputs_adjust_gamma_image_tensor():
     gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0]
-    for image_loader in make_image_loaders(
-        sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
-    ):
+    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
         yield ArgsKwargs(image_loader, gamma=gamma, gain=gain)

 def reference_inputs_adjust_gamma_image_tensor():
     for image_loader, (gamma, gain) in itertools.product(
-        make_image_loaders(
-            color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
-        ),
+        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
         _ADJUST_GAMMA_GAMMAS_GAINS,
     ):
         yield ArgsKwargs(image_loader, gamma=gamma, gain=gain)
@@ -2007,17 +1905,13 @@ _ADJUST_HUE_FACTORS = [-0.1, 0.5]
 def sample_inputs_adjust_hue_image_tensor():
-    for image_loader in make_image_loaders(
-        sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
-    ):
+    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
         yield ArgsKwargs(image_loader, hue_factor=_ADJUST_HUE_FACTORS[0])

 def reference_inputs_adjust_hue_image_tensor():
     for image_loader, hue_factor in itertools.product(
-        make_image_loaders(
-            color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
-        ),
+        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
         _ADJUST_HUE_FACTORS,
     ):
         yield ArgsKwargs(image_loader, hue_factor=hue_factor)
@@ -2053,17 +1947,13 @@ _ADJUST_SATURATION_FACTORS = [0.1, 0.5]
 def sample_inputs_adjust_saturation_image_tensor():
-    for image_loader in make_image_loaders(
-        sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
-    ):
+    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
         yield ArgsKwargs(image_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0])

 def reference_inputs_adjust_saturation_image_tensor():
     for image_loader, saturation_factor in itertools.product(
-        make_image_loaders(
-            color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
-        ),
+        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
         _ADJUST_SATURATION_FACTORS,
     ):
         yield ArgsKwargs(image_loader, saturation_factor=saturation_factor)
@@ -2128,7 +2018,7 @@ def sample_inputs_five_crop_image_tensor():
     for size in _FIVE_TEN_CROP_SIZES:
         for image_loader in make_image_loaders(
             sizes=[_get_five_ten_crop_spatial_size(size)],
-            color_spaces=[datapoints.ColorSpace.RGB],
+            color_spaces=["RGB"],
             dtypes=[torch.float32],
         ):
             yield ArgsKwargs(image_loader, size=size)
@@ -2152,7 +2042,7 @@ def sample_inputs_ten_crop_image_tensor():
     for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
         for image_loader in make_image_loaders(
             sizes=[_get_five_ten_crop_spatial_size(size)],
-            color_spaces=[datapoints.ColorSpace.RGB],
+            color_spaces=["RGB"],
             dtypes=[torch.float32],
         ):
             yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip)
@@ -2226,7 +2116,7 @@ _NORMALIZE_MEANS_STDS = [
 def sample_inputs_normalize_image_tensor():
     for image_loader, (mean, std) in itertools.product(
-        make_image_loaders(sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]),
+        make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]),
         _NORMALIZE_MEANS_STDS,
     ):
         yield ArgsKwargs(image_loader, mean=mean, std=std)
@@ -2242,7 +2132,7 @@ def reference_normalize_image_tensor(image, mean, std, inplace=False):
 def reference_inputs_normalize_image_tensor():
     yield ArgsKwargs(
-        make_image_loader(size=(32, 32), color_space=datapoints.ColorSpace.RGB, extra_dims=[1]),
+        make_image_loader(size=(32, 32), color_space="RGB", extra_dims=[1]),
         mean=[0.5, 0.5, 0.5],
         std=[1.0, 1.0, 1.0],
     )
@@ -2251,7 +2141,7 @@ def reference_inputs_normalize_image_tensor():
 def sample_inputs_normalize_video():
     mean, std = _NORMALIZE_MEANS_STDS[0]
     for video_loader in make_video_loaders(
-        sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], num_frames=["random"], dtypes=[torch.float32]
+        sizes=["random"], color_spaces=["RGB"], num_frames=["random"], dtypes=[torch.float32]
     ):
         yield ArgsKwargs(video_loader, mean=mean, std=std)
@@ -2285,9 +2175,7 @@ def sample_inputs_convert_dtype_image_tensor():
             # conversion cannot be performed safely
             continue

-        for image_loader in make_image_loaders(
-            sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[input_dtype]
-        ):
+        for image_loader in make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[input_dtype]):
            yield ArgsKwargs(image_loader, dtype=output_dtype)
@@ -2414,7 +2302,7 @@ def reference_uniform_temporal_subsample_video(x, num_samples, temporal_dim=-4):
 def reference_inputs_uniform_temporal_subsample_video():
-    for video_loader in make_video_loaders(sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], num_frames=[10]):
+    for video_loader in make_video_loaders(sizes=["random"], color_spaces=["RGB"], num_frames=[10]):
         for num_samples in range(1, video_loader.shape[-4] + 1):
             yield ArgsKwargs(video_loader, num_samples)
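For context, the deleted `KernelInfo`s checked the tensor kernel against PIL. A rough sketch of the equivalence that was verified, kept in comment form because these kernels no longer exist after this commit (names taken from the removed code above):

    # expected = F.convert_color_space_image_pil(image_pil, color_space=new_color_space)
    # actual = F.convert_color_space_image_tensor(
    #     image_tensor, old_color_space=old_color_space, new_color_space=new_color_space
    # )
    # The two were compared within pil_reference_pixel_difference() and
    # float32_vs_uint8_pixel_difference() tolerances.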
test/test_prototype_transforms.py

@@ -161,8 +161,8 @@ class TestSmoke:
             itertools.chain.from_iterable(
                 fn(
                     color_spaces=[
-                        datapoints.ColorSpace.GRAY,
-                        datapoints.ColorSpace.RGB,
+                        "GRAY",
+                        "RGB",
                     ],
                     dtypes=[torch.uint8],
                     extra_dims=[(), (4,)],
@@ -192,7 +192,7 @@ class TestSmoke:
         (
             transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
             itertools.chain.from_iterable(
-                fn(color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32])
+                fn(color_spaces=["RGB"], dtypes=[torch.float32])
                 for fn in [
                     make_images,
                     make_vanilla_tensor_images,
@@ -221,45 +221,6 @@ class TestSmoke:
     def test_random_resized_crop(self, transform, input):
         transform(input)

-    @parametrize(
-        [
-            (
-                transforms.ConvertColorSpace(color_space=new_color_space, old_color_space=old_color_space),
-                itertools.chain.from_iterable(
-                    [
-                        fn(color_spaces=[old_color_space])
-                        for fn in (
-                            make_images,
-                            make_vanilla_tensor_images,
-                            make_pil_images,
-                            make_videos,
-                        )
-                    ]
-                ),
-            )
-            for old_color_space, new_color_space in itertools.product(
-                [
-                    datapoints.ColorSpace.GRAY,
-                    datapoints.ColorSpace.GRAY_ALPHA,
-                    datapoints.ColorSpace.RGB,
-                    datapoints.ColorSpace.RGB_ALPHA,
-                ],
-                repeat=2,
-            )
-        ]
-    )
-    def test_convert_color_space(self, transform, input):
-        transform(input)
-
-    def test_convert_color_space_unsupported_types(self):
-        transform = transforms.ConvertColorSpace(
-            color_space=datapoints.ColorSpace.RGB, old_color_space=datapoints.ColorSpace.GRAY
-        )
-
-        for inpt in [make_bounding_box(format="XYXY"), make_masks()]:
-            output = transform(inpt)
-            assert output is inpt
-
     @pytest.mark.parametrize("p", [0.0, 1.0])
     class TestRandomHorizontalFlip:
@@ -1558,7 +1519,7 @@ class TestFixedSizeCrop:
         transform = transforms.FixedSizeCrop(size=crop_size)

         flat_inputs = [
-            make_image(size=spatial_size, color_space=datapoints.ColorSpace.RGB),
+            make_image(size=spatial_size, color_space="RGB"),
             make_bounding_box(
                 format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=batch_shape
             ),
test/test_prototype_transforms_consistency.py

@@ -31,7 +31,7 @@ from torchvision.prototype.transforms.functional import to_image_pil
 from torchvision.prototype.transforms.utils import query_spatial_size
 from torchvision.transforms import functional as legacy_F

-DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[datapoints.ColorSpace.RGB], extra_dims=[(4,)])
+DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)])

 class ConsistencyConfig:
@@ -138,9 +138,7 @@ CONSISTENCY_CONFIGS = [
         ],
         # Make sure that the product of the height, width and number of channels matches the number of elements in
         # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
-        make_images_kwargs=dict(
-            DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=[datapoints.ColorSpace.RGB]
-        ),
+        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"]),
         supports_pil=False,
     ),
     ConsistencyConfig(
@@ -150,9 +148,7 @@ CONSISTENCY_CONFIGS = [
             ArgsKwargs(num_output_channels=1),
             ArgsKwargs(num_output_channels=3),
         ],
-        make_images_kwargs=dict(
-            DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=[datapoints.ColorSpace.RGB, datapoints.ColorSpace.GRAY]
-        ),
+        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
     ),
     ConsistencyConfig(
         prototype_transforms.ConvertDtype,
@@ -174,10 +170,10 @@ CONSISTENCY_CONFIGS = [
         [ArgsKwargs()],
         make_images_kwargs=dict(
             color_spaces=[
-                datapoints.ColorSpace.GRAY,
-                datapoints.ColorSpace.GRAY_ALPHA,
-                datapoints.ColorSpace.RGB,
-                datapoints.ColorSpace.RGB_ALPHA,
+                "GRAY",
+                "GRAY_ALPHA",
+                "RGB",
+                "RGBA",
             ],
             extra_dims=[()],
         ),
@@ -911,7 +907,7 @@ class TestRefDetTransforms:
         size = (600, 800)
         num_objects = 22

-        pil_image = to_image_pil(make_image(size=size, color_space=datapoints.ColorSpace.RGB))
+        pil_image = to_image_pil(make_image(size=size, color_space="RGB"))
         target = {
             "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
             "labels": make_label(extra_dims=(num_objects,), categories=80),
@@ -921,7 +917,7 @@ class TestRefDetTransforms:
         yield (pil_image, target)

-        tensor_image = torch.Tensor(make_image(size=size, color_space=datapoints.ColorSpace.RGB))
+        tensor_image = torch.Tensor(make_image(size=size, color_space="RGB"))
         target = {
             "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
             "labels": make_label(extra_dims=(num_objects,), categories=80),
@@ -931,7 +927,7 @@ class TestRefDetTransforms:
         yield (tensor_image, target)

-        datapoint_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB)
+        datapoint_image = make_image(size=size, color_space="RGB")
         target = {
             "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
             "labels": make_label(extra_dims=(num_objects,), categories=80),
@@ -1015,7 +1011,7 @@ class TestRefSegTransforms:
             conv_fns.extend([torch.Tensor, lambda x: x])

         for conv_fn in conv_fns:
-            datapoint_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB, dtype=image_dtype)
+            datapoint_image = make_image(size=size, color_space="RGB", dtype=image_dtype)
             datapoint_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)

             dp = (conv_fn(datapoint_image), datapoint_mask)
test/test_prototype_transforms_functional.py

@@ -340,7 +340,6 @@ class TestDispatchers:
     "dispatcher",
     [
         F.clamp_bounding_box,
-        F.convert_color_space,
         F.get_dimensions,
         F.get_image_num_channels,
         F.get_image_size,
test/test_prototype_transforms_utils.py

@@ -11,7 +11,7 @@ from torchvision.prototype.transforms.functional import to_image_pil
 from torchvision.prototype.transforms.utils import has_all, has_any

-IMAGE = make_image(color_space=datapoints.ColorSpace.RGB)
+IMAGE = make_image(color_space="RGB")
 BOUNDING_BOX = make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, spatial_size=IMAGE.spatial_size)
 MASK = make_detection_mask(size=IMAGE.spatial_size)
torchvision/prototype/datapoints/__init__.py

 from ._bounding_box import BoundingBox, BoundingBoxFormat
 from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT
-from ._image import ColorSpace, Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
+from ._image import Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
 from ._label import Label, OneHotLabel
 from ._mask import Mask
 from ._video import TensorVideoType, TensorVideoTypeJIT, Video, VideoType, VideoTypeJIT
torchvision/prototype/datapoints/_image.py

 from __future__ import annotations

-import warnings
 from typing import Any, List, Optional, Tuple, Union

 import PIL.Image
 import torch
-from torchvision._utils import StrEnum
 from torchvision.transforms.functional import InterpolationMode

 from ._datapoint import Datapoint, FillTypeJIT

-class ColorSpace(StrEnum):
-    OTHER = StrEnum.auto()
-    GRAY = StrEnum.auto()
-    GRAY_ALPHA = StrEnum.auto()
-    RGB = StrEnum.auto()
-    RGB_ALPHA = StrEnum.auto()
-
-    @classmethod
-    def from_pil_mode(cls, mode: str) -> ColorSpace:
-        if mode == "L":
-            return cls.GRAY
-        elif mode == "LA":
-            return cls.GRAY_ALPHA
-        elif mode == "RGB":
-            return cls.RGB
-        elif mode == "RGBA":
-            return cls.RGB_ALPHA
-        else:
-            return cls.OTHER
-
-    @staticmethod
-    def from_tensor_shape(shape: List[int]) -> ColorSpace:
-        return _from_tensor_shape(shape)
-
-def _from_tensor_shape(shape: List[int]) -> ColorSpace:
-    # Needed as a standalone method for JIT
-    ndim = len(shape)
-    if ndim < 2:
-        return ColorSpace.OTHER
-    elif ndim == 2:
-        return ColorSpace.GRAY
-
-    num_channels = shape[-3]
-    if num_channels == 1:
-        return ColorSpace.GRAY
-    elif num_channels == 2:
-        return ColorSpace.GRAY_ALPHA
-    elif num_channels == 3:
-        return ColorSpace.RGB
-    elif num_channels == 4:
-        return ColorSpace.RGB_ALPHA
-    else:
-        return ColorSpace.OTHER
-
 class Image(Datapoint):
-    color_space: ColorSpace
-
     @classmethod
-    def _wrap(cls, tensor: torch.Tensor, *, color_space: ColorSpace) -> Image:
+    def _wrap(cls, tensor: torch.Tensor) -> Image:
         image = tensor.as_subclass(cls)
-        image.color_space = color_space
         return image

     def __new__(
         cls,
         data: Any,
         *,
-        color_space: Optional[Union[ColorSpace, str]] = None,
         dtype: Optional[torch.dtype] = None,
         device: Optional[Union[torch.device, str, int]] = None,
         requires_grad: bool = False,
@@ -81,26 +29,14 @@ class Image(Datapoint):
         elif tensor.ndim == 2:
             tensor = tensor.unsqueeze(0)

-        if color_space is None:
-            color_space = ColorSpace.from_tensor_shape(tensor.shape)  # type: ignore[arg-type]
-            if color_space == ColorSpace.OTHER:
-                warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
-        elif isinstance(color_space, str):
-            color_space = ColorSpace.from_str(color_space.upper())
-        elif not isinstance(color_space, ColorSpace):
-            raise ValueError
-
-        return cls._wrap(tensor, color_space=color_space)
+        return cls._wrap(tensor)

     @classmethod
-    def wrap_like(cls, other: Image, tensor: torch.Tensor, *, color_space: Optional[ColorSpace] = None) -> Image:
-        return cls._wrap(
-            tensor,
-            color_space=color_space if color_space is not None else other.color_space,
-        )
+    def wrap_like(cls, other: Image, tensor: torch.Tensor) -> Image:
+        return cls._wrap(tensor)

     def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
-        return self._make_repr(color_space=self.color_space)
+        return self._make_repr()

     @property
     def spatial_size(self) -> Tuple[int, int]:
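With the metadata gone, `Image._wrap` and `Image.wrap_like` are plain re-wraps. A minimal sketch of the simplified behavior at this commit (tensor values invented):

    import torch
    from torchvision.prototype import datapoints

    img = datapoints.Image(torch.rand(3, 8, 8))
    flipped = datapoints.Image.wrap_like(img, img.flip(-1))
    # repr(flipped) no longer reports a color_space field, and construction no
    # longer warns when a color space cannot be guessed from the shape.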
torchvision/prototype/datapoints/_video.py

 from __future__ import annotations

-import warnings
 from typing import Any, List, Optional, Tuple, Union

 import torch
 from torchvision.transforms.functional import InterpolationMode

 from ._datapoint import Datapoint, FillTypeJIT
-from ._image import ColorSpace

 class Video(Datapoint):
-    color_space: ColorSpace
-
     @classmethod
-    def _wrap(cls, tensor: torch.Tensor, *, color_space: ColorSpace) -> Video:
+    def _wrap(cls, tensor: torch.Tensor) -> Video:
         video = tensor.as_subclass(cls)
-        video.color_space = color_space
         return video

     def __new__(
         cls,
         data: Any,
         *,
-        color_space: Optional[Union[ColorSpace, str]] = None,
         dtype: Optional[torch.dtype] = None,
         device: Optional[Union[torch.device, str, int]] = None,
         requires_grad: bool = False,
@@ -31,28 +25,14 @@ class Video(Datapoint):
         tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
         if data.ndim < 4:
             raise ValueError
-        video = super().__new__(cls, data, requires_grad=requires_grad)
-
-        if color_space is None:
-            color_space = ColorSpace.from_tensor_shape(video.shape)  # type: ignore[arg-type]
-            if color_space == ColorSpace.OTHER:
-                warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
-        elif isinstance(color_space, str):
-            color_space = ColorSpace.from_str(color_space.upper())
-        elif not isinstance(color_space, ColorSpace):
-            raise ValueError
-
-        return cls._wrap(tensor, color_space=color_space)
+        return cls._wrap(tensor)

     @classmethod
-    def wrap_like(cls, other: Video, tensor: torch.Tensor, *, color_space: Optional[ColorSpace] = None) -> Video:
-        return cls._wrap(
-            tensor,
-            color_space=color_space if color_space is not None else other.color_space,
-        )
+    def wrap_like(cls, other: Video, tensor: torch.Tensor) -> Video:
+        return cls._wrap(tensor)

     def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
-        return self._make_repr(color_space=self.color_space)
+        return self._make_repr()

     @property
     def spatial_size(self) -> Tuple[int, int]:
torchvision/prototype/transforms/__init__.py

@@ -39,7 +39,7 @@ from ._geometry import (
     ScaleJitter,
     TenCrop,
 )
-from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertDtype, ConvertImageDtype
+from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype
 from ._misc import (
     GaussianBlur,
     Identity,
torchvision/prototype/transforms/_deprecated.py View file @ 5dd95944
...
@@ -28,6 +28,7 @@ class ToTensor(Transform):
        return _F.to_tensor(inpt)


+# TODO: in other PR (?) undeprecate those and make them use _rgb_to_gray?
class Grayscale(Transform):
    _transformed_types = (
        datapoints.Image,
...
@@ -62,7 +63,7 @@ class Grayscale(Transform):
    ) -> Union[datapoints.ImageType, datapoints.VideoType]:
        output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
        if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-            output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.GRAY)  # type: ignore[arg-type]
+            output = inpt.wrap_like(inpt, output)  # type: ignore[arg-type]
        return output
...
@@ -98,5 +99,5 @@ class RandomGrayscale(_RandomApplyTransform):
    ) -> Union[datapoints.ImageType, datapoints.VideoType]:
        output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
        if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-            output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.GRAY)  # type: ignore[arg-type]
+            output = inpt.wrap_like(inpt, output)  # type: ignore[arg-type]
        return output
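A hedged usage sketch of the (still deprecated) `Grayscale` transform after this hunk; it assumes `Grayscale` is re-exported from `torchvision.prototype.transforms`, and the input values are illustrative:

import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import Grayscale

rgb = datapoints.Image(torch.rand(3, 16, 16))
gray = Grayscale()(rgb)  # emits a deprecation warning

# The result is still an Image datapoint; grayscale-ness is now encoded
# in the channel dimension alone rather than in color_space metadata.
assert isinstance(gray, datapoints.Image) and gray.shape[-3] == 1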
torchvision/prototype/transforms/_meta.py View file @ 5dd95944
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Union

-import PIL.Image
import torch
...
@@ -46,35 +44,6 @@ class ConvertDtype(Transform):
ConvertImageDtype = ConvertDtype


-class ConvertColorSpace(Transform):
-    _transformed_types = (
-        is_simple_tensor,
-        datapoints.Image,
-        PIL.Image.Image,
-        datapoints.Video,
-    )
-
-    def __init__(
-        self,
-        color_space: Union[str, datapoints.ColorSpace],
-        old_color_space: Optional[Union[str, datapoints.ColorSpace]] = None,
-    ) -> None:
-        super().__init__()
-
-        if isinstance(color_space, str):
-            color_space = datapoints.ColorSpace.from_str(color_space)
-        self.color_space = color_space
-
-        if isinstance(old_color_space, str):
-            old_color_space = datapoints.ColorSpace.from_str(old_color_space)
-        self.old_color_space = old_color_space
-
-    def _transform(
-        self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
-    ) -> Union[datapoints.ImageType, datapoints.VideoType]:
-        return F.convert_color_space(inpt, color_space=self.color_space, old_color_space=self.old_color_space)


class ClampBoundingBoxes(Transform):
    _transformed_types = (datapoints.BoundingBox,)
...
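For code that relied on the removed `ConvertColorSpace`, the common RGB-to-gray path can be reproduced with the stable grayscale functional. A minimal sketch; the class name `RgbToGray` is hypothetical and not part of torchvision, and it assumes `Transform` is importable from `torchvision.prototype.transforms`:

from typing import Any, Dict, Union

import torch

from torchvision.prototype import datapoints
from torchvision.prototype.transforms import Transform
from torchvision.transforms import functional as _F


class RgbToGray(Transform):  # hypothetical replacement, covers RGB -> GRAY only
    _transformed_types = (datapoints.Image, datapoints.Video)

    def _transform(
        self, inpt: Union[datapoints.Image, datapoints.Video], params: Dict[str, Any]
    ) -> Union[datapoints.Image, datapoints.Video]:
        # Convert on the plain tensor, then re-wrap into the original datapoint type.
        output = _F.rgb_to_grayscale(inpt.as_subclass(torch.Tensor))
        return type(inpt).wrap_like(inpt, output)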
torchvision/prototype/transforms/functional/__init__.py View file @ 5dd95944
...
@@ -7,10 +7,6 @@ from ._utils import is_simple_tensor  # usort: skip
from ._meta import (
    clamp_bounding_box,
    convert_format_bounding_box,
-    convert_color_space_image_tensor,
-    convert_color_space_image_pil,
-    convert_color_space_video,
-    convert_color_space,
    convert_dtype_image_tensor,
    convert_dtype,
    convert_dtype_video,
...
torchvision/prototype/transforms/functional/_deprecated.py View file @ 5dd95944
...
@@ -27,13 +27,11 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
def rgb_to_grayscale(
    inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1
) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        old_color_space = datapoints._image._from_tensor_shape(inpt.shape)  # type: ignore[arg-type]
-    else:
-        old_color_space = None
-
-    if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-        inpt = inpt.as_subclass(torch.Tensor)
+    old_color_space = None  # TODO: remove when un-deprecating
+
+    if not (torch.jit.is_scripting() or is_simple_tensor(inpt)) and isinstance(
+        inpt, (datapoints.Image, datapoints.Video)
+    ):
+        inpt = inpt.as_subclass(torch.Tensor)

    call = ", num_output_channels=3" if num_output_channels == 3 else ""
    replacement = (
...
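A minimal sketch of calling the deprecated functional after this hunk, assuming it is re-exported under `torchvision.prototype.transforms.functional` (input values illustrative):

import torch
from torchvision.prototype.transforms import functional as F

img = torch.rand(3, 16, 16)  # plain RGB tensor
gray = F.rgb_to_grayscale(img)  # warns with the `replacement` message above
assert gray.shape[-3] == 1  # default num_output_channels=1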
torchvision/prototype/transforms/functional/_meta.py View file @ 5dd95944
-from typing import List, Optional, Tuple, Union
+from typing import List, Tuple, Union

import PIL.Image
import torch
from torchvision.prototype import datapoints
-from torchvision.prototype.datapoints import BoundingBoxFormat, ColorSpace
+from torchvision.prototype.datapoints import BoundingBoxFormat
from torchvision.transforms import functional_pil as _FP
from torchvision.transforms.functional_tensor import _max_value
...
@@ -225,29 +225,6 @@ def clamp_bounding_box(
    return convert_format_bounding_box(xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True)


-def _strip_alpha(image: torch.Tensor) -> torch.Tensor:
-    image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
-    if not torch.all(alpha == _max_value(alpha.dtype)):
-        raise RuntimeError(
-            "Stripping the alpha channel if it contains values other than the max value is not supported."
-        )
-    return image
-
-
-def _add_alpha(image: torch.Tensor, alpha: Optional[torch.Tensor] = None) -> torch.Tensor:
-    if alpha is None:
-        shape = list(image.shape)
-        shape[-3] = 1
-        alpha = torch.full(shape, _max_value(image.dtype), dtype=image.dtype, device=image.device)
-    return torch.cat((image, alpha), dim=-3)
-
-
-def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
-    repeats = [1] * grayscale.ndim
-    repeats[-3] = 3
-    return grayscale.repeat(repeats)


def _rgb_to_gray(image: torch.Tensor, cast: bool = True) -> torch.Tensor:
    r, g, b = image.unbind(dim=-3)
    l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
...
@@ -257,107 +234,6 @@ def _rgb_to_gray(image: torch.Tensor, cast: bool = True) -> torch.Tensor:
    return l_img


-def convert_color_space_image_tensor(
-    image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace
-) -> torch.Tensor:
-    if new_color_space == old_color_space:
-        return image
-
-    if old_color_space == ColorSpace.OTHER or new_color_space == ColorSpace.OTHER:
-        raise RuntimeError(f"Conversion to or from {ColorSpace.OTHER} is not supported.")
-
-    if old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.GRAY_ALPHA:
-        return _add_alpha(image)
-    elif old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.RGB:
-        return _gray_to_rgb(image)
-    elif old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.RGB_ALPHA:
-        return _add_alpha(_gray_to_rgb(image))
-    elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.GRAY:
-        return _strip_alpha(image)
-    elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB:
-        return _gray_to_rgb(_strip_alpha(image))
-    elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB_ALPHA:
-        image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
-        return _add_alpha(_gray_to_rgb(image), alpha)
-    elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.GRAY:
-        return _rgb_to_gray(image)
-    elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.GRAY_ALPHA:
-        return _add_alpha(_rgb_to_gray(image))
-    elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.RGB_ALPHA:
-        return _add_alpha(image)
-    elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY:
-        return _rgb_to_gray(_strip_alpha(image))
-    elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY_ALPHA:
-        image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
-        return _add_alpha(_rgb_to_gray(image), alpha)
-    elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.RGB:
-        return _strip_alpha(image)
-    else:
-        raise RuntimeError(f"Conversion from {old_color_space} to {new_color_space} is not supported.")
-
-
-_COLOR_SPACE_TO_PIL_MODE = {
-    ColorSpace.GRAY: "L",
-    ColorSpace.GRAY_ALPHA: "LA",
-    ColorSpace.RGB: "RGB",
-    ColorSpace.RGB_ALPHA: "RGBA",
-}
-
-
-@torch.jit.unused
-def convert_color_space_image_pil(image: PIL.Image.Image, color_space: ColorSpace) -> PIL.Image.Image:
-    old_mode = image.mode
-    try:
-        new_mode = _COLOR_SPACE_TO_PIL_MODE[color_space]
-    except KeyError:
-        raise ValueError(f"Conversion from {ColorSpace.from_pil_mode(old_mode)} to {color_space} is not supported.")
-
-    if image.mode == new_mode:
-        return image
-
-    return image.convert(new_mode)
-
-
-def convert_color_space_video(
-    video: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace
-) -> torch.Tensor:
-    return convert_color_space_image_tensor(video, old_color_space=old_color_space, new_color_space=new_color_space)
-
-
-def convert_color_space(
-    inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT],
-    color_space: ColorSpace,
-    old_color_space: Optional[ColorSpace] = None,
-) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(convert_color_space)
-
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        if old_color_space is None:
-            raise RuntimeError(
-                "In order to convert the color space of simple tensors, "
-                "the `old_color_space=...` parameter needs to be passed."
-            )
-        return convert_color_space_image_tensor(inpt, old_color_space=old_color_space, new_color_space=color_space)
-    elif isinstance(inpt, datapoints.Image):
-        output = convert_color_space_image_tensor(
-            inpt.as_subclass(torch.Tensor), old_color_space=inpt.color_space, new_color_space=color_space
-        )
-        return datapoints.Image.wrap_like(inpt, output, color_space=color_space)
-    elif isinstance(inpt, datapoints.Video):
-        output = convert_color_space_video(
-            inpt.as_subclass(torch.Tensor), old_color_space=inpt.color_space, new_color_space=color_space
-        )
-        return datapoints.Video.wrap_like(inpt, output, color_space=color_space)
-    elif isinstance(inpt, PIL.Image.Image):
-        return convert_color_space_image_pil(inpt, color_space=color_space)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )


def _num_value_bits(dtype: torch.dtype) -> int:
    if dtype == torch.uint8:
        return 8
...
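The `_rgb_to_gray` context line kept above uses the ITU-R BT.601 luma weights (0.2989, 0.587, 0.114). A quick worked check on a single pixel, independent of the removed helpers (values illustrative):

import torch

pixel = torch.tensor([[[1.0]], [[0.5]], [[0.25]]])  # shape (C=3, H=1, W=1)
r, g, b = pixel.unbind(dim=-3)
luma = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
# 0.2989*1.0 + 0.587*0.5 + 0.114*0.25 = 0.2989 + 0.2935 + 0.0285 = 0.6209
assert torch.allclose(luma, torch.tensor([[0.6209]]))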
torchvision/transforms/functional.py View file @ 5dd95944
...
@@ -1234,6 +1234,9 @@ def affine(
    return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill)


+# Looks like to_grayscale() is a stand-alone functional that is never called
+# from the transform classes. Perhaps it's still here for BC? I can't be
+# bothered to dig. Anyway, this can be deprecated as we migrate to V2.
@torch.jit.unused
def to_grayscale(img, num_output_channels=1):
    """Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image.
...
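For reference, a short sketch of the stable, PIL-only functional that the new comment targets for eventual deprecation (image contents illustrative):

from PIL import Image
from torchvision.transforms.functional import to_grayscale

img = Image.new("RGB", (16, 16), color=(200, 30, 30))
gray = to_grayscale(img, num_output_channels=1)
assert gray.mode == "L"  # single-channel PIL image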