Unverified commit 13186d71, authored by Sylvain Gugger, committed by GitHub

Move pyctcdecode (#14686)

* Move pyctcdecode dep

* Fix doc and last objects

* Quality

* Style

* Ignore this black
parent d104dd46
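
The diff below moves the `pyctcdecode` dependency from import time to call time: `Wav2Vec2ProcessorWithLM` stays importable everywhere, and `pyctcdecode` is only loaded inside the methods that actually need it. A minimal sketch of that pattern, using placeholder names rather than the actual `transformers` code:

```python
import importlib.util


def is_pyctcdecode_available():
    # Detect the optional package without importing it.
    return importlib.util.find_spec("pyctcdecode") is not None


class ProcessorWithLM:
    """Toy stand-in: importing this module never touches pyctcdecode."""

    def decode(self, logits):
        if not is_pyctcdecode_available():
            raise ImportError("decode() requires pyctcdecode: pip install pyctcdecode")
        # The import is deferred to the first call that actually needs it.
        from pyctcdecode.constants import DEFAULT_BEAM_WIDTH

        return self._beam_search(logits, beam_width=DEFAULT_BEAM_WIDTH)

    def _beam_search(self, logits, beam_width):
        raise NotImplementedError("decoding itself is out of scope for this sketch")
```

Deferring the import keeps `import transformers` working (and fast) in environments without the optional dependency, at the cost of surfacing the missing-package error at call time instead of import time.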
@@ -77,7 +77,7 @@ Wav2Vec2ProcessorWithLM
 Wav2Vec2 specific outputs
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-.. autoclass:: transformers.models.wav2vec2.processing_wav2vec2_with_lm.Wav2Vec2DecoderWithLMOutput
+.. autoclass:: transformers.models.wav2vec2_with_lm.processing_wav2vec2_with_lm.Wav2Vec2DecoderWithLMOutput
     :members:

 .. autoclass:: transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2BaseModelOutput
......
@@ -313,7 +313,7 @@ _import_structure = {
         "Wav2Vec2Processor",
         "Wav2Vec2Tokenizer",
     ],
-    "models.wav2vec2_with_lm": [],
+    "models.wav2vec2_with_lm": ["Wav2Vec2ProcessorWithLM"],
     "models.xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMTokenizer"],
     "models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"],
     "models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"],
@@ -475,15 +475,6 @@ else:
         name for name in dir(dummy_speech_objects) if not name.startswith("_")
     ]

-if is_pyctcdecode_available():
-    _import_structure["models.wav2vec2_with_lm"].append("Wav2Vec2ProcessorWithLM")
-else:
-    from .utils import dummy_pyctcdecode_objects
-
-    _import_structure["utils.dummy_pyctcdecode_objects"] = [
-        name for name in dir(dummy_pyctcdecode_objects) if not name.startswith("_")
-    ]
-
 if is_sentencepiece_available() and is_speech_available():
     _import_structure["models.speech_to_text"].append("Speech2TextProcessor")
 else:
@@ -2329,6 +2320,7 @@ if TYPE_CHECKING:
         Wav2Vec2Processor,
         Wav2Vec2Tokenizer,
     )
+    from .models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
     from .models.xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig, XLMTokenizer
     from .models.xlm_prophetnet import XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMProphetNetConfig
     from .models.xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
@@ -2472,11 +2464,6 @@ if TYPE_CHECKING:
     else:
         from .utils.dummy_speech_objects import *

-    if is_pyctcdecode_available():
-        from .models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
-    else:
-        from .utils.dummy_pyctcdecode_objects import *
-
     if is_speech_available() and is_sentencepiece_available():
         from .models.speech_to_text import Speech2TextProcessor
     else:
......
@@ -17,19 +17,18 @@
 # limitations under the License.

 from typing import TYPE_CHECKING

-from ...file_utils import _LazyModule, is_pyctcdecode_available
+from ...file_utils import _LazyModule


-_import_structure = {"processing_wav2vec2_with_lm": []}
-
-if is_pyctcdecode_available():
-    _import_structure["processing_wav2vec2_with_lm"].append("Wav2Vec2ProcessorWithLM")
+# fmt: off
+_import_structure = {
+    "processing_wav2vec2_with_lm": ["Wav2Vec2ProcessorWithLM"]
+}
+# fmt: on


 if TYPE_CHECKING:
-    if is_pyctcdecode_available():
-        from .processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
-
+    from .processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
 else:
     import sys
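
Judging by the import paths, the hunks above touch the top-level `transformers/__init__.py` and `models/wav2vec2_with_lm/__init__.py`. Registering `Wav2Vec2ProcessorWithLM` unconditionally in the lazy import structure means the class can be imported whether or not `pyctcdecode` is installed; the backend check only fires when the processor is actually used. A hedged usage sketch (the checkpoint name is just an example):

```python
# Succeeds even without pyctcdecode installed: the name is resolved lazily and
# the wav2vec2_with_lm module no longer imports pyctcdecode at load time.
from transformers import Wav2Vec2ProcessorWithLM

# pyctcdecode is only needed from here on: from_pretrained() calls
# requires_backends(cls, "pyctcdecode") and __init__ imports the decoder class,
# so a missing library produces a clear ImportError at this point instead.
processor = Wav2Vec2ProcessorWithLM.from_pretrained(
    "patrickvonplaten/wav2vec2-base-100h-with-lm"  # example checkpoint
)
```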
......
@@ -19,19 +19,10 @@ import os
 from contextlib import contextmanager
 from dataclasses import dataclass
 from multiprocessing import Pool
-from typing import Iterable, List, Optional, Union
+from typing import TYPE_CHECKING, Iterable, List, Optional, Union

 import numpy as np
-from pyctcdecode import BeamSearchDecoderCTC
-from pyctcdecode.alphabet import BLANK_TOKEN_PTN, UNK_TOKEN, UNK_TOKEN_PTN
-from pyctcdecode.constants import (
-    DEFAULT_BEAM_WIDTH,
-    DEFAULT_HOTWORD_WEIGHT,
-    DEFAULT_MIN_TOKEN_LOGP,
-    DEFAULT_PRUNE_LOGP,
-)

 from ...feature_extraction_utils import FeatureExtractionMixin
 from ...file_utils import ModelOutput, requires_backends
 from ...tokenization_utils import PreTrainedTokenizer
@@ -39,6 +30,10 @@ from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
 from ..wav2vec2.tokenization_wav2vec2 import Wav2Vec2CTCTokenizer


+if TYPE_CHECKING:
+    from pyctcdecode import BeamSearchDecoderCTC
+
+
 @dataclass
 class Wav2Vec2DecoderWithLMOutput(ModelOutput):
     """
@@ -70,8 +65,10 @@ class Wav2Vec2ProcessorWithLM:
         self,
         feature_extractor: FeatureExtractionMixin,
         tokenizer: PreTrainedTokenizer,
-        decoder: BeamSearchDecoderCTC,
+        decoder: "BeamSearchDecoderCTC",
     ):
+        from pyctcdecode import BeamSearchDecoderCTC
+
         if not isinstance(feature_extractor, Wav2Vec2FeatureExtractor):
             raise ValueError(
                 f"`feature_extractor` has to be of type {Wav2Vec2FeatureExtractor.__class__}, but is {type(feature_extractor)}"
@@ -153,6 +150,8 @@ class Wav2Vec2ProcessorWithLM:
                 :class:`~transformers.PreTrainedTokenizer`
         """
         requires_backends(cls, "pyctcdecode")
+        from pyctcdecode import BeamSearchDecoderCTC
+
         feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
         tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
@@ -183,7 +182,7 @@ class Wav2Vec2ProcessorWithLM:
         return cls(feature_extractor=feature_extractor, tokenizer=tokenizer, decoder=decoder)

     @staticmethod
-    def _set_language_model_attribute(decoder: BeamSearchDecoderCTC, attribute: str, value: float):
+    def _set_language_model_attribute(decoder: "BeamSearchDecoderCTC", attribute: str, value: float):
         setattr(decoder.model_container[decoder._model_key], attribute, value)

     @property
@@ -192,6 +191,8 @@ class Wav2Vec2ProcessorWithLM:

     @staticmethod
     def get_missing_alphabet_tokens(decoder, tokenizer):
+        from pyctcdecode.alphabet import BLANK_TOKEN_PTN, UNK_TOKEN, UNK_TOKEN_PTN
+
         # we need to make sure that all of the tokenizer's except the special tokens
         # are present in the decoder's alphabet. Retrieve missing alphabet token
         # from decoder
@@ -270,6 +271,12 @@ class Wav2Vec2ProcessorWithLM:
                 :class:`~transformers.models.wav2vec2.Wav2Vec2DecoderWithLMOutput` or :obj:`tuple`.

         """
+        from pyctcdecode.constants import (
+            DEFAULT_BEAM_WIDTH,
+            DEFAULT_HOTWORD_WEIGHT,
+            DEFAULT_MIN_TOKEN_LOGP,
+            DEFAULT_PRUNE_LOGP,
+        )

         # set defaults
         beam_width = beam_width if beam_width is not None else DEFAULT_BEAM_WIDTH
@@ -330,6 +337,12 @@ class Wav2Vec2ProcessorWithLM:
                 :class:`~transformers.models.wav2vec2.Wav2Vec2DecoderWithLMOutput` or :obj:`tuple`.

         """
+        from pyctcdecode.constants import (
+            DEFAULT_BEAM_WIDTH,
+            DEFAULT_HOTWORD_WEIGHT,
+            DEFAULT_MIN_TOKEN_LOGP,
+            DEFAULT_PRUNE_LOGP,
+        )

         # set defaults
         beam_width = beam_width if beam_width is not None else DEFAULT_BEAM_WIDTH
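
The two decode-side hunks above follow the same shape: pull the `pyctcdecode.constants` import into the method body, then fall back to those constants whenever the caller leaves an argument as `None`. Roughly, with the remaining parameters elided:

```python
def decode(self, logits, beam_width=None, **kwargs):
    # Imported at call time so that importing the module stays pyctcdecode-free.
    from pyctcdecode.constants import DEFAULT_BEAM_WIDTH

    # None means "use pyctcdecode's own default" instead of hard-coding a copy
    # of the value in transformers; the other defaults are applied the same way.
    beam_width = beam_width if beam_width is not None else DEFAULT_BEAM_WIDTH
    ...
```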
......
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..file_utils import requires_backends


class Wav2Vec2ProcessorWithLM:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["pyctcdecode"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["pyctcdecode"])
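
The file above looks like the generated `utils/dummy_pyctcdecode_objects.py`; with the conditional branches removed from the top-level `__init__`, it is no longer referenced there, and the real class now runs the backend check itself (`requires_backends(cls, "pyctcdecode")` in `from_pretrained`). For context, a simplified sketch of what such a guard does; the real helper in `transformers.file_utils` consults a registry of backend checks and error messages, so treat this as an approximation:

```python
import importlib.util


def requires_backends(obj, backends):
    """Simplified guard: raise a readable error if an optional backend is missing."""
    if not isinstance(backends, (list, tuple)):
        backends = [backends]
    name = obj.__name__ if hasattr(obj, "__name__") else type(obj).__name__
    missing = [b for b in backends if importlib.util.find_spec(b) is None]
    if missing:
        raise ImportError(
            f"{name} requires the following libraries that were not found in your "
            f"environment: {', '.join(missing)}. Please install them to use it."
        )
```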