Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
067b9dca
Unverified
Commit
067b9dca
authored
Feb 15, 2021
by
vfdev
Committed by
GitHub
Feb 15, 2021
Browse files
Functional to_tensor returns float tensor of default dtype (#3398)
Fixes #3393
parent
f04e9cb9
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
3 deletions
+19
-3
test/test_transforms.py
test/test_transforms.py
+14
-0
torchvision/transforms/functional.py
torchvision/transforms/functional.py
+5
-3
No files found.
test/test_transforms.py
View file @
067b9dca
...
...
@@ -620,6 +620,20 @@ class Tester(unittest.TestCase):
output
=
trans
(
img
)
self
.
assertTrue
(
np
.
allclose
(
input_data
.
numpy
(),
output
.
numpy
()))
def
test_to_tensor_with_other_default_dtypes
(
self
):
current_def_dtype
=
torch
.
get_default_dtype
()
t
=
transforms
.
ToTensor
()
np_arr
=
np
.
random
.
randint
(
0
,
255
,
(
32
,
32
,
3
),
dtype
=
np
.
uint8
)
img
=
Image
.
fromarray
(
np_arr
)
for
dtype
in
[
torch
.
float16
,
torch
.
float
,
torch
.
double
]:
torch
.
set_default_dtype
(
dtype
)
res
=
t
(
img
)
self
.
assertTrue
(
res
.
dtype
==
dtype
,
msg
=
f
"
{
res
.
dtype
}
vs
{
dtype
}
"
)
torch
.
set_default_dtype
(
current_def_dtype
)
def
test_max_value
(
self
):
for
dtype
in
int_dtypes
():
self
.
assertEqual
(
F_t
.
_max_value
(
dtype
),
torch
.
iinfo
(
dtype
).
max
)
...
...
torchvision/transforms/functional.py
View file @
067b9dca
...
...
@@ -104,6 +104,8 @@ def to_tensor(pic):
if
_is_numpy
(
pic
)
and
not
_is_numpy_image
(
pic
):
raise
ValueError
(
'pic should be 2/3 dimensional. Got {} dimensions.'
.
format
(
pic
.
ndim
))
default_float_dtype
=
torch
.
get_default_dtype
()
if
isinstance
(
pic
,
np
.
ndarray
):
# handle numpy array
if
pic
.
ndim
==
2
:
...
...
@@ -112,12 +114,12 @@ def to_tensor(pic):
img
=
torch
.
from_numpy
(
pic
.
transpose
((
2
,
0
,
1
))).
contiguous
()
# backward compatibility
if
isinstance
(
img
,
torch
.
ByteTensor
):
return
img
.
float
(
).
div
(
255
)
return
img
.
to
(
dtype
=
default_float_dtype
).
div
(
255
)
else
:
return
img
if
accimage
is
not
None
and
isinstance
(
pic
,
accimage
.
Image
):
nppic
=
np
.
zeros
([
pic
.
channels
,
pic
.
height
,
pic
.
width
],
dtype
=
np
.
float32
)
nppic
=
np
.
zeros
([
pic
.
channels
,
pic
.
height
,
pic
.
width
],
dtype
=
default_float_dtype
)
pic
.
copyto
(
nppic
)
return
torch
.
from_numpy
(
nppic
)
...
...
@@ -137,7 +139,7 @@ def to_tensor(pic):
# put it from HWC to CHW format
img
=
img
.
permute
((
2
,
0
,
1
)).
contiguous
()
if
isinstance
(
img
,
torch
.
ByteTensor
):
return
img
.
float
(
).
div
(
255
)
return
img
.
to
(
dtype
=
default_float_dtype
).
div
(
255
)
else
:
return
img
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment