Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
76924384
Unverified
Commit
76924384
authored
Dec 19, 2022
by
amyeroberts
Committed by
GitHub
Dec 19, 2022
Browse files
Vilt - use image_transforms pad (#20780)
Use image_transforms pad
parent
ecd7de3d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
44 deletions
+30
-44
src/transformers/models/detr/image_processing_detr.py
src/transformers/models/detr/image_processing_detr.py
+3
-0
src/transformers/models/vilt/image_processing_vilt.py
src/transformers/models/vilt/image_processing_vilt.py
+27
-44
No files found.
src/transformers/models/detr/image_processing_detr.py
View file @
76924384
...
@@ -189,6 +189,7 @@ def normalize_annotation(annotation: Dict, image_size: Tuple[int, int]) -> Dict:
...
@@ -189,6 +189,7 @@ def normalize_annotation(annotation: Dict, image_size: Tuple[int, int]) -> Dict:
return
norm_annotation
return
norm_annotation
# Copied from transformers.models.vilt.image_processing_vilt.max_across_indices
def
max_across_indices
(
values
:
Iterable
[
Any
])
->
List
[
Any
]:
def
max_across_indices
(
values
:
Iterable
[
Any
])
->
List
[
Any
]:
"""
"""
Return the maximum value across all indices of an iterable of values.
Return the maximum value across all indices of an iterable of values.
...
@@ -196,6 +197,7 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:
...
@@ -196,6 +197,7 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:
return
[
max
(
values_i
)
for
values_i
in
zip
(
*
values
)]
return
[
max
(
values_i
)
for
values_i
in
zip
(
*
values
)]
# Copied from transformers.models.vilt.image_processing_vilt.get_max_height_width
def
get_max_height_width
(
images
:
List
[
np
.
ndarray
])
->
List
[
int
]:
def
get_max_height_width
(
images
:
List
[
np
.
ndarray
])
->
List
[
int
]:
"""
"""
Get the maximum height and width across all images in a batch.
Get the maximum height and width across all images in a batch.
...
@@ -211,6 +213,7 @@ def get_max_height_width(images: List[np.ndarray]) -> List[int]:
...
@@ -211,6 +213,7 @@ def get_max_height_width(images: List[np.ndarray]) -> List[int]:
return
(
max_height
,
max_width
)
return
(
max_height
,
max_width
)
# Copied from transformers.models.vilt.image_processing_vilt.make_pixel_mask
def
make_pixel_mask
(
image
:
np
.
ndarray
,
output_size
:
Tuple
[
int
,
int
])
->
np
.
ndarray
:
def
make_pixel_mask
(
image
:
np
.
ndarray
,
output_size
:
Tuple
[
int
,
int
])
->
np
.
ndarray
:
"""
"""
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
...
...
src/transformers/models/vilt/image_processing_vilt.py
View file @
76924384
...
@@ -23,7 +23,7 @@ from transformers.utils import is_vision_available
...
@@ -23,7 +23,7 @@ from transformers.utils import is_vision_available
from
transformers.utils.generic
import
TensorType
from
transformers.utils.generic
import
TensorType
from
...image_processing_utils
import
BaseImageProcessor
,
BatchFeature
,
get_size_dict
from
...image_processing_utils
import
BaseImageProcessor
,
BatchFeature
,
get_size_dict
from
...image_transforms
import
normalize
,
rescale
,
resize
,
to_channel_dimension_format
from
...image_transforms
import
PaddingMode
,
normalize
,
pad
,
rescale
,
resize
,
to_channel_dimension_format
from
...image_utils
import
(
from
...image_utils
import
(
IMAGENET_STANDARD_MEAN
,
IMAGENET_STANDARD_MEAN
,
IMAGENET_STANDARD_STD
,
IMAGENET_STANDARD_STD
,
...
@@ -53,46 +53,6 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:
...
@@ -53,46 +53,6 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:
return
[
max
(
values_i
)
for
values_i
in
zip
(
*
values
)]
return
[
max
(
values_i
)
for
values_i
in
zip
(
*
values
)]
def pad(
    image: np.ndarray,
    output_size: Tuple[int, int],
    input_channel_dimension: Optional[ChannelDimension] = None,
    data_format: Optional[ChannelDimension] = None,
) -> np.ndarray:
    """
    Pad the bottom and right of the image with zeros to the output size.

    Args:
        image (`np.ndarray`):
            Image to pad.
        output_size (`Tuple[int, int]`):
            Output size of the image, as `(height, width)`.
        input_channel_dimension (`ChannelDimension`, *optional*):
            The channel dimension format of the input image. If not provided, it will be inferred from the input
            image.
        data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format of the returned image. If not provided, it will be the same as the input
            image.

    Returns:
        `np.ndarray`: The zero-padded image.

    Raises:
        ValueError: If the channel dimension format is neither `ChannelDimension.FIRST` nor
            `ChannelDimension.LAST`, or if `output_size` is smaller than the image in either dimension.
    """
    if input_channel_dimension is None:
        input_channel_dimension = infer_channel_dimension_format(image)

    output_height, output_width = output_size
    input_height, input_width = get_image_size(image)
    pad_bottom = output_height - input_height
    pad_right = output_width - input_width

    # Only the two spatial axes are padded; the channel axis is left untouched.
    if input_channel_dimension == ChannelDimension.FIRST:
        pad_width = [(0, 0), (0, pad_bottom), (0, pad_right)]
    elif input_channel_dimension == ChannelDimension.LAST:
        pad_width = [(0, pad_bottom), (0, pad_right), (0, 0)]
    else:
        raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}")

    # np.pad rejects negative widths with a generic message; fail early with one that names the sizes.
    if pad_bottom < 0 or pad_right < 0:
        raise ValueError(
            f"Output size {(output_height, output_width)} is smaller than the image size "
            f"{(input_height, input_width)}; padding cannot shrink an image."
        )

    padded_image = np.pad(image, pad_width, mode="constant", constant_values=0)

    if data_format is not None:
        padded_image = to_channel_dimension_format(padded_image, data_format)
    return padded_image
def
make_pixel_mask
(
image
:
np
.
ndarray
,
output_size
:
Tuple
[
int
,
int
])
->
np
.
ndarray
:
def
make_pixel_mask
(
image
:
np
.
ndarray
,
output_size
:
Tuple
[
int
,
int
])
->
np
.
ndarray
:
"""
"""
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
...
@@ -109,7 +69,7 @@ def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarr
...
@@ -109,7 +69,7 @@ def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarr
return
mask
return
mask
def
get_max_
dimensions
(
images
:
List
[
np
.
ndarray
])
->
List
[
int
]:
def
get_max_
height_width
(
images
:
List
[
np
.
ndarray
])
->
List
[
int
]:
"""
"""
Get the maximum height and width across all images in a batch.
Get the maximum height and width across all images in a batch.
"""
"""
...
@@ -304,6 +264,27 @@ class ViltImageProcessor(BaseImageProcessor):
...
@@ -304,6 +264,27 @@ class ViltImageProcessor(BaseImageProcessor):
"""
"""
return
normalize
(
image
,
mean
=
mean
,
std
=
std
,
data_format
=
data_format
,
**
kwargs
)
return
normalize
(
image
,
mean
=
mean
,
std
=
std
,
data_format
=
data_format
,
**
kwargs
)
def _pad_image(
    self,
    image: np.ndarray,
    output_size: Tuple[int, int],
    constant_values: Union[float, Iterable[float]] = 0,
    data_format: Optional[ChannelDimension] = None,
) -> np.ndarray:
    """
    Pad the bottom and right of an image with `constant_values` up to the given size.

    Args:
        image (`np.ndarray`):
            Image to pad.
        output_size (`Tuple[int, int]`):
            Target size of the padded image, as `(height, width)`.
        constant_values (`float` or `Iterable[float]`, *optional*, defaults to 0):
            Fill value(s) forwarded to `pad` for the padded region.
        data_format (`str` or `ChannelDimension`, *optional*):
            Channel dimension format forwarded to `pad`.

    Returns:
        `np.ndarray`: The padded image.
    """
    input_height, input_width = get_image_size(image)
    output_height, output_width = output_size

    # Padding is added only after the existing pixels, so image content stays anchored at the top-left corner.
    pad_bottom = output_height - input_height
    pad_right = output_width - input_width
    padding = ((0, pad_bottom), (0, pad_right))

    padded_image = pad(
        image,
        padding,
        mode=PaddingMode.CONSTANT,
        constant_values=constant_values,
        data_format=data_format,
    )
    return padded_image
def
pad
(
def
pad
(
self
,
self
,
images
:
List
[
np
.
ndarray
],
images
:
List
[
np
.
ndarray
],
...
@@ -330,8 +311,10 @@ class ViltImageProcessor(BaseImageProcessor):
...
@@ -330,8 +311,10 @@ class ViltImageProcessor(BaseImageProcessor):
data_format (`str` or `ChannelDimension`, *optional*):
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
"""
pad_size
=
get_max_dimensions
(
images
)
pad_size
=
get_max_height_width
(
images
)
padded_images
=
[
pad
(
image
=
image
,
output_size
=
pad_size
,
data_format
=
data_format
)
for
image
in
images
]
padded_images
=
[
self
.
_pad_image
(
image
=
image
,
output_size
=
pad_size
,
data_format
=
data_format
)
for
image
in
images
]
data
=
{
"pixel_values"
:
padded_images
}
data
=
{
"pixel_values"
:
padded_images
}
if
return_pixel_mask
:
if
return_pixel_mask
:
masks
=
[
make_pixel_mask
(
image
=
image
,
output_size
=
pad_size
)
for
image
in
images
]
masks
=
[
make_pixel_mask
(
image
=
image
,
output_size
=
pad_size
)
for
image
in
images
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment