Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
9fa8000d
Unverified
Commit
9fa8000d
authored
Jan 28, 2022
by
Nicolas Hug
Committed by
GitHub
Jan 28, 2022
Browse files
Add support for flow batches in flow_to_image (#5308)
parent
8e874ff8
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
48 additions
and
26 deletions
+48
-26
test/test_utils.py
test/test_utils.py
+25
-12
torchvision/utils.py
torchvision/utils.py
+23
-14
No files found.
test/test_utils.py
View file @
9fa8000d
...
...
@@ -317,29 +317,42 @@ def test_draw_keypoints_errors():
utils
.
draw_keypoints
(
image
=
img
,
keypoints
=
invalid_keypoints
)
def
test_flow_to_image
():
@
pytest
.
mark
.
parametrize
(
"batch"
,
(
True
,
False
))
def
test_flow_to_image
(
batch
):
h
,
w
=
100
,
100
flow
=
torch
.
meshgrid
(
torch
.
arange
(
h
),
torch
.
arange
(
w
),
indexing
=
"ij"
)
flow
=
torch
.
stack
(
flow
[::
-
1
],
dim
=
0
).
float
()
flow
[
0
]
-=
h
/
2
flow
[
1
]
-=
w
/
2
if
batch
:
flow
=
torch
.
stack
([
flow
,
flow
])
img
=
utils
.
flow_to_image
(
flow
)
assert
img
.
shape
==
(
2
,
3
,
h
,
w
)
if
batch
else
(
3
,
h
,
w
)
path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
"assets"
,
"expected_flow.pt"
)
expected_img
=
torch
.
load
(
path
,
map_location
=
"cpu"
)
assert_equal
(
expected_img
,
img
)
if
batch
:
expected_img
=
torch
.
stack
([
expected_img
,
expected_img
])
assert_equal
(
expected_img
,
img
)
def
test_flow_to_image_errors
():
wrong_flow1
=
torch
.
full
((
3
,
10
,
10
),
0
,
dtype
=
torch
.
float
)
wrong_flow2
=
torch
.
full
((
2
,
10
),
0
,
dtype
=
torch
.
float
)
wrong_flow3
=
torch
.
full
((
2
,
10
,
30
),
0
,
dtype
=
torch
.
int
)
with
pytest
.
raises
(
ValueError
,
match
=
"Input flow should have shape"
):
utils
.
flow_to_image
(
flow
=
wrong_flow1
)
with
pytest
.
raises
(
ValueError
,
match
=
"Input flow should have shape"
):
utils
.
flow_to_image
(
flow
=
wrong_flow2
)
with
pytest
.
raises
(
ValueError
,
match
=
"Flow should be of dtype torch.float"
):
utils
.
flow_to_image
(
flow
=
wrong_flow3
)
@
pytest
.
mark
.
parametrize
(
"input_flow, match"
,
(
(
torch
.
full
((
3
,
10
,
10
),
0
,
dtype
=
torch
.
float
),
"Input flow should have shape"
),
(
torch
.
full
((
5
,
3
,
10
,
10
),
0
,
dtype
=
torch
.
float
),
"Input flow should have shape"
),
(
torch
.
full
((
2
,
10
),
0
,
dtype
=
torch
.
float
),
"Input flow should have shape"
),
(
torch
.
full
((
5
,
2
,
10
),
0
,
dtype
=
torch
.
float
),
"Input flow should have shape"
),
(
torch
.
full
((
2
,
10
,
30
),
0
,
dtype
=
torch
.
int
),
"Flow should be of dtype torch.float"
),
),
)
def
test_flow_to_image_errors
(
input_flow
,
match
):
with
pytest
.
raises
(
ValueError
,
match
=
match
):
utils
.
flow_to_image
(
flow
=
input_flow
)
if
__name__
==
"__main__"
:
...
...
torchvision/utils.py
View file @
9fa8000d
...
...
@@ -397,42 +397,51 @@ def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
Converts a flow to an RGB image.
Args:
flow (Tensor): Flow of shape (2, H, W) and dtype torch.float.
flow (Tensor): Flow of shape
(N, 2, H, W) or
(2, H, W) and dtype torch.float.
Returns:
img (Tensor(3, H, W)): Image Tensor of dtype uint8 where each color corresponds to a given flow direction.
img (Tensor): Image Tensor of dtype uint8 where each color corresponds
to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.
"""
if
flow
.
dtype
!=
torch
.
float
:
raise
ValueError
(
f
"Flow should be of dtype torch.float, got
{
flow
.
dtype
}
."
)
if
flow
.
ndim
!=
3
or
flow
.
size
(
0
)
!=
2
:
raise
ValueError
(
f
"Input flow should have shape (2, H, W), got
{
flow
.
shape
}
."
)
orig_shape
=
flow
.
shape
if
flow
.
ndim
==
3
:
flow
=
flow
[
None
]
# Add batch dim
max_norm
=
torch
.
sum
(
flow
**
2
,
dim
=
0
).
sqrt
().
max
()
if
flow
.
ndim
!=
4
or
flow
.
shape
[
1
]
!=
2
:
raise
ValueError
(
f
"Input flow should have shape (2, H, W) or (N, 2, H, W), got
{
orig_shape
}
."
)
max_norm
=
torch
.
sum
(
flow
**
2
,
dim
=
1
).
sqrt
().
max
()
epsilon
=
torch
.
finfo
((
flow
).
dtype
).
eps
normalized_flow
=
flow
/
(
max_norm
+
epsilon
)
return
_normalized_flow_to_image
(
normalized_flow
)
img
=
_normalized_flow_to_image
(
normalized_flow
)
if
len
(
orig_shape
)
==
3
:
img
=
img
[
0
]
# Remove batch dim
return
img
@
torch
.
no_grad
()
def
_normalized_flow_to_image
(
normalized_flow
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Converts a normalized flow to an RGB image.
Converts a
batch of
normalized flow to an RGB image.
Args:
normalized_flow (torch.Tensor): Normalized flow tensor of shape (2, H, W)
normalized_flow (torch.Tensor): Normalized flow tensor of shape (
N,
2, H, W)
Returns:
img (Tensor(3, H, W)): Flow visualization image of dtype uint8.
img (Tensor(
N,
3, H, W)): Flow visualization image of dtype uint8.
"""
_
,
H
,
W
=
normalized_flow
.
shape
flow_image
=
torch
.
zeros
((
3
,
H
,
W
),
dtype
=
torch
.
uint8
)
N
,
_
,
H
,
W
=
normalized_flow
.
shape
flow_image
=
torch
.
zeros
((
N
,
3
,
H
,
W
),
dtype
=
torch
.
uint8
)
colorwheel
=
_make_colorwheel
()
# shape [55x3]
num_cols
=
colorwheel
.
shape
[
0
]
norm
=
torch
.
sum
(
normalized_flow
**
2
,
dim
=
0
).
sqrt
()
a
=
torch
.
atan2
(
-
normalized_flow
[
1
],
-
normalized_flow
[
0
])
/
torch
.
pi
norm
=
torch
.
sum
(
normalized_flow
**
2
,
dim
=
1
).
sqrt
()
a
=
torch
.
atan2
(
-
normalized_flow
[
:,
1
,
:,
:
],
-
normalized_flow
[
:,
0
,
:,
:
])
/
torch
.
pi
fk
=
(
a
+
1
)
/
2
*
(
num_cols
-
1
)
k0
=
torch
.
floor
(
fk
).
to
(
torch
.
long
)
k1
=
k0
+
1
...
...
@@ -445,7 +454,7 @@ def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
col1
=
tmp
[
k1
]
/
255.0
col
=
(
1
-
f
)
*
col0
+
f
*
col1
col
=
1
-
norm
*
(
1
-
col
)
flow_image
[
c
,
:,
:]
=
torch
.
floor
(
255
*
col
)
flow_image
[
:,
c
,
:,
:]
=
torch
.
floor
(
255
*
col
)
return
flow_image
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment