"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "ca169dbdf189d30b3aacfd7cc50d668121361ec8"
Unverified Commit efb35a41 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Wav2Vec2ProcessorWithLM] improve decoder download (#15040)

parent 6ea62666
...@@ -166,7 +166,14 @@ class Wav2Vec2ProcessorWithLM: ...@@ -166,7 +166,14 @@ class Wav2Vec2ProcessorWithLM:
# BeamSearchDecoderCTC has no auto class # BeamSearchDecoderCTC has no auto class
kwargs.pop("_from_auto", None) kwargs.pop("_from_auto", None)
decoder = BeamSearchDecoderCTC.load_from_hf_hub(pretrained_model_name_or_path, **kwargs) # make sure that only relevant filenames are downloaded
language_model_filenames = os.path.join(BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*")
alphabet_filename = BeamSearchDecoderCTC._ALPHABET_SERIALIZED_FILENAME
allow_regex = [language_model_filenames, alphabet_filename]
decoder = BeamSearchDecoderCTC.load_from_hf_hub(
pretrained_model_name_or_path, allow_regex=allow_regex, **kwargs
)
# set language model attributes # set language model attributes
for attribute in ["alpha", "beta", "unk_score_offset", "score_boundary"]: for attribute in ["alpha", "beta", "unk_score_offset", "score_boundary"]:
......
...@@ -18,6 +18,7 @@ import shutil ...@@ -18,6 +18,7 @@ import shutil
import tempfile import tempfile
import unittest import unittest
from multiprocessing import Pool from multiprocessing import Pool
from pathlib import Path
import numpy as np import numpy as np
...@@ -234,3 +235,16 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase): ...@@ -234,3 +235,16 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
self.assertListEqual(decoded_decoder, decoded_processor) self.assertListEqual(decoded_decoder, decoded_processor)
self.assertListEqual(["<s> </s> </s>", "<s> <s> </s>"], decoded_processor) self.assertListEqual(["<s> </s> </s>", "<s> <s> </s>"], decoded_processor)
def test_decoder_download_ignores_files(self):
    """Check that loading the processor only downloads decoder-relevant files.

    Loads the processor from "hf-internal-testing/processor_with_lm", takes
    the directory two levels above the KenLM model file path as the cached
    download directory, and asserts that it contains exactly the alphabet
    file and the language model directory.
    """
    processor = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm")
    language_model = processor.decoder.model_container[processor.decoder._model_key]
    # the KenLM model path is stored as bytes; its grandparent is used as the cache directory
    path_to_cached_dir = Path(language_model._kenlm_model.path.decode("utf-8")).parent.parent.absolute()
    # os.listdir returns entries in arbitrary order, so sort before the
    # order-sensitive list comparison below to keep the test deterministic
    downloaded_decoder_files = sorted(os.listdir(path_to_cached_dir))
    # test that only decoder relevant files from
    # https://huggingface.co/hf-internal-testing/processor_with_lm/tree/main
    # are downloaded and none of the rest (e.g. README.md, ...)
    self.assertListEqual(downloaded_decoder_files, ["alphabet.json", "language_model"])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment