"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "492bb6aa486856f8243dfeb533ed1b23e996e403"
Unverified Commit cd09a8df authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[Feature Extractors] Fix kwargs to pre-trained (#30260)

fixes
parent 4ab7a282
...@@ -566,17 +566,17 @@ class FeatureExtractionMixin(PushToHubMixin): ...@@ -566,17 +566,17 @@ class FeatureExtractionMixin(PushToHubMixin):
""" """
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
feature_extractor = cls(**feature_extractor_dict)
# Update feature_extractor with kwargs if needed # Update feature_extractor with kwargs if needed
to_remove = [] to_remove = []
for key, value in kwargs.items(): for key, value in kwargs.items():
if hasattr(feature_extractor, key): if key in feature_extractor_dict:
setattr(feature_extractor, key, value) feature_extractor_dict[key] = value
to_remove.append(key) to_remove.append(key)
for key in to_remove: for key in to_remove:
kwargs.pop(key, None) kwargs.pop(key, None)
feature_extractor = cls(**feature_extractor_dict)
logger.info(f"Feature extractor {feature_extractor}") logger.info(f"Feature extractor {feature_extractor}")
if return_unused_kwargs: if return_unused_kwargs:
return feature_extractor, kwargs return feature_extractor, kwargs
......
...@@ -142,6 +142,20 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest. ...@@ -142,6 +142,20 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
self.assertTrue(np.allclose(mel_1, mel_2)) self.assertTrue(np.allclose(mel_1, mel_2))
self.assertEqual(dict_first, dict_second) self.assertEqual(dict_first, dict_second)
def test_feat_extract_from_pretrained_kwargs(self):
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
check_json_file_has_correct_format(saved_file)
feat_extract_second = self.feature_extraction_class.from_pretrained(
tmpdirname, feature_size=2 * self.feat_extract_dict["feature_size"]
)
mel_1 = feat_extract_first.mel_filters
mel_2 = feat_extract_second.mel_filters
self.assertTrue(2 * mel_1.shape[1] == mel_2.shape[1])
def test_call(self): def test_call(self):
# Tests that all call wrap to encode_plus and batch_encode_plus # Tests that all call wrap to encode_plus and batch_encode_plus
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment