Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
6c44ceb5
Unverified
Commit
6c44ceb5
authored
Aug 15, 2023
by
Nicolas Granger
Committed by
GitHub
Aug 15, 2023
Browse files
Replace stack/mask/reduce by indexing in _hsv2rgb (#7754)
Co-authored-by:
vfdev
<
vfdev.5@gmail.com
>
parent
f244e27e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
6 deletions
+12
-6
torchvision/transforms/v2/functional/_color.py
torchvision/transforms/v2/functional/_color.py
+12
-6
No files found.
torchvision/transforms/v2/functional/_color.py
View file @
6c44ceb5
...
@@ -317,14 +317,20 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
...
@@ -317,14 +317,20 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
p
=
one_minus_s
.
mul_
(
v
).
clamp_
(
0.0
,
1.0
)
p
=
one_minus_s
.
mul_
(
v
).
clamp_
(
0.0
,
1.0
)
i
.
remainder_
(
6
)
i
.
remainder_
(
6
)
mask
=
i
.
unsqueeze
(
dim
=-
3
)
==
torch
.
arange
(
6
,
device
=
i
.
device
).
view
(
-
1
,
1
,
1
)
vpqt
=
torch
.
stack
((
v
,
p
,
q
,
t
),
dim
=-
3
)
a1
=
torch
.
stack
((
v
,
q
,
p
,
p
,
t
,
v
),
dim
=-
3
)
# vpqt -> rgb mapping based on i
a2
=
torch
.
stack
((
t
,
v
,
v
,
q
,
p
,
p
),
dim
=-
3
)
select
=
torch
.
tensor
([[
0
,
2
,
1
,
1
,
3
,
0
],
[
3
,
0
,
0
,
2
,
1
,
1
],
[
1
,
1
,
3
,
0
,
0
,
2
]],
dtype
=
torch
.
long
)
a3
=
torch
.
stack
((
p
,
p
,
t
,
v
,
v
,
q
),
dim
=-
3
)
select
=
select
.
to
(
device
=
img
.
device
,
non_blocking
=
True
)
a4
=
torch
.
stack
((
a1
,
a2
,
a3
),
dim
=-
4
)
return
(
a4
.
mul_
(
mask
.
unsqueeze
(
dim
=-
4
))).
sum
(
dim
=-
3
)
select
=
select
[:,
i
]
if
select
.
ndim
>
3
:
# if input.shape is (B, ..., C, H, W) then
# select.shape is (C, B, ..., H, W)
# thus we move C axis to get (B, ..., C, H, W)
select
=
select
.
moveaxis
(
0
,
-
3
)
return
vpqt
.
gather
(
-
3
,
select
)
@
_register_kernel_internal
(
adjust_hue
,
torch
.
Tensor
)
@
_register_kernel_internal
(
adjust_hue
,
torch
.
Tensor
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment