Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
76924384
Unverified
Commit
76924384
authored
Dec 19, 2022
by
amyeroberts
Committed by
GitHub
Dec 19, 2022
Browse files
Vilt - use image_transforms pad (#20780)
Use image_transforms pad
parent
ecd7de3d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
44 deletions
+30
-44
src/transformers/models/detr/image_processing_detr.py
src/transformers/models/detr/image_processing_detr.py
+3
-0
src/transformers/models/vilt/image_processing_vilt.py
src/transformers/models/vilt/image_processing_vilt.py
+27
-44
No files found.
src/transformers/models/detr/image_processing_detr.py
View file @
76924384
...
...
@@ -189,6 +189,7 @@ def normalize_annotation(annotation: Dict, image_size: Tuple[int, int]) -> Dict:
return
norm_annotation
# Copied from transformers.models.vilt.image_processing_vilt.max_across_indices
def
max_across_indices
(
values
:
Iterable
[
Any
])
->
List
[
Any
]:
"""
Return the maximum value across all indices of an iterable of values.
...
...
@@ -196,6 +197,7 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:
return
[
max
(
values_i
)
for
values_i
in
zip
(
*
values
)]
# Copied from transformers.models.vilt.image_processing_vilt.get_max_height_width
def
get_max_height_width
(
images
:
List
[
np
.
ndarray
])
->
List
[
int
]:
"""
Get the maximum height and width across all images in a batch.
...
...
@@ -211,6 +213,7 @@ def get_max_height_width(images: List[np.ndarray]) -> List[int]:
return
(
max_height
,
max_width
)
# Copied from transformers.models.vilt.image_processing_vilt.make_pixel_mask
def
make_pixel_mask
(
image
:
np
.
ndarray
,
output_size
:
Tuple
[
int
,
int
])
->
np
.
ndarray
:
"""
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
...
...
src/transformers/models/vilt/image_processing_vilt.py
View file @
76924384
...
...
@@ -23,7 +23,7 @@ from transformers.utils import is_vision_available
from
transformers.utils.generic
import
TensorType
from
...image_processing_utils
import
BaseImageProcessor
,
BatchFeature
,
get_size_dict
from
...image_transforms
import
normalize
,
rescale
,
resize
,
to_channel_dimension_format
from
...image_transforms
import
PaddingMode
,
normalize
,
pad
,
rescale
,
resize
,
to_channel_dimension_format
from
...image_utils
import
(
IMAGENET_STANDARD_MEAN
,
IMAGENET_STANDARD_STD
,
...
...
@@ -53,46 +53,6 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:
return
[
max
(
values_i
)
for
values_i
in
zip
(
*
values
)]
def
pad
(
image
:
np
.
ndarray
,
output_size
:
Tuple
[
int
,
int
],
input_channel_dimension
:
Optional
[
ChannelDimension
]
=
None
,
data_format
:
Optional
[
ChannelDimension
]
=
None
,
)
->
np
.
ndarray
:
"""
Pad the bottom and right of the image with zeros to the output size.
Args:
image (`np.ndarray`):
Image to pad.
output_size (`Tuple[int, int]`):
Output size of the image.
input_channel_dimension (`ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be inferred from the input image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
if
input_channel_dimension
is
None
:
input_channel_dimension
=
infer_channel_dimension_format
(
image
)
output_height
,
output_width
=
output_size
input_height
,
input_width
=
get_image_size
(
image
)
pad_bottom
=
output_height
-
input_height
pad_right
=
output_width
-
input_width
if
input_channel_dimension
==
ChannelDimension
.
FIRST
:
padded_image
=
np
.
pad
(
image
,
[(
0
,
0
),
(
0
,
pad_bottom
),
(
0
,
pad_right
)],
mode
=
"constant"
,
constant_values
=
0
)
elif
input_channel_dimension
==
ChannelDimension
.
LAST
:
padded_image
=
np
.
pad
(
image
,
[(
0
,
pad_bottom
),
(
0
,
pad_right
),
(
0
,
0
)],
mode
=
"constant"
,
constant_values
=
0
)
else
:
raise
ValueError
(
f
"Invalid channel dimension format:
{
input_channel_dimension
}
"
)
if
data_format
is
not
None
:
padded_image
=
to_channel_dimension_format
(
padded_image
,
data_format
)
return
padded_image
def
make_pixel_mask
(
image
:
np
.
ndarray
,
output_size
:
Tuple
[
int
,
int
])
->
np
.
ndarray
:
"""
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
...
...
@@ -109,7 +69,7 @@ def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarr
return
mask
def
get_max_
dimensions
(
images
:
List
[
np
.
ndarray
])
->
List
[
int
]:
def
get_max_
height_width
(
images
:
List
[
np
.
ndarray
])
->
List
[
int
]:
"""
Get the maximum height and width across all images in a batch.
"""
...
...
@@ -304,6 +264,27 @@ class ViltImageProcessor(BaseImageProcessor):
"""
return
normalize
(
image
,
mean
=
mean
,
std
=
std
,
data_format
=
data_format
,
**
kwargs
)
def
_pad_image
(
self
,
image
:
np
.
ndarray
,
output_size
:
Tuple
[
int
,
int
],
constant_values
:
Union
[
float
,
Iterable
[
float
]]
=
0
,
data_format
:
Optional
[
ChannelDimension
]
=
None
,
)
->
np
.
ndarray
:
"""
Pad an image with zeros to the given size.
"""
input_height
,
input_width
=
get_image_size
(
image
)
output_height
,
output_width
=
output_size
pad_bottom
=
output_height
-
input_height
pad_right
=
output_width
-
input_width
padding
=
((
0
,
pad_bottom
),
(
0
,
pad_right
))
padded_image
=
pad
(
image
,
padding
,
mode
=
PaddingMode
.
CONSTANT
,
constant_values
=
constant_values
,
data_format
=
data_format
)
return
padded_image
def
pad
(
self
,
images
:
List
[
np
.
ndarray
],
...
...
@@ -330,8 +311,10 @@ class ViltImageProcessor(BaseImageProcessor):
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
pad_size
=
get_max_dimensions
(
images
)
padded_images
=
[
pad
(
image
=
image
,
output_size
=
pad_size
,
data_format
=
data_format
)
for
image
in
images
]
pad_size
=
get_max_height_width
(
images
)
padded_images
=
[
self
.
_pad_image
(
image
=
image
,
output_size
=
pad_size
,
data_format
=
data_format
)
for
image
in
images
]
data
=
{
"pixel_values"
:
padded_images
}
if
return_pixel_mask
:
masks
=
[
make_pixel_mask
(
image
=
image
,
output_size
=
pad_size
)
for
image
in
images
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment