Unverified commit cb555af2, authored by Sylvain Gugger, committed by GitHub
Browse files

Return input_ids in ImageGPT feature extractor (#16872)

parent e789418e
...@@ -68,7 +68,7 @@ class ImageGPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMix ...@@ -68,7 +68,7 @@ class ImageGPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMix
Whether or not to normalize the input to the range between -1 and +1. Whether or not to normalize the input to the range between -1 and +1.
""" """
model_input_names = ["pixel_values"] model_input_names = ["input_ids"]
def __init__(self, clusters, do_resize=True, size=32, resample=Image.BILINEAR, do_normalize=True, **kwargs): def __init__(self, clusters, do_resize=True, size=32, resample=Image.BILINEAR, do_normalize=True, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
...@@ -128,8 +128,7 @@ class ImageGPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMix ...@@ -128,8 +128,7 @@ class ImageGPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMix
Returns: Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields: [`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height, - **input_ids** -- Input IDs to be fed to a model, of shape `(batch_size, height * width)`.
width).
""" """
# Input type checking for clearer error # Input type checking for clearer error
valid_images = False valid_images = False
...@@ -171,7 +170,7 @@ class ImageGPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMix ...@@ -171,7 +170,7 @@ class ImageGPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMix
images = images.reshape(batch_size, -1) images = images.reshape(batch_size, -1)
# return as BatchFeature # return as BatchFeature
data = {"pixel_values": images} data = {"input_ids": images}
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs return encoded_inputs
...@@ -161,17 +161,17 @@ class ImageGPTFeatureExtractorIntegrationTest(unittest.TestCase): ...@@ -161,17 +161,17 @@ class ImageGPTFeatureExtractorIntegrationTest(unittest.TestCase):
# test non-batched # test non-batched
encoding = feature_extractor(images[0], return_tensors="pt") encoding = feature_extractor(images[0], return_tensors="pt")
self.assertIsInstance(encoding.pixel_values, torch.LongTensor) self.assertIsInstance(encoding.input_ids, torch.LongTensor)
self.assertEqual(encoding.pixel_values.shape, (1, 1024)) self.assertEqual(encoding.input_ids.shape, (1, 1024))
expected_slice = [306, 191, 191] expected_slice = [306, 191, 191]
self.assertEqual(encoding.pixel_values[0, :3].tolist(), expected_slice) self.assertEqual(encoding.input_ids[0, :3].tolist(), expected_slice)
# test batched # test batched
encoding = feature_extractor(images, return_tensors="pt") encoding = feature_extractor(images, return_tensors="pt")
self.assertIsInstance(encoding.pixel_values, torch.LongTensor) self.assertIsInstance(encoding.input_ids, torch.LongTensor)
self.assertEqual(encoding.pixel_values.shape, (2, 1024)) self.assertEqual(encoding.input_ids.shape, (2, 1024))
expected_slice = [303, 13, 13] expected_slice = [303, 13, 13]
self.assertEqual(encoding.pixel_values[1, -3:].tolist(), expected_slice) self.assertEqual(encoding.input_ids[1, -3:].tolist(), expected_slice)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment