"vscode:/vscode.git/clone" did not exist on "08d609bfb8fbbaf508ae55c5cf414b262cc04061"
Unverified commit ecfa7eb2, authored by Patrick von Platen and committed by GitHub
Browse files

[AutoFeatureExtractor] Fix loading of local folders if config.json exists (#13166)

* up

* up
parent 439a43b6
......@@ -20,6 +20,7 @@ from collections import OrderedDict
from typing import List, Union
from ...configuration_utils import PretrainedConfig
from ...file_utils import CONFIG_NAME
CONFIG_MAPPING_NAMES = OrderedDict(
......@@ -520,6 +521,6 @@ class AutoConfig:
raise ValueError(
f"Unrecognized model in {pretrained_model_name_or_path}. "
"Should have a `model_type` key in its config.json, or contain one of the following strings "
f"Should have a `model_type` key in its {CONFIG_NAME}, or contain one of the following strings "
f"in its name: {', '.join(CONFIG_MAPPING.keys())}"
)
......@@ -20,7 +20,7 @@ from collections import OrderedDict
# Build the list of all feature extractors
from ...configuration_utils import PretrainedConfig
from ...feature_extraction_utils import FeatureExtractionMixin
from ...file_utils import FEATURE_EXTRACTOR_NAME
from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
CONFIG_MAPPING_NAMES,
......@@ -142,7 +142,12 @@ class AutoFeatureExtractor:
os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
)
if not is_feature_extraction_file and not is_directory:
has_local_config = (
os.path.exists(os.path.join(pretrained_model_name_or_path, CONFIG_NAME)) if is_directory else False
)
# load config, if it can be loaded
if not is_feature_extraction_file and (has_local_config or not is_directory):
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
......@@ -150,6 +155,7 @@ class AutoFeatureExtractor:
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
model_type = config_class_to_model_type(type(config).__name__)
if model_type is not None:
return FEATURE_EXTRACTOR_MAPPING[type(config)].from_dict(config_dict, **kwargs)
elif "feature_extractor_type" in config_dict:
......@@ -157,7 +163,7 @@ class AutoFeatureExtractor:
return feature_extractor_class.from_dict(config_dict, **kwargs)
raise ValueError(
f"Unrecognized model in {pretrained_model_name_or_path}. Should have a `feature_extractor_type` key in "
f"its {FEATURE_EXTRACTOR_NAME}, or contain one of the following strings "
f"in its name: {', '.join(FEATURE_EXTRACTOR_MAPPING.keys())}"
f"Unrecognized feature extractor in {pretrained_model_name_or_path}. Should have a `feature_extractor_type` key in "
f"its {FEATURE_EXTRACTOR_NAME}, or one of the following `model_type` keys in its {CONFIG_NAME}: "
f"{', '.join(c for c in FEATURE_EXTRACTOR_MAPPING_NAMES.keys())}"
)
......@@ -14,15 +14,17 @@
# limitations under the License.
import os
import tempfile
import unittest
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
from transformers import AutoFeatureExtractor, Wav2Vec2Config, Wav2Vec2FeatureExtractor
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
SAMPLE_FEATURE_EXTRACTION_CONFIG = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json"
)
SAMPLE_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")
class AutoFeatureExtractorTest(unittest.TestCase):
......@@ -30,10 +32,27 @@ class AutoFeatureExtractorTest(unittest.TestCase):
config = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
def test_feature_extractor_from_local_directory(self):
def test_feature_extractor_from_local_directory_from_key(self):
config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
def test_feature_extractor_from_local_directory_from_config(self):
with tempfile.TemporaryDirectory() as tmpdirname:
model_config = Wav2Vec2Config()
# remove feature_extractor_type to make sure config.json alone is enough to load feature processor locally
config_dict = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR).to_dict()
config_dict.pop("feature_extractor_type")
config = Wav2Vec2FeatureExtractor(config_dict)
# save in new folder
model_config.save_pretrained(tmpdirname)
config.save_pretrained(tmpdirname)
config = AutoFeatureExtractor.from_pretrained(tmpdirname)
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
def test_feature_extractor_from_local_file(self):
config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG)
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment