Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
6bc6797e
Unverified
Commit
6bc6797e
authored
May 11, 2022
by
Heng Kuan Wee
Committed by
GitHub
May 11, 2022
Browse files
Convert image to rgb for clip model (#17101)
Co-authored-by:
kuanwee.heng
<
kuanwee.heng@aaqua.live
>
parent
0a2bea47
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
86 additions
and
1 deletion
+86
-1
src/transformers/models/clip/feature_extraction_clip.py
src/transformers/models/clip/feature_extraction_clip.py
+21
-1
tests/models/clip/test_feature_extraction_clip.py
tests/models/clip/test_feature_extraction_clip.py
+65
-0
No files found.
src/transformers/models/clip/feature_extraction_clip.py
View file @
6bc6797e
...
...
@@ -54,6 +54,8 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
The sequence of means for each channel, to be used when normalizing images.
image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
The sequence of standard deviations for each channel, to be used when normalizing images.
convert_rgb (`bool`, defaults to `True`):
Whether or not to convert `PIL.Image.Image` into `RGB` format
"""
model_input_names
=
[
"pixel_values"
]
...
...
@@ -68,6 +70,7 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
do_normalize
=
True
,
image_mean
=
None
,
image_std
=
None
,
do_convert_rgb
=
True
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
...
...
@@ -79,6 +82,7 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
self
.
do_normalize
=
do_normalize
self
.
image_mean
=
image_mean
if
image_mean
is
not
None
else
[
0.48145466
,
0.4578275
,
0.40821073
]
self
.
image_std
=
image_std
if
image_std
is
not
None
else
[
0.26862954
,
0.26130258
,
0.27577711
]
self
.
do_convert_rgb
=
do_convert_rgb
def
__call__
(
self
,
...
...
@@ -141,7 +145,9 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
if
not
is_batched
:
images
=
[
images
]
# transformations (resizing + center cropping + normalization)
# transformations (convert rgb + resizing + center cropping + normalization)
if
self
.
do_convert_rgb
:
images
=
[
self
.
convert_rgb
(
image
)
for
image
in
images
]
if
self
.
do_resize
and
self
.
size
is
not
None
and
self
.
resample
is
not
None
:
images
=
[
self
.
resize
(
image
=
image
,
size
=
self
.
size
,
resample
=
self
.
resample
)
for
image
in
images
]
if
self
.
do_center_crop
and
self
.
crop_size
is
not
None
:
...
...
@@ -155,6 +161,20 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
return
encoded_inputs
def
convert_rgb
(
self
,
image
):
"""
Converts `image` to RGB format. Note that this will trigger a conversion of `image` to a PIL Image.
Args:
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
The image to convert.
"""
self
.
_ensure_format_supported
(
image
)
if
not
isinstance
(
image
,
Image
.
Image
):
return
image
return
image
.
convert
(
"RGB"
)
def
center_crop
(
self
,
image
,
size
):
"""
Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the
...
...
tests/models/clip/test_feature_extraction_clip.py
View file @
6bc6797e
...
...
@@ -49,6 +49,7 @@ class CLIPFeatureExtractionTester(unittest.TestCase):
do_normalize
=
True
,
image_mean
=
[
0.48145466
,
0.4578275
,
0.40821073
],
image_std
=
[
0.26862954
,
0.26130258
,
0.27577711
],
do_convert_rgb
=
True
,
):
self
.
parent
=
parent
self
.
batch_size
=
batch_size
...
...
@@ -63,6 +64,7 @@ class CLIPFeatureExtractionTester(unittest.TestCase):
self
.
do_normalize
=
do_normalize
self
.
image_mean
=
image_mean
self
.
image_std
=
image_std
self
.
do_convert_rgb
=
do_convert_rgb
def
prepare_feat_extract_dict
(
self
):
return
{
...
...
@@ -73,6 +75,7 @@ class CLIPFeatureExtractionTester(unittest.TestCase):
"do_normalize"
:
self
.
do_normalize
,
"image_mean"
:
self
.
image_mean
,
"image_std"
:
self
.
image_std
,
"do_convert_rgb"
:
self
.
do_convert_rgb
,
}
def
prepare_inputs
(
self
,
equal_resolution
=
False
,
numpify
=
False
,
torchify
=
False
):
...
...
@@ -128,6 +131,7 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
self
.
assertTrue
(
hasattr
(
feature_extractor
,
"do_normalize"
))
self
.
assertTrue
(
hasattr
(
feature_extractor
,
"image_mean"
))
self
.
assertTrue
(
hasattr
(
feature_extractor
,
"image_std"
))
self
.
assertTrue
(
hasattr
(
feature_extractor
,
"do_convert_rgb"
))
def
test_batch_feature
(
self
):
pass
...
...
@@ -227,3 +231,64 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
self
.
feature_extract_tester
.
crop_size
,
),
)
@
require_torch
@
require_vision
class
CLIPFeatureExtractionTestFourChannels
(
FeatureExtractionSavingTestMixin
,
unittest
.
TestCase
):
feature_extraction_class
=
CLIPFeatureExtractor
if
is_vision_available
()
else
None
def
setUp
(
self
):
self
.
feature_extract_tester
=
CLIPFeatureExtractionTester
(
self
,
num_channels
=
4
)
self
.
expected_encoded_image_num_channels
=
3
@
property
def
feat_extract_dict
(
self
):
return
self
.
feature_extract_tester
.
prepare_feat_extract_dict
()
def
test_feat_extract_properties
(
self
):
feature_extractor
=
self
.
feature_extraction_class
(
**
self
.
feat_extract_dict
)
self
.
assertTrue
(
hasattr
(
feature_extractor
,
"do_resize"
))
self
.
assertTrue
(
hasattr
(
feature_extractor
,
"size"
))
self
.
assertTrue
(
hasattr
(
feature_extractor
,
"do_center_crop"
))
self
.
assertTrue
(
hasattr
(
feature_extractor
,
"center_crop"
))
self
.
assertTrue
(
hasattr
(
feature_extractor
,
"do_normalize"
))
self
.
assertTrue
(
hasattr
(
feature_extractor
,
"image_mean"
))
self
.
assertTrue
(
hasattr
(
feature_extractor
,
"image_std"
))
self
.
assertTrue
(
hasattr
(
feature_extractor
,
"do_convert_rgb"
))
def
test_batch_feature
(
self
):
pass
def
test_call_pil_four_channels
(
self
):
# Initialize feature_extractor
feature_extractor
=
self
.
feature_extraction_class
(
**
self
.
feat_extract_dict
)
# create random PIL images
image_inputs
=
self
.
feature_extract_tester
.
prepare_inputs
(
equal_resolution
=
False
)
for
image
in
image_inputs
:
self
.
assertIsInstance
(
image
,
Image
.
Image
)
# Test not batched input
encoded_images
=
feature_extractor
(
image_inputs
[
0
],
return_tensors
=
"pt"
).
pixel_values
self
.
assertEqual
(
encoded_images
.
shape
,
(
1
,
self
.
expected_encoded_image_num_channels
,
self
.
feature_extract_tester
.
crop_size
,
self
.
feature_extract_tester
.
crop_size
,
),
)
# Test batched
encoded_images
=
feature_extractor
(
image_inputs
,
return_tensors
=
"pt"
).
pixel_values
self
.
assertEqual
(
encoded_images
.
shape
,
(
self
.
feature_extract_tester
.
batch_size
,
self
.
expected_encoded_image_num_channels
,
self
.
feature_extract_tester
.
crop_size
,
self
.
feature_extract_tester
.
crop_size
,
),
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment