Unverified Commit 3a4376d0 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Wav2Vec2ProcessorWithLM] Fix auto processor with lm (#15683)

parent cdc51ffd
...@@ -138,6 +138,8 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin): ...@@ -138,6 +138,8 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
else: else:
# BeamSearchDecoderCTC has no auto class # BeamSearchDecoderCTC has no auto class
kwargs.pop("_from_auto", None) kwargs.pop("_from_auto", None)
# snapshot_download has no `trust_remote_code` flag
kwargs.pop("trust_remote_code", None)
# make sure that only relevant filenames are downloaded # make sure that only relevant filenames are downloaded
language_model_filenames = os.path.join(BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*") language_model_filenames = os.path.join(BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*")
......
...@@ -22,6 +22,7 @@ from pathlib import Path ...@@ -22,6 +22,7 @@ from pathlib import Path
import numpy as np import numpy as np
from transformers import AutoProcessor
from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_pyctcdecode_available from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_pyctcdecode_available
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
...@@ -330,3 +331,22 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase): ...@@ -330,3 +331,22 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
# test that both decoder form hub and local files in cache are the same # test that both decoder form hub and local files in cache are the same
self.assertListEqual(local_decoder_files, expected_decoder_files) self.assertListEqual(local_decoder_files, expected_decoder_files)
def test_processor_from_auto_processor(self):
processor_wav2vec2 = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm")
processor_auto = AutoProcessor.from_pretrained("hf-internal-testing/processor_with_lm")
raw_speech = floats_list((3, 1000))
input_wav2vec2 = processor_wav2vec2(raw_speech, return_tensors="np")
input_auto = processor_auto(raw_speech, return_tensors="np")
for key in input_wav2vec2.keys():
self.assertAlmostEqual(input_wav2vec2[key].sum(), input_auto[key].sum(), delta=1e-2)
logits = self._get_dummy_logits()
decoded_wav2vec2 = processor_wav2vec2.batch_decode(logits)
decoded_auto = processor_auto.batch_decode(logits)
self.assertListEqual(decoded_wav2vec2.text, decoded_auto.text)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment