Unverified commit 961732c2 authored by Patrick von Platen, committed by GitHub

[Wav2Vec2] PyCTCDecode Integration to support language model boosted decoding (#14339)



* up

* up

* up

* make it cleaner

* correct

* make style

* add more tests

* finish

* small fix

* make style

* up

* try to solve CircleCI

* up

* fix more tests

* fix more tests

* apply Sylvain's suggestions

* fix import

* correct docs

* add pyctcdecode only to speech tests

* fix more tests

* add tf, flax and pt tests

* add pt

* fix last tests

* fix more tests

* Apply suggestions from code review

* change lines

* Apply suggestions from code review
Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>

* correct tests

* correct tests

* add doc string
Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>
parent 2e12d90b
......@@ -83,6 +83,7 @@ jobs:
- run: pip install .[sklearn,tf-cpu,torch,testing,sentencepiece,torch-speech,vision]
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.0+cpu.html
- run: pip install tensorflow_probability
- run: pip install https://github.com/kpu/kenlm/archive/master.zip
- save_cache:
key: v0.4-{{ checksum "setup.py" }}
paths:
......@@ -151,6 +152,7 @@ jobs:
- run: pip install --upgrade pip
- run: pip install .[sklearn,flax,torch,testing,sentencepiece,torch-speech,vision]
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.0+cpu.html
- run: pip install https://github.com/kpu/kenlm/archive/master.zip
- save_cache:
key: v0.4-{{ checksum "setup.py" }}
paths:
......@@ -187,6 +189,7 @@ jobs:
- run: pip install --upgrade pip
- run: pip install .[sklearn,flax,torch,testing,sentencepiece,torch-speech,vision]
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.0+cpu.html
- run: pip install https://github.com/kpu/kenlm/archive/master.zip
- save_cache:
key: v0.4-{{ checksum "setup.py" }}
paths:
......@@ -217,6 +220,7 @@ jobs:
- run: pip install --upgrade pip
- run: pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm]
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.0+cpu.html
- run: pip install https://github.com/kpu/kenlm/archive/master.zip
- save_cache:
key: v0.4-torch-{{ checksum "setup.py" }}
paths:
......@@ -252,6 +256,7 @@ jobs:
- run: pip install --upgrade pip
- run: pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm]
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.0+cpu.html
- run: pip install https://github.com/kpu/kenlm/archive/master.zip
- save_cache:
key: v0.4-torch-{{ checksum "setup.py" }}
paths:
......@@ -278,9 +283,11 @@ jobs:
keys:
- v0.4-tf-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
- run: pip install --upgrade pip
- run: pip install .[sklearn,tf-cpu,testing,sentencepiece,tf-speech,vision]
- run: pip install tensorflow_probability
- run: pip install https://github.com/kpu/kenlm/archive/master.zip
- save_cache:
key: v0.4-tf-{{ checksum "setup.py" }}
paths:
......@@ -312,9 +319,11 @@ jobs:
keys:
- v0.4-tf-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
- run: pip install --upgrade pip
- run: pip install .[sklearn,tf-cpu,testing,sentencepiece,tf-speech,vision]
- run: pip install tensorflow_probability
- run: pip install https://github.com/kpu/kenlm/archive/master.zip
- save_cache:
key: v0.4-tf-{{ checksum "setup.py" }}
paths:
......@@ -341,8 +350,10 @@ jobs:
keys:
- v0.4-flax-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
- run: pip install --upgrade pip
- run: sudo pip install .[flax,testing,sentencepiece,flax-speech,vision]
- run: pip install .[flax,testing,sentencepiece,flax-speech,vision]
- run: pip install https://github.com/kpu/kenlm/archive/master.zip
- save_cache:
key: v0.4-flax-{{ checksum "setup.py" }}
paths:
......@@ -374,8 +385,10 @@ jobs:
keys:
- v0.4-flax-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
- run: pip install --upgrade pip
- run: sudo pip install .[flax,testing,sentencepiece,vision,flax-speech]
- run: pip install .[flax,testing,sentencepiece,vision,flax-speech]
- run: pip install https://github.com/kpu/kenlm/archive/master.zip
- save_cache:
key: v0.4-flax-{{ checksum "setup.py" }}
paths:
......@@ -407,6 +420,7 @@ jobs:
- run: pip install --upgrade pip
- run: pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm]
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.0+cpu.html
- run: pip install https://github.com/kpu/kenlm/archive/master.zip
- save_cache:
key: v0.4-torch-{{ checksum "setup.py" }}
paths:
......@@ -443,6 +457,7 @@ jobs:
- run: pip install --upgrade pip
- run: pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm]
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.0+cpu.html
- run: pip install https://github.com/kpu/kenlm/archive/master.zip
- save_cache:
key: v0.4-torch-{{ checksum "setup.py" }}
paths:
......@@ -582,7 +597,7 @@ jobs:
path: ~/transformers/examples_output.txt
- store_artifacts:
path: ~/transformers/reports
run_examples_torch_all:
working_directory: ~/transformers
docker:
......
......@@ -34,6 +34,7 @@ jobs:
apt install -y libsndfile1-dev
pip install --upgrade pip
pip install .[sklearn,testing,onnxruntime,sentencepiece,torch-speech,vision,timm]
pip install https://github.com/kpu/kenlm/archive/master.zip
- name: Launcher docker
uses: actions/checkout@v2
......@@ -87,6 +88,7 @@ jobs:
pip install --upgrade "jax[cuda111]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install --upgrade pip
pip install .[sklearn,testing,sentencepiece,flax,flax-speech,vision]
pip install https://github.com/kpu/kenlm/archive/master.zip
- name: Launcher docker
uses: actions/checkout@v2
......@@ -142,6 +144,7 @@ jobs:
# apt -y update && apt install -y software-properties-common && apt -y update && add-apt-repository -y ppa:git-core/ppa && apt -y update && apt install -y git
# pip install --upgrade pip
# pip install .[sklearn,testing,onnxruntime,sentencepiece,tf-speech]
# pip install https://github.com/kpu/kenlm/archive/master.zip
#
# - name: Launcher docker
# uses: actions/checkout@v2
......@@ -200,7 +203,7 @@ jobs:
apt install -y libsndfile1-dev
pip install --upgrade pip
pip install .[sklearn,testing,onnxruntime,sentencepiece,torch-speech,vision,timm]
pip install https://github.com/kpu/kenlm/archive/master.zip
- name: Launcher docker
uses: actions/checkout@v2
with:
......@@ -256,6 +259,7 @@ jobs:
# pip install --upgrade "jax[cuda111]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
# pip install --upgrade pip
# pip install .[sklearn,testing,sentencepiece,flax,flax-speech,vision]
# pip install https://github.com/kpu/kenlm/archive/master.zip
#
# - name: Launcher docker
# uses: actions/checkout@v2
......@@ -311,6 +315,7 @@ jobs:
# apt -y update && apt install -y software-properties-common && apt -y update && add-apt-repository -y ppa:git-core/ppa && apt -y update && apt install -y git
# pip install --upgrade pip
# pip install .[sklearn,testing,onnxruntime,sentencepiece,tf-speech]
# pip install https://github.com/kpu/kenlm/archive/master.zip
#
# - name: Launcher docker
# uses: actions/checkout@v2
......
......@@ -36,6 +36,7 @@ jobs:
apt -y update && apt install -y libsndfile1-dev git
pip install --upgrade pip
pip install .[integrations,sklearn,testing,onnxruntime,sentencepiece,torch-speech,vision,timm]
pip install https://github.com/kpu/kenlm/archive/master.zip
- name: Are GPUs recognized by our DL frameworks
run: |
......@@ -102,6 +103,7 @@ jobs:
pip install --upgrade pip
pip install --upgrade "jax[cuda111]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install .[flax,integrations,sklearn,testing,sentencepiece,flax-speech,vision]
pip install https://github.com/kpu/kenlm/archive/master.zip
- name: Are GPUs recognized by our DL frameworks
run: |
......@@ -141,6 +143,8 @@ jobs:
apt -y update && apt install -y libsndfile1-dev git
pip install --upgrade pip
pip install .[sklearn,testing,onnx,sentencepiece,tf-speech,vision]
pip install https://github.com/kpu/kenlm/archive/master.zip
- name: Are GPUs recognized by our DL frameworks
run: |
......@@ -236,6 +240,7 @@ jobs:
apt -y update && apt install -y libsndfile1-dev git
pip install --upgrade pip
pip install .[integrations,sklearn,testing,onnxruntime,sentencepiece,torch-speech,vision,timm]
pip install https://github.com/kpu/kenlm/archive/master.zip
- name: Are GPUs recognized by our DL frameworks
run: |
......@@ -288,6 +293,7 @@ jobs:
apt -y update && apt install -y libsndfile1-dev git
pip install --upgrade pip
pip install .[sklearn,testing,onnx,sentencepiece,tf-speech,vision]
pip install https://github.com/kpu/kenlm/archive/master.zip
- name: Are GPUs recognized by our DL frameworks
run: |
......
......@@ -67,9 +67,19 @@ Wav2Vec2Processor
:members: __call__, pad, from_pretrained, save_pretrained, batch_decode, decode, as_target_processor
Wav2Vec2ProcessorWithLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.Wav2Vec2ProcessorWithLM
:members: __call__, pad, from_pretrained, save_pretrained, batch_decode, decode, as_target_processor
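A minimal usage sketch of language model boosted decoding (the checkpoint name comes from this PR's tests; ``raw_audio`` is assumed to be a 1-D float array sampled at 16 kHz):

.. code-block:: python

    from transformers import Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM

    # the checkpoint bundles the tokenizer, feature extractor and pyctcdecode decoder files
    processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
    model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")

    input_values = processor(raw_audio, return_tensors="pt").input_values
    logits = model(input_values).logits

    # beam search boosted by the n-gram language model instead of plain argmax decoding
    transcription = processor.batch_decode(logits.detach().numpy()).text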
Wav2Vec2 specific outputs
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.models.wav2vec2.processing_wav2vec2_with_lm.Wav2Vec2DecoderWithLMOutput
:members:
.. autoclass:: transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2BaseModelOutput
:members:
......
......@@ -51,7 +51,7 @@ To create the package for pypi.
pip install -i https://testpypi.python.org/pypi transformers
Check you can run the following commands:
python -c "from transformers import pipeline; classifier = pipeline('text-classification'); print(classifier('What a nice release'))"
python -c "from transformers import pipeline; classifier = pipeline('text-classification'); print(classifier('What a nice release'))"
python -c "from transformers import *"
9. Upload the final version to actual pypi:
......@@ -59,7 +59,7 @@ To create the package for pypi.
10. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory.
11. Run `make post-release` (or, for a patch release, `make post-patch`). If you were on a branch for the release,
you need to go back to master before executing this.
"""
......@@ -159,6 +159,7 @@ _deps = [
"tokenizers>=0.10.1,<0.11",
"torch>=1.0",
"torchaudio",
"pyctcdecode>=0.2.0",
"tqdm>=4.27",
"unidic>=1.0.2",
"unidic_lite>=1.0.7",
......@@ -262,7 +263,7 @@ extras["sigopt"] = deps_list("sigopt")
extras["integrations"] = extras["optuna"] + extras["ray"] + extras["sigopt"]
extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
extras["audio"] = deps_list("librosa")
extras["audio"] = deps_list("librosa", "pyctcdecode")
extras["speech"] = deps_list("torchaudio") + extras["audio"] # `pip install ".[speech]"` is deprecated and `pip install ".[torch-speech]"` should be used instead
extras["torch-speech"] = deps_list("torchaudio") + extras["audio"]
extras["tf-speech"] = extras["audio"]
......
......@@ -44,6 +44,7 @@ from . import dependency_versions_check
from .file_utils import (
_LazyModule,
is_flax_available,
is_pyctcdecode_available,
is_pytorch_quantization_available,
is_scatter_available,
is_sentencepiece_available,
......@@ -471,6 +472,15 @@ else:
name for name in dir(dummy_speech_objects) if not name.startswith("_")
]
if is_pyctcdecode_available():
_import_structure["models.wav2vec2"].append("Wav2Vec2ProcessorWithLM")
else:
from .utils import dummy_pyctcdecode_objects
_import_structure["utils.dummy_pyctcdecode_objects"] = [
name for name in dir(dummy_pyctcdecode_objects) if not name.startswith("_")
]
if is_sentencepiece_available() and is_speech_available():
_import_structure["models.speech_to_text"].append("Speech2TextProcessor")
else:
......@@ -2441,6 +2451,11 @@ if TYPE_CHECKING:
else:
from .utils.dummy_speech_objects import *
if is_pyctcdecode_available():
from .models.wav2vec2 import Wav2Vec2ProcessorWithLM
else:
from .utils.dummy_pyctcdecode_objects import *
if is_speech_available() and is_sentencepiece_available():
from .models.speech_to_text import Speech2TextProcessor
else:
......
......@@ -70,6 +70,7 @@ deps = {
"tokenizers": "tokenizers>=0.10.1,<0.11",
"torch": "torch>=1.0",
"torchaudio": "torchaudio",
"pyctcdecode": "pyctcdecode>=0.2.0",
"tqdm": "tqdm>=4.27",
"unidic": "unidic>=1.0.2",
"unidic_lite": "unidic_lite>=1.0.7",
......
......@@ -237,6 +237,22 @@ except importlib_metadata.PackageNotFoundError:
_torchaudio_available = False
_pyctcdecode_available = importlib.util.find_spec("pyctcdecode") is not None
try:
_pyctcdecode_version = importlib_metadata.version("pyctcdecode")
logger.debug(f"Successfully imported pyctcdecode version {_pyctcdecode_version}")
except importlib_metadata.PackageNotFoundError:
_pyctcdecode_available = False
_librosa_available = importlib.util.find_spec("librosa") is not None
try:
_librosa_version = importlib_metadata.version("librosa")
logger.debug(f"Successfully imported librosa version {_librosa_version}")
except importlib_metadata.PackageNotFoundError:
_librosa_available = False
torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
old_default_cache_path = os.path.join(torch_cache_home, "transformers")
# New default cache, shared with the Datasets library
......@@ -311,6 +327,14 @@ def is_torch_available():
return _torch_available
def is_pyctcdecode_available():
return _pyctcdecode_available
def is_librosa_available():
return _librosa_available
def is_torch_cuda_available():
if is_torch_available():
import torch
......@@ -736,6 +760,12 @@ PYTESSERACT_IMPORT_ERROR = """
`pip install pytesseract`
"""
# docstyle-ignore
PYCTCDECODE_IMPORT_ERROR = """
{0} requires the pyctcdecode library but it was not found in your environment. You can install it with pip:
`pip install pyctcdecode`
"""
BACKENDS_MAPPING = OrderedDict(
[
......@@ -745,6 +775,7 @@ BACKENDS_MAPPING = OrderedDict(
("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
("pandas", (is_pandas_available, PANDAS_IMPORT_ERROR)),
("protobuf", (is_protobuf_available, PROTOBUF_IMPORT_ERROR)),
("pyctcdecode", (is_pyctcdecode_available, PYCTCDECODE_IMPORT_ERROR)),
("pytesseract", (is_pytesseract_available, PYTESSERACT_IMPORT_ERROR)),
("scatter", (is_scatter_available, SCATTER_IMPORT_ERROR)),
("pytorch_quantization", (is_pytorch_quantization_available, PYTORCH_QUANTIZATION_IMPORT_ERROR)),
......
......@@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
from ...file_utils import _LazyModule, is_flax_available, is_pyctcdecode_available, is_tf_available, is_torch_available
_import_structure = {
......@@ -27,6 +27,9 @@ _import_structure = {
"tokenization_wav2vec2": ["Wav2Vec2CTCTokenizer", "Wav2Vec2Tokenizer"],
}
if is_pyctcdecode_available():
_import_structure["processing_wav2vec2_with_lm"] = ["Wav2Vec2ProcessorWithLM"]
if is_torch_available():
_import_structure["modeling_wav2vec2"] = [
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
......@@ -61,6 +64,9 @@ if TYPE_CHECKING:
from .processing_wav2vec2 import Wav2Vec2Processor
from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2Tokenizer
if is_pyctcdecode_available():
from .processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
if is_torch_available():
from .modeling_wav2vec2 import (
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
......
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Speech processor class for Wav2Vec2
"""
import os
from contextlib import contextmanager
from dataclasses import dataclass
from multiprocessing import Pool
from typing import Iterable, List, Optional, Union
import numpy as np
from pyctcdecode import BeamSearchDecoderCTC
from pyctcdecode.alphabet import BLANK_TOKEN_PTN, UNK_TOKEN, UNK_TOKEN_PTN
from pyctcdecode.constants import (
DEFAULT_BEAM_WIDTH,
DEFAULT_HOTWORD_WEIGHT,
DEFAULT_MIN_TOKEN_LOGP,
DEFAULT_PRUNE_LOGP,
)
from ...feature_extraction_utils import FeatureExtractionMixin
from ...file_utils import ModelOutput, requires_backends
from ...tokenization_utils import PreTrainedTokenizer
from .feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
@dataclass
class Wav2Vec2DecoderWithLMOutput(ModelOutput):
"""
Output type of :class:`~transformers.Wav2Vec2ProcessorWithLM`, with transcription.
Args:
text (:obj:`list` of :obj:`str` or :obj:`str`):
Decoded logits in text form. Usually the speech transcription.
"""
text: Union[List[str], str]
class Wav2Vec2ProcessorWithLM:
r"""
Constructs a Wav2Vec2 processor which wraps a Wav2Vec2 feature extractor, a Wav2Vec2 CTC tokenizer and a decoder
with language model support into a single processor for language model boosted speech recognition decoding.
Args:
feature_extractor (:class:`~transformers.Wav2Vec2FeatureExtractor`):
An instance of :class:`~transformers.Wav2Vec2FeatureExtractor`. The feature extractor is a required input.
tokenizer (:class:`~transformers.Wav2Vec2CTCTokenizer`):
An instance of :class:`~transformers.Wav2Vec2CTCTokenizer`. The tokenizer is a required input.
decoder (:obj:`pyctcdecode.BeamSearchDecoderCTC`):
An instance of :class:`pyctcdecode.BeamSearchDecoderCTC`. The decoder is a required input.
"""
def __init__(
self,
feature_extractor: FeatureExtractionMixin,
tokenizer: PreTrainedTokenizer,
decoder: BeamSearchDecoderCTC,
):
if not isinstance(feature_extractor, Wav2Vec2FeatureExtractor):
raise ValueError(
f"`feature_extractor` has to be of type {Wav2Vec2FeatureExtractor.__class__}, but is {type(feature_extractor)}"
)
if not isinstance(tokenizer, Wav2Vec2CTCTokenizer):
# TODO(PVP) - this can be relaxed in the future to allow other kinds of tokenizers
raise ValueError(
f"`tokenizer` has to be of type {Wav2Vec2CTCTokenizer.__class__}, but is {type(tokenizer)}"
)
if not isinstance(decoder, BeamSearchDecoderCTC):
raise ValueError(f"`decoder` has to be of type {BeamSearchDecoderCTC.__class__}, but is {type(decoder)}")
# make sure that decoder's alphabet and tokenizer's vocab match in content
missing_decoder_tokens = self.get_missing_alphabet_tokens(decoder, tokenizer)
if len(missing_decoder_tokens) > 0:
raise ValueError(
f"The tokens {missing_decoder_tokens} are defined in the tokenizer's "
"vocabulary, but not in the decoder's alphabet. "
f"Make sure to include {missing_decoder_tokens} in the decoder's alphabet."
)
self.feature_extractor = feature_extractor
self.tokenizer = tokenizer
self.decoder = decoder
self.current_processor = self.feature_extractor
def save_pretrained(self, save_directory):
"""
Save the Wav2Vec2 feature extractor, tokenizer and pyctcdecode decoder to the directory
``save_directory``, so that they can be re-loaded using the
:func:`~transformers.Wav2Vec2ProcessorWithLM.from_pretrained` class method.
.. note::
This class method is simply calling
:meth:`~transformers.feature_extraction_utils.FeatureExtractionMixin.save_pretrained`,
:meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.save_pretrained` and pyctcdecode's
:meth:`pyctcdecode.BeamSearchDecoderCTC.save_to_dir`.
Please refer to the docstrings of the methods above for more information.
Args:
save_directory (:obj:`str` or :obj:`os.PathLike`):
Directory where the feature extractor JSON file, the tokenizer files and the decoder files will be
saved (directory will be created if it does not exist).
"""
self.feature_extractor.save_pretrained(save_directory)
self.tokenizer.save_pretrained(save_directory)
self.decoder.save_to_dir(save_directory)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
r"""
Instantiate a :class:`~transformers.Wav2Vec2ProcessorWithLM` from a pretrained Wav2Vec2 processor.
.. note::
This class method is simply calling Wav2Vec2FeatureExtractor's
:meth:`~transformers.feature_extraction_utils.FeatureExtractionMixin.from_pretrained`,
Wav2Vec2CTCTokenizer's :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.from_pretrained`,
and :meth:`pyctcdecode.BeamSearchDecoderCTC.load_from_hf_hub`.
Please refer to the docstrings of the methods above for more information.
Args:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
This can be either:
- a string, the `model id` of a pretrained feature_extractor hosted inside a model repo on
huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing a feature extractor file saved using the
:meth:`~transformers.SequenceFeatureExtractor.save_pretrained` method, e.g.,
``./my_model_directory/``.
- a path or url to a saved feature extractor JSON `file`, e.g.,
``./my_model_directory/preprocessor_config.json``.
**kwargs
Additional keyword arguments passed along to both :class:`~transformers.SequenceFeatureExtractor` and
:class:`~transformers.PreTrainedTokenizer`
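Example (a sketch; the repository name is the one used in this PR's tests, and the attribute
overrides are optional)::

    >>> processor = Wav2Vec2ProcessorWithLM.from_pretrained(
    ...     "patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm",
    ...     alpha=0.5,  # overrides the decoder's language model weight
    ...     beta=1.5,  # overrides the decoder's length score adjustment
    ... )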
"""
requires_backends(cls, "pyctcdecode")
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
if os.path.isdir(pretrained_model_name_or_path):
decoder = BeamSearchDecoderCTC.load_from_dir(pretrained_model_name_or_path)
else:
decoder = BeamSearchDecoderCTC.load_from_hf_hub(pretrained_model_name_or_path, **kwargs)
# set language model attributes
for attribute in ["alpha", "beta", "unk_score_offset", "score_boundary"]:
value = kwargs.pop(attribute, None)
if value is not None:
cls._set_language_model_attribute(decoder, attribute, value)
# make sure that decoder's alphabet and tokenizer's vocab match in content
missing_decoder_tokens = cls.get_missing_alphabet_tokens(decoder, tokenizer)
if len(missing_decoder_tokens) > 0:
raise ValueError(
f"The tokens {missing_decoder_tokens} are defined in the tokenizer's "
"vocabulary, but not in the decoder's alphabet. "
f"Make sure to include {missing_decoder_tokens} in the decoder's alphabet."
)
return cls(feature_extractor=feature_extractor, tokenizer=tokenizer, decoder=decoder)
@staticmethod
def _set_language_model_attribute(decoder: BeamSearchDecoderCTC, attribute: str, value: float):
setattr(decoder.model_container[decoder._model_key], attribute, value)
@property
def language_model(self):
return self.decoder.model_container[self.decoder._model_key]
@staticmethod
def get_missing_alphabet_tokens(decoder, tokenizer):
# we need to make sure that all of the tokenizer's tokens, except the special
# tokens, are present in the decoder's alphabet. Retrieve the missing alphabet
# tokens from the decoder.
tokenizer_vocab_list = list(tokenizer.get_vocab().keys())
# replace special tokens
for i, token in enumerate(tokenizer_vocab_list):
if BLANK_TOKEN_PTN.match(token):
tokenizer_vocab_list[i] = ""
if token == tokenizer.word_delimiter_token:
tokenizer_vocab_list[i] = " "
if UNK_TOKEN_PTN.match(token):
tokenizer_vocab_list[i] = UNK_TOKEN
# which tokenizer tokens are missing from the decoder's alphabet?
missing_tokens = set(tokenizer_vocab_list) - set(decoder._alphabet.labels)
return missing_tokens
def __call__(self, *args, **kwargs):
"""
When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's
:meth:`~transformers.Wav2Vec2FeatureExtractor.__call__` and returns its output. If used in the context
:meth:`~transformers.Wav2Vec2ProcessorWithLM.as_target_processor` this method forwards all its arguments to
Wav2Vec2CTCTokenizer's :meth:`~transformers.Wav2Vec2CTCTokenizer.__call__`. Please refer to the docstring of
the above two methods for more information.
"""
return self.current_processor(*args, **kwargs)
def pad(self, *args, **kwargs):
"""
When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's
:meth:`~transformers.Wav2Vec2FeatureExtractor.pad` and returns its output. If used in the context
:meth:`~transformers.Wav2Vec2ProcessorWithLM.as_target_processor` this method forwards all its arguments to
Wav2Vec2CTCTokenizer's :meth:`~transformers.Wav2Vec2CTCTokenizer.pad`. Please refer to the docstring of the
above two methods for more information.
"""
return self.current_processor.pad(*args, **kwargs)
def batch_decode(
self,
logits: np.ndarray,
num_processes: Optional[int] = None,
beam_width: Optional[int] = None,
beam_prune_logp: Optional[float] = None,
token_min_logp: Optional[float] = None,
hotwords: Optional[Iterable[str]] = None,
hotword_weight: Optional[float] = None,
):
"""
Batch decode output logits to audio transcription with language model support.
.. note::
This function makes use of Python's multiprocessing.
Args:
logits (:obj:`np.ndarray`):
The logits output vector of the model representing the log probabilities for each token.
num_processes (:obj:`int`, `optional`):
Number of processes the function should be parallelized over. Defaults to the number of
available CPUs.
beam_width (:obj:`int`, `optional`):
Maximum number of beams at each step in decoding. Defaults to pyctcdecode's DEFAULT_BEAM_WIDTH.
beam_prune_logp (:obj:`float`, `optional`):
Beams that are much worse than the best beam will be pruned. Defaults to pyctcdecode's
DEFAULT_PRUNE_LOGP.
token_min_logp (:obj:`float`, `optional`):
Tokens with log-probs below this value are skipped, unless they are the argmax of the frame.
Defaults to pyctcdecode's DEFAULT_MIN_TOKEN_LOGP.
hotwords (:obj:`List[str]`, `optional`):
List of words with extra importance; these can be missing from the LM's vocabulary, e.g. ["huggingface"].
hotword_weight (:obj:`float`, `optional`):
Weight multiplier that boosts hotword scores. Defaults to pyctcdecode's DEFAULT_HOTWORD_WEIGHT.
Returns:
:class:`~transformers.models.wav2vec2.Wav2Vec2DecoderWithLMOutput` or :obj:`tuple`.
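Example (a sketch; ``logits`` is assumed to be a ``(batch_size, sequence_length, vocab_size)``
numpy array returned by a Wav2Vec2 CTC model)::

    >>> output = processor.batch_decode(logits, beam_width=100, hotwords=["huggingface"])
    >>> transcriptions = output.text  # one transcription string per batch entry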
"""
# set defaults
beam_width = beam_width if beam_width is not None else DEFAULT_BEAM_WIDTH
beam_prune_logp = beam_prune_logp if beam_prune_logp is not None else DEFAULT_PRUNE_LOGP
token_min_logp = token_min_logp if token_min_logp is not None else DEFAULT_MIN_TOKEN_LOGP
hotword_weight = hotword_weight if hotword_weight is not None else DEFAULT_HOTWORD_WEIGHT
# list the numpy arrays so each utterance can be decoded in its own process
logits_list = list(logits)
# pyctcdecode; use the pool as a context manager so worker processes are cleaned up
with Pool(num_processes) as pool:
    decoded_beams = self.decoder.decode_beams_batch(
        pool,
        logits_list=logits_list,
        beam_width=beam_width,
        beam_prune_logp=beam_prune_logp,
        token_min_logp=token_min_logp,
        hotwords=hotwords,
        hotword_weight=hotword_weight,
    )
# extract text
batch_texts = [d[0][0] for d in decoded_beams]
# more output features will be added in the future
return Wav2Vec2DecoderWithLMOutput(text=batch_texts)
def decode(
self,
logits: np.ndarray,
beam_width: Optional[int] = None,
beam_prune_logp: Optional[float] = None,
token_min_logp: Optional[float] = None,
hotwords: Optional[Iterable[str]] = None,
hotword_weight: Optional[float] = None,
):
"""
Decode output logits to audio transcription with language model support.
Args:
logits (:obj:`np.ndarray`):
The logits output vector of the model representing the log probabilities for each token.
beam_width (:obj:`int`, `optional`):
Maximum number of beams at each step in decoding. Defaults to pyctcdecode's DEFAULT_BEAM_WIDTH.
beam_prune_logp (:obj:`float`, `optional`):
A threshold to prune beams with log-probs less than best_beam_logp + beam_prune_logp. The value should
be <= 0. Defaults to pyctcdecode's DEFAULT_PRUNE_LOGP.
token_min_logp (:obj:`float`, `optional`):
Tokens with log-probs below token_min_logp are skipped unless they have the maximum log-prob for an
utterance. Defaults to pyctcdecode's DEFAULT_MIN_TOKEN_LOGP.
hotwords (:obj:`List[str]`, `optional`):
List of words with extra importance which can be missing from the LM's vocabulary, e.g. ["huggingface"]
hotword_weight (:obj:`float`, `optional`):
Weight multiplier that boosts hotword scores. Defaults to pyctcdecode's DEFAULT_HOTWORD_WEIGHT.
Returns:
:class:`~transformers.models.wav2vec2.Wav2Vec2DecoderWithLMOutput` or :obj:`tuple`.
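Example (a sketch; ``logits`` is assumed to be a single utterance's ``(sequence_length, vocab_size)``
numpy array)::

    >>> transcription = processor.decode(logits, hotwords=["huggingface"], hotword_weight=10.0).text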
"""
# set defaults
beam_width = beam_width if beam_width is not None else DEFAULT_BEAM_WIDTH
beam_prune_logp = beam_prune_logp if beam_prune_logp is not None else DEFAULT_PRUNE_LOGP
token_min_logp = token_min_logp if token_min_logp is not None else DEFAULT_MIN_TOKEN_LOGP
hotword_weight = hotword_weight if hotword_weight is not None else DEFAULT_HOTWORD_WEIGHT
# pyctcdecode
decoded_beams = self.decoder.decode_beams(
logits,
beam_width=beam_width,
beam_prune_logp=beam_prune_logp,
token_min_logp=token_min_logp,
hotwords=hotwords,
hotword_weight=hotword_weight,
)
# more output features will be added in the future
return Wav2Vec2DecoderWithLMOutput(text=decoded_beams[0][0])
@contextmanager
def as_target_processor(self):
"""
Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning
Wav2Vec2.
"""
self.current_processor = self.tokenizer
yield
self.current_processor = self.feature_extractor
......@@ -36,8 +36,10 @@ from .file_utils import (
is_faiss_available,
is_flax_available,
is_keras2onnx_available,
is_librosa_available,
is_onnx_available,
is_pandas_available,
is_pyctcdecode_available,
is_pytesseract_available,
is_pytorch_quantization_available,
is_rjieba_available,
......@@ -598,6 +600,26 @@ def require_deepspeed(test_case):
return test_case
def require_pyctcdecode(test_case):
"""
Decorator marking a test that requires pyctcdecode
"""
if not is_pyctcdecode_available():
return unittest.skip("test requires pyctcdecode")(test_case)
else:
return test_case
def require_librosa(test_case):
"""
Decorator marking a test that requires librosa
"""
if not is_librosa_available():
return unittest.skip("test requires librosa")(test_case)
else:
return test_case
def get_gpu_count():
"""
Return the number of available gpus (regardless of whether torch, tf or jax is used)
......
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..file_utils import requires_backends
class Wav2Vec2ProcessorWithLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["pyctcdecode"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["pyctcdecode"])
......@@ -17,9 +17,19 @@ import math
import unittest
import numpy as np
from datasets import load_dataset
from transformers import Wav2Vec2Config, is_flax_available
from transformers.testing_utils import require_datasets, require_flax, require_soundfile, slow
from transformers.testing_utils import (
is_librosa_available,
is_pyctcdecode_available,
require_datasets,
require_flax,
require_librosa,
require_pyctcdecode,
require_soundfile,
slow,
)
from .test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, random_attention_mask
......@@ -39,6 +49,14 @@ if is_flax_available():
)
if is_pyctcdecode_available():
from transformers import Wav2Vec2ProcessorWithLM
if is_librosa_available():
import librosa
class FlaxWav2Vec2ModelTester:
def __init__(
self,
......@@ -354,8 +372,6 @@ class FlaxWav2Vec2UtilsTest(unittest.TestCase):
@slow
class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase):
def _load_datasamples(self, num_samples):
from datasets import load_dataset
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id").filter(
......@@ -447,3 +463,22 @@ class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase):
# a random wav2vec2 model has not learned to predict the quantized latent states
# => the cosine similarity between quantized states and predicted states is very likely < 0.1
self.assertTrue(cosine_sim_masked.mean().item() - 5 * cosine_sim_masked_rand.mean().item() > 0)
@require_pyctcdecode
@require_librosa
def test_wav2vec2_with_lm(self):
ds = load_dataset("common_voice", "es", split="test", streaming=True)
sample = next(iter(ds))
resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000)
model = FlaxWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
input_values = processor(resampled_audio, return_tensors="np").input_values
logits = model(input_values).logits
transcription = processor.batch_decode(np.array(logits)).text
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
......@@ -21,9 +21,11 @@ import unittest
import numpy as np
import pytest
from datasets import load_dataset
from transformers import Wav2Vec2Config, is_tf_available
from transformers.testing_utils import require_datasets, require_soundfile, require_tf, slow
from transformers.file_utils import is_librosa_available, is_pyctcdecode_available
from transformers.testing_utils import require_datasets, require_librosa, require_pyctcdecode, require_tf, slow
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
......@@ -36,6 +38,14 @@ if is_tf_available():
from transformers.models.wav2vec2.modeling_tf_wav2vec2 import _compute_mask_indices
if is_pyctcdecode_available():
from transformers import Wav2Vec2ProcessorWithLM
if is_librosa_available():
import librosa
@require_tf
class TFWav2Vec2ModelTester:
def __init__(
......@@ -474,7 +484,6 @@ class TFWav2Vec2UtilsTest(unittest.TestCase):
@require_tf
@slow
@require_datasets
@require_soundfile
class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
def _load_datasamples(self, num_samples):
from datasets import load_dataset
......@@ -544,3 +553,22 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
"his instant panic was followed by a small sharp blow high on his chest",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
@require_pyctcdecode
@require_librosa
def test_wav2vec2_with_lm(self):
ds = load_dataset("common_voice", "es", split="test", streaming=True)
sample = next(iter(ds))
resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000)
model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
input_values = processor(resampled_audio, return_tensors="tf").input_values
logits = model(input_values).logits
transcription = processor.batch_decode(logits.numpy()).text
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
......@@ -18,15 +18,19 @@ import math
import unittest
import numpy as np
import pytest
from datasets import load_dataset
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from transformers import Wav2Vec2Config, is_torch_available
from transformers.testing_utils import (
is_pt_flax_cross_test,
is_pyctcdecode_available,
is_torchaudio_available,
require_datasets,
require_pyctcdecode,
require_soundfile,
require_torch,
require_torchaudio,
slow,
torch_device,
)
......@@ -54,6 +58,14 @@ if is_torch_available():
)
if is_torchaudio_available():
import torchaudio
if is_pyctcdecode_available():
from transformers import Wav2Vec2ProcessorWithLM
class Wav2Vec2ModelTester:
def __init__(
self,
......@@ -331,7 +343,7 @@ class Wav2Vec2ModelTester:
max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
labels = ids_tensor((input_values.shape[0], max(max_length_labels) - 2), model.config.vocab_size + 100)
with pytest.raises(ValueError):
with self.parent.assertRaises(ValueError):
model(input_values, labels=labels)
def prepare_config_and_inputs_for_common(self):
......@@ -998,8 +1010,6 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
@slow
class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
def _load_datasamples(self, num_samples):
from datasets import load_dataset
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id").filter(
......@@ -1009,8 +1019,6 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
return [x["array"] for x in speech_samples]
def _load_superb(self, task, num_samples):
from datasets import load_dataset
ds = load_dataset("anton-l/superb_dummy", task, split="test")
return ds[:num_samples]
......@@ -1337,3 +1345,27 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
self.assertListEqual(predicted_ids.tolist(), expected_labels)
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2))
@require_pyctcdecode
@require_torchaudio
def test_wav2vec2_with_lm(self):
ds = load_dataset("common_voice", "es", split="test", streaming=True)
sample = next(iter(ds))
resampled_audio = torchaudio.functional.resample(
torch.tensor(sample["audio"]["array"]), 48_000, 16_000
).numpy()
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm").to(
torch_device
)
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
input_values = processor(resampled_audio, return_tensors="pt").input_values
with torch.no_grad():
logits = model(input_values.to(torch_device)).logits
transcription = processor.batch_decode(logits.cpu().numpy()).text
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import shutil
import tempfile
import unittest
from multiprocessing import Pool
import numpy as np
from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_pyctcdecode_available
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
from transformers.testing_utils import require_pyctcdecode
from .test_feature_extraction_wav2vec2 import floats_list
if is_pyctcdecode_available():
from pyctcdecode import BeamSearchDecoderCTC
from transformers.models.wav2vec2 import Wav2Vec2ProcessorWithLM
@require_pyctcdecode
class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
def setUp(self):
vocab = "| <pad> <unk> <s> </s> a b c d e f g h i j k".split()
vocab_tokens = dict(zip(vocab, range(len(vocab))))
self.add_kwargs_tokens_map = {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
}
feature_extractor_map = {
"feature_size": 1,
"padding_value": 0.0,
"sampling_rate": 16000,
"return_attention_mask": False,
"do_normalize": True,
}
self.tmpdirname = tempfile.mkdtemp()
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
self.feature_extraction_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME)
with open(self.vocab_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(vocab_tokens) + "\n")
with open(self.feature_extraction_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(feature_extractor_map) + "\n")
# load decoder from hub
self.decoder_name = "hf-internal-testing/ngram-beam-search-decoder"
def get_tokenizer(self, **kwargs_init):
kwargs = self.add_kwargs_tokens_map.copy()
kwargs.update(kwargs_init)
return Wav2Vec2CTCTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_feature_extractor(self, **kwargs):
return Wav2Vec2FeatureExtractor.from_pretrained(self.tmpdirname, **kwargs)
def get_decoder(self, **kwargs):
return BeamSearchDecoderCTC.load_from_hf_hub(self.decoder_name, **kwargs)
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def test_save_load_pretrained_default(self):
tokenizer = self.get_tokenizer()
feature_extractor = self.get_feature_extractor()
decoder = self.get_decoder()
processor = Wav2Vec2ProcessorWithLM(tokenizer=tokenizer, feature_extractor=feature_extractor, decoder=decoder)
processor.save_pretrained(self.tmpdirname)
processor = Wav2Vec2ProcessorWithLM.from_pretrained(self.tmpdirname)
# tokenizer
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
self.assertIsInstance(processor.tokenizer, Wav2Vec2CTCTokenizer)
# feature extractor
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
self.assertIsInstance(processor.feature_extractor, Wav2Vec2FeatureExtractor)
# decoder
self.assertEqual(processor.decoder._alphabet.labels, decoder._alphabet.labels)
self.assertEqual(
processor.decoder.model_container[decoder._model_key]._unigram_set,
decoder.model_container[decoder._model_key]._unigram_set,
)
self.assertIsInstance(processor.decoder, BeamSearchDecoderCTC)
def test_save_load_pretrained_additional_features(self):
processor = Wav2Vec2ProcessorWithLM(
tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor(), decoder=self.get_decoder()
)
processor.save_pretrained(self.tmpdirname)
# load the decoder while overriding additional language model attributes
processor = Wav2Vec2ProcessorWithLM.from_pretrained(
self.tmpdirname, alpha=5.0, beta=3.0, score_boundary=-7.0, unk_score_offset=3
)
# decoder
self.assertEqual(processor.language_model.alpha, 5.0)
self.assertEqual(processor.language_model.beta, 3.0)
self.assertEqual(processor.language_model.score_boundary, -7.0)
self.assertEqual(processor.language_model.unk_score_offset, 3)
def test_load_decoder_tokenizer_mismatch_content(self):
tokenizer = self.get_tokenizer()
# add token to trigger raise
tokenizer.add_tokens(["xx"])
with self.assertRaisesRegex(ValueError, "include"):
Wav2Vec2ProcessorWithLM(
tokenizer=tokenizer, feature_extractor=self.get_feature_extractor(), decoder=self.get_decoder()
)
def test_feature_extractor(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
decoder = self.get_decoder()
processor = Wav2Vec2ProcessorWithLM(tokenizer=tokenizer, feature_extractor=feature_extractor, decoder=decoder)
raw_speech = floats_list((3, 1000))
input_feat_extract = feature_extractor(raw_speech, return_tensors="np")
input_processor = processor(raw_speech, return_tensors="np")
for key in input_feat_extract.keys():
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
def test_tokenizer(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
decoder = self.get_decoder()
processor = Wav2Vec2ProcessorWithLM(tokenizer=tokenizer, feature_extractor=feature_extractor, decoder=decoder)
input_str = "This is a test string"
with processor.as_target_processor():
encoded_processor = processor(input_str)
encoded_tok = tokenizer(input_str)
for key in encoded_tok.keys():
self.assertListEqual(encoded_tok[key], encoded_processor[key])
def _get_dummy_logits(self, shape=(2, 10, 16), seed=77):
np.random.seed(seed)
return np.random.rand(*shape)
def test_decoder(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
decoder = self.get_decoder()
processor = Wav2Vec2ProcessorWithLM(tokenizer=tokenizer, feature_extractor=feature_extractor, decoder=decoder)
logits = self._get_dummy_logits(shape=(10, 16), seed=13)
decoded_processor = processor.decode(logits).text
decoded_decoder = decoder.decode_beams(logits)[0][0]
self.assertEqual(decoded_decoder, decoded_processor)
self.assertEqual("</s> <s> </s>", decoded_processor)
def test_decoder_batch(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
decoder = self.get_decoder()
processor = Wav2Vec2ProcessorWithLM(tokenizer=tokenizer, feature_extractor=feature_extractor, decoder=decoder)
logits = self._get_dummy_logits()
decoded_processor = processor.batch_decode(logits).text
logits_list = [array for array in logits]
decoded_decoder = [d[0][0] for d in decoder.decode_beams_batch(Pool(), logits_list)]
self.assertListEqual(decoded_decoder, decoded_processor)
self.assertListEqual(["<s> <s> </s>", "<s> <s> <s>"], decoded_processor)
def test_decoder_with_params(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
decoder = self.get_decoder()
processor = Wav2Vec2ProcessorWithLM(tokenizer=tokenizer, feature_extractor=feature_extractor, decoder=decoder)
logits = self._get_dummy_logits()
beam_width = 20
beam_prune_logp = -20.0
token_min_logp = -4.0
decoded_processor_out = processor.batch_decode(
logits,
beam_width=beam_width,
beam_prune_logp=beam_prune_logp,
token_min_logp=token_min_logp,
)
decoded_processor = decoded_processor_out.text
logits_list = [array for array in logits]
decoded_decoder_out = decoder.decode_beams_batch(
Pool(),
logits_list,
beam_width=beam_width,
beam_prune_logp=beam_prune_logp,
token_min_logp=token_min_logp,
)
decoded_decoder = [d[0][0] for d in decoded_decoder_out]
self.assertListEqual(decoded_decoder, decoded_processor)
self.assertListEqual(["<s> </s> </s>", "<s> <s> </s>"], decoded_processor)