Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
7b18556c
Commit
7b18556c
authored
Sep 04, 2017
by
Francisco Massa
Committed by
Soumith Chintala
Sep 04, 2017
Browse files
Add asserts to make_grid and avoid inplace modification (#241)
parent
8e375670
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
37 additions
and
6 deletions
+37
-6
test/test_utils.py
test/test_utils.py
+31
-0
torchvision/utils.py
torchvision/utils.py
+6
-6
No files found.
test/test_utils.py
0 → 100644
View file @
7b18556c
import
torch
import
torchvision.utils
as
utils
import
unittest
class
Tester
(
unittest
.
TestCase
):
def
test_make_grid_not_inplace
(
self
):
t
=
torch
.
rand
(
5
,
3
,
10
,
10
)
t_clone
=
t
.
clone
()
utils
.
make_grid
(
t
,
normalize
=
False
)
assert
torch
.
equal
(
t
,
t_clone
),
'make_grid modified tensor in-place'
utils
.
make_grid
(
t
,
normalize
=
True
,
scale_each
=
False
)
assert
torch
.
equal
(
t
,
t_clone
),
'make_grid modified tensor in-place'
utils
.
make_grid
(
t
,
normalize
=
True
,
scale_each
=
True
)
assert
torch
.
equal
(
t
,
t_clone
),
'make_grid modified tensor in-place'
def
test_make_grid_raises_with_variable
(
self
):
t
=
torch
.
autograd
.
Variable
(
torch
.
rand
(
3
,
10
,
10
))
with
self
.
assertRaises
(
TypeError
):
utils
.
make_grid
(
t
)
with
self
.
assertRaises
(
TypeError
):
utils
.
make_grid
([
t
,
t
,
t
,
t
])
if
__name__
==
'__main__'
:
unittest
.
main
()
torchvision/utils.py
View file @
7b18556c
...
...
@@ -26,14 +26,13 @@ def make_grid(tensor, nrow=8, padding=2,
See this notebook `here <https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>`_
"""
if
not
(
torch
.
is_tensor
(
tensor
)
or
(
isinstance
(
tensor
,
list
)
and
all
(
torch
.
is_tensor
(
t
)
for
t
in
tensor
))):
raise
TypeError
(
'tensor or list of tensors expected, got {}'
.
format
(
type
(
tensor
)))
# if list of tensors, convert to a 4D mini-batch Tensor
if
isinstance
(
tensor
,
list
):
tensorlist
=
tensor
numImages
=
len
(
tensorlist
)
size
=
torch
.
Size
(
torch
.
Size
([
numImages
])
+
tensorlist
[
0
].
size
())
tensor
=
tensorlist
[
0
].
new
(
size
)
for
i
in
irange
(
numImages
):
tensor
[
i
].
copy_
(
tensorlist
[
i
])
tensor
=
torch
.
stack
(
tensor
,
dim
=
0
)
if
tensor
.
dim
()
==
2
:
# single image H x W
tensor
=
tensor
.
view
(
1
,
tensor
.
size
(
0
),
tensor
.
size
(
1
))
...
...
@@ -45,6 +44,7 @@ def make_grid(tensor, nrow=8, padding=2,
tensor
=
torch
.
cat
((
tensor
,
tensor
,
tensor
),
1
)
if
normalize
is
True
:
tensor
=
tensor
.
clone
()
# avoid modifying tensor in-place
if
range
is
not
None
:
assert
isinstance
(
range
,
tuple
),
\
"range has to be a tuple (min, max) if specified. min and max are numbers"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment