Unverified Commit 448e050b authored by Naga Sai Abhinay's avatar Naga Sai Abhinay Committed by GitHub
Browse files

Make ImageProcessorMixin compatible with subfolder kwarg (#21725)

* Add subfolder support

* Add kwarg docstring

* formatting fix

* Add test
parent 064f3748
...@@ -128,6 +128,9 @@ class ImageProcessingMixin(PushToHubMixin): ...@@ -128,6 +128,9 @@ class ImageProcessingMixin(PushToHubMixin):
functions returns a `Tuple(image_processor, unused_kwargs)` where *unused_kwargs* is a dictionary functions returns a `Tuple(image_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
consisting of the key/value pairs whose keys are not image processor attributes: i.e., the part of consisting of the key/value pairs whose keys are not image processor attributes: i.e., the part of
`kwargs` which has not been used to update `image_processor` and is otherwise ignored. `kwargs` which has not been used to update `image_processor` and is otherwise ignored.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
kwargs (`Dict[str, Any]`, *optional*): kwargs (`Dict[str, Any]`, *optional*):
The values in kwargs of any keys which are image processor attributes will be used to override the The values in kwargs of any keys which are image processor attributes will be used to override the
loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is
...@@ -221,6 +224,9 @@ class ImageProcessingMixin(PushToHubMixin): ...@@ -221,6 +224,9 @@ class ImageProcessingMixin(PushToHubMixin):
Parameters: Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`): pretrained_model_name_or_path (`str` or `os.PathLike`):
The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
Returns: Returns:
`Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the image processor object. `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the image processor object.
...@@ -232,6 +238,7 @@ class ImageProcessingMixin(PushToHubMixin): ...@@ -232,6 +238,7 @@ class ImageProcessingMixin(PushToHubMixin):
use_auth_token = kwargs.pop("use_auth_token", None) use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False) local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", "")
from_pipeline = kwargs.pop("_from_pipeline", None) from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False) from_auto_class = kwargs.pop("_from_auto", False)
...@@ -269,6 +276,7 @@ class ImageProcessingMixin(PushToHubMixin): ...@@ -269,6 +276,7 @@ class ImageProcessingMixin(PushToHubMixin):
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
user_agent=user_agent, user_agent=user_agent,
revision=revision, revision=revision,
subfolder=subfolder,
) )
except EnvironmentError: except EnvironmentError:
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
......
...@@ -311,3 +311,14 @@ class ImageProcessorPushToHubTester(unittest.TestCase): ...@@ -311,3 +311,14 @@ class ImageProcessorPushToHubTester(unittest.TestCase):
) )
# Can't make an isinstance check because the new_image_processor is from the CustomImageProcessor class of a dynamic module # Can't make an isinstance check because the new_image_processor is from the CustomImageProcessor class of a dynamic module
self.assertEqual(new_image_processor.__class__.__name__, "CustomImageProcessor") self.assertEqual(new_image_processor.__class__.__name__, "CustomImageProcessor")
def test_image_processor_from_pretrained_subfolder(self):
with self.assertRaises(OSError):
# config is in subfolder, the following should not work without specifying the subfolder
_ = AutoImageProcessor.from_pretrained("hf-internal-testing/stable-diffusion-all-variants")
config = AutoImageProcessor.from_pretrained(
"hf-internal-testing/stable-diffusion-all-variants", subfolder="feature_extractor"
)
self.assertIsNotNone(config)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment