Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
b16dec19
Unverified
Commit
b16dec19
authored
Oct 13, 2022
by
vfdev
Committed by
GitHub
Oct 13, 2022
Browse files
[proto] Performance improvements for equalize op (#6757)
* [proto] Performance improvements for equalize op * Added tests
parent
54a2d4e8
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
26 deletions
+44
-26
test/test_prototype_transforms_functional.py
test/test_prototype_transforms_functional.py
+11
-0
torchvision/prototype/transforms/functional/_color.py
torchvision/prototype/transforms/functional/_color.py
+33
-26
No files found.
test/test_prototype_transforms_functional.py
View file @
b16dec19
...
...
@@ -1037,3 +1037,14 @@ def test_to_image_pil(inpt, mode):
assert
isinstance
(
output
,
PIL
.
Image
.
Image
)
assert
np
.
asarray
(
inpt
).
sum
()
==
np
.
asarray
(
output
).
sum
()
def
test_equalize_image_tensor_edge_cases
():
inpt
=
torch
.
zeros
(
3
,
200
,
200
,
dtype
=
torch
.
uint8
)
output
=
F
.
equalize_image_tensor
(
inpt
)
torch
.
testing
.
assert_close
(
inpt
,
output
)
inpt
=
torch
.
zeros
(
5
,
3
,
200
,
200
,
dtype
=
torch
.
uint8
)
inpt
[...,
100
:,
100
:]
=
1
output
=
F
.
equalize_image_tensor
(
inpt
)
assert
output
.
unique
().
tolist
()
==
[
0
,
255
]
torchvision/prototype/transforms/functional/_color.py
View file @
b16dec19
...
...
@@ -183,28 +183,37 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
return
autocontrast_image_pil
(
inpt
)
def
_scale_channel
(
img_chan
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# TODO: we should expect bincount to always be faster than histc, but this
# isn't always the case. Once
# https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
# block and only use bincount.
if
img_chan
.
is_cuda
:
hist
=
torch
.
histc
(
img_chan
.
to
(
torch
.
float32
),
bins
=
256
,
min
=
0
,
max
=
255
)
else
:
hist
=
torch
.
bincount
(
img_chan
.
view
(
-
1
),
minlength
=
256
)
nonzero_hist
=
hist
[
hist
!=
0
]
step
=
torch
.
div
(
nonzero_hist
[:
-
1
].
sum
(),
255
,
rounding_mode
=
"floor"
)
if
step
==
0
:
return
img_chan
lut
=
torch
.
div
(
torch
.
cumsum
(
hist
,
0
)
+
torch
.
div
(
step
,
2
,
rounding_mode
=
"floor"
),
step
,
rounding_mode
=
"floor"
)
# Doing inplace clamp and converting lut to uint8 improves perfs
lut
.
clamp_
(
0
,
255
)
lut
=
lut
.
to
(
torch
.
uint8
)
lut
=
torch
.
nn
.
functional
.
pad
(
lut
[:
-
1
],
[
1
,
0
])
return
lut
[
img_chan
.
to
(
torch
.
int64
)]
def
_equalize_image_tensor_vec
(
img
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# input img shape should be [N, H, W]
shape
=
img
.
shape
# Compute image histogram:
flat_img
=
img
.
flatten
(
start_dim
=
1
).
to
(
torch
.
long
)
# -> [N, H * W]
hist
=
flat_img
.
new_zeros
(
shape
[
0
],
256
)
hist
.
scatter_add_
(
dim
=
1
,
index
=
flat_img
,
src
=
flat_img
.
new_ones
(
1
).
expand_as
(
flat_img
))
# Compute image cdf
chist
=
hist
.
cumsum_
(
dim
=
1
)
# Compute steps, where step per channel is nonzero_hist[:-1].sum() // 255
# Trick: nonzero_hist[:-1].sum() == chist[idx - 1], where idx = chist.argmax()
idx
=
chist
.
argmax
(
dim
=
1
).
sub_
(
1
)
# If histogram is degenerate (hist of zero image), index is -1
neg_idx_mask
=
idx
<
0
idx
.
clamp_
(
min
=
0
)
step
=
chist
.
gather
(
dim
=
1
,
index
=
idx
.
unsqueeze
(
1
))
step
[
neg_idx_mask
]
=
0
step
.
div_
(
255
,
rounding_mode
=
"floor"
)
# Compute batched Look-up-table:
# Necessary to avoid an integer division by zero, which raises
clamped_step
=
step
.
clamp
(
min
=
1
)
chist
.
add_
(
torch
.
div
(
step
,
2
,
rounding_mode
=
"floor"
)).
div_
(
clamped_step
,
rounding_mode
=
"floor"
).
clamp_
(
0
,
255
)
lut
=
chist
.
to
(
torch
.
uint8
)
# [N, 256]
# Pad lut with zeros
zeros
=
lut
.
new_zeros
((
1
,
1
)).
expand
(
shape
[
0
],
1
)
lut
=
torch
.
cat
([
zeros
,
lut
[:,
:
-
1
]],
dim
=
1
)
return
torch
.
where
((
step
==
0
).
unsqueeze
(
-
1
),
img
,
lut
.
gather
(
dim
=
1
,
index
=
flat_img
).
view_as
(
img
))
def
equalize_image_tensor
(
image
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -217,10 +226,8 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
if
image
.
numel
()
==
0
:
return
image
elif
image
.
ndim
==
2
:
return
_scale_channel
(
image
)
else
:
return
torch
.
stack
([
_scale_channel
(
x
)
for
x
in
image
.
view
(
-
1
,
height
,
width
)]).
view
(
image
.
shape
)
return
_equalize_image_tensor_vec
(
image
.
view
(
-
1
,
height
,
width
)).
view
(
image
.
shape
)
equalize_image_pil
=
_FP
.
equalize
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment