Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
cb555af2
Unverified
Commit
cb555af2
authored
Apr 21, 2022
by
Sylvain Gugger
Committed by
GitHub
Apr 21, 2022
Browse files
Return input_ids in ImageGPT feature extractor (#16872)
parent
e789418e
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
10 deletions
+9
-10
src/transformers/models/imagegpt/feature_extraction_imagegpt.py
...ansformers/models/imagegpt/feature_extraction_imagegpt.py
+3
-4
tests/imagegpt/test_feature_extraction_imagegpt.py
tests/imagegpt/test_feature_extraction_imagegpt.py
+6
-6
No files found.
src/transformers/models/imagegpt/feature_extraction_imagegpt.py
View file @
cb555af2
...
@@ -68,7 +68,7 @@ class ImageGPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMix
...
@@ -68,7 +68,7 @@ class ImageGPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMix
Whether or not to normalize the input to the range between -1 and +1.
Whether or not to normalize the input to the range between -1 and +1.
"""
"""
model_input_names
=
[
"
pixel_value
s"
]
model_input_names
=
[
"
input_id
s"
]
def
__init__
(
self
,
clusters
,
do_resize
=
True
,
size
=
32
,
resample
=
Image
.
BILINEAR
,
do_normalize
=
True
,
**
kwargs
):
def
__init__
(
self
,
clusters
,
do_resize
=
True
,
size
=
32
,
resample
=
Image
.
BILINEAR
,
do_normalize
=
True
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
...
@@ -128,8 +128,7 @@ class ImageGPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMix
...
@@ -128,8 +128,7 @@ class ImageGPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMix
Returns:
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
- **input_ids** -- Input IDs to be fed to a model, of shape `(batch_size, height * width)`.
width).
"""
"""
# Input type checking for clearer error
# Input type checking for clearer error
valid_images
=
False
valid_images
=
False
...
@@ -171,7 +170,7 @@ class ImageGPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMix
...
@@ -171,7 +170,7 @@ class ImageGPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMix
images
=
images
.
reshape
(
batch_size
,
-
1
)
images
=
images
.
reshape
(
batch_size
,
-
1
)
# return as BatchFeature
# return as BatchFeature
data
=
{
"
pixel_value
s"
:
images
}
data
=
{
"
input_id
s"
:
images
}
encoded_inputs
=
BatchFeature
(
data
=
data
,
tensor_type
=
return_tensors
)
encoded_inputs
=
BatchFeature
(
data
=
data
,
tensor_type
=
return_tensors
)
return
encoded_inputs
return
encoded_inputs
tests/imagegpt/test_feature_extraction_imagegpt.py
View file @
cb555af2
...
@@ -161,17 +161,17 @@ class ImageGPTFeatureExtractorIntegrationTest(unittest.TestCase):
...
@@ -161,17 +161,17 @@ class ImageGPTFeatureExtractorIntegrationTest(unittest.TestCase):
# test non-batched
# test non-batched
encoding
=
feature_extractor
(
images
[
0
],
return_tensors
=
"pt"
)
encoding
=
feature_extractor
(
images
[
0
],
return_tensors
=
"pt"
)
self
.
assertIsInstance
(
encoding
.
pixel_value
s
,
torch
.
LongTensor
)
self
.
assertIsInstance
(
encoding
.
input_id
s
,
torch
.
LongTensor
)
self
.
assertEqual
(
encoding
.
pixel_value
s
.
shape
,
(
1
,
1024
))
self
.
assertEqual
(
encoding
.
input_id
s
.
shape
,
(
1
,
1024
))
expected_slice
=
[
306
,
191
,
191
]
expected_slice
=
[
306
,
191
,
191
]
self
.
assertEqual
(
encoding
.
pixel_value
s
[
0
,
:
3
].
tolist
(),
expected_slice
)
self
.
assertEqual
(
encoding
.
input_id
s
[
0
,
:
3
].
tolist
(),
expected_slice
)
# test batched
# test batched
encoding
=
feature_extractor
(
images
,
return_tensors
=
"pt"
)
encoding
=
feature_extractor
(
images
,
return_tensors
=
"pt"
)
self
.
assertIsInstance
(
encoding
.
pixel_value
s
,
torch
.
LongTensor
)
self
.
assertIsInstance
(
encoding
.
input_id
s
,
torch
.
LongTensor
)
self
.
assertEqual
(
encoding
.
pixel_value
s
.
shape
,
(
2
,
1024
))
self
.
assertEqual
(
encoding
.
input_id
s
.
shape
,
(
2
,
1024
))
expected_slice
=
[
303
,
13
,
13
]
expected_slice
=
[
303
,
13
,
13
]
self
.
assertEqual
(
encoding
.
pixel_value
s
[
1
,
-
3
:].
tolist
(),
expected_slice
)
self
.
assertEqual
(
encoding
.
input_id
s
[
1
,
-
3
:].
tolist
(),
expected_slice
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment