"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "592f2eabd17cbdebd13dec54edf412f9f8232152"
Unverified commit af150e4a authored by Antonio Carlos Falcão Petri, committed by GitHub

Allow user-managed Pool in Wav2Vec2ProcessorWithLM.batch_decode (#18351)



* [Wav2Vec2] Allow user-managed Pool in Wav2Vec2ProcessorWithLM.batch_decode

* [Wav2Vec2] Add user-managed LM's pool tests and usage examples

* Improve styling
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* [Wav2Vec2] Fix hyperlink references
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent bf0e0941
@@ -73,6 +73,61 @@ This model was contributed by [patrickvonplaten](https://huggingface.co/patrickvonplaten)
    - batch_decode
    - decode
### Decoding multiple audios
If you are planning to decode multiple batches of audios, you should consider using [`~Wav2Vec2ProcessorWithLM.batch_decode`] and passing an instantiated `multiprocessing.Pool`.
Otherwise, [`~Wav2Vec2ProcessorWithLM.batch_decode`] performance will be slower than calling [`~Wav2Vec2ProcessorWithLM.decode`] for each audio individually, as it internally instantiates a new `Pool` for every call. See the example below:
```python
>>> # Let's see how to use a user-managed pool for batch decoding multiple audios
>>> from multiprocessing import get_context
>>> from transformers import AutoTokenizer, AutoProcessor, AutoModelForCTC
>>> from datasets import load_dataset
>>> import datasets
>>> import torch
>>> # import model, feature extractor, tokenizer
>>> model = AutoModelForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm").to("cuda")
>>> processor = AutoProcessor.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
>>> # load example dataset
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> dataset = dataset.cast_column("audio", datasets.Audio(sampling_rate=16_000))
>>> def map_to_array(batch):
... batch["speech"] = batch["audio"]["array"]
... return batch
>>> # prepare speech data for batch inference
>>> dataset = dataset.map(map_to_array, remove_columns=["audio"])
>>> def map_to_pred(batch, pool):
... inputs = processor(batch["speech"], sampling_rate=16_000, padding=True, return_tensors="pt")
... inputs = {k: v.to("cuda") for k, v in inputs.items()}
... with torch.no_grad():
... logits = model(**inputs).logits
... transcription = processor.batch_decode(logits.cpu().numpy(), pool).text
... batch["transcription"] = transcription
... return batch
>>> # note: pool should be instantiated *after* `Wav2Vec2ProcessorWithLM`.
>>> # otherwise, the LM won't be available to the pool's sub-processes
>>> # select number of processes and batch_size based on number of CPU cores available and on dataset size
>>> with get_context("fork").Pool(processes=2) as pool:
... result = dataset.map(
... map_to_pred, batched=True, batch_size=2, fn_kwargs={"pool": pool}, remove_columns=["speech"]
... )
>>> result["transcription"][:2]
['MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL', "NOR IS MISTER COULTER'S MANNER LESS INTERESTING THAN HIS MATTER"]
```
## Wav2Vec2 specific outputs

[[autodoc]] models.wav2vec2_with_lm.processing_wav2vec2_with_lm.Wav2Vec2DecoderWithLMOutput
......
@@ -164,7 +164,7 @@ _deps = [
    "tokenizers>=0.11.1,!=0.11.3,<0.14",
    "torch>=1.7,!=1.12.0",
    "torchaudio",
-   "pyctcdecode>=0.3.0",
+   "pyctcdecode>=0.4.0",
    "tqdm>=4.27",
    "unidic>=1.0.2",
    "unidic_lite>=1.0.7",
......
@@ -70,7 +70,7 @@ deps = {
    "tokenizers": "tokenizers>=0.11.1,!=0.11.3,<0.14",
    "torch": "torch>=1.7,!=1.12.0",
    "torchaudio": "torchaudio",
-   "pyctcdecode": "pyctcdecode>=0.3.0",
+   "pyctcdecode": "pyctcdecode>=0.4.0",
    "tqdm": "tqdm>=4.27",
    "unidic": "unidic>=1.0.2",
    "unidic_lite": "unidic_lite>=1.0.7",
......
@@ -442,9 +442,9 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
        <Tip>

-       Please take a look at the Example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better
-       understand how to make use of `output_word_offsets`.
-       [`~model.wav2vec2.tokenization_wav2vec2.batch_decode`] works the same way with batched output.
+       Please take a look at the Example of [`~Wav2Vec2CTCTokenizer.decode`] to better understand how to make
+       use of `output_char_offsets`. [`~Wav2Vec2CTCTokenizer.batch_decode`] works the same way with batched
+       output.

        </Tip>
@@ -454,9 +454,9 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
        <Tip>

-       Please take a look at the Example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better
-       understand how to make use of `output_word_offsets`.
-       [`~model.wav2vec2.tokenization_wav2vec2.batch_decode`] works the same way with batched output.
+       Please take a look at the Example of [`~Wav2Vec2CTCTokenizer.decode`] to better understand how to make
+       use of `output_word_offsets`. [`~Wav2Vec2CTCTokenizer.batch_decode`] works the same way with batched
+       output.

        </Tip>
@@ -515,8 +515,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
        <Tip>

-       Please take a look at the example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better
-       understand how to make use of `output_word_offsets`.
+       Please take a look at the example below to better understand how to make use of `output_char_offsets`.

        </Tip>
@@ -526,8 +525,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
        <Tip>

-       Please take a look at the example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better
-       understand how to make use of `output_word_offsets`.
+       Please take a look at the example below to better understand how to make use of `output_word_offsets`.

        </Tip>
......
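The Tips above reference a worked offsets example that this truncated diff does not show. As orientation, here is a hedged sketch of that kind of usage; the `facebook/wav2vec2-base-960h` checkpoint and the offsets schema (`word`, `start_offset`, `end_offset`, scaled by `inputs_to_logits_ratio`) follow the wav2vec2 documentation of this era rather than this diff:

```python
import torch
from datasets import load_dataset
from transformers import AutoModelForCTC, AutoProcessor

model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
inputs = processor(ds[0]["audio"]["array"], sampling_rate=16_000, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits
pred_ids = torch.argmax(logits, dim=-1)

# each logit frame spans `inputs_to_logits_ratio / sampling_rate` seconds of audio
outputs = processor.tokenizer.decode(pred_ids[0], output_word_offsets=True)
time_offset = model.config.inputs_to_logits_ratio / 16_000
word_times = [
    (d["word"], d["start_offset"] * time_offset, d["end_offset"] * time_offset)
    for d in outputs.word_offsets
]
```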
@@ -17,15 +17,18 @@ Speech processor class for Wav2Vec2
"""
import os
import warnings
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
-from multiprocessing import get_context
+from multiprocessing import Pool, get_context, get_start_method
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Union

import numpy as np

from ...processing_utils import ProcessorMixin
-from ...utils import ModelOutput, requires_backends
+from ...utils import ModelOutput, logging, requires_backends
+
+logger = logging.get_logger(__name__)


if TYPE_CHECKING:
@@ -115,7 +118,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
        This class method is simply calling Wav2Vec2FeatureExtractor's
        [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`], Wav2Vec2CTCTokenizer's
-       [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`], and
+       [`~tokenization_utils_base.PreTrainedTokenizerBase.from_pretrained`], and
        [`pyctcdecode.BeamSearchDecoderCTC.load_from_hf_hub`].

        Please refer to the docstrings of the methods above for more information.
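For reference, the delegation described above means one call assembles all three components; a minimal sketch, reusing the checkpoint from the documentation example earlier in this commit:

```python
from transformers import Wav2Vec2ProcessorWithLM

# one call loads the feature extractor, the CTC tokenizer, and the pyctcdecode
# beam-search decoder (with its language model) from a single Hub repo
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
```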
@@ -280,6 +283,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
    def batch_decode(
        self,
        logits: np.ndarray,
+       pool: Optional[Pool] = None,
        num_processes: Optional[int] = None,
        beam_width: Optional[int] = None,
        beam_prune_logp: Optional[float] = None,
@@ -297,16 +301,32 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
        <Tip>

-       This function makes use of Python's multiprocessing.
+       This function makes use of Python's multiprocessing. Currently, multiprocessing is available only on Unix
+       systems (see this [issue](https://github.com/kensho-technologies/pyctcdecode/issues/65)).
+
+       If you are decoding multiple batches, consider creating a `Pool` and passing it to `batch_decode`. Otherwise,
+       `batch_decode` will be very slow since it will create a fresh `Pool` for each call. See usage example below.

        </Tip>

        Args:
            logits (`np.ndarray`):
                The logits output vector of the model representing the log probabilities for each token.
+           pool (`multiprocessing.Pool`, *optional*):
+               An optional user-managed pool. If not set, one will be automatically created and closed. The pool
+               should be instantiated *after* `Wav2Vec2ProcessorWithLM`. Otherwise, the LM won't be available to the
+               pool's sub-processes.
+
+               <Tip>
+
+               Currently, only pools created with a 'fork' context can be used. If a 'spawn' pool is passed, it will
+               be ignored and sequential decoding will be used instead.
+
+               </Tip>
+
            num_processes (`int`, *optional*):
-               Number of processes on which the function should be parallelized over. Defaults to the number of
-               available CPUs.
+               If `pool` is not set, number of processes on which the function should be parallelized over. Defaults
+               to the number of available CPUs.
            beam_width (`int`, *optional*):
                Maximum number of beams at each step in decoding. Defaults to pyctcdecode's DEFAULT_BEAM_WIDTH.
            beam_prune_logp (`int`, *optional*):
@@ -332,17 +352,19 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
        <Tip>

-       Please take a look at the Example of [`~model.wav2vec2_with_lm.processing_wav2vec2_with_lm.decode`] to
-       better understand how to make use of `output_word_offsets`.
-       [`~model.wav2vec2_with_lm.processing_wav2vec2_with_lm.batch_decode`] works the same way with batched
-       output.
+       Please take a look at the Example of [`~Wav2Vec2ProcessorWithLM.decode`] to better understand how to
+       make use of `output_word_offsets`. [`~Wav2Vec2ProcessorWithLM.batch_decode`] works the same way with
+       batched output.

        </Tip>

        Returns:
-           [`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
+           [`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`].
+
+       Example:
+           See [Decoding multiple audios](#decoding-multiple-audios).

        """
        from pyctcdecode.constants import (
            DEFAULT_BEAM_WIDTH,
            DEFAULT_HOTWORD_WEIGHT,
@@ -364,21 +386,41 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
        # create multiprocessing pool and list numpy arrays
        # filter out logits padding
        logits_list = [array[(array != -100.0).all(axis=-1)] for array in logits]
-       pool = get_context("fork").Pool(num_processes)

-       # pyctcdecode
-       decoded_beams = self.decoder.decode_beams_batch(
-           pool,
-           logits_list=logits_list,
-           beam_width=beam_width,
-           beam_prune_logp=beam_prune_logp,
-           token_min_logp=token_min_logp,
-           hotwords=hotwords,
-           hotword_weight=hotword_weight,
-       )
+       # create a pool if necessary while also using it as a context manager to close itself
+       if pool is None:
+           # fork is safe to use only on Unix, see "Contexts and start methods" section on
+           # multiprocessing's docs (https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods)
+           default_context = get_start_method()
+
+           if default_context == "fork":
+               cm = pool = get_context().Pool(num_processes)
+           else:
+               logger.warning(
+                   "Parallel batch decoding is not currently supported in this platform. "
+                   "Falling back to sequential decoding."
+               )
+               cm = nullcontext()
+       else:
+           # pool is managed by the user, so we don't need to close it
+           cm = nullcontext()
+
+           if num_processes is not None:
+               logger.warning(
+                   "Parameter `num_process` was passed, but it will be ignored since `pool` was also specified."
+               )

-       # clone multi-processing pool
-       pool.close()
+       # pyctcdecode
+       with cm:
+           decoded_beams = self.decoder.decode_beams_batch(
+               pool=pool,
+               logits_list=logits_list,
+               beam_width=beam_width,
+               beam_prune_logp=beam_prune_logp,
+               token_min_logp=token_min_logp,
+               hotwords=hotwords,
+               hotword_weight=hotword_weight,
+           )
        # extract text and scores
        batch_texts, logit_scores, lm_scores, word_offsets = [], [], [], []
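Taken together, the branching above gives `batch_decode` three call patterns. Below is a hedged, self-contained sketch: the random stand-in logits, and the assumption that `len(processor.tokenizer)` matches the decoder's alphabet size, are mine rather than the commit's, so the decoded text is meaningless:

```python
import multiprocessing

import numpy as np

from transformers import Wav2Vec2ProcessorWithLM

processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")

# stand-in logits of shape (batch, frames, vocab); real logits would come from a
# CTC model as in the documentation example above
vocab_size = len(processor.tokenizer)
logits = np.log(np.random.dirichlet(np.ones(vocab_size), (2, 100)).astype(np.float32))

# 1) no pool: a fork pool is created (and closed) per call on Unix; on platforms whose
#    default start method is not "fork", decoding falls back to sequential with a warning
out = processor.batch_decode(logits)

# 2) a user-managed fork pool, created *after* the processor and reused across calls:
#    the fast path when decoding many batches
with multiprocessing.get_context("fork").Pool(2) as pool:
    out_a = processor.batch_decode(logits, pool)
    out_b = processor.batch_decode(logits, pool)

# 3) passing num_processes together with a pool only logs a warning; the pool wins
with multiprocessing.get_context("fork").Pool(2) as pool:
    out_c = processor.batch_decode(logits, pool, num_processes=4)
```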
@@ -440,13 +482,12 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
        <Tip>

-       Please take a look at the example of [`~models.wav2vec2_with_lm.processing_wav2vec2_with_lm.decode`] to
-       better understand how to make use of `output_word_offsets`.
+       Please take a look at the example below to better understand how to make use of `output_word_offsets`.

        </Tip>

        Returns:
-           [`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
+           [`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`].

        Example:
......
@@ -99,8 +99,8 @@ class ProcessorMixin(PushToHubMixin):
        <Tip>

        This class method is simply calling [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] and
-       [`~tokenization_utils_base.PreTrainedTokenizer.save_pretrained`]. Please refer to the docstrings of the methods
-       above for more information.
+       [`~tokenization_utils_base.PreTrainedTokenizerBase.save_pretrained`]. Please refer to the docstrings of the
+       methods above for more information.

        </Tip>
......
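Because `save_pretrained` simply fans out to the wrapped components, a save/reload round trip is short. A minimal sketch; the checkpoint and local directory are illustrative choices, not taken from this commit:

```python
from transformers import Wav2Vec2Processor

# saving writes the feature extractor and tokenizer files into one directory,
# which can then be reloaded with from_pretrained
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
processor.save_pretrained("./wav2vec2-processor")
reloaded = Wav2Vec2Processor.from_pretrained("./wav2vec2-processor")
```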
@@ -14,6 +14,7 @@
import inspect
import math
+import multiprocessing
import unittest

import numpy as np
@@ -21,6 +22,7 @@ from datasets import load_dataset
from transformers import Wav2Vec2Config, is_flax_available
from transformers.testing_utils import (
+   CaptureLogger,
    is_flaky,
    is_librosa_available,
    is_pt_flax_cross_test,
@@ -53,6 +55,7 @@ if is_flax_available():
if is_pyctcdecode_available():
    from transformers import Wav2Vec2ProcessorWithLM
+   from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm


if is_librosa_available():
@@ -554,3 +557,58 @@ class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase):
        transcription = processor.batch_decode(np.array(logits)).text

        self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
    @require_pyctcdecode
    @require_librosa
    def test_wav2vec2_with_lm_pool(self):
        ds = load_dataset("common_voice", "es", split="test", streaming=True)
        sample = next(iter(ds))

        resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000)

        model = FlaxWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
        processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")

        input_values = processor(resampled_audio, return_tensors="np").input_values

        logits = model(input_values).logits

        # test user-managed pool (Flax arrays have no .numpy() method, so convert with np.array,
        # as in the existing test above)
        with multiprocessing.get_context("fork").Pool(2) as pool:
            transcription = processor.batch_decode(np.array(logits), pool).text

        self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")

        # user-managed pool + num_processes should trigger a warning
        with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
            2
        ) as pool:
            transcription = processor.batch_decode(np.array(logits), pool, num_processes=2).text

        self.assertIn("num_process", cl.out)
        self.assertIn("it will be ignored", cl.out)

        self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")

    @require_pyctcdecode
    @require_librosa
    def test_wav2vec2_with_lm_invalid_pool(self):
        ds = load_dataset("common_voice", "es", split="test", streaming=True)
        sample = next(iter(ds))

        resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000)

        model = FlaxWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
        processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")

        input_values = processor(resampled_audio, return_tensors="np").input_values

        logits = model(input_values).logits

        # change default start method, which should trigger a warning if different than fork
        multiprocessing.set_start_method("spawn")

        with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
            transcription = processor.batch_decode(np.array(logits)).text

        self.assertIn("Falling back to sequential decoding.", cl.out)
        self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
@@ -18,6 +18,7 @@ import copy
import glob
import inspect
import math
+import multiprocessing
import unittest

import numpy as np
@@ -26,7 +27,7 @@ from datasets import load_dataset
from huggingface_hub import snapshot_download
from transformers import Wav2Vec2Config, is_tf_available
-from transformers.testing_utils import is_flaky, require_librosa, require_pyctcdecode, require_tf, slow
+from transformers.testing_utils import CaptureLogger, is_flaky, require_librosa, require_pyctcdecode, require_tf, slow
from transformers.utils import is_librosa_available, is_pyctcdecode_available

from ...test_configuration_common import ConfigTester
@@ -42,6 +43,7 @@ if is_tf_available():
if is_pyctcdecode_available():
    from transformers import Wav2Vec2ProcessorWithLM
+   from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm


if is_librosa_available():
@@ -590,3 +592,56 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
        transcription = processor.batch_decode(logits.numpy()).text

        self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
    @require_pyctcdecode
    @require_librosa
    def test_wav2vec2_with_lm_pool(self):
        downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample")
        file_path = glob.glob(downloaded_folder + "/*")[0]
        sample = librosa.load(file_path, sr=16_000)[0]

        model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
        processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")

        input_values = processor(sample, return_tensors="tf").input_values

        logits = model(input_values).logits

        # test user-managed pool
        with multiprocessing.get_context("fork").Pool(2) as pool:
            transcription = processor.batch_decode(logits.numpy(), pool).text

        self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")

        # user-managed pool + num_processes should trigger a warning
        with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
            2
        ) as pool:
            transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text

        self.assertIn("num_process", cl.out)
        self.assertIn("it will be ignored", cl.out)

        self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")

    @require_pyctcdecode
    @require_librosa
    def test_wav2vec2_with_lm_invalid_pool(self):
        downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample")
        file_path = glob.glob(downloaded_folder + "/*")[0]
        sample = librosa.load(file_path, sr=16_000)[0]

        model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
        processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")

        input_values = processor(sample, return_tensors="tf").input_values

        logits = model(input_values).logits

        # change default start method, which should trigger a warning if different than fork
        multiprocessing.set_start_method("spawn")

        with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
            transcription = processor.batch_decode(logits.numpy()).text

        self.assertIn("Falling back to sequential decoding.", cl.out)
        self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
@@ -15,6 +15,7 @@
""" Testing suite for the PyTorch Wav2Vec2 model. """

import math
+import multiprocessing
import os
import pickle
import tempfile
@@ -25,6 +26,7 @@ from datasets import load_dataset
from transformers import Wav2Vec2Config, is_torch_available
from transformers.testing_utils import (
+   CaptureLogger,
    is_pt_flax_cross_test,
    is_pyctcdecode_available,
    is_torchaudio_available,
@@ -74,6 +76,7 @@ if is_torchaudio_available():
if is_pyctcdecode_available():
    from transformers import Wav2Vec2ProcessorWithLM
+   from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm


if is_torch_fx_available():
@@ -1611,6 +1614,71 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
        self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
    @require_pyctcdecode
    @require_torchaudio
    def test_wav2vec2_with_lm_pool(self):
        ds = load_dataset("common_voice", "es", split="test", streaming=True)
        sample = next(iter(ds))

        resampled_audio = torchaudio.functional.resample(
            torch.tensor(sample["audio"]["array"]), 48_000, 16_000
        ).numpy()

        model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm").to(
            torch_device
        )
        processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")

        input_values = processor(resampled_audio, return_tensors="pt").input_values

        with torch.no_grad():
            logits = model(input_values.to(torch_device)).logits

        # test user-managed pool (logits may live on GPU, so move them to CPU before converting)
        with multiprocessing.get_context("fork").Pool(2) as pool:
            transcription = processor.batch_decode(logits.cpu().numpy(), pool).text

        self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")

        # user-managed pool + num_processes should trigger a warning
        with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
            2
        ) as pool:
            transcription = processor.batch_decode(logits.cpu().numpy(), pool, num_processes=2).text

        self.assertIn("num_process", cl.out)
        self.assertIn("it will be ignored", cl.out)

        self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")

    @require_pyctcdecode
    @require_torchaudio
    def test_wav2vec2_with_lm_invalid_pool(self):
        ds = load_dataset("common_voice", "es", split="test", streaming=True)
        sample = next(iter(ds))

        resampled_audio = torchaudio.functional.resample(
            torch.tensor(sample["audio"]["array"]), 48_000, 16_000
        ).numpy()

        model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm").to(
            torch_device
        )
        processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")

        input_values = processor(resampled_audio, return_tensors="pt").input_values

        with torch.no_grad():
            logits = model(input_values.to(torch_device)).logits

        # change default start method, which should trigger a warning if different than fork
        multiprocessing.set_start_method("spawn")

        with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
            transcription = processor.batch_decode(logits.cpu().numpy()).text

        self.assertIn("Falling back to sequential decoding.", cl.out)
        # same common_voice "es" sample as the test above, so the same expected transcription
        self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
    def test_inference_diarization(self):
        model = Wav2Vec2ForAudioFrameClassification.from_pretrained("anton-l/wav2vec2-base-superb-sd").to(torch_device)
        processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/wav2vec2-base-superb-sd")
......
@@ -25,6 +25,7 @@ import numpy as np
from datasets import load_dataset
from packaging import version
+from parameterized import parameterized

from transformers import AutoProcessor
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
@@ -194,7 +195,8 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
        self.assertEqual(decoded_decoder[-2], decoded_processor.logit_score)
        self.assertEqual(decoded_decoder[-1], decoded_processor.lm_score)

-   def test_decoder_batch(self):
+   @parameterized.expand([[None], ["fork"], ["spawn"]])
+   def test_decoder_batch(self, pool_context):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()
        decoder = self.get_decoder()
@@ -203,17 +205,25 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
        logits = self._get_dummy_logits()

-       decoded_processor = processor.batch_decode(logits)
+       # note: pool should be instantiated *after* Wav2Vec2ProcessorWithLM.
+       # otherwise, the LM won't be available to the pool's sub-processes.
+       # manual logic used to allow parameterized test for both pool=None and pool=Pool(...)
+       if pool_context is None:
+           decoded_processor = processor.batch_decode(logits)
+       else:
+           with get_context(pool_context).Pool() as pool:
+               decoded_processor = processor.batch_decode(logits, pool)

        logits_list = [array for array in logits]
-       pool = get_context("fork").Pool()
-       decoded_beams = decoder.decode_beams_batch(pool, logits_list)
+
+       with get_context("fork").Pool() as p:
+           decoded_beams = decoder.decode_beams_batch(p, logits_list)

        texts_decoder, logit_scores_decoder, lm_scores_decoder = [], [], []
        for beams in decoded_beams:
            texts_decoder.append(beams[0][0])
            logit_scores_decoder.append(beams[0][-2])
            lm_scores_decoder.append(beams[0][-1])
-       pool.close()

        self.assertListEqual(texts_decoder, decoded_processor.text)
        self.assertListEqual(["<s> <s> </s>", "<s> <s> <s>"], decoded_processor.text)
@@ -242,15 +252,15 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
        decoded_processor = decoded_processor_out.text

        logits_list = [array for array in logits]
-       pool = get_context("fork").Pool()
-       decoded_decoder_out = decoder.decode_beams_batch(
-           pool,
-           logits_list,
-           beam_width=beam_width,
-           beam_prune_logp=beam_prune_logp,
-           token_min_logp=token_min_logp,
-       )
-       pool.close()
+
+       with get_context("fork").Pool() as pool:
+           decoded_decoder_out = decoder.decode_beams_batch(
+               pool,
+               logits_list,
+               beam_width=beam_width,
+               beam_prune_logp=beam_prune_logp,
+               token_min_logp=token_min_logp,
+           )

        decoded_decoder = [d[0][0] for d in decoded_decoder_out]
@@ -287,12 +297,12 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
            unk_score_offset=unk_score_offset,
            lm_score_boundary=lm_score_boundary,
        )
-       pool = get_context("fork").Pool()
-       decoded_decoder_out = decoder.decode_beams_batch(
-           pool,
-           logits_list,
-       )
-       pool.close()
+
+       with get_context("fork").Pool() as pool:
+           decoded_decoder_out = decoder.decode_beams_batch(
+               pool,
+               logits_list,
+           )

        decoded_decoder = [d[0][0] for d in decoded_decoder_out]
......