"examples/pytorch/text-classification/run_glue.py" did not exist on "dfd16af8322788e6dd58e8396e0d6f2f5312bf99"
Unverified Commit 9eb7e9ba authored by Javier de la Rosa's avatar Javier de la Rosa Committed by GitHub
Browse files

Fix ASR pipelines from local directories with wav2vec models that have...

Fix ASR pipelines from local directories with wav2vec models that have language models attached (#15590)

* Fix loading pipelines with wav2vec models with lm when in local paths

* Adding tests

* Fix test

* Adding tests

* Flake8 fixes

* Removing conflict files :(

* Adding task type to test

* Remove unnecessary test and imports
parent e1cbc073
......@@ -127,7 +127,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
feature_extractor, tokenizer = super()._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
if os.path.isdir(pretrained_model_name_or_path):
if os.path.isdir(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path):
decoder = BeamSearchDecoderCTC.load_from_dir(pretrained_model_name_or_path)
else:
# BeamSearchDecoderCTC has no auto class
......
......@@ -621,15 +621,20 @@ def pipeline(
import kenlm # to trigger `ImportError` if not installed
from pyctcdecode import BeamSearchDecoderCTC
language_model_glob = os.path.join(BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*")
if os.path.isdir(model_name) or os.path.isfile(model_name):
decoder = BeamSearchDecoderCTC.load_from_dir(model_name)
else:
language_model_glob = os.path.join(
BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*"
)
alphabet_filename = BeamSearchDecoderCTC._ALPHABET_SERIALIZED_FILENAME
allow_regex = [language_model_glob, alphabet_filename]
decoder = BeamSearchDecoderCTC.load_from_hf_hub(model_name, allow_regex=allow_regex)
kwargs["decoder"] = decoder
except ImportError as e:
logger.warning(
"Could not load the `decoder` for {model_name}. Defaulting to raw CTC. Try to install `pyctcdecode` and `kenlm`: (`pip install pyctcdecode`, `pip install https://github.com/kpu/kenlm/archive/master.zip`): Error: {e}"
f"Could not load the `decoder` for {model_name}. Defaulting to raw CTC. Try to install `pyctcdecode` and `kenlm`: (`pip install pyctcdecode`, `pip install https://github.com/kpu/kenlm/archive/master.zip`): Error: {e}"
)
if task == "translation" and model.config.task_specific_params:
......
......@@ -18,6 +18,7 @@ import numpy as np
import pytest
from datasets import load_dataset
from huggingface_hub import snapshot_download
from transformers import (
MODEL_FOR_CTC_MAPPING,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
......@@ -368,6 +369,27 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
self.assertEqual(output, [{"text": ANY(str)}])
self.assertEqual(output[0]["text"][:6], "<s> <s")
@require_torch
@require_pyctcdecode
def test_with_local_lm_fast(self):
local_dir = snapshot_download("hf-internal-testing/processor_with_lm")
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model=local_dir,
)
self.assertEqual(speech_recognizer.type, "ctc_with_lm")
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
audio = ds[40]["audio"]["array"]
n_repeats = 2
audio_tiled = np.tile(audio, n_repeats)
output = speech_recognizer([audio_tiled], batch_size=2)
self.assertEqual(output, [{"text": ANY(str)}])
self.assertEqual(output[0]["text"][:6], "<s> <s")
@require_torch
@slow
def test_chunking(self):
......
......@@ -31,6 +31,7 @@ from .test_feature_extraction_wav2vec2 import floats_list
if is_pyctcdecode_available():
from huggingface_hub import snapshot_download
from pyctcdecode import BeamSearchDecoderCTC
from transformers.models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
......@@ -303,3 +304,20 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
# https://huggingface.co/hf-internal-testing/processor_with_lm/tree/main
# are downloaded and none of the rest (e.g. README.md, ...)
self.assertListEqual(downloaded_decoder_files, expected_decoder_files)
def test_decoder_local_files(self):
local_dir = snapshot_download("hf-internal-testing/processor_with_lm")
processor = Wav2Vec2ProcessorWithLM.from_pretrained(local_dir)
language_model = processor.decoder.model_container[processor.decoder._model_key]
path_to_cached_dir = Path(language_model._kenlm_model.path.decode("utf-8")).parent.parent.absolute()
local_decoder_files = os.listdir(local_dir)
expected_decoder_files = os.listdir(path_to_cached_dir)
local_decoder_files.sort()
expected_decoder_files.sort()
# test that both decoder form hub and local files in cache are the same
self.assertListEqual(local_decoder_files, expected_decoder_files)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment