Unverified Commit af150e4a authored by Antonio Carlos Falcão Petri's avatar Antonio Carlos Falcão Petri Committed by GitHub
Browse files

Allow user-managed Pool in Wav2Vec2ProcessorWithLM.batch_decode (#18351)



* [Wav2Vec2] Allow user-managed Pool in Wav2Vec2ProcessorWithLM.batch_decode

* [Wav2Vec2] Add user-managed LM's pool tests and usage examples

* Improve styling
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* [Wav2Vec2] Fix hyperlink references
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent bf0e0941
......@@ -73,6 +73,61 @@ This model was contributed by [patrickvonplaten](https://huggingface.co/patrickv
- batch_decode
- decode
### Decoding multiple audios
If you are planning to decode multiple batches of audios, you should consider using [`~Wav2Vec2ProcessorWithLM.batch_decode`] and passing an instantiated `multiprocessing.Pool`.
Otherwise, [`~Wav2Vec2ProcessorWithLM.batch_decode`] performance will be slower than calling [`~Wav2Vec2ProcessorWithLM.decode`] for each audio individually, as it internally instantiates a new `Pool` for every call. See the example below:
```python
>>> # Let's see how to use a user-managed pool for batch decoding multiple audios
>>> from multiprocessing import get_context
>>> from transformers import AutoTokenizer, AutoProcessor, AutoModelForCTC
>>> from datasets import load_dataset
>>> import datasets
>>> import torch
>>> # import model, feature extractor, tokenizer
>>> model = AutoModelForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm").to("cuda")
>>> processor = AutoProcessor.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
>>> # load example dataset
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> dataset = dataset.cast_column("audio", datasets.Audio(sampling_rate=16_000))
>>> def map_to_array(batch):
... batch["speech"] = batch["audio"]["array"]
... return batch
>>> # prepare speech data for batch inference
>>> dataset = dataset.map(map_to_array, remove_columns=["audio"])
>>> def map_to_pred(batch, pool):
... inputs = processor(batch["speech"], sampling_rate=16_000, padding=True, return_tensors="pt")
... inputs = {k: v.to("cuda") for k, v in inputs.items()}
... with torch.no_grad():
... logits = model(**inputs).logits
... transcription = processor.batch_decode(logits.cpu().numpy(), pool).text
... batch["transcription"] = transcription
... return batch
>>> # note: pool should be instantiated *after* `Wav2Vec2ProcessorWithLM`.
>>> # otherwise, the LM won't be available to the pool's sub-processes
>>> # select number of processes and batch_size based on number of CPU cores available and on dataset size
>>> with get_context("fork").Pool(processes=2) as pool:
... result = dataset.map(
... map_to_pred, batched=True, batch_size=2, fn_kwargs={"pool": pool}, remove_columns=["speech"]
... )
>>> result["transcription"][:2]
['MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL', "NOR IS MISTER COULTER'S MANNER LESS INTERESTING THAN HIS MATTER"]
```
## Wav2Vec2 specific outputs
[[autodoc]] models.wav2vec2_with_lm.processing_wav2vec2_with_lm.Wav2Vec2DecoderWithLMOutput
......
......@@ -164,7 +164,7 @@ _deps = [
"tokenizers>=0.11.1,!=0.11.3,<0.14",
"torch>=1.7,!=1.12.0",
"torchaudio",
"pyctcdecode>=0.3.0",
"pyctcdecode>=0.4.0",
"tqdm>=4.27",
"unidic>=1.0.2",
"unidic_lite>=1.0.7",
......
......@@ -70,7 +70,7 @@ deps = {
"tokenizers": "tokenizers>=0.11.1,!=0.11.3,<0.14",
"torch": "torch>=1.7,!=1.12.0",
"torchaudio": "torchaudio",
"pyctcdecode": "pyctcdecode>=0.3.0",
"pyctcdecode": "pyctcdecode>=0.4.0",
"tqdm": "tqdm>=4.27",
"unidic": "unidic>=1.0.2",
"unidic_lite": "unidic_lite>=1.0.7",
......
......@@ -442,9 +442,9 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
<Tip>
Please take a look at the Example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better
understand how to make use of `output_word_offsets`.
[`~model.wav2vec2.tokenization_wav2vec2.batch_decode`] works the same way with batched output.
Please take a look at the Example of [`~Wav2Vec2CTCTokenizer.decode`] to better understand how to make
use of `output_char_offsets`. [`~Wav2Vec2CTCTokenizer.batch_decode`] works the same way with batched
output.
</Tip>
......@@ -454,9 +454,9 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
<Tip>
Please take a look at the Example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better
understand how to make use of `output_word_offsets`.
[`~model.wav2vec2.tokenization_wav2vec2.batch_decode`] works the same way with batched output.
Please take a look at the Example of [`~Wav2Vec2CTCTokenizer.decode`] to better understand how to make
use of `output_word_offsets`. [`~Wav2Vec2CTCTokenizer.batch_decode`] works the same way with batched
output.
</Tip>
......@@ -515,8 +515,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
<Tip>
Please take a look at the example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better
understand how to make use of `output_word_offsets`.
Please take a look at the example below to better understand how to make use of `output_char_offsets`.
</Tip>
......@@ -526,8 +525,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
<Tip>
Please take a look at the example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better
understand how to make use of `output_word_offsets`.
Please take a look at the example below to better understand how to make use of `output_word_offsets`.
</Tip>
......
......@@ -17,15 +17,18 @@ Speech processor class for Wav2Vec2
"""
import os
import warnings
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from multiprocessing import get_context
from multiprocessing import Pool, get_context, get_start_method
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Union
import numpy as np
from ...processing_utils import ProcessorMixin
from ...utils import ModelOutput, requires_backends
from ...utils import ModelOutput, logging, requires_backends
logger = logging.get_logger(__name__)
if TYPE_CHECKING:
......@@ -115,7 +118,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
This class method is simply calling Wav2Vec2FeatureExtractor's
[`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`], Wav2Vec2CTCTokenizer's
[`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`], and
[`~tokenization_utils_base.PreTrainedTokenizerBase.from_pretrained`], and
[`pyctcdecode.BeamSearchDecoderCTC.load_from_hf_hub`].
Please refer to the docstrings of the methods above for more information.
......@@ -280,6 +283,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
def batch_decode(
self,
logits: np.ndarray,
pool: Optional[Pool] = None,
num_processes: Optional[int] = None,
beam_width: Optional[int] = None,
beam_prune_logp: Optional[float] = None,
......@@ -297,16 +301,32 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
<Tip>
This function makes use of Python's multiprocessing.
This function makes use of Python's multiprocessing. Currently, multiprocessing is available only on Unix
systems (see this [issue](https://github.com/kensho-technologies/pyctcdecode/issues/65)).
If you are decoding multiple batches, consider creating a `Pool` and passing it to `batch_decode`. Otherwise,
`batch_decode` will be very slow since it will create a fresh `Pool` for each call. See usage example below.
</Tip>
Args:
logits (`np.ndarray`):
The logits output vector of the model representing the log probabilities for each token.
pool (`multiprocessing.Pool`, *optional*):
An optional user-managed pool. If not set, one will be automatically created and closed. The pool
should be instantiated *after* `Wav2Vec2ProcessorWithLM`. Otherwise, the LM won't be available to the
pool's sub-processes.
<Tip>
Currently, only pools created with a 'fork' context can be used. If a 'spawn' pool is passed, it will
be ignored and sequential decoding will be used instead.
</Tip>
num_processes (`int`, *optional*):
Number of processes on which the function should be parallelized over. Defaults to the number of
available CPUs.
If `pool` is not set, number of processes on which the function should be parallelized over. Defaults
to the number of available CPUs.
beam_width (`int`, *optional*):
Maximum number of beams at each step in decoding. Defaults to pyctcdecode's DEFAULT_BEAM_WIDTH.
beam_prune_logp (`int`, *optional*):
......@@ -332,17 +352,19 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
<Tip>
Please take a look at the Example of [`~model.wav2vec2_with_lm.processing_wav2vec2_with_lm.decode`] to
better understand how to make use of `output_word_offsets`.
[`~model.wav2vec2_with_lm.processing_wav2vec2_with_lm.batch_decode`] works the same way with batched
output.
Please take a look at the Example of [`~Wav2Vec2ProcessorWithLM.decode`] to better understand how to
make use of `output_word_offsets`. [`~Wav2Vec2ProcessorWithLM.batch_decode`] works the same way with
batched output.
</Tip>
Returns:
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`].
Example:
See [Decoding multiple audios](#decoding-multiple-audios).
"""
from pyctcdecode.constants import (
DEFAULT_BEAM_WIDTH,
DEFAULT_HOTWORD_WEIGHT,
......@@ -364,21 +386,41 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
# create multiprocessing pool and list numpy arrays
# filter out logits padding
logits_list = [array[(array != -100.0).all(axis=-1)] for array in logits]
pool = get_context("fork").Pool(num_processes)
# pyctcdecode
decoded_beams = self.decoder.decode_beams_batch(
pool,
logits_list=logits_list,
beam_width=beam_width,
beam_prune_logp=beam_prune_logp,
token_min_logp=token_min_logp,
hotwords=hotwords,
hotword_weight=hotword_weight,
)
# create a pool if necessary while also using it as a context manager to close itself
if pool is None:
# fork is safe to use only on Unix, see "Contexts and start methods" section on
# multiprocessing's docs (https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods)
default_context = get_start_method()
if default_context == "fork":
cm = pool = get_context().Pool(num_processes)
else:
logger.warning(
"Parallel batch decoding is not currently supported in this platform. "
"Falling back to sequential decoding."
)
cm = nullcontext()
else:
# pool is managed by the user, so we don't need to close it
cm = nullcontext()
if num_processes is not None:
logger.warning(
"Parameter `num_process` was passed, but it will be ignored since `pool` was also specified."
)
# clone multi-processing pool
pool.close()
# pyctcdecode
with cm:
decoded_beams = self.decoder.decode_beams_batch(
pool=pool,
logits_list=logits_list,
beam_width=beam_width,
beam_prune_logp=beam_prune_logp,
token_min_logp=token_min_logp,
hotwords=hotwords,
hotword_weight=hotword_weight,
)
# extract text and scores
batch_texts, logit_scores, lm_scores, word_offsets = [], [], [], []
......@@ -440,13 +482,12 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
<Tip>
Please take a look at the example of [`~models.wav2vec2_with_lm.processing_wav2vec2_with_lm.decode`] to
better understand how to make use of `output_word_offsets`.
Please take a look at the example below to better understand how to make use of `output_word_offsets`.
</Tip>
Returns:
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`].
Example:
......
......@@ -99,8 +99,8 @@ class ProcessorMixin(PushToHubMixin):
<Tip>
This class method is simply calling [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] and
[`~tokenization_utils_base.PreTrainedTokenizer.save_pretrained`]. Please refer to the docstrings of the methods
above for more information.
[`~tokenization_utils_base.PreTrainedTokenizerBase.save_pretrained`]. Please refer to the docstrings of the
methods above for more information.
</Tip>
......
......@@ -14,6 +14,7 @@
import inspect
import math
import multiprocessing
import unittest
import numpy as np
......@@ -21,6 +22,7 @@ from datasets import load_dataset
from transformers import Wav2Vec2Config, is_flax_available
from transformers.testing_utils import (
CaptureLogger,
is_flaky,
is_librosa_available,
is_pt_flax_cross_test,
......@@ -53,6 +55,7 @@ if is_flax_available():
if is_pyctcdecode_available():
from transformers import Wav2Vec2ProcessorWithLM
from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm
if is_librosa_available():
......@@ -554,3 +557,58 @@ class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase):
transcription = processor.batch_decode(np.array(logits)).text
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
@require_pyctcdecode
@require_librosa
def test_wav2vec2_with_lm_pool(self):
ds = load_dataset("common_voice", "es", split="test", streaming=True)
sample = next(iter(ds))
resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000)
model = FlaxWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
input_values = processor(resampled_audio, return_tensors="np").input_values
logits = model(input_values).logits
# test user-managed pool
with multiprocessing.get_context("fork").Pool(2) as pool:
transcription = processor.batch_decode(logits.numpy(), pool).text
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
# user-managed pool + num_processes should trigger a warning
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
2
) as pool:
transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text
self.assertIn("num_process", cl.out)
self.assertIn("it will be ignored", cl.out)
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
@require_pyctcdecode
@require_librosa
def test_wav2vec2_with_lm_invalid_pool(self):
ds = load_dataset("common_voice", "es", split="test", streaming=True)
sample = next(iter(ds))
resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000)
model = FlaxWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
input_values = processor(resampled_audio, return_tensors="np").input_values
logits = model(input_values).logits
# change default start method, which should trigger a warning if different than fork
multiprocessing.set_start_method("spawn")
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
transcription = processor.batch_decode(logits.numpy()).text
self.assertIn("Falling back to sequential decoding.", cl.out)
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
......@@ -18,6 +18,7 @@ import copy
import glob
import inspect
import math
import multiprocessing
import unittest
import numpy as np
......@@ -26,7 +27,7 @@ from datasets import load_dataset
from huggingface_hub import snapshot_download
from transformers import Wav2Vec2Config, is_tf_available
from transformers.testing_utils import is_flaky, require_librosa, require_pyctcdecode, require_tf, slow
from transformers.testing_utils import CaptureLogger, is_flaky, require_librosa, require_pyctcdecode, require_tf, slow
from transformers.utils import is_librosa_available, is_pyctcdecode_available
from ...test_configuration_common import ConfigTester
......@@ -42,6 +43,7 @@ if is_tf_available():
if is_pyctcdecode_available():
from transformers import Wav2Vec2ProcessorWithLM
from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm
if is_librosa_available():
......@@ -590,3 +592,56 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
transcription = processor.batch_decode(logits.numpy()).text
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
@require_pyctcdecode
@require_librosa
def test_wav2vec2_with_lm_pool(self):
downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample")
file_path = glob.glob(downloaded_folder + "/*")[0]
sample = librosa.load(file_path, sr=16_000)[0]
model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
input_values = processor(sample, return_tensors="tf").input_values
logits = model(input_values).logits
# test user-managed pool
with multiprocessing.get_context("fork").Pool(2) as pool:
transcription = processor.batch_decode(logits.numpy(), pool).text
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
# user-managed pool + num_processes should trigger a warning
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
2
) as pool:
transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text
self.assertIn("num_process", cl.out)
self.assertIn("it will be ignored", cl.out)
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
@require_pyctcdecode
@require_librosa
def test_wav2vec2_with_lm_invalid_pool(self):
downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample")
file_path = glob.glob(downloaded_folder + "/*")[0]
sample = librosa.load(file_path, sr=16_000)[0]
model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
input_values = processor(sample, return_tensors="tf").input_values
logits = model(input_values).logits
# change default start method, which should trigger a warning if different than fork
multiprocessing.set_start_method("spawn")
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
transcription = processor.batch_decode(logits.numpy()).text
self.assertIn("Falling back to sequential decoding.", cl.out)
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
......@@ -15,6 +15,7 @@
""" Testing suite for the PyTorch Wav2Vec2 model. """
import math
import multiprocessing
import os
import pickle
import tempfile
......@@ -25,6 +26,7 @@ from datasets import load_dataset
from transformers import Wav2Vec2Config, is_torch_available
from transformers.testing_utils import (
CaptureLogger,
is_pt_flax_cross_test,
is_pyctcdecode_available,
is_torchaudio_available,
......@@ -74,6 +76,7 @@ if is_torchaudio_available():
if is_pyctcdecode_available():
from transformers import Wav2Vec2ProcessorWithLM
from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm
if is_torch_fx_available():
......@@ -1611,6 +1614,71 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
@require_pyctcdecode
@require_torchaudio
def test_wav2vec2_with_lm_pool(self):
ds = load_dataset("common_voice", "es", split="test", streaming=True)
sample = next(iter(ds))
resampled_audio = torchaudio.functional.resample(
torch.tensor(sample["audio"]["array"]), 48_000, 16_000
).numpy()
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm").to(
torch_device
)
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
input_values = processor(resampled_audio, return_tensors="pt").input_values
with torch.no_grad():
logits = model(input_values.to(torch_device)).logits
# test user-managed pool
with multiprocessing.get_context("fork").Pool(2) as pool:
transcription = processor.batch_decode(logits.numpy(), pool).text
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
# user-managed pool + num_processes should trigger a warning
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
2
) as pool:
transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text
self.assertIn("num_process", cl.out)
self.assertIn("it will be ignored", cl.out)
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
@require_pyctcdecode
@require_torchaudio
def test_wav2vec2_with_lm_invalid_pool(self):
ds = load_dataset("common_voice", "es", split="test", streaming=True)
sample = next(iter(ds))
resampled_audio = torchaudio.functional.resample(
torch.tensor(sample["audio"]["array"]), 48_000, 16_000
).numpy()
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm").to(
torch_device
)
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
input_values = processor(resampled_audio, return_tensors="pt").input_values
with torch.no_grad():
logits = model(input_values.to(torch_device)).logits
# change default start method, which should trigger a warning if different than fork
multiprocessing.set_start_method("spawn")
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
transcription = processor.batch_decode(logits.numpy()).text
self.assertIn("Falling back to sequential decoding.", cl.out)
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
def test_inference_diarization(self):
model = Wav2Vec2ForAudioFrameClassification.from_pretrained("anton-l/wav2vec2-base-superb-sd").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/wav2vec2-base-superb-sd")
......
......@@ -25,6 +25,7 @@ import numpy as np
from datasets import load_dataset
from packaging import version
from parameterized import parameterized
from transformers import AutoProcessor
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
......@@ -194,7 +195,8 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
self.assertEqual(decoded_decoder[-2], decoded_processor.logit_score)
self.assertEqual(decoded_decoder[-1], decoded_processor.lm_score)
def test_decoder_batch(self):
@parameterized.expand([[None], ["fork"], ["spawn"]])
def test_decoder_batch(self, pool_context):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
decoder = self.get_decoder()
......@@ -203,17 +205,25 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
logits = self._get_dummy_logits()
decoded_processor = processor.batch_decode(logits)
# note: pool should be instantiated *after* Wav2Vec2ProcessorWithLM.
# otherwise, the LM won't be available to the pool's sub-processes.
# manual logic used to allow parameterized test for both pool=None and pool=Pool(...)
if pool_context is None:
decoded_processor = processor.batch_decode(logits)
else:
with get_context(pool_context).Pool() as pool:
decoded_processor = processor.batch_decode(logits, pool)
logits_list = [array for array in logits]
pool = get_context("fork").Pool()
decoded_beams = decoder.decode_beams_batch(pool, logits_list)
with get_context("fork").Pool() as p:
decoded_beams = decoder.decode_beams_batch(p, logits_list)
texts_decoder, logit_scores_decoder, lm_scores_decoder = [], [], []
for beams in decoded_beams:
texts_decoder.append(beams[0][0])
logit_scores_decoder.append(beams[0][-2])
lm_scores_decoder.append(beams[0][-1])
pool.close()
self.assertListEqual(texts_decoder, decoded_processor.text)
self.assertListEqual(["<s> <s> </s>", "<s> <s> <s>"], decoded_processor.text)
......@@ -242,15 +252,15 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
decoded_processor = decoded_processor_out.text
logits_list = [array for array in logits]
pool = get_context("fork").Pool()
decoded_decoder_out = decoder.decode_beams_batch(
pool,
logits_list,
beam_width=beam_width,
beam_prune_logp=beam_prune_logp,
token_min_logp=token_min_logp,
)
pool.close()
with get_context("fork").Pool() as pool:
decoded_decoder_out = decoder.decode_beams_batch(
pool,
logits_list,
beam_width=beam_width,
beam_prune_logp=beam_prune_logp,
token_min_logp=token_min_logp,
)
decoded_decoder = [d[0][0] for d in decoded_decoder_out]
......@@ -287,12 +297,12 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
unk_score_offset=unk_score_offset,
lm_score_boundary=lm_score_boundary,
)
pool = get_context("fork").Pool()
decoded_decoder_out = decoder.decode_beams_batch(
pool,
logits_list,
)
pool.close()
with get_context("fork").Pool() as pool:
decoded_decoder_out = decoder.decode_beams_batch(
pool,
logits_list,
)
decoded_decoder = [d[0][0] for d in decoded_decoder_out]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment