Unverified commit f7f0ec2f, authored by Nicolas Patry, committed by GitHub

Adding support for `fp16` for asr pipeline. (#20864)

* Supporting `fp16` for asr pipeline

* Adding test.

* Style.

* Oops.

* Flake8 update ?

* Fixing flake8 ?

* Revert "Flake8 update ?"

This reverts commit 0b917fcb520e5f34d1933d9d37d8f32b64553048.

* Style (accidentally deleted flake8 F401.)

* Move to a bigger test (no small whisper model, and s2t doesn't seem to
accept torch_dtype=fp16).

Also we need to use a GPU to actually compute on fp16.

* Using BatchFeature capability.
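In practice the change lets the whole pipeline run in half precision by passing `torch_dtype` at construction time. A minimal usage sketch, mirroring the new `test_whisper_fp16` test below (it assumes a CUDA device, since fp16 compute needs a GPU):

import numpy as np
import torch

from transformers import pipeline

# fp16 weights on GPU 0; the same dtype is now also applied to the
# extracted audio features during preprocessing.
speech_recognizer = pipeline(
    model="openai/whisper-base",
    device=0,
    torch_dtype=torch.float16,
)

# Any float32 waveform works as input; here a synthetic ramp as in the test.
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
print(speech_recognizer(waveform))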
@@ -874,6 +874,9 @@ def pipeline(
     if feature_extractor is not None:
         kwargs["feature_extractor"] = feature_extractor
 
+    if torch_dtype is not None:
+        kwargs["torch_dtype"] = torch_dtype
+
     if device is not None:
         kwargs["device"] = device
@@ -52,13 +52,15 @@ def rescale_stride(stride, ratio):
     return new_strides
 
 
-def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right):
+def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, dtype=None):
     inputs_len = inputs.shape[0]
     step = chunk_len - stride_left - stride_right
     for i in range(0, inputs_len, step):
         # add start and end paddings to the chunk
         chunk = inputs[i : i + chunk_len]
         processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
+        if dtype is not None:
+            processed = processed.to(dtype=dtype)
         _stride_left = 0 if i == 0 else stride_left
         is_last = i + step + stride_left >= inputs_len
         _stride_right = 0 if is_last else stride_right
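The `processed.to(dtype=dtype)` call is the "BatchFeature capability" mentioned in the commit message: the feature extractor returns a `BatchFeature`, whose tensors can be cast in one go instead of converting each key by hand. A small standalone sketch of that behaviour (the model name and silent input are placeholders):

import numpy as np
import torch

from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base")
audio = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz

# The extractor returns a BatchFeature; .to() casts its tensors much like a torch tensor.
processed = feature_extractor(audio, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
processed = processed.to(dtype=torch.float16)
print(processed["input_features"].dtype)  # torch.float16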
@@ -240,6 +242,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
             preprocess_params["stride_length_s"] = kwargs["stride_length_s"]
         if "ignore_warning" in kwargs:
             preprocess_params["ignore_warning"] = kwargs["ignore_warning"]
+        if "torch_dtype" in kwargs:
+            preprocess_params["dtype"] = kwargs["torch_dtype"]
 
         postprocess_params = {}
         if "decoder_kwargs" in kwargs:
@@ -249,7 +253,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         return preprocess_params, {}, postprocess_params
 
-    def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False):
+    def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False, dtype=None):
         if isinstance(inputs, str):
             if inputs.startswith("http://") or inputs.startswith("https://"):
                 # We need to actually check for a real protocol, otherwise it's impossible to use a local file
@@ -332,12 +336,14 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                 raise ValueError("Chunk length must be superior to stride length")
 
             # make sure that
-            for item in chunk_iter(inputs, self.feature_extractor, chunk_len, stride_left, stride_right):
+            for item in chunk_iter(inputs, self.feature_extractor, chunk_len, stride_left, stride_right, dtype):
                 yield item
         else:
             processed = self.feature_extractor(
                 inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
            )
+            if dtype is not None:
+                processed = processed.to(dtype=dtype)
             if stride is not None:
                 if self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
                     raise ValueError("Stride is only usable with CTC models, try removing it")
@@ -366,6 +372,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
             # `generate` magic to create the mask automatically won't work, we basically need to help
             # it here.
             attention_mask = model_inputs.pop("attention_mask", None)
+
             tokens = self.model.generate(
                 encoder_outputs=encoder(inputs, attention_mask=attention_mask),
                 attention_mask=attention_mask,
@@ -145,6 +145,19 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
         with self.assertRaisesRegex(ValueError, "^We cannot return_timestamps yet on non-ctc models !$"):
             _ = speech_recognizer(waveform, return_timestamps="char")
 
+    @slow
+    @require_torch
+    def test_whisper_fp16(self):
+        if not torch.cuda.is_available():
+            self.skipTest("Cuda is necessary for this test")
+        speech_recognizer = pipeline(
+            model="openai/whisper-base",
+            device=0,
+            torch_dtype=torch.float16,
+        )
+        waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
+        speech_recognizer(waveform)
+
     @require_torch
     def test_small_model_pt_seq2seq(self):
         speech_recognizer = pipeline(
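For reference, the synthetic waveform the test feeds in is just a repeated float32 ramp, enough to exercise the fp16 path without fetching real audio. Interpreted at Whisper's 16 kHz sampling rate it amounts to roughly two seconds of signal:

import numpy as np

# 34 repetitions of a 0..999 ramp: 34000 samples, ~2.1 s at 16 kHz.
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
print(waveform.shape)  # (34000,)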