Unverified Commit ecfa7eb2 authored by Patrick von Platen, committed by GitHub
Browse files

[AutoFeatureExtractor] Fix loading of local folders if config.json exists (#13166)

* up

* up
parent 439a43b6
...@@ -20,6 +20,7 @@ from collections import OrderedDict ...@@ -20,6 +20,7 @@ from collections import OrderedDict
from typing import List, Union from typing import List, Union
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...file_utils import CONFIG_NAME
CONFIG_MAPPING_NAMES = OrderedDict( CONFIG_MAPPING_NAMES = OrderedDict(
...@@ -520,6 +521,6 @@ class AutoConfig: ...@@ -520,6 +521,6 @@ class AutoConfig:
raise ValueError( raise ValueError(
f"Unrecognized model in {pretrained_model_name_or_path}. " f"Unrecognized model in {pretrained_model_name_or_path}. "
"Should have a `model_type` key in its config.json, or contain one of the following strings " f"Should have a `model_type` key in its {CONFIG_NAME}, or contain one of the following strings "
f"in its name: {', '.join(CONFIG_MAPPING.keys())}" f"in its name: {', '.join(CONFIG_MAPPING.keys())}"
) )
...@@ -20,7 +20,7 @@ from collections import OrderedDict ...@@ -20,7 +20,7 @@ from collections import OrderedDict
# Build the list of all feature extractors # Build the list of all feature extractors
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...feature_extraction_utils import FeatureExtractionMixin from ...feature_extraction_utils import FeatureExtractionMixin
from ...file_utils import FEATURE_EXTRACTOR_NAME from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME
from .auto_factory import _LazyAutoMapping from .auto_factory import _LazyAutoMapping
from .configuration_auto import ( from .configuration_auto import (
CONFIG_MAPPING_NAMES, CONFIG_MAPPING_NAMES,
...@@ -142,7 +142,12 @@ class AutoFeatureExtractor: ...@@ -142,7 +142,12 @@ class AutoFeatureExtractor:
os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME) os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
) )
if not is_feature_extraction_file and not is_directory: has_local_config = (
os.path.exists(os.path.join(pretrained_model_name_or_path, CONFIG_NAME)) if is_directory else False
)
# load config, if it can be loaded
if not is_feature_extraction_file and (has_local_config or not is_directory):
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
...@@ -150,6 +155,7 @@ class AutoFeatureExtractor: ...@@ -150,6 +155,7 @@ class AutoFeatureExtractor:
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs) config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
model_type = config_class_to_model_type(type(config).__name__) model_type = config_class_to_model_type(type(config).__name__)
if model_type is not None: if model_type is not None:
return FEATURE_EXTRACTOR_MAPPING[type(config)].from_dict(config_dict, **kwargs) return FEATURE_EXTRACTOR_MAPPING[type(config)].from_dict(config_dict, **kwargs)
elif "feature_extractor_type" in config_dict: elif "feature_extractor_type" in config_dict:
...@@ -157,7 +163,7 @@ class AutoFeatureExtractor: ...@@ -157,7 +163,7 @@ class AutoFeatureExtractor:
return feature_extractor_class.from_dict(config_dict, **kwargs) return feature_extractor_class.from_dict(config_dict, **kwargs)
raise ValueError( raise ValueError(
f"Unrecognized model in {pretrained_model_name_or_path}. Should have a `feature_extractor_type` key in " f"Unrecognized feature extractor in {pretrained_model_name_or_path}. Should have a `feature_extractor_type` key in "
f"its {FEATURE_EXTRACTOR_NAME}, or contain one of the following strings " f"its {FEATURE_EXTRACTOR_NAME}, or one of the following `model_type` keys in its {CONFIG_NAME}: "
f"in its name: {', '.join(FEATURE_EXTRACTOR_MAPPING.keys())}" f"{', '.join(c for c in FEATURE_EXTRACTOR_MAPPING_NAMES.keys())}"
) )
...@@ -14,15 +14,17 @@ ...@@ -14,15 +14,17 @@
# limitations under the License. # limitations under the License.
import os import os
import tempfile
import unittest import unittest
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor from transformers import AutoFeatureExtractor, Wav2Vec2Config, Wav2Vec2FeatureExtractor
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures") SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
SAMPLE_FEATURE_EXTRACTION_CONFIG = os.path.join( SAMPLE_FEATURE_EXTRACTION_CONFIG = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json" os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json"
) )
SAMPLE_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")
class AutoFeatureExtractorTest(unittest.TestCase): class AutoFeatureExtractorTest(unittest.TestCase):
...@@ -30,10 +32,27 @@ class AutoFeatureExtractorTest(unittest.TestCase): ...@@ -30,10 +32,27 @@ class AutoFeatureExtractorTest(unittest.TestCase):
config = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h") config = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
self.assertIsInstance(config, Wav2Vec2FeatureExtractor) self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
def test_feature_extractor_from_local_directory(self): def test_feature_extractor_from_local_directory_from_key(self):
config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR) config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
self.assertIsInstance(config, Wav2Vec2FeatureExtractor) self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
def test_feature_extractor_from_local_directory_from_config(self):
    """Load a feature extractor from a local folder using `config.json` alone.

    The feature extractor saved into the folder has no `feature_extractor_type`
    key, so `AutoFeatureExtractor` must pick the class from the model config
    that is saved next to it.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        model_config = Wav2Vec2Config()

        # remove feature_extractor_type to make sure config.json alone is enough to load feature processor locally
        config_dict = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR).to_dict()
        config_dict.pop("feature_extractor_type")

        # Unpack the dict into keyword arguments: passing it positionally would
        # bind the whole dict to the constructor's first positional parameter
        # instead of restoring the saved attributes.
        feature_extractor = Wav2Vec2FeatureExtractor(**config_dict)

        # save in new folder
        model_config.save_pretrained(tmpdirname)
        feature_extractor.save_pretrained(tmpdirname)

        loaded = AutoFeatureExtractor.from_pretrained(tmpdirname)
        self.assertIsInstance(loaded, Wav2Vec2FeatureExtractor)
def test_feature_extractor_from_local_file(self): def test_feature_extractor_from_local_file(self):
config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG) config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG)
self.assertIsInstance(config, Wav2Vec2FeatureExtractor) self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment