Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
491e9518
Unverified
Commit
491e9518
authored
Dec 15, 2022
by
amyeroberts
Committed by
GitHub
Dec 15, 2022
Browse files
Move convert_to_rgb to image_transforms module (#20784)
* Move convert_to_rgb to image_transforms module * Fix tests
parent
4bc723f8
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
58 additions
and
61 deletions
+58
-61
src/transformers/image_transforms.py
src/transformers/image_transforms.py
+20
-0
src/transformers/models/bit/image_processing_bit.py
src/transformers/models/bit/image_processing_bit.py
+2
-15
src/transformers/models/chinese_clip/image_processing_chinese_clip.py
...mers/models/chinese_clip/image_processing_chinese_clip.py
+2
-15
src/transformers/models/clip/image_processing_clip.py
src/transformers/models/clip/image_processing_clip.py
+2
-15
src/transformers/models/vit_hybrid/image_processing_vit_hybrid.py
...sformers/models/vit_hybrid/image_processing_vit_hybrid.py
+2
-16
tests/test_image_transforms.py
tests/test_image_transforms.py
+30
-0
No files found.
src/transformers/image_transforms.py
View file @
491e9518
...
@@ -20,6 +20,7 @@ import numpy as np
...
@@ -20,6 +20,7 @@ import numpy as np
from
transformers.image_utils
import
(
from
transformers.image_utils
import
(
ChannelDimension
,
ChannelDimension
,
ImageInput
,
get_channel_dimension_axis
,
get_channel_dimension_axis
,
get_image_size
,
get_image_size
,
infer_channel_dimension_format
,
infer_channel_dimension_format
,
...
@@ -687,3 +688,22 @@ def pad(
...
@@ -687,3 +688,22 @@ def pad(
image
=
to_channel_dimension_format
(
image
,
data_format
)
if
data_format
is
not
None
else
image
image
=
to_channel_dimension_format
(
image
,
data_format
)
if
data_format
is
not
None
else
image
return
image
return
image
# TODO (Amy): Accept 1/3/4 channel numpy array as input and return np.array as default
def
convert_to_rgb
(
image
:
ImageInput
)
->
ImageInput
:
"""
Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
as is.
Args:
image (Image):
The image to convert.
"""
requires_backends
(
convert_to_rgb
,
[
"vision"
])
if
not
isinstance
(
image
,
PIL
.
Image
.
Image
):
return
image
image
=
image
.
convert
(
"RGB"
)
return
image
src/transformers/models/bit/image_processing_bit.py
View file @
491e9518
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
# limitations under the License.
# limitations under the License.
"""Image processor class for BiT."""
"""Image processor class for BiT."""
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType
...
@@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType
from
...image_processing_utils
import
BaseImageProcessor
,
BatchFeature
,
get_size_dict
from
...image_processing_utils
import
BaseImageProcessor
,
BatchFeature
,
get_size_dict
from
...image_transforms
import
(
from
...image_transforms
import
(
center_crop
,
center_crop
,
convert_to_rgb
,
get_resize_output_image_size
,
get_resize_output_image_size
,
normalize
,
normalize
,
rescale
,
rescale
,
...
@@ -41,20 +42,6 @@ if is_vision_available():
...
@@ -41,20 +42,6 @@ if is_vision_available():
import
PIL
import
PIL
def
convert_to_rgb
(
image
:
Union
[
Any
,
PIL
.
Image
.
Image
])
->
Union
[
Any
,
PIL
.
Image
.
Image
]:
"""
Converts `PIL.Image.Image` to RGB format. Images in other formats are returned as is.
Args:
image (`PIL.Image.Image`):
The image to convert.
"""
if
not
isinstance
(
image
,
PIL
.
Image
.
Image
):
return
image
return
image
.
convert
(
"RGB"
)
class
BitImageProcessor
(
BaseImageProcessor
):
class
BitImageProcessor
(
BaseImageProcessor
):
r
"""
r
"""
Constructs a BiT image processor.
Constructs a BiT image processor.
...
...
src/transformers/models/chinese_clip/image_processing_chinese_clip.py
View file @
491e9518
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
# limitations under the License.
# limitations under the License.
"""Image processor class for Chinese-CLIP."""
"""Image processor class for Chinese-CLIP."""
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType
...
@@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType
from
...image_processing_utils
import
BaseImageProcessor
,
BatchFeature
,
get_size_dict
from
...image_processing_utils
import
BaseImageProcessor
,
BatchFeature
,
get_size_dict
from
...image_transforms
import
(
from
...image_transforms
import
(
center_crop
,
center_crop
,
convert_to_rgb
,
get_resize_output_image_size
,
get_resize_output_image_size
,
normalize
,
normalize
,
rescale
,
rescale
,
...
@@ -41,20 +42,6 @@ if is_vision_available():
...
@@ -41,20 +42,6 @@ if is_vision_available():
import
PIL
import
PIL
def
convert_to_rgb
(
image
:
Union
[
Any
,
PIL
.
Image
.
Image
])
->
Union
[
Any
,
PIL
.
Image
.
Image
]:
"""
Converts `PIL.Image.Image` to RGB format. Images in other formats are returned as is.
Args:
image (`PIL.Image.Image`):
The image to convert.
"""
if
not
isinstance
(
image
,
PIL
.
Image
.
Image
):
return
image
return
image
.
convert
(
"RGB"
)
class
ChineseCLIPImageProcessor
(
BaseImageProcessor
):
class
ChineseCLIPImageProcessor
(
BaseImageProcessor
):
r
"""
r
"""
Constructs a Chinese-CLIP image processor.
Constructs a Chinese-CLIP image processor.
...
...
src/transformers/models/clip/image_processing_clip.py
View file @
491e9518
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
# limitations under the License.
# limitations under the License.
"""Image processor class for CLIP."""
"""Image processor class for CLIP."""
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType
...
@@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType
from
...image_processing_utils
import
BaseImageProcessor
,
BatchFeature
,
get_size_dict
from
...image_processing_utils
import
BaseImageProcessor
,
BatchFeature
,
get_size_dict
from
...image_transforms
import
(
from
...image_transforms
import
(
center_crop
,
center_crop
,
convert_to_rgb
,
get_resize_output_image_size
,
get_resize_output_image_size
,
normalize
,
normalize
,
rescale
,
rescale
,
...
@@ -41,20 +42,6 @@ if is_vision_available():
...
@@ -41,20 +42,6 @@ if is_vision_available():
import
PIL
import
PIL
def
convert_to_rgb
(
image
:
Union
[
Any
,
PIL
.
Image
.
Image
])
->
Union
[
Any
,
PIL
.
Image
.
Image
]:
"""
Converts `PIL.Image.Image` to RGB format. Images in other formats are returned as is.
Args:
image (`PIL.Image.Image`):
The image to convert.
"""
if
not
isinstance
(
image
,
PIL
.
Image
.
Image
):
return
image
return
image
.
convert
(
"RGB"
)
class
CLIPImageProcessor
(
BaseImageProcessor
):
class
CLIPImageProcessor
(
BaseImageProcessor
):
r
"""
r
"""
Constructs a CLIP image processor.
Constructs a CLIP image processor.
...
...
src/transformers/models/vit_hybrid/image_processing_vit_hybrid.py
View file @
491e9518
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
# limitations under the License.
# limitations under the License.
"""Image processor class for ViT hybrid."""
"""Image processor class for ViT hybrid."""
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType
...
@@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType
from
...image_processing_utils
import
BaseImageProcessor
,
BatchFeature
,
get_size_dict
from
...image_processing_utils
import
BaseImageProcessor
,
BatchFeature
,
get_size_dict
from
...image_transforms
import
(
from
...image_transforms
import
(
center_crop
,
center_crop
,
convert_to_rgb
,
get_resize_output_image_size
,
get_resize_output_image_size
,
normalize
,
normalize
,
rescale
,
rescale
,
...
@@ -41,21 +42,6 @@ if is_vision_available():
...
@@ -41,21 +42,6 @@ if is_vision_available():
import
PIL
import
PIL
# Copied from transformers.models.bit.image_processing_bit.convert_to_rgb
def
convert_to_rgb
(
image
:
Union
[
Any
,
PIL
.
Image
.
Image
])
->
Union
[
Any
,
PIL
.
Image
.
Image
]:
"""
Converts `PIL.Image.Image` to RGB format. Images in other formats are returned as is.
Args:
image (`PIL.Image.Image`):
The image to convert.
"""
if
not
isinstance
(
image
,
PIL
.
Image
.
Image
):
return
image
return
image
.
convert
(
"RGB"
)
class
ViTHybridImageProcessor
(
BaseImageProcessor
):
class
ViTHybridImageProcessor
(
BaseImageProcessor
):
r
"""
r
"""
Constructs a ViT Hybrid image processor.
Constructs a ViT Hybrid image processor.
...
...
tests/test_image_transforms.py
View file @
491e9518
...
@@ -37,6 +37,7 @@ if is_vision_available():
...
@@ -37,6 +37,7 @@ if is_vision_available():
from
transformers.image_transforms
import
(
from
transformers.image_transforms
import
(
center_crop
,
center_crop
,
center_to_corners_format
,
center_to_corners_format
,
convert_to_rgb
,
corners_to_center_format
,
corners_to_center_format
,
get_resize_output_image_size
,
get_resize_output_image_size
,
id_to_rgb
,
id_to_rgb
,
...
@@ -456,3 +457,32 @@ class ImageTransformsTester(unittest.TestCase):
...
@@ -456,3 +457,32 @@ class ImageTransformsTester(unittest.TestCase):
self
.
assertTrue
(
self
.
assertTrue
(
np
.
allclose
(
expected_image
,
pad
(
image
,
((
0
,
2
),
(
2
,
1
)),
mode
=
"reflect"
,
data_format
=
"channels_last"
))
np
.
allclose
(
expected_image
,
pad
(
image
,
((
0
,
2
),
(
2
,
1
)),
mode
=
"reflect"
,
data_format
=
"channels_last"
))
)
)
@
require_vision
def
test_convert_to_rgb
(
self
):
# Test that an RGBA image is converted to RGB
image
=
np
.
array
([[[
1
,
2
,
3
,
4
],
[
5
,
6
,
7
,
8
]]],
dtype
=
np
.
uint8
)
pil_image
=
PIL
.
Image
.
fromarray
(
image
)
self
.
assertEqual
(
pil_image
.
mode
,
"RGBA"
)
self
.
assertEqual
(
pil_image
.
size
,
(
2
,
1
))
# For the moment, numpy images are returned as is
rgb_image
=
convert_to_rgb
(
image
)
self
.
assertEqual
(
rgb_image
.
shape
,
(
1
,
2
,
4
))
self
.
assertTrue
(
np
.
allclose
(
rgb_image
,
image
))
# And PIL images are converted
rgb_image
=
convert_to_rgb
(
pil_image
)
self
.
assertEqual
(
rgb_image
.
mode
,
"RGB"
)
self
.
assertEqual
(
rgb_image
.
size
,
(
2
,
1
))
self
.
assertTrue
(
np
.
allclose
(
np
.
array
(
rgb_image
),
np
.
array
([[[
1
,
2
,
3
],
[
5
,
6
,
7
]]],
dtype
=
np
.
uint8
)))
# Test that a grayscale image is converted to RGB
image
=
np
.
array
([[
0
,
255
]],
dtype
=
np
.
uint8
)
pil_image
=
PIL
.
Image
.
fromarray
(
image
)
self
.
assertEqual
(
pil_image
.
mode
,
"L"
)
self
.
assertEqual
(
pil_image
.
size
,
(
2
,
1
))
rgb_image
=
convert_to_rgb
(
pil_image
)
self
.
assertEqual
(
rgb_image
.
mode
,
"RGB"
)
self
.
assertEqual
(
rgb_image
.
size
,
(
2
,
1
))
self
.
assertTrue
(
np
.
allclose
(
np
.
array
(
rgb_image
),
np
.
array
([[[
0
,
0
,
0
],
[
255
,
255
,
255
]]],
dtype
=
np
.
uint8
)))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment