Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
2bababf2
Unverified
Commit
2bababf2
authored
Mar 15, 2024
by
ahmadsharif1
Committed by
GitHub
Mar 15, 2024
Browse files
Add a GrayscaleToRgb transform that can expand channels to 3 (#8247)
parent
fa82fd3b
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
94 additions
and
1 deletion
+94
-1
docs/source/transforms.rst
docs/source/transforms.rst
+3
-1
test/test_transforms_v2.py
test/test_transforms_v2.py
+48
-0
torchvision/transforms/v2/__init__.py
torchvision/transforms/v2/__init__.py
+1
-0
torchvision/transforms/v2/_color.py
torchvision/transforms/v2/_color.py
+14
-0
torchvision/transforms/v2/functional/__init__.py
torchvision/transforms/v2/functional/__init__.py
+2
-0
torchvision/transforms/v2/functional/_color.py
torchvision/transforms/v2/functional/_color.py
+26
-0
No files found.
docs/source/transforms.rst
View file @
2bababf2
...
@@ -347,6 +347,7 @@ Color
...
@@ -347,6 +347,7 @@ Color
v2.RandomChannelPermutation
v2.RandomChannelPermutation
v2.RandomPhotometricDistort
v2.RandomPhotometricDistort
v2.Grayscale
v2.Grayscale
v2.RGB
v2.RandomGrayscale
v2.RandomGrayscale
v2.GaussianBlur
v2.GaussianBlur
v2.RandomInvert
v2.RandomInvert
...
@@ -364,6 +365,7 @@ Functionals
...
@@ -364,6 +365,7 @@ Functionals
v2.functional.permute_channels
v2.functional.permute_channels
v2.functional.rgb_to_grayscale
v2.functional.rgb_to_grayscale
v2.functional.grayscale_to_rgb
v2.functional.to_grayscale
v2.functional.to_grayscale
v2.functional.gaussian_blur
v2.functional.gaussian_blur
v2.functional.invert
v2.functional.invert
...
@@ -584,7 +586,7 @@ Conversion
...
@@ -584,7 +586,7 @@ Conversion
while performing the conversion, while some may not do any scaling. By
while performing the conversion, while some may not do any scaling. By
scaling, we mean e.g. that a ``uint8`` -> ``float32`` would map the [0,
scaling, we mean e.g. that a ``uint8`` -> ``float32`` would map the [0,
255] range into [0, 1] (and vice-versa). See :ref:`range_and_dtype`.
255] range into [0, 1] (and vice-versa). See :ref:`range_and_dtype`.
.. autosummary::
.. autosummary::
:toctree: generated/
:toctree: generated/
:template: class.rst
:template: class.rst
...
...
test/test_transforms_v2.py
View file @
2bababf2
...
@@ -5005,6 +5005,54 @@ class TestRgbToGrayscale:
...
@@ -5005,6 +5005,54 @@ class TestRgbToGrayscale:
assert_equal
(
actual
,
expected
,
rtol
=
0
,
atol
=
1
)
assert_equal
(
actual
,
expected
,
rtol
=
0
,
atol
=
1
)
class
TestGrayscaleToRgb
:
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
uint8
,
torch
.
float32
])
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
def
test_kernel_image
(
self
,
dtype
,
device
):
check_kernel
(
F
.
grayscale_to_rgb_image
,
make_image
(
dtype
=
dtype
,
device
=
device
))
@
pytest
.
mark
.
parametrize
(
"make_input"
,
[
make_image_tensor
,
make_image_pil
,
make_image
])
def
test_functional
(
self
,
make_input
):
check_functional
(
F
.
grayscale_to_rgb
,
make_input
())
@
pytest
.
mark
.
parametrize
(
(
"kernel"
,
"input_type"
),
[
(
F
.
rgb_to_grayscale_image
,
torch
.
Tensor
),
(
F
.
_rgb_to_grayscale_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
rgb_to_grayscale_image
,
tv_tensors
.
Image
),
],
)
def
test_functional_signature
(
self
,
kernel
,
input_type
):
check_functional_kernel_signature_match
(
F
.
grayscale_to_rgb
,
kernel
=
kernel
,
input_type
=
input_type
)
@
pytest
.
mark
.
parametrize
(
"make_input"
,
[
make_image_tensor
,
make_image_pil
,
make_image
])
def
test_transform
(
self
,
make_input
):
check_transform
(
transforms
.
RGB
(),
make_input
(
color_space
=
"GRAY"
))
@
pytest
.
mark
.
parametrize
(
"fn"
,
[
F
.
grayscale_to_rgb
,
transform_cls_to_functional
(
transforms
.
RGB
)])
def
test_image_correctness
(
self
,
fn
):
image
=
make_image
(
dtype
=
torch
.
uint8
,
device
=
"cpu"
,
color_space
=
"GRAY"
)
actual
=
fn
(
image
)
expected
=
F
.
to_image
(
F
.
grayscale_to_rgb
(
F
.
to_pil_image
(
image
)))
assert_equal
(
actual
,
expected
,
rtol
=
0
,
atol
=
1
)
def
test_expanded_channels_are_not_views_into_the_same_underlying_tensor
(
self
):
image
=
make_image
(
dtype
=
torch
.
uint8
,
device
=
"cpu"
,
color_space
=
"GRAY"
)
output_image
=
F
.
grayscale_to_rgb
(
image
)
assert_equal
(
output_image
[
0
][
0
][
0
],
output_image
[
1
][
0
][
0
])
output_image
[
0
][
0
][
0
]
=
output_image
[
0
][
0
][
0
]
+
1
assert
output_image
[
0
][
0
][
0
]
!=
output_image
[
1
][
0
][
0
]
def
test_rgb_image_is_unchanged
(
self
):
image
=
make_image
(
dtype
=
torch
.
uint8
,
device
=
"cpu"
,
color_space
=
"RGB"
)
assert_equal
(
image
.
shape
[
-
3
],
3
)
assert_equal
(
F
.
grayscale_to_rgb
(
image
),
image
)
class
TestRandomZoomOut
:
class
TestRandomZoomOut
:
# Tests are light because this largely relies on the already tested `pad` kernels.
# Tests are light because this largely relies on the already tested `pad` kernels.
...
...
torchvision/transforms/v2/__init__.py
View file @
2bababf2
...
@@ -18,6 +18,7 @@ from ._color import (
...
@@ -18,6 +18,7 @@ from ._color import (
RandomPhotometricDistort
,
RandomPhotometricDistort
,
RandomPosterize
,
RandomPosterize
,
RandomSolarize
,
RandomSolarize
,
RGB
,
)
)
from
._container
import
Compose
,
RandomApply
,
RandomChoice
,
RandomOrder
from
._container
import
Compose
,
RandomApply
,
RandomChoice
,
RandomOrder
from
._geometry
import
(
from
._geometry
import
(
...
...
torchvision/transforms/v2/_color.py
View file @
2bababf2
...
@@ -54,6 +54,20 @@ class RandomGrayscale(_RandomApplyTransform):
...
@@ -54,6 +54,20 @@ class RandomGrayscale(_RandomApplyTransform):
return
self
.
_call_kernel
(
F
.
rgb_to_grayscale
,
inpt
,
num_output_channels
=
params
[
"num_input_channels"
])
return
self
.
_call_kernel
(
F
.
rgb_to_grayscale
,
inpt
,
num_output_channels
=
params
[
"num_input_channels"
])
class
RGB
(
Transform
):
"""Convert images or videos to RGB (if they are already not RGB).
If the input is a :class:`torch.Tensor`, it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions
"""
def
__init__
(
self
):
super
().
__init__
()
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
return
self
.
_call_kernel
(
F
.
grayscale_to_rgb
,
inpt
)
class
ColorJitter
(
Transform
):
class
ColorJitter
(
Transform
):
"""Randomly change the brightness, contrast, saturation and hue of an image or video.
"""Randomly change the brightness, contrast, saturation and hue of an image or video.
...
...
torchvision/transforms/v2/functional/__init__.py
View file @
2bababf2
...
@@ -63,6 +63,8 @@ from ._color import (
...
@@ -63,6 +63,8 @@ from ._color import (
equalize
,
equalize
,
equalize_image
,
equalize_image
,
equalize_video
,
equalize_video
,
grayscale_to_rgb
,
grayscale_to_rgb_image
,
invert
,
invert
,
invert_image
,
invert_image
,
invert_video
,
invert_video
,
...
...
torchvision/transforms/v2/functional/_color.py
View file @
2bababf2
...
@@ -65,6 +65,32 @@ def _rgb_to_grayscale_image_pil(image: PIL.Image.Image, num_output_channels: int
...
@@ -65,6 +65,32 @@ def _rgb_to_grayscale_image_pil(image: PIL.Image.Image, num_output_channels: int
return
_FP
.
to_grayscale
(
image
,
num_output_channels
=
num_output_channels
)
return
_FP
.
to_grayscale
(
image
,
num_output_channels
=
num_output_channels
)
def
grayscale_to_rgb
(
inpt
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""See :class:`~torchvision.transforms.v2.GrayscaleToRgb` for details."""
if
torch
.
jit
.
is_scripting
():
return
grayscale_to_rgb_image
(
inpt
)
_log_api_usage_once
(
grayscale_to_rgb
)
kernel
=
_get_kernel
(
grayscale_to_rgb
,
type
(
inpt
))
return
kernel
(
inpt
)
@
_register_kernel_internal
(
grayscale_to_rgb
,
torch
.
Tensor
)
@
_register_kernel_internal
(
grayscale_to_rgb
,
tv_tensors
.
Image
)
def
grayscale_to_rgb_image
(
image
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
image
.
shape
[
-
3
]
>=
3
:
# Image already has RGB channels. We don't need to do anything.
return
image
# rgb_to_grayscale can be used to add channels so we reuse that function.
return
_rgb_to_grayscale_image
(
image
,
num_output_channels
=
3
,
preserve_dtype
=
True
)
@
_register_kernel_internal
(
grayscale_to_rgb
,
PIL
.
Image
.
Image
)
def
grayscale_to_rgb_image_pil
(
image
:
PIL
.
Image
.
Image
)
->
PIL
.
Image
.
Image
:
return
image
.
convert
(
mode
=
"RGB"
)
def
_blend
(
image1
:
torch
.
Tensor
,
image2
:
torch
.
Tensor
,
ratio
:
float
)
->
torch
.
Tensor
:
def
_blend
(
image1
:
torch
.
Tensor
,
image2
:
torch
.
Tensor
,
ratio
:
float
)
->
torch
.
Tensor
:
ratio
=
float
(
ratio
)
ratio
=
float
(
ratio
)
fp
=
image1
.
is_floating_point
()
fp
=
image1
.
is_floating_point
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment