Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
991bad2f
Commit
991bad2f
authored
Mar 23, 2017
by
Bodo Kaiser
Committed by
Soumith Chintala
Mar 23, 2017
Browse files
updated ToTensor to support more types
parent
6cbb22bb
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
46 additions
and
13 deletions
+46
-13
test/test_transforms.py
test/test_transforms.py
+24
-0
torchvision/transforms.py
torchvision/transforms.py
+22
-13
No files found.
test/test_transforms.py
View file @
991bad2f
...
...
@@ -151,6 +151,30 @@ class Tester(unittest.TestCase):
expected_output
=
img_data
.
mul
(
255
).
int
().
float
().
div
(
255
)
assert
np
.
allclose
(
expected_output
[
0
].
numpy
(),
to_tensor
(
l
).
numpy
())
def
test_tensor_gray_to_pil_image
(
self
):
trans
=
transforms
.
ToPILImage
()
to_tensor
=
transforms
.
ToTensor
()
img_data_byte
=
torch
.
ByteTensor
(
1
,
4
,
4
).
random_
(
0
,
255
)
img_data_short
=
torch
.
ShortTensor
(
1
,
4
,
4
).
random_
()
img_data_int
=
torch
.
IntTensor
(
1
,
4
,
4
).
random_
()
img_data_float
=
torch
.
FloatTensor
(
1
,
4
,
4
).
uniform_
()
img_byte
=
trans
(
img_data_byte
)
img_short
=
trans
(
img_data_short
)
img_int
=
trans
(
img_data_int
)
img_float
=
trans
(
img_data_float
)
assert
img_byte
.
mode
==
'L'
assert
img_short
.
mode
==
'I;16'
assert
img_int
.
mode
==
'I'
#assert img_float.mode == 'F'
assert
np
.
allclose
(
img_data_short
.
numpy
(),
to_tensor
(
img_short
).
numpy
())
assert
np
.
allclose
(
img_data_int
.
numpy
(),
to_tensor
(
img_int
).
numpy
())
# would cause breaking changes as ToTensor converts to range [0, 1]
#assert np.allclose(img_data_byte.numpy(), to_tensor(img_byte).numpy())
#assert np.allclose(img_data_float.numpy(), to_tensor(img_float).numpy())
def
test_ndarray_to_pil_image
(
self
):
trans
=
transforms
.
ToPILImage
()
img_data
=
torch
.
ByteTensor
(
4
,
4
,
3
).
random_
(
0
,
255
).
numpy
()
...
...
torchvision/transforms.py
View file @
991bad2f
...
...
@@ -39,19 +39,30 @@ class ToTensor(object):
if
isinstance
(
pic
,
np
.
ndarray
):
# handle numpy array
img
=
torch
.
from_numpy
(
pic
.
transpose
((
2
,
0
,
1
)))
else
:
# backard compability
return
img
.
float
().
div
(
255
)
# handle PIL Image
if
pic
.
mode
==
'I'
:
img
=
torch
.
from_numpy
(
np
.
array
(
pic
,
np
.
int32
))
elif
pic
.
mode
==
'I;16'
:
img
=
torch
.
from_numpy
(
np
.
array
(
pic
,
np
.
int16
))
else
:
img
=
torch
.
ByteTensor
(
torch
.
ByteStorage
.
from_buffer
(
pic
.
tobytes
()))
# PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
if
pic
.
mode
==
'YCbCr'
:
nchannel
=
3
elif
pic
.
mode
==
'I;16'
:
nchannel
=
1
else
:
nchannel
=
len
(
pic
.
mode
)
img
=
img
.
view
(
pic
.
size
[
1
],
pic
.
size
[
0
],
nchannel
)
# put it from HWC to CHW format
# yikes, this transpose takes 80% of the loading time/CPU
img
=
img
.
transpose
(
0
,
1
).
transpose
(
0
,
2
).
contiguous
()
if
isinstance
(
img
,
torch
.
ByteTensor
):
return
img
.
float
().
div
(
255
)
else
:
return
img
class
ToPILImage
(
object
):
...
...
@@ -67,7 +78,6 @@ class ToPILImage(object):
if
torch
.
is_tensor
(
pic
):
npimg
=
np
.
transpose
(
pic
.
numpy
(),
(
1
,
2
,
0
))
assert
isinstance
(
npimg
,
np
.
ndarray
),
'pic should be Tensor or ndarray'
if
npimg
.
shape
[
2
]
==
1
:
npimg
=
npimg
[:,
:,
0
]
...
...
@@ -83,7 +93,6 @@ class ToPILImage(object):
if
npimg
.
dtype
==
np
.
uint8
:
mode
=
'RGB'
assert
mode
is
not
None
,
'{} is not supported'
.
format
(
npimg
.
dtype
)
return
Image
.
fromarray
(
npimg
,
mode
=
mode
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment