Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
944ddce8
Unverified
Commit
944ddce8
authored
Aug 09, 2023
by
amyeroberts
Committed by
GitHub
Aug 09, 2023
Browse files
Enable passing number of channels when inferring data format (#25412)
parent
cb3c821c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing 2 changed files with 14 additions and 3 deletions
+14
-3
src/transformers/image_utils.py
src/transformers/image_utils.py
+10
-3
tests/utils/test_image_utils.py
tests/utils/test_image_utils.py
+4
-0
No files found.
src/transformers/image_utils.py
View file @
944ddce8
...
@@ -144,17 +144,24 @@ def to_numpy_array(img) -> np.ndarray:
     return to_numpy(img)
 
 
-def infer_channel_dimension_format(image: np.ndarray) -> ChannelDimension:
+def infer_channel_dimension_format(
+    image: np.ndarray, num_channels: Optional[Union[int, Tuple[int, ...]]] = None
+) -> ChannelDimension:
     """
     Infers the channel dimension format of `image`.
 
     Args:
         image (`np.ndarray`):
             The image to infer the channel dimension of.
+        num_channels (`int` or `Tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
+            The number of channels of the image.
 
     Returns:
         The channel dimension of the image.
     """
+    num_channels = num_channels if num_channels is not None else (1, 3)
+    num_channels = (num_channels,) if isinstance(num_channels, int) else num_channels
+
     if image.ndim == 3:
         first_dim, last_dim = 0, 2
     elif image.ndim == 4:
...
@@ -162,9 +169,9 @@ def infer_channel_dimension_format(image: np.ndarray) -> ChannelDimension:
     else:
         raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
 
-    if image.shape[first_dim] in (1, 3):
+    if image.shape[first_dim] in num_channels:
         return ChannelDimension.FIRST
-    elif image.shape[last_dim] in (1, 3):
+    elif image.shape[last_dim] in num_channels:
         return ChannelDimension.LAST
     raise ValueError("Unable to infer channel dimension format")
 
...
...
tests/utils/test_image_utils.py
View file @
944ddce8
...
@@ -578,6 +578,10 @@ class UtilFunctionTester(unittest.TestCase):
         with pytest.raises(ValueError):
             infer_channel_dimension_format(np.random.randint(0, 256, (10, 1, 50)))
 
+        # But if we explicitly set one of the number of channels to 50 it works
+        inferred_dim = infer_channel_dimension_format(np.random.randint(0, 256, (10, 1, 50)), num_channels=50)
+        self.assertEqual(inferred_dim, ChannelDimension.LAST)
+
         # Test we correctly identify the channel dimension
         image = np.random.randint(0, 256, (3, 4, 5))
         inferred_dim = infer_channel_dimension_format(image)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment