Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ecfa7eb2
Unverified
Commit
ecfa7eb2
authored
Aug 18, 2021
by
Patrick von Platen
Committed by
GitHub
Aug 18, 2021
Browse files
[AutoFeatureExtractor] Fix loading of local folders if config.json exists (#13166)
* up * up
parent
439a43b6
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
34 additions
and
8 deletions
+34
-8
src/transformers/models/auto/configuration_auto.py
src/transformers/models/auto/configuration_auto.py
+2
-1
src/transformers/models/auto/feature_extraction_auto.py
src/transformers/models/auto/feature_extraction_auto.py
+11
-5
tests/test_feature_extraction_auto.py
tests/test_feature_extraction_auto.py
+21
-2
No files found.
src/transformers/models/auto/configuration_auto.py
View file @
ecfa7eb2
...
@@ -20,6 +20,7 @@ from collections import OrderedDict
...
@@ -20,6 +20,7 @@ from collections import OrderedDict
from
typing
import
List
,
Union
from
typing
import
List
,
Union
from
...configuration_utils
import
PretrainedConfig
from
...configuration_utils
import
PretrainedConfig
from
...file_utils
import
CONFIG_NAME
CONFIG_MAPPING_NAMES
=
OrderedDict
(
CONFIG_MAPPING_NAMES
=
OrderedDict
(
...
@@ -520,6 +521,6 @@ class AutoConfig:
...
@@ -520,6 +521,6 @@ class AutoConfig:
raise
ValueError
(
raise
ValueError
(
f
"Unrecognized model in
{
pretrained_model_name_or_path
}
. "
f
"Unrecognized model in
{
pretrained_model_name_or_path
}
. "
"Should have a `model_type` key in its
config.json
, or contain one of the following strings "
f
"Should have a `model_type` key in its
{
CONFIG_NAME
}
, or contain one of the following strings "
f
"in its name:
{
', '
.
join
(
CONFIG_MAPPING
.
keys
())
}
"
f
"in its name:
{
', '
.
join
(
CONFIG_MAPPING
.
keys
())
}
"
)
)
src/transformers/models/auto/feature_extraction_auto.py
View file @
ecfa7eb2
...
@@ -20,7 +20,7 @@ from collections import OrderedDict
...
@@ -20,7 +20,7 @@ from collections import OrderedDict
# Build the list of all feature extractors
# Build the list of all feature extractors
from
...configuration_utils
import
PretrainedConfig
from
...configuration_utils
import
PretrainedConfig
from
...feature_extraction_utils
import
FeatureExtractionMixin
from
...feature_extraction_utils
import
FeatureExtractionMixin
from
...file_utils
import
FEATURE_EXTRACTOR_NAME
from
...file_utils
import
CONFIG_NAME
,
FEATURE_EXTRACTOR_NAME
from
.auto_factory
import
_LazyAutoMapping
from
.auto_factory
import
_LazyAutoMapping
from
.configuration_auto
import
(
from
.configuration_auto
import
(
CONFIG_MAPPING_NAMES
,
CONFIG_MAPPING_NAMES
,
...
@@ -142,7 +142,12 @@ class AutoFeatureExtractor:
...
@@ -142,7 +142,12 @@ class AutoFeatureExtractor:
os
.
path
.
join
(
pretrained_model_name_or_path
,
FEATURE_EXTRACTOR_NAME
)
os
.
path
.
join
(
pretrained_model_name_or_path
,
FEATURE_EXTRACTOR_NAME
)
)
)
if
not
is_feature_extraction_file
and
not
is_directory
:
has_local_config
=
(
os
.
path
.
exists
(
os
.
path
.
join
(
pretrained_model_name_or_path
,
CONFIG_NAME
))
if
is_directory
else
False
)
# load config, if it can be loaded
if
not
is_feature_extraction_file
and
(
has_local_config
or
not
is_directory
):
if
not
isinstance
(
config
,
PretrainedConfig
):
if
not
isinstance
(
config
,
PretrainedConfig
):
config
=
AutoConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
config
=
AutoConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
...
@@ -150,6 +155,7 @@ class AutoFeatureExtractor:
...
@@ -150,6 +155,7 @@ class AutoFeatureExtractor:
config_dict
,
_
=
FeatureExtractionMixin
.
get_feature_extractor_dict
(
pretrained_model_name_or_path
,
**
kwargs
)
config_dict
,
_
=
FeatureExtractionMixin
.
get_feature_extractor_dict
(
pretrained_model_name_or_path
,
**
kwargs
)
model_type
=
config_class_to_model_type
(
type
(
config
).
__name__
)
model_type
=
config_class_to_model_type
(
type
(
config
).
__name__
)
if
model_type
is
not
None
:
if
model_type
is
not
None
:
return
FEATURE_EXTRACTOR_MAPPING
[
type
(
config
)].
from_dict
(
config_dict
,
**
kwargs
)
return
FEATURE_EXTRACTOR_MAPPING
[
type
(
config
)].
from_dict
(
config_dict
,
**
kwargs
)
elif
"feature_extractor_type"
in
config_dict
:
elif
"feature_extractor_type"
in
config_dict
:
...
@@ -157,7 +163,7 @@ class AutoFeatureExtractor:
...
@@ -157,7 +163,7 @@ class AutoFeatureExtractor:
return
feature_extractor_class
.
from_dict
(
config_dict
,
**
kwargs
)
return
feature_extractor_class
.
from_dict
(
config_dict
,
**
kwargs
)
raise
ValueError
(
raise
ValueError
(
f
"Unrecognized
model
in
{
pretrained_model_name_or_path
}
. Should have a `feature_extractor_type` key in "
f
"Unrecognized
feature extractor
in
{
pretrained_model_name_or_path
}
. Should have a `feature_extractor_type` key in "
f
"its
{
FEATURE_EXTRACTOR_NAME
}
, or
contain
one of the following
strings
"
f
"its
{
FEATURE_EXTRACTOR_NAME
}
, or one of the following
`model_type` keys in its
{
CONFIG_NAME
}
:
"
f
"
in its name:
{
', '
.
join
(
FEATURE_EXTRACTOR_MAPPING
.
keys
())
}
"
f
"
{
', '
.
join
(
c
for
c
in
FEATURE_EXTRACTOR_MAPPING
_NAMES
.
keys
())
}
"
)
)
tests/test_feature_extraction_auto.py
View file @
ecfa7eb2
...
@@ -14,15 +14,17 @@
...
@@ -14,15 +14,17 @@
# limitations under the License.
# limitations under the License.
import
os
import
os
import
tempfile
import
unittest
import
unittest
from
transformers
import
AutoFeatureExtractor
,
Wav2Vec2FeatureExtractor
from
transformers
import
AutoFeatureExtractor
,
Wav2Vec2Config
,
Wav2Vec2FeatureExtractor
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
"fixtures"
)
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
"fixtures"
)
SAMPLE_FEATURE_EXTRACTION_CONFIG
=
os
.
path
.
join
(
SAMPLE_FEATURE_EXTRACTION_CONFIG
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
"fixtures/dummy_feature_extractor_config.json"
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
"fixtures/dummy_feature_extractor_config.json"
)
)
SAMPLE_CONFIG
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
"fixtures/dummy-config.json"
)
class
AutoFeatureExtractorTest
(
unittest
.
TestCase
):
class
AutoFeatureExtractorTest
(
unittest
.
TestCase
):
...
@@ -30,10 +32,27 @@ class AutoFeatureExtractorTest(unittest.TestCase):
...
@@ -30,10 +32,27 @@ class AutoFeatureExtractorTest(unittest.TestCase):
config
=
AutoFeatureExtractor
.
from_pretrained
(
"facebook/wav2vec2-base-960h"
)
config
=
AutoFeatureExtractor
.
from_pretrained
(
"facebook/wav2vec2-base-960h"
)
self
.
assertIsInstance
(
config
,
Wav2Vec2FeatureExtractor
)
self
.
assertIsInstance
(
config
,
Wav2Vec2FeatureExtractor
)
def
test_feature_extractor_from_local_directory
(
self
):
def
test_feature_extractor_from_local_directory
_from_key
(
self
):
config
=
AutoFeatureExtractor
.
from_pretrained
(
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR
)
config
=
AutoFeatureExtractor
.
from_pretrained
(
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR
)
self
.
assertIsInstance
(
config
,
Wav2Vec2FeatureExtractor
)
self
.
assertIsInstance
(
config
,
Wav2Vec2FeatureExtractor
)
def
test_feature_extractor_from_local_directory_from_config
(
self
):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
model_config
=
Wav2Vec2Config
()
# remove feature_extractor_type to make sure config.json alone is enough to load feature processor locally
config_dict
=
AutoFeatureExtractor
.
from_pretrained
(
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR
).
to_dict
()
config_dict
.
pop
(
"feature_extractor_type"
)
config
=
Wav2Vec2FeatureExtractor
(
config_dict
)
# save in new folder
model_config
.
save_pretrained
(
tmpdirname
)
config
.
save_pretrained
(
tmpdirname
)
config
=
AutoFeatureExtractor
.
from_pretrained
(
tmpdirname
)
self
.
assertIsInstance
(
config
,
Wav2Vec2FeatureExtractor
)
def
test_feature_extractor_from_local_file
(
self
):
def
test_feature_extractor_from_local_file
(
self
):
config
=
AutoFeatureExtractor
.
from_pretrained
(
SAMPLE_FEATURE_EXTRACTION_CONFIG
)
config
=
AutoFeatureExtractor
.
from_pretrained
(
SAMPLE_FEATURE_EXTRACTION_CONFIG
)
self
.
assertIsInstance
(
config
,
Wav2Vec2FeatureExtractor
)
self
.
assertIsInstance
(
config
,
Wav2Vec2FeatureExtractor
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment