"tests/vscode:/vscode.git/clone" did not exist on "a7920065f2cfd2549b838f9a30afd7c265fcdd88"
Unverified commit 13186d71, authored by Sylvain Gugger and committed via GitHub.
Browse files

Move pyctcdecode (#14686)

* Move pyctcdecode dep

* Fix doc and last objects

* Quality

* Style

* Ignore this black
parent d104dd46
......@@ -77,7 +77,7 @@ Wav2Vec2ProcessorWithLM
Wav2Vec2 specific outputs
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.models.wav2vec2.processing_wav2vec2_with_lm.Wav2Vec2DecoderWithLMOutput
.. autoclass:: transformers.models.wav2vec2_with_lm.processing_wav2vec2_with_lm.Wav2Vec2DecoderWithLMOutput
:members:
.. autoclass:: transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2BaseModelOutput
......
......@@ -313,7 +313,7 @@ _import_structure = {
"Wav2Vec2Processor",
"Wav2Vec2Tokenizer",
],
"models.wav2vec2_with_lm": [],
"models.wav2vec2_with_lm": ["Wav2Vec2ProcessorWithLM"],
"models.xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMTokenizer"],
"models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"],
"models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"],
......@@ -475,15 +475,6 @@ else:
name for name in dir(dummy_speech_objects) if not name.startswith("_")
]
if is_pyctcdecode_available():
_import_structure["models.wav2vec2_with_lm"].append("Wav2Vec2ProcessorWithLM")
else:
from .utils import dummy_pyctcdecode_objects
_import_structure["utils.dummy_pyctcdecode_objects"] = [
name for name in dir(dummy_pyctcdecode_objects) if not name.startswith("_")
]
if is_sentencepiece_available() and is_speech_available():
_import_structure["models.speech_to_text"].append("Speech2TextProcessor")
else:
......@@ -2329,6 +2320,7 @@ if TYPE_CHECKING:
Wav2Vec2Processor,
Wav2Vec2Tokenizer,
)
from .models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
from .models.xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig, XLMTokenizer
from .models.xlm_prophetnet import XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMProphetNetConfig
from .models.xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
......@@ -2472,11 +2464,6 @@ if TYPE_CHECKING:
else:
from .utils.dummy_speech_objects import *
if is_pyctcdecode_available():
from .models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
else:
from .utils.dummy_pyctcdecode_objects import *
if is_speech_available() and is_sentencepiece_available():
from .models.speech_to_text import Speech2TextProcessor
else:
......
......@@ -17,19 +17,18 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_pyctcdecode_available
from ...file_utils import _LazyModule
_import_structure = {"processing_wav2vec2_with_lm": []}
if is_pyctcdecode_available():
_import_structure["processing_wav2vec2_with_lm"].append("Wav2Vec2ProcessorWithLM")
# fmt: off
_import_structure = {
"processing_wav2vec2_with_lm": ["Wav2Vec2ProcessorWithLM"]
}
# fmt: on
if TYPE_CHECKING:
if is_pyctcdecode_available():
from .processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
from .processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
else:
import sys
......
......@@ -19,19 +19,10 @@ import os
from contextlib import contextmanager
from dataclasses import dataclass
from multiprocessing import Pool
from typing import Iterable, List, Optional, Union
from typing import TYPE_CHECKING, Iterable, List, Optional, Union
import numpy as np
from pyctcdecode import BeamSearchDecoderCTC
from pyctcdecode.alphabet import BLANK_TOKEN_PTN, UNK_TOKEN, UNK_TOKEN_PTN
from pyctcdecode.constants import (
DEFAULT_BEAM_WIDTH,
DEFAULT_HOTWORD_WEIGHT,
DEFAULT_MIN_TOKEN_LOGP,
DEFAULT_PRUNE_LOGP,
)
from ...feature_extraction_utils import FeatureExtractionMixin
from ...file_utils import ModelOutput, requires_backends
from ...tokenization_utils import PreTrainedTokenizer
......@@ -39,6 +30,10 @@ from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
from ..wav2vec2.tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
if TYPE_CHECKING:
from pyctcdecode import BeamSearchDecoderCTC
@dataclass
class Wav2Vec2DecoderWithLMOutput(ModelOutput):
"""
......@@ -70,8 +65,10 @@ class Wav2Vec2ProcessorWithLM:
self,
feature_extractor: FeatureExtractionMixin,
tokenizer: PreTrainedTokenizer,
decoder: BeamSearchDecoderCTC,
decoder: "BeamSearchDecoderCTC",
):
from pyctcdecode import BeamSearchDecoderCTC
if not isinstance(feature_extractor, Wav2Vec2FeatureExtractor):
raise ValueError(
f"`feature_extractor` has to be of type {Wav2Vec2FeatureExtractor.__class__}, but is {type(feature_extractor)}"
......@@ -153,6 +150,8 @@ class Wav2Vec2ProcessorWithLM:
:class:`~transformers.PreTrainedTokenizer`
"""
requires_backends(cls, "pyctcdecode")
from pyctcdecode import BeamSearchDecoderCTC
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
......@@ -183,7 +182,7 @@ class Wav2Vec2ProcessorWithLM:
return cls(feature_extractor=feature_extractor, tokenizer=tokenizer, decoder=decoder)
@staticmethod
def _set_language_model_attribute(decoder: BeamSearchDecoderCTC, attribute: str, value: float):
def _set_language_model_attribute(decoder: "BeamSearchDecoderCTC", attribute: str, value: float):
setattr(decoder.model_container[decoder._model_key], attribute, value)
@property
......@@ -192,6 +191,8 @@ class Wav2Vec2ProcessorWithLM:
@staticmethod
def get_missing_alphabet_tokens(decoder, tokenizer):
from pyctcdecode.alphabet import BLANK_TOKEN_PTN, UNK_TOKEN, UNK_TOKEN_PTN
# we need to make sure that all of the tokenizer's except the special tokens
# are present in the decoder's alphabet. Retrieve missing alphabet token
# from decoder
......@@ -270,6 +271,12 @@ class Wav2Vec2ProcessorWithLM:
:class:`~transformers.models.wav2vec2.Wav2Vec2DecoderWithLMOutput` or :obj:`tuple`.
"""
from pyctcdecode.constants import (
DEFAULT_BEAM_WIDTH,
DEFAULT_HOTWORD_WEIGHT,
DEFAULT_MIN_TOKEN_LOGP,
DEFAULT_PRUNE_LOGP,
)
# set defaults
beam_width = beam_width if beam_width is not None else DEFAULT_BEAM_WIDTH
......@@ -330,6 +337,12 @@ class Wav2Vec2ProcessorWithLM:
:class:`~transformers.models.wav2vec2.Wav2Vec2DecoderWithLMOutput` or :obj:`tuple`.
"""
from pyctcdecode.constants import (
DEFAULT_BEAM_WIDTH,
DEFAULT_HOTWORD_WEIGHT,
DEFAULT_MIN_TOKEN_LOGP,
DEFAULT_PRUNE_LOGP,
)
# set defaults
beam_width = beam_width if beam_width is not None else DEFAULT_BEAM_WIDTH
......
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..file_utils import requires_backends
class Wav2Vec2ProcessorWithLM:
    """Dummy stand-in for `Wav2Vec2ProcessorWithLM` when `pyctcdecode` is missing.

    This placeholder is importable even without the optional `pyctcdecode`
    dependency; any attempt to use it calls `requires_backends`, which
    presumably raises with an instructive "install pyctcdecode" message —
    NOTE(review): exact failure behavior lives in `requires_backends`, confirm there.
    """

    def __init__(self, *args, **kwargs):
        # Fail fast on construction: the real processor cannot work without
        # the "pyctcdecode" backend.
        requires_backends(self, ["pyctcdecode"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Mirror the real class's alternate constructor so callers hitting
        # `from_pretrained` get the same backend error instead of an AttributeError.
        requires_backends(cls, ["pyctcdecode"])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment