Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
7a7ab7e7
Unverified
Commit
7a7ab7e7
authored
Nov 08, 2022
by
Vasilis Vryniotis
Committed by
GitHub
Nov 08, 2022
Browse files
[prototype] Speed up `adjust_sharpness_image_tensor` (#6930)
* Speed up `adjust_sharpness_image_tensor` * Add a comment
parent
bf58902b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
1 deletion
+27
-1
torchvision/prototype/transforms/functional/_color.py
torchvision/prototype/transforms/functional/_color.py
+27
-1
No files found.
torchvision/prototype/transforms/functional/_color.py
View file @
7a7ab7e7
import
torch
import
torch
from
torch.nn.functional
import
conv2d
from
torchvision.prototype
import
features
from
torchvision.prototype
import
features
from
torchvision.transforms
import
functional_pil
as
_FP
,
functional_tensor
as
_FT
from
torchvision.transforms
import
functional_pil
as
_FP
,
functional_tensor
as
_FT
...
@@ -111,6 +112,8 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
...
@@ -111,6 +112,8 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
if
image
.
numel
()
==
0
or
height
<=
2
or
width
<=
2
:
if
image
.
numel
()
==
0
or
height
<=
2
or
width
<=
2
:
return
image
return
image
bound
=
_FT
.
_max_value
(
image
.
dtype
)
fp
=
image
.
is_floating_point
()
shape
=
image
.
shape
shape
=
image
.
shape
if
image
.
ndim
>
4
:
if
image
.
ndim
>
4
:
...
@@ -119,7 +122,30 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
...
@@ -119,7 +122,30 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
else
:
else
:
needs_unsquash
=
False
needs_unsquash
=
False
output
=
_blend
(
image
,
_FT
.
_blurred_degenerate_image
(
image
),
sharpness_factor
)
# The following is a normalized 3x3 kernel with 1s in the edges and a 5 in the middle.
kernel_dtype
=
image
.
dtype
if
fp
else
torch
.
float32
a
,
b
=
1.0
/
13.0
,
5.0
/
13.0
kernel
=
torch
.
tensor
([[
a
,
a
,
a
],
[
a
,
b
,
a
],
[
a
,
a
,
a
]],
dtype
=
kernel_dtype
,
device
=
image
.
device
)
kernel
=
kernel
.
expand
(
num_channels
,
1
,
3
,
3
)
# We copy and cast at the same time to avoid modifications on the original data
output
=
image
.
to
(
dtype
=
kernel_dtype
,
copy
=
True
)
blurred_degenerate
=
conv2d
(
output
,
kernel
,
groups
=
num_channels
)
if
not
fp
:
# it is better to round before cast
blurred_degenerate
=
blurred_degenerate
.
round_
()
# Create a view on the underlying output while pointing at the same data. We do this to avoid indexing twice.
view
=
output
[...,
1
:
-
1
,
1
:
-
1
]
# We speed up blending by minimizing flops and doing in-place. The 2 blend options are mathematically equivalent:
# x+(1-r)*(y-x) = x + (1-r)*y - (1-r)*x = x*r + y*(1-r)
view
.
add_
(
blurred_degenerate
.
sub_
(
view
),
alpha
=
(
1.0
-
sharpness_factor
))
# The actual data of ouput have been modified by the above. We only need to clamp and cast now.
output
=
output
.
clamp_
(
0
,
bound
)
if
not
fp
:
output
=
output
.
to
(
image
.
dtype
)
if
needs_unsquash
:
if
needs_unsquash
:
output
=
output
.
reshape
(
shape
)
output
=
output
.
reshape
(
shape
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment