Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
9fa8000d
"...text-generation-inference.git" did not exist on "c2d4a3b5c7bb6a8367c00f7c797bf87f4b2fcef9"
Unverified
Commit
9fa8000d
authored
Jan 28, 2022
by
Nicolas Hug
Committed by
GitHub
Jan 28, 2022
Browse files
Add support for flow batches in flow_to_image (#5308)
parent
8e874ff8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
48 additions
and
26 deletions
+48
-26
test/test_utils.py
test/test_utils.py
+25
-12
torchvision/utils.py
torchvision/utils.py
+23
-14
No files found.
test/test_utils.py
View file @
9fa8000d
...
...
@@ -317,29 +317,42 @@ def test_draw_keypoints_errors():
utils
.
draw_keypoints
(
image
=
img
,
keypoints
=
invalid_keypoints
)
def
test_flow_to_image
():
@
pytest
.
mark
.
parametrize
(
"batch"
,
(
True
,
False
))
def
test_flow_to_image
(
batch
):
h
,
w
=
100
,
100
flow
=
torch
.
meshgrid
(
torch
.
arange
(
h
),
torch
.
arange
(
w
),
indexing
=
"ij"
)
flow
=
torch
.
stack
(
flow
[::
-
1
],
dim
=
0
).
float
()
flow
[
0
]
-=
h
/
2
flow
[
1
]
-=
w
/
2
if
batch
:
flow
=
torch
.
stack
([
flow
,
flow
])
img
=
utils
.
flow_to_image
(
flow
)
assert
img
.
shape
==
(
2
,
3
,
h
,
w
)
if
batch
else
(
3
,
h
,
w
)
path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
"assets"
,
"expected_flow.pt"
)
expected_img
=
torch
.
load
(
path
,
map_location
=
"cpu"
)
assert_equal
(
expected_img
,
img
)
if
batch
:
expected_img
=
torch
.
stack
([
expected_img
,
expected_img
])
assert_equal
(
expected_img
,
img
)
def
test_flow_to_image_errors
():
wrong_flow1
=
torch
.
full
((
3
,
10
,
10
),
0
,
dtype
=
torch
.
float
)
wrong_flow2
=
torch
.
full
((
2
,
10
),
0
,
dtype
=
torch
.
float
)
wrong_flow3
=
torch
.
full
((
2
,
10
,
30
),
0
,
dtype
=
torch
.
int
)
with
pytest
.
raises
(
ValueError
,
match
=
"Input flow should have shape"
):
utils
.
flow_to_image
(
flow
=
wrong_flow1
)
with
pytest
.
raises
(
ValueError
,
match
=
"Input flow should have shape"
):
utils
.
flow_to_image
(
flow
=
wrong_flow2
)
with
pytest
.
raises
(
ValueError
,
match
=
"Flow should be of dtype torch.float"
):
utils
.
flow_to_image
(
flow
=
wrong_flow3
)
@
pytest
.
mark
.
parametrize
(
"input_flow, match"
,
(
(
torch
.
full
((
3
,
10
,
10
),
0
,
dtype
=
torch
.
float
),
"Input flow should have shape"
),
(
torch
.
full
((
5
,
3
,
10
,
10
),
0
,
dtype
=
torch
.
float
),
"Input flow should have shape"
),
(
torch
.
full
((
2
,
10
),
0
,
dtype
=
torch
.
float
),
"Input flow should have shape"
),
(
torch
.
full
((
5
,
2
,
10
),
0
,
dtype
=
torch
.
float
),
"Input flow should have shape"
),
(
torch
.
full
((
2
,
10
,
30
),
0
,
dtype
=
torch
.
int
),
"Flow should be of dtype torch.float"
),
),
)
def
test_flow_to_image_errors
(
input_flow
,
match
):
with
pytest
.
raises
(
ValueError
,
match
=
match
):
utils
.
flow_to_image
(
flow
=
input_flow
)
if
__name__
==
"__main__"
:
...
...
torchvision/utils.py
View file @
9fa8000d
...
...
@@ -397,42 +397,51 @@ def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
Converts a flow to an RGB image.
Args:
flow (Tensor): Flow of shape (2, H, W) and dtype torch.float.
flow (Tensor): Flow of shape
(N, 2, H, W) or
(2, H, W) and dtype torch.float.
Returns:
img (Tensor(3, H, W)): Image Tensor of dtype uint8 where each color corresponds to a given flow direction.
img (Tensor): Image Tensor of dtype uint8 where each color corresponds
to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.
"""
if
flow
.
dtype
!=
torch
.
float
:
raise
ValueError
(
f
"Flow should be of dtype torch.float, got
{
flow
.
dtype
}
."
)
if
flow
.
ndim
!=
3
or
flow
.
size
(
0
)
!=
2
:
raise
ValueError
(
f
"Input flow should have shape (2, H, W), got
{
flow
.
shape
}
."
)
orig_shape
=
flow
.
shape
if
flow
.
ndim
==
3
:
flow
=
flow
[
None
]
# Add batch dim
max_norm
=
torch
.
sum
(
flow
**
2
,
dim
=
0
).
sqrt
().
max
()
if
flow
.
ndim
!=
4
or
flow
.
shape
[
1
]
!=
2
:
raise
ValueError
(
f
"Input flow should have shape (2, H, W) or (N, 2, H, W), got
{
orig_shape
}
."
)
max_norm
=
torch
.
sum
(
flow
**
2
,
dim
=
1
).
sqrt
().
max
()
epsilon
=
torch
.
finfo
((
flow
).
dtype
).
eps
normalized_flow
=
flow
/
(
max_norm
+
epsilon
)
return
_normalized_flow_to_image
(
normalized_flow
)
img
=
_normalized_flow_to_image
(
normalized_flow
)
if
len
(
orig_shape
)
==
3
:
img
=
img
[
0
]
# Remove batch dim
return
img
@
torch
.
no_grad
()
def
_normalized_flow_to_image
(
normalized_flow
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Converts a normalized flow to an RGB image.
Converts a
batch of
normalized flow to an RGB image.
Args:
normalized_flow (torch.Tensor): Normalized flow tensor of shape (2, H, W)
normalized_flow (torch.Tensor): Normalized flow tensor of shape (
N,
2, H, W)
Returns:
img (Tensor(3, H, W)): Flow visualization image of dtype uint8.
img (Tensor(
N,
3, H, W)): Flow visualization image of dtype uint8.
"""
_
,
H
,
W
=
normalized_flow
.
shape
flow_image
=
torch
.
zeros
((
3
,
H
,
W
),
dtype
=
torch
.
uint8
)
N
,
_
,
H
,
W
=
normalized_flow
.
shape
flow_image
=
torch
.
zeros
((
N
,
3
,
H
,
W
),
dtype
=
torch
.
uint8
)
colorwheel
=
_make_colorwheel
()
# shape [55x3]
num_cols
=
colorwheel
.
shape
[
0
]
norm
=
torch
.
sum
(
normalized_flow
**
2
,
dim
=
0
).
sqrt
()
a
=
torch
.
atan2
(
-
normalized_flow
[
1
],
-
normalized_flow
[
0
])
/
torch
.
pi
norm
=
torch
.
sum
(
normalized_flow
**
2
,
dim
=
1
).
sqrt
()
a
=
torch
.
atan2
(
-
normalized_flow
[
:,
1
,
:,
:
],
-
normalized_flow
[
:,
0
,
:,
:
])
/
torch
.
pi
fk
=
(
a
+
1
)
/
2
*
(
num_cols
-
1
)
k0
=
torch
.
floor
(
fk
).
to
(
torch
.
long
)
k1
=
k0
+
1
...
...
@@ -445,7 +454,7 @@ def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
col1
=
tmp
[
k1
]
/
255.0
col
=
(
1
-
f
)
*
col0
+
f
*
col1
col
=
1
-
norm
*
(
1
-
col
)
flow_image
[
c
,
:,
:]
=
torch
.
floor
(
255
*
col
)
flow_image
[
:,
c
,
:,
:]
=
torch
.
floor
(
255
*
col
)
return
flow_image
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment