Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
5d8d61ac
"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "4742e12c2360bd2b43aedcf6d11cefc3a048f791"
Unverified
Commit
5d8d61ac
authored
Aug 09, 2023
by
Philip Meier
Committed by
GitHub
Aug 09, 2023
Browse files
add PermuteChannels transform (#7624)
parent
2ab937a0
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
151 additions
and
18 deletions
+151
-18
docs/source/transforms.rst
docs/source/transforms.rst
+1
-0
test/test_transforms_v2.py
test/test_transforms_v2.py
+1
-0
test/test_transforms_v2_refactored.py
test/test_transforms_v2_refactored.py
+58
-0
torchvision/transforms/v2/__init__.py
torchvision/transforms/v2/__init__.py
+1
-0
torchvision/transforms/v2/_color.py
torchvision/transforms/v2/_color.py
+22
-17
torchvision/transforms/v2/functional/__init__.py
torchvision/transforms/v2/functional/__init__.py
+4
-0
torchvision/transforms/v2/functional/_color.py
torchvision/transforms/v2/functional/_color.py
+64
-1
No files found.
docs/source/transforms.rst
View file @
5d8d61ac
...
@@ -155,6 +155,7 @@ Color
...
@@ -155,6 +155,7 @@ Color
ColorJitter
ColorJitter
v2.ColorJitter
v2.ColorJitter
v2.RandomChannelPermutation
v2.RandomPhotometricDistort
v2.RandomPhotometricDistort
Grayscale
Grayscale
v2.Grayscale
v2.Grayscale
...
...
test/test_transforms_v2.py
View file @
5d8d61ac
...
@@ -124,6 +124,7 @@ class TestSmoke:
...
@@ -124,6 +124,7 @@ class TestSmoke:
(
transforms
.
RandomEqualize
(
p
=
1.0
),
None
),
(
transforms
.
RandomEqualize
(
p
=
1.0
),
None
),
(
transforms
.
RandomGrayscale
(
p
=
1.0
),
None
),
(
transforms
.
RandomGrayscale
(
p
=
1.0
),
None
),
(
transforms
.
RandomInvert
(
p
=
1.0
),
None
),
(
transforms
.
RandomInvert
(
p
=
1.0
),
None
),
(
transforms
.
RandomChannelPermutation
(),
None
),
(
transforms
.
RandomPhotometricDistort
(
p
=
1.0
),
None
),
(
transforms
.
RandomPhotometricDistort
(
p
=
1.0
),
None
),
(
transforms
.
RandomPosterize
(
bits
=
4
,
p
=
1.0
),
None
),
(
transforms
.
RandomPosterize
(
bits
=
4
,
p
=
1.0
),
None
),
(
transforms
.
RandomSolarize
(
threshold
=
0.5
,
p
=
1.0
),
None
),
(
transforms
.
RandomSolarize
(
threshold
=
0.5
,
p
=
1.0
),
None
),
...
...
test/test_transforms_v2_refactored.py
View file @
5d8d61ac
...
@@ -2280,3 +2280,61 @@ class TestGetKernel:
...
@@ -2280,3 +2280,61 @@ class TestGetKernel:
_register_kernel_internal
(
F
.
resize
,
MyDatapoint
,
datapoint_wrapper
=
False
)(
resize_my_datapoint
)
_register_kernel_internal
(
F
.
resize
,
MyDatapoint
,
datapoint_wrapper
=
False
)(
resize_my_datapoint
)
assert
_get_kernel
(
F
.
resize
,
MyDatapoint
)
is
resize_my_datapoint
assert
_get_kernel
(
F
.
resize
,
MyDatapoint
)
is
resize_my_datapoint
class TestPermuteChannels:
    """Tests for ``F.permute_channels`` and the kernels registered for it."""

    # Not its own inverse, so mixing up "gather" and "scatter" semantics
    # produces a different result.
    _DEFAULT_PERMUTATION = [2, 0, 1]

    @pytest.mark.parametrize(
        ("kernel", "make_input"),
        [
            (F.permute_channels_image_tensor, make_image_tensor),
            # FIXME
            # check_kernel does not support PIL kernel, but it should
            (F.permute_channels_image_tensor, make_image),
            (F.permute_channels_video, make_video),
        ],
    )
    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel(self, kernel, make_input, dtype, device):
        """Run the shared kernel checks for every tensor-based kernel, dtype and device."""
        check_kernel(kernel, make_input(dtype=dtype, device=device), permutation=self._DEFAULT_PERMUTATION)

    @pytest.mark.parametrize(
        ("kernel", "make_input"),
        [
            (F.permute_channels_image_tensor, make_image_tensor),
            (F.permute_channels_image_pil, make_image_pil),
            (F.permute_channels_image_tensor, make_image),
            (F.permute_channels_video, make_video),
        ],
    )
    def test_dispatcher(self, kernel, make_input):
        """Run the shared dispatcher checks for every supported input type."""
        check_dispatcher(F.permute_channels, kernel, make_input(), permutation=self._DEFAULT_PERMUTATION)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.permute_channels_image_tensor, torch.Tensor),
            (F.permute_channels_image_pil, PIL.Image.Image),
            (F.permute_channels_image_tensor, datapoints.Image),
            (F.permute_channels_video, datapoints.Video),
        ],
    )
    def test_dispatcher_signature(self, kernel, input_type):
        """Check that each kernel's signature matches the dispatcher's."""
        check_dispatcher_kernel_signature_match(F.permute_channels, kernel=kernel, input_type=input_type)

    def reference_image_correctness(self, image, permutation):
        """Reference implementation: output channel ``i`` is input channel ``permutation[i]``."""
        channel_images = image.split(1, dim=-3)
        permuted_channel_images = [channel_images[channel_idx] for channel_idx in permutation]
        return datapoints.Image(torch.concat(permuted_channel_images, dim=-3))

    # Two 3-cycles, one transposition and the identity. The list previously
    # contained [2, 0, 1] twice, which only re-ran the same case.
    @pytest.mark.parametrize("permutation", [[2, 0, 1], [1, 2, 0], [2, 1, 0], [0, 1, 2]])
    @pytest.mark.parametrize("batch_dims", [(), (2,), (2, 1)])
    def test_image_correctness(self, permutation, batch_dims):
        """Compare the dispatcher output against the reference implementation."""
        image = make_image(batch_dims=batch_dims)

        actual = F.permute_channels(image, permutation=permutation)
        expected = self.reference_image_correctness(image, permutation=permutation)

        torch.testing.assert_close(actual, expected)
torchvision/transforms/v2/__init__.py
View file @
5d8d61ac
...
@@ -11,6 +11,7 @@ from ._color import (
...
@@ -11,6 +11,7 @@ from ._color import (
Grayscale
,
Grayscale
,
RandomAdjustSharpness
,
RandomAdjustSharpness
,
RandomAutocontrast
,
RandomAutocontrast
,
RandomChannelPermutation
,
RandomEqualize
,
RandomEqualize
,
RandomGrayscale
,
RandomGrayscale
,
RandomInvert
,
RandomInvert
,
...
...
torchvision/transforms/v2/_color.py
View file @
5d8d61ac
...
@@ -177,7 +177,27 @@ class ColorJitter(Transform):
...
@@ -177,7 +177,27 @@ class ColorJitter(Transform):
return
output
return
output
# TODO: This class seems to be untested
class RandomChannelPermutation(Transform):
    """[BETA] Randomly permute the channels of an image or video

    .. v2betastatus:: RandomChannelPermutation transform
    """

    # Only these input kinds are passed to ``_transform``.
    _transformed_types = (datapoints.Image, PIL.Image.Image, is_simple_tensor, datapoints.Video)

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Draw one random channel permutation for the whole sample."""
        # ``query_chw`` yields the channel count first; the remaining values are not needed.
        channel_count, *_unused = query_chw(flat_inputs)
        shuffled = torch.randperm(channel_count)
        return dict(permutation=shuffled)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Apply the permutation sampled in ``_get_params`` to a single input."""
        permutation = params["permutation"]
        return F.permute_channels(inpt, permutation)
class
RandomPhotometricDistort
(
Transform
):
class
RandomPhotometricDistort
(
Transform
):
"""[BETA] Randomly distorts the image or video as used in `SSD: Single Shot
"""[BETA] Randomly distorts the image or video as used in `SSD: Single Shot
MultiBox Detector <https://arxiv.org/abs/1512.02325>`_.
MultiBox Detector <https://arxiv.org/abs/1512.02325>`_.
...
@@ -241,21 +261,6 @@ class RandomPhotometricDistort(Transform):
...
@@ -241,21 +261,6 @@ class RandomPhotometricDistort(Transform):
params
[
"channel_permutation"
]
=
torch
.
randperm
(
num_channels
)
if
torch
.
rand
(
1
)
<
self
.
p
else
None
params
[
"channel_permutation"
]
=
torch
.
randperm
(
num_channels
)
if
torch
.
rand
(
1
)
<
self
.
p
else
None
return
params
return
params
def
_permute_channels
(
self
,
inpt
:
Union
[
datapoints
.
_ImageType
,
datapoints
.
_VideoType
],
permutation
:
torch
.
Tensor
)
->
Union
[
datapoints
.
_ImageType
,
datapoints
.
_VideoType
]:
orig_inpt
=
inpt
if
isinstance
(
orig_inpt
,
PIL
.
Image
.
Image
):
inpt
=
F
.
pil_to_tensor
(
inpt
)
# TODO: Find a better fix than as_subclass???
output
=
inpt
[...,
permutation
,
:,
:].
as_subclass
(
type
(
inpt
))
if
isinstance
(
orig_inpt
,
PIL
.
Image
.
Image
):
output
=
F
.
to_image_pil
(
output
)
return
output
def
_transform
(
def
_transform
(
self
,
inpt
:
Union
[
datapoints
.
_ImageType
,
datapoints
.
_VideoType
],
params
:
Dict
[
str
,
Any
]
self
,
inpt
:
Union
[
datapoints
.
_ImageType
,
datapoints
.
_VideoType
],
params
:
Dict
[
str
,
Any
]
)
->
Union
[
datapoints
.
_ImageType
,
datapoints
.
_VideoType
]:
)
->
Union
[
datapoints
.
_ImageType
,
datapoints
.
_VideoType
]:
...
@@ -270,7 +275,7 @@ class RandomPhotometricDistort(Transform):
...
@@ -270,7 +275,7 @@ class RandomPhotometricDistort(Transform):
if
params
[
"contrast_factor"
]
is
not
None
and
not
params
[
"contrast_before"
]:
if
params
[
"contrast_factor"
]
is
not
None
and
not
params
[
"contrast_before"
]:
inpt
=
F
.
adjust_contrast
(
inpt
,
contrast_factor
=
params
[
"contrast_factor"
])
inpt
=
F
.
adjust_contrast
(
inpt
,
contrast_factor
=
params
[
"contrast_factor"
])
if
params
[
"channel_permutation"
]
is
not
None
:
if
params
[
"channel_permutation"
]
is
not
None
:
inpt
=
self
.
_
permute_channels
(
inpt
,
permutation
=
params
[
"channel_permutation"
])
inpt
=
F
.
permute_channels
(
inpt
,
permutation
=
params
[
"channel_permutation"
])
return
inpt
return
inpt
...
...
torchvision/transforms/v2/functional/__init__.py
View file @
5d8d61ac
...
@@ -62,6 +62,10 @@ from ._color import (
...
@@ -62,6 +62,10 @@ from ._color import (
invert_image_pil
,
invert_image_pil
,
invert_image_tensor
,
invert_image_tensor
,
invert_video
,
invert_video
,
permute_channels
,
permute_channels_image_pil
,
permute_channels_image_tensor
,
permute_channels_video
,
posterize
,
posterize
,
posterize_image_pil
,
posterize_image_pil
,
posterize_image_tensor
,
posterize_image_tensor
,
...
...
torchvision/transforms/v2/functional/_color.py
View file @
5d8d61ac
from
typing
import
Union
from
typing
import
List
,
Union
import
PIL.Image
import
PIL.Image
import
torch
import
torch
...
@@ -10,6 +10,8 @@ from torchvision.transforms._functional_tensor import _max_value
...
@@ -10,6 +10,8 @@ from torchvision.transforms._functional_tensor import _max_value
from
torchvision.utils
import
_log_api_usage_once
from
torchvision.utils
import
_log_api_usage_once
from
._misc
import
_num_value_bits
,
to_dtype_image_tensor
from
._misc
import
_num_value_bits
,
to_dtype_image_tensor
from
._type_conversion
import
pil_to_tensor
,
to_image_pil
from
._utils
import
_get_kernel
,
_register_explicit_noop
,
_register_kernel_internal
from
._utils
import
_get_kernel
,
_register_explicit_noop
,
_register_kernel_internal
...
@@ -641,3 +643,64 @@ invert_image_pil = _register_kernel_internal(invert, PIL.Image.Image)(_FP.invert
...
@@ -641,3 +643,64 @@ invert_image_pil = _register_kernel_internal(invert, PIL.Image.Image)(_FP.invert
@
_register_kernel_internal
(
invert
,
datapoints
.
Video
)
@
_register_kernel_internal
(
invert
,
datapoints
.
Video
)
def
invert_video
(
video
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
invert_video
(
video
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
invert_image_tensor
(
video
)
return
invert_image_tensor
(
video
)
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def permute_channels(inpt: datapoints._InputTypeJIT, permutation: List[int]) -> datapoints._InputTypeJIT:
    """Permute the channels of the input according to the given permutation.

    This function supports plain :class:`~torch.Tensor`'s, :class:`PIL.Image.Image`'s, and
    :class:`torchvision.datapoints.Image` and :class:`torchvision.datapoints.Video`.

    Example:
        >>> rgb_image = torch.rand(3, 256, 256)
        >>> bgr_image = F.permute_channels(rgb_image, permutation=[2, 1, 0])

    Args:
        inpt: Input whose channels are permuted.
        permutation (List[int]): Valid permutation of the input channel indices. The index of the element determines the
            channel index in the output and the value determines the channel index in the input. For example,
            ``permutation=[2, 0 , 1]``

            - takes ``inpt[..., 2, :, :]`` and puts it at ``output[..., 0, :, :]``,
            - takes ``inpt[..., 0, :, :]`` and puts it at ``output[..., 1, :, :]``, and
            - takes ``inpt[..., 1, :, :]`` and puts it at ``output[..., 2, :, :]``.

    Raises:
        ValueError: If ``len(permutation)`` doesn't match the number of channels in the input.
    """
    # Under TorchScript there is no dynamic dispatch: go straight to the tensor kernel.
    if torch.jit.is_scripting():
        return permute_channels_image_tensor(inpt, permutation=permutation)

    _log_api_usage_once(permute_channels)

    # Look up the kernel registered for the concrete input type and delegate to it.
    kernel = _get_kernel(permute_channels, type(inpt))
    return kernel(inpt, permutation=permutation)
@
_register_kernel_internal
(
permute_channels
,
torch
.
Tensor
)
@
_register_kernel_internal
(
permute_channels
,
datapoints
.
Image
)
def permute_channels_image_tensor(image: torch.Tensor, permutation: List[int]) -> torch.Tensor:
    """Reorder the channel dimension of an image tensor.

    The last three dimensions of ``image`` are read as ``(channels, height, width)``; any
    leading dimensions are kept as they are. Channel ``i`` of the result is channel
    ``permutation[i]`` of ``image``.

    Args:
        image: Tensor with at least three dimensions.
        permutation: Source channel index for every output channel.

    Returns:
        A tensor with the same shape as ``image``. An input without any elements is
        returned as is.

    Raises:
        ValueError: If ``len(permutation)`` differs from the number of channels.
    """
    original_shape = image.shape
    channels, rows, cols = original_shape[-3:]

    if len(permutation) != channels:
        raise ValueError(
            f"Length of permutation does not match number of channels: " f"{len(permutation)} != {channels}"
        )

    # NOTE(review): presumably guards the ``-1`` reshape below, which can be
    # ambiguous for tensors without elements - confirm.
    if image.numel() == 0:
        return image

    # Fold all leading dimensions into one so the channel axis has a fixed position,
    # pick the channels in the requested order, then restore the original shape.
    batched = image.reshape(-1, channels, rows, cols)
    reordered = batched[:, permutation, :, :]
    return reordered.reshape(original_shape)
@
_register_kernel_internal
(
permute_channels
,
PIL
.
Image
.
Image
)
def permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int]) -> PIL.Image.Image:
    """Permute the channels of a PIL image.

    The image is converted with ``pil_to_tensor``, permuted by
    ``permute_channels_image_tensor`` and converted back with ``to_image_pil``.

    Args:
        image: PIL image whose channels are permuted.
        permutation: Passed through unchanged to ``permute_channels_image_tensor``.

    Returns:
        The result of ``to_image_pil`` applied to the permuted tensor.
    """
    # The return annotation names the image class; ``PIL.Image`` alone is the module.
    return to_image_pil(permute_channels_image_tensor(pil_to_tensor(image), permutation=permutation))
@
_register_kernel_internal
(
permute_channels
,
datapoints
.
Video
)
def permute_channels_video(video: torch.Tensor, permutation: List[int]) -> torch.Tensor:
    """Permute the channels of a video tensor.

    This is a thin wrapper: the video is handed to ``permute_channels_image_tensor``
    unchanged, together with ``permutation``.

    Args:
        video: Video tensor whose channels are permuted.
        permutation: Passed through unchanged to ``permute_channels_image_tensor``.

    Returns:
        Whatever ``permute_channels_image_tensor`` returns for ``video``.
    """
    permuted = permute_channels_image_tensor(video, permutation=permutation)
    return permuted
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment