Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
15c166ac
Unverified
Commit
15c166ac
authored
Nov 08, 2023
by
Philip Meier
Committed by
GitHub
Nov 08, 2023
Browse files
refactor to_pil_image and align array with tensor inputs (#8097)
Co-authored-by:
Nicolas Hug
<
contact@nicolas-hug.com
>
parent
a0fcd083
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
33 deletions
+20
-33
test/test_transforms.py
test/test_transforms.py
+6
-4
torchvision/transforms/functional.py
torchvision/transforms/functional.py
+14
-29
No files found.
test/test_transforms.py
View file @
15c166ac
...
@@ -661,7 +661,7 @@ class TestToPil:
...
@@ -661,7 +661,7 @@ class TestToPil:
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"img_data, expected_mode"
,
"img_data, expected_mode"
,
[
[
(
torch
.
Tensor
(
4
,
4
,
1
).
uniform_
().
numpy
(),
"
F
"
),
(
torch
.
Tensor
(
4
,
4
,
1
).
uniform_
().
numpy
(),
"
L
"
),
(
torch
.
ByteTensor
(
4
,
4
,
1
).
random_
(
0
,
255
).
numpy
(),
"L"
),
(
torch
.
ByteTensor
(
4
,
4
,
1
).
random_
(
0
,
255
).
numpy
(),
"L"
),
(
torch
.
ShortTensor
(
4
,
4
,
1
).
random_
().
numpy
(),
"I;16"
),
(
torch
.
ShortTensor
(
4
,
4
,
1
).
random_
().
numpy
(),
"I;16"
),
(
torch
.
IntTensor
(
4
,
4
,
1
).
random_
().
numpy
(),
"I"
),
(
torch
.
IntTensor
(
4
,
4
,
1
).
random_
().
numpy
(),
"I"
),
...
@@ -671,6 +671,8 @@ class TestToPil:
...
@@ -671,6 +671,8 @@ class TestToPil:
transform
=
transforms
.
ToPILImage
(
mode
=
expected_mode
)
if
with_mode
else
transforms
.
ToPILImage
()
transform
=
transforms
.
ToPILImage
(
mode
=
expected_mode
)
if
with_mode
else
transforms
.
ToPILImage
()
img
=
transform
(
img_data
)
img
=
transform
(
img_data
)
assert
img
.
mode
==
expected_mode
assert
img
.
mode
==
expected_mode
if
np
.
issubdtype
(
img_data
.
dtype
,
np
.
floating
):
img_data
=
(
img_data
*
255
).
astype
(
np
.
uint8
)
# note: we explicitly convert img's dtype because pytorch doesn't support uint16
# note: we explicitly convert img's dtype because pytorch doesn't support uint16
# and otherwise assert_close wouldn't be able to construct a tensor from the uint16 array
# and otherwise assert_close wouldn't be able to construct a tensor from the uint16 array
torch
.
testing
.
assert_close
(
img_data
[:,
:,
0
],
np
.
asarray
(
img
).
astype
(
img_data
.
dtype
))
torch
.
testing
.
assert_close
(
img_data
[:,
:,
0
],
np
.
asarray
(
img
).
astype
(
img_data
.
dtype
))
...
@@ -741,7 +743,7 @@ class TestToPil:
...
@@ -741,7 +743,7 @@ class TestToPil:
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"img_data, expected_mode"
,
"img_data, expected_mode"
,
[
[
(
torch
.
Tensor
(
4
,
4
).
uniform_
().
numpy
(),
"
F
"
),
(
torch
.
Tensor
(
4
,
4
).
uniform_
().
numpy
(),
"
L
"
),
(
torch
.
ByteTensor
(
4
,
4
).
random_
(
0
,
255
).
numpy
(),
"L"
),
(
torch
.
ByteTensor
(
4
,
4
).
random_
(
0
,
255
).
numpy
(),
"L"
),
(
torch
.
ShortTensor
(
4
,
4
).
random_
().
numpy
(),
"I;16"
),
(
torch
.
ShortTensor
(
4
,
4
).
random_
().
numpy
(),
"I;16"
),
(
torch
.
IntTensor
(
4
,
4
).
random_
().
numpy
(),
"I"
),
(
torch
.
IntTensor
(
4
,
4
).
random_
().
numpy
(),
"I"
),
...
@@ -751,6 +753,8 @@ class TestToPil:
...
@@ -751,6 +753,8 @@ class TestToPil:
transform
=
transforms
.
ToPILImage
(
mode
=
expected_mode
)
if
with_mode
else
transforms
.
ToPILImage
()
transform
=
transforms
.
ToPILImage
(
mode
=
expected_mode
)
if
with_mode
else
transforms
.
ToPILImage
()
img
=
transform
(
img_data
)
img
=
transform
(
img_data
)
assert
img
.
mode
==
expected_mode
assert
img
.
mode
==
expected_mode
if
np
.
issubdtype
(
img_data
.
dtype
,
np
.
floating
):
img_data
=
(
img_data
*
255
).
astype
(
np
.
uint8
)
np
.
testing
.
assert_allclose
(
img_data
,
img
)
np
.
testing
.
assert_allclose
(
img_data
,
img
)
@
pytest
.
mark
.
parametrize
(
"expected_mode"
,
[
None
,
"RGB"
,
"HSV"
,
"YCbCr"
])
@
pytest
.
mark
.
parametrize
(
"expected_mode"
,
[
None
,
"RGB"
,
"HSV"
,
"YCbCr"
])
...
@@ -874,8 +878,6 @@ class TestToPil:
...
@@ -874,8 +878,6 @@ class TestToPil:
trans
(
np
.
ones
([
4
,
4
,
1
],
np
.
uint16
))
trans
(
np
.
ones
([
4
,
4
,
1
],
np
.
uint16
))
with
pytest
.
raises
(
TypeError
,
match
=
reg_msg
):
with
pytest
.
raises
(
TypeError
,
match
=
reg_msg
):
trans
(
np
.
ones
([
4
,
4
,
1
],
np
.
uint32
))
trans
(
np
.
ones
([
4
,
4
,
1
],
np
.
uint32
))
with
pytest
.
raises
(
TypeError
,
match
=
reg_msg
):
trans
(
np
.
ones
([
4
,
4
,
1
],
np
.
float64
))
with
pytest
.
raises
(
ValueError
,
match
=
r
"pic should be 2/3 dimensional. Got \d+ dimensions."
):
with
pytest
.
raises
(
ValueError
,
match
=
r
"pic should be 2/3 dimensional. Got \d+ dimensions."
):
transforms
.
ToPILImage
()(
np
.
ones
([
1
,
4
,
4
,
3
]))
transforms
.
ToPILImage
()(
np
.
ones
([
1
,
4
,
4
,
3
]))
...
...
torchvision/transforms/functional.py
View file @
15c166ac
...
@@ -258,41 +258,26 @@ def to_pil_image(pic, mode=None):
...
@@ -258,41 +258,26 @@ def to_pil_image(pic, mode=None):
if
not
torch
.
jit
.
is_scripting
()
and
not
torch
.
jit
.
is_tracing
():
if
not
torch
.
jit
.
is_scripting
()
and
not
torch
.
jit
.
is_tracing
():
_log_api_usage_once
(
to_pil_image
)
_log_api_usage_once
(
to_pil_image
)
if
not
(
isinstance
(
pic
,
torch
.
Tensor
)
or
isinstance
(
pic
,
np
.
ndarray
)):
if
isinstance
(
pic
,
torch
.
Tensor
):
if
pic
.
ndim
==
3
:
pic
=
pic
.
permute
((
1
,
2
,
0
))
pic
=
pic
.
numpy
(
force
=
True
)
elif
not
isinstance
(
pic
,
np
.
ndarray
):
raise
TypeError
(
f
"pic should be Tensor or ndarray. Got
{
type
(
pic
)
}
."
)
raise
TypeError
(
f
"pic should be Tensor or ndarray. Got
{
type
(
pic
)
}
."
)
elif
isinstance
(
pic
,
torch
.
Tensor
):
if
pic
.
ndim
==
2
:
if
pic
.
ndimension
()
not
in
{
2
,
3
}:
# if 2D image, add channel dimension (HWC)
raise
ValueError
(
f
"pic should be 2/3 dimensional. Got
{
pic
.
ndimension
()
}
dimensions."
)
pic
=
np
.
expand_dims
(
pic
,
2
)
if
pic
.
ndim
!=
3
:
elif
pic
.
ndimension
()
==
2
:
raise
ValueError
(
f
"pic should be 2/3 dimensional. Got
{
pic
.
ndim
}
dimensions."
)
# if 2D image, add channel dimension (CHW)
pic
=
pic
.
unsqueeze
(
0
)
# check number of channels
if
pic
.
shape
[
-
3
]
>
4
:
raise
ValueError
(
f
"pic should not have > 4 channels. Got
{
pic
.
shape
[
-
3
]
}
channels."
)
elif
isinstance
(
pic
,
np
.
ndarray
):
if
pic
.
ndim
not
in
{
2
,
3
}:
raise
ValueError
(
f
"pic should be 2/3 dimensional. Got
{
pic
.
ndim
}
dimensions."
)
elif
pic
.
ndim
==
2
:
# if 2D image, add channel dimension (HWC)
pic
=
np
.
expand_dims
(
pic
,
2
)
# check number of channels
if
pic
.
shape
[
-
1
]
>
4
:
if
pic
.
shape
[
-
1
]
>
4
:
raise
ValueError
(
f
"pic should not have > 4 channels. Got
{
pic
.
shape
[
-
1
]
}
channels."
)
raise
ValueError
(
f
"pic should not have > 4 channels. Got
{
pic
.
shape
[
-
1
]
}
channels."
)
npimg
=
pic
npimg
=
pic
if
isinstance
(
pic
,
torch
.
Tensor
):
if
pic
.
is_floating_point
()
and
mode
!=
"F"
:
pic
=
pic
.
mul
(
255
).
byte
()
npimg
=
np
.
transpose
(
pic
.
cpu
().
numpy
(),
(
1
,
2
,
0
))
if
n
ot
isinstance
(
npimg
,
np
.
ndarray
)
:
if
n
p
.
issubdtype
(
npimg
.
dtype
,
np
.
floating
)
and
mode
!=
"F"
:
raise
TypeError
(
"Input pic must be a torch.Tensor or NumPy ndarray, not {
type(np
img)}"
)
npimg
=
(
npimg
*
255
).
as
type
(
np
.
uint8
)
if
npimg
.
shape
[
2
]
==
1
:
if
npimg
.
shape
[
2
]
==
1
:
expected_mode
=
None
expected_mode
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment