Commit 0fdb72ac authored by Attila Dusnoki

Add Whisper example

parent e3e00547
# Transformers Inference Examples
- [Python Whisper](./python_whisper)
# Whisper
This example was tested with the [rocm-5.7](https://github.com/ROCmSoftwarePlatform/AMDMIGraphX/tree/rocm-5.7.0) revision.
## Jupyter notebook
There is a dedicated step-by-step notebook; see [whisper.ipynb](./whisper.ipynb).
## Console application
To run the console application, follow the steps below.
Set up the Python environment:
```bash
# this requires the python venv module to be installed (e.g. apt install python3.8-venv)
python3 -m venv w_venv
. w_venv/bin/activate
```
Install the dependencies.
`ffmpeg` is needed to handle audio files.
```bash
apt install ffmpeg
```
```bash
pip install -r requirements.txt
```
Make the MIGraphX Python module available:
```bash
export PYTHONPATH=/opt/rocm/lib:$PYTHONPATH
```
Use the helper script to download and export the model with Optimum.
The decoder's `attention_mask` input is not exposed by default, but MIGraphX requires it.
```bash
python download_whisper.py
```
*Note: `models/whisper-tiny.en_modified` will be used in the scripts*
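To double-check that the exported decoder really exposes `attention_mask`, here is a minimal sketch (assuming the `onnx` package, which the `optimum[onnxruntime]` install already pulls in):
```python
# Sketch: list the exported decoder's graph inputs; "attention_mask" should
# appear alongside "input_ids" and "encoder_hidden_states".
import onnx

decoder = onnx.load("models/whisper-tiny.en_modified/decoder_model.onnx")
print([i.name for i in decoder.graph.input])
```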
There are *optional* samples that can be downloaded, but the example can be tested without them.
```bash
./download_samples.sh
```
Run the automatic-speech-recognition script with the following example input:
```bash
python asr.py --audio audio/sample1.flac --log-process
```
Or run it without any audio input to transcribe the [Hugging Face dummy dataset](https://huggingface.co/datasets/hf-internal-testing/librispeech_asr_dummy) samples.
## Gradio application
Note: this requires the `Console application` setup above.
Install the Gradio dependencies:
```bash
pip install -r gradio_requirements.txt
```
Usage
```bash
python gradio_app.py
```
This will load the models (which can take several minutes) and, once setup is complete, start a server at `http://127.0.0.1:7860`.
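By default the server only listens on localhost. To reach the demo from another machine, `demo.launch()` in `gradio_app.py` can be given Gradio's standard host/port options, for example:
```python
# Sketch: bind to all interfaces instead of 127.0.0.1 (standard Gradio options).
demo.launch(server_name="0.0.0.0", server_port=7860)
```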
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
from argparse import ArgumentParser
from transformers import WhisperProcessor
from datasets import load_dataset
from pydub import AudioSegment
import migraphx as mgx
import os
import sys
import numpy as np
import time
from functools import wraps


# measurement helper
def measure(fn):
    @wraps(fn)
    def measure_ms(*args, **kwargs):
        start_time = time.perf_counter_ns()
        result = fn(*args, **kwargs)
        end_time = time.perf_counter_ns()
        print(
            f"Elapsed time for {fn.__name__}: {(end_time - start_time) * 1e-6:.4f} ms\n"
        )
        return result

    return measure_ms


def get_args():
    parser = ArgumentParser()
    parser.add_argument(
        "-a",
        "--audio",
        type=str,
        help="Path to audio file. Default: HF test dataset",
    )
    parser.add_argument(
        "-l",
        "--log-process",
        action="store_true",
        help="Print the current state of transcribing.",
    )
    return parser.parse_args()


class WhisperMGX():
    def __init__(self):
        model_id = "openai/whisper-tiny.en"
        print(f"Using {model_id}")
        print("Creating Whisper processor")
        self.processor = WhisperProcessor.from_pretrained(model_id)
        self.decoder_start_token_id = 50257  # <|startoftranscript|>
        self.eos_token_id = 50256  # <|endoftext|>
        self.notimestamps = 50362  # <|notimestamps|>
        self.max_length = 448
        self.sot = [self.decoder_start_token_id, self.notimestamps]
        print("Load models...")
        self.encoder_model = WhisperMGX.load_mgx_model(
            "encoder", {"input_features": [1, 80, 3000]})
        self.decoder_model = WhisperMGX.load_mgx_model(
            "decoder", {
                "input_ids": [1, self.max_length],
                "attention_mask": [1, self.max_length],
                "encoder_hidden_states": [1, 1500, 384]
            })

    @staticmethod
    @measure
    def load_audio_from_file(filepath):
        audio = AudioSegment.from_file(filepath)
        # Only 16k is supported
        audio = audio.set_frame_rate(16000)
        data = np.array(audio.get_array_of_samples(), dtype=np.float32)
        # normalize to [-1, 1]
        data /= np.max(np.abs(data))
        return data, audio.frame_rate

    @staticmethod
    @measure
    def load_mgx_model(name, shapes):
        file = f"models/whisper-tiny.en_modified/{name}_model"
        print(f"Loading {name} model from {file}")
        if os.path.isfile(f"{file}.mxr"):
            print("Found mxr, loading it...")
            model = mgx.load(f"{file}.mxr", format="msgpack")
        elif os.path.isfile(f"{file}.onnx"):
            print("Parsing from onnx file...")
            model = mgx.parse_onnx(f"{file}.onnx", map_input_dims=shapes)
            model.compile(mgx.get_target("gpu"))
            print(f"Saving {name} model to mxr file...")
            mgx.save(model, f"{file}.mxr", format="msgpack")
        else:
            print(f"No {name} model found. Please download it and re-try.")
            sys.exit(1)
        return model

    @property
    def initial_decoder_inputs(self):
        input_ids = np.array([
            self.sot + [self.eos_token_id] * (self.max_length - len(self.sot))
        ])
        # 0 masked | 1 un-masked
        attention_mask = np.array([[1] * len(self.sot) + [0] *
                                   (self.max_length - len(self.sot))])
        return (input_ids, attention_mask)

    @measure
    def get_input_features_from_sample(self, sample_data, sampling_rate):
        return self.processor(sample_data,
                              sampling_rate=sampling_rate,
                              return_tensors="np").input_features

    @measure
    def encode_features(self, input_features):
        return np.array(
            self.encoder_model.run(
                {"input_features": input_features.astype(np.float32)})[0])

    def decode_step(self, input_ids, attention_mask, hidden_states):
        return np.array(
            self.decoder_model.run({
                "input_ids": input_ids.astype(np.int64),
                "attention_mask": attention_mask.astype(np.int64),
                "encoder_hidden_states": hidden_states.astype(np.float32)
            })[0])

    @measure
    def generate(self, input_features, log_process=False):
        hidden_states = self.encode_features(input_features)
        input_ids, attention_mask = self.initial_decoder_inputs
        # stop one step early so input_ids[0][timestep + 1] stays in bounds
        for timestep in range(len(self.sot) - 1, self.max_length - 1):
            # get logits for the current timestep
            logits = self.decode_step(input_ids, attention_mask, hidden_states)
            # greedily pick the most probable token
            new_token = np.argmax(logits[0][timestep])
            # add it to the tokens and unmask it
            input_ids[0][timestep + 1] = new_token
            attention_mask[0][timestep + 1] = 1
            if log_process:
                print("Transcribing: " + ''.join(
                    self.processor.decode(input_ids[0][:timestep + 1],
                                          skip_special_tokens=True)),
                      end='\r')
            if new_token == self.eos_token_id:
                break
        if log_process:
            print(flush=True)
        return ''.join(
            self.processor.decode(input_ids[0][:timestep + 1],
                                  skip_special_tokens=True))


if __name__ == "__main__":
    args = get_args()
    if args.audio:
        data, fr = WhisperMGX.load_audio_from_file(args.audio)
        ds = [{"audio": {"array": data, "sampling_rate": fr}}]
    else:
        # load dummy dataset and read audio files
        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy",
                          "clean",
                          split="validation")
    w = WhisperMGX()
    for idx, data in enumerate(ds):
        print(f"#{idx+1}/{len(ds)} Sample...")
        sample = data["audio"]
        input_features = w.get_input_features_from_sample(
            sample["array"], sample["sampling_rate"])
        result = w.generate(input_features, log_process=args.log_process)
        print(f"Result: {result}")
#!/bin/bash
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
DIR="$(dirname "$0")/audio/"
mkdir -p $DIR
wget -q https://cdn-media.huggingface.co/speech_samples/sample1.flac -O "$DIR/sample1.flac"
wget -q https://cdn-media.huggingface.co/speech_samples/sample2.flac -O "$DIR/sample2.flac"
echo "Samples downloaded to $DIR"
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
from optimum.exporters.onnx import main_export
from optimum.exporters.onnx.model_configs import WhisperOnnxConfig
from transformers import AutoConfig
from optimum.exporters.onnx.base import ConfigBehavior
from typing import Dict


class CustomWhisperOnnxConfig(WhisperOnnxConfig):
    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        common_inputs = {}
        if self._behavior is ConfigBehavior.ENCODER:
            common_inputs["input_features"] = {
                0: "batch_size",
                1: "feature_size",
                2: "encoder_sequence_length"
            }
        if self._behavior is ConfigBehavior.DECODER:
            common_inputs["decoder_input_ids"] = {
                0: "batch_size",
                1: "decoder_sequence_length"
            }
            common_inputs["decoder_attention_mask"] = {
                0: "batch_size",
                1: "decoder_sequence_length"
            }
            common_inputs["encoder_outputs"] = {
                0: "batch_size",
                1: "encoder_sequence_length"
            }
        return common_inputs

    @property
    def torch_to_onnx_input_map(self) -> Dict[str, str]:
        if self._behavior is ConfigBehavior.DECODER:
            return {
                "decoder_input_ids": "input_ids",
                "decoder_attention_mask": "attention_mask",
                "encoder_outputs": "encoder_hidden_states",
            }
        return {}


def export():
    model_id = "openai/whisper-tiny.en"
    config = AutoConfig.from_pretrained(model_id)
    custom_whisper_onnx_config = CustomWhisperOnnxConfig(
        config=config,
        task="automatic-speech-recognition",
    )
    encoder_config = custom_whisper_onnx_config.with_behavior("encoder")
    decoder_config = custom_whisper_onnx_config.with_behavior("decoder",
                                                              use_past=False)
    custom_onnx_configs = {
        "encoder_model": encoder_config,
        "decoder_model": decoder_config,
    }
    output = "models/whisper-tiny.en_modified"
    main_export(model_id,
                output=output,
                no_post_process=True,
                do_validation=False,
                custom_onnx_configs=custom_onnx_configs)
    print(f"Done. Check {output}")


if __name__ == "__main__":
    export()
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
from asr import WhisperMGX
import gradio as gr
import os


def main():
    # Note: This will load the models, which can take several minutes
    w = WhisperMGX()

    def gr_wrapper(audio):
        data, fr = WhisperMGX.load_audio_from_file(audio)
        input_features = w.get_input_features_from_sample(data, fr)
        return w.generate(input_features)

    examples = [
        os.path.join(os.path.dirname(__file__), "audio/sample1.flac"),
        os.path.join(os.path.dirname(__file__), "audio/sample2.flac"),
    ]
    # skip if there is no file
    examples = [e for e in examples if os.path.isfile(e)]
    demo = gr.Interface(
        gr_wrapper,
        gr.Audio(sources=["upload", "microphone"], type="filepath"),
        "text",
        examples=examples,
    )
    demo.launch()


if __name__ == "__main__":
    main()
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
-r requirements.txt
gradio
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
accelerate
datasets
optimum[onnxruntime]
pydub
transformers
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The MIT License (MIT)\n",
"#\n",
"# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.\n",
"#\n",
"# Permission is hereby granted, free of charge, to any person obtaining a copy\n",
"# of this software and associated documentation files (the 'Software'), to deal\n",
"# in the Software without restriction, including without limitation the rights\n",
"# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n",
"# copies of the Software, and to permit persons to whom the Software is\n",
"# furnished to do so, subject to the following conditions:\n",
"#\n",
"# The above copyright notice and this permission notice shall be included in\n",
"# all copies or substantial portions of the Software.\n",
"#\n",
"# THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
"# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
"# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n",
"# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
"# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n",
"# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n",
"# THE SOFTWARE."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Whisper\n",
"\n",
"The following example will show how to run `Whisper` with `MIGraphX`.\n",
"\n",
"Install the required dependencies."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install dependencies\n",
"%pip install accelerate datasets optimum[onnxruntime] transformers"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use optimum to download the model.\n",
"\n",
"The attention_mask for decoder is not exposed by default, but required to work with MIGraphX.\n",
"The following script will do that:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# download and export models\n",
"from download_whisper import export\n",
"export()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now it is time to load these models with python.\n",
"\n",
"First, we make sure that MIGraphX module is found in the python path."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"mgx_lib_path = \"/opt/rocm/lib/\" # or \"/code/AMDMIGraphX/build/lib/\"\n",
"if mgx_lib_path not in sys.path:\n",
" sys.path.append(mgx_lib_path)\n",
"import migraphx as mgx\n",
"\n",
"import numpy as np\n",
"import os"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, a helper method to load and cache the models.\n",
"\n",
"This will use the `models/whisper-tiny.en_modified` path. If you changed it, make sure to update here as well."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def load_mgx_model(name, shapes):\n",
" file = f\"models/whisper-tiny.en_modified/{name}_model\"\n",
" print(f\"Loading {name} model from {file}\")\n",
" if os.path.isfile(f\"{file}.mxr\"):\n",
" print(\"Found mxr, loading it...\")\n",
" model = mgx.load(f\"{file}.mxr\", format=\"msgpack\")\n",
" elif os.path.isfile(f\"{file}.onnx\"):\n",
" print(\"Parsing from onnx file...\")\n",
" model = mgx.parse_onnx(f\"{file}.onnx\", map_input_dims=shapes)\n",
" model.compile(mgx.get_target(\"gpu\"))\n",
" print(f\"Saving {name} model to mxr file...\")\n",
" mgx.save(model, f\"{file}.mxr\", format=\"msgpack\")\n",
" else:\n",
" print(f\"No {name} model found. Please download it and re-try.\")\n",
" os.exit(1)\n",
" return model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With that, we can load the models. This could take several minutes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"encoder_model = load_mgx_model(\"encoder\", {\"input_features\": [1, 80, 3000]})\n",
"decoder_model = load_mgx_model(\n",
" \"decoder\", {\n",
" \"input_ids\": [1, 448],\n",
" \"attention_mask\": [1, 448],\n",
" \"encoder_hidden_states\": [1, 1500, 384]\n",
" })"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Time to load the processor from the original source.\n",
"It will be used to get feature embeddings from the audio data and decode the output tokens."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import WhisperProcessor\n",
"processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny.en\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we will define all the steps one by one, to make the last step short and simple.\n",
"\n",
"The first step will be to get audio data.\n",
"For testing purposes, we will use Hugging Face's dummy samples."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"ds = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\",\n",
" \"clean\",\n",
" split=\"validation\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next step will be to get the input features from the audio data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_input_features_from_sample(sample_data, sampling_rate):\n",
" return processor(sample_data,\n",
" sampling_rate=sampling_rate,\n",
" return_tensors=\"np\").input_features"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will encode these, and use them in the decoding step."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def encode_features(input_features):\n",
" return np.array(\n",
" encoder_model.run(\n",
" {\"input_features\": input_features.astype(np.float32)})[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The decoding process will be explained later in `generate`.\n",
"\n",
"The decoder model will expect the encoded features, the input ids (decoded tokens), and the attention mask to ignore parts as needed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def decode_step(input_ids, attention_mask, hidden_states):\n",
" return np.array(\n",
" decoder_model.run({\n",
" \"input_ids\":\n",
" input_ids.astype(np.int64),\n",
" \"attention_mask\":\n",
" attention_mask.astype(np.int64),\n",
" \"encoder_hidden_states\":\n",
" hidden_states.astype(np.float32)\n",
" })[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following parameters are from [whisper-tiny.en's config](https://huggingface.co/openai/whisper-tiny.en/blob/main/config.json).\n",
"\n",
"You might need to change them if you change the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# model params\n",
"decoder_start_token_id = 50257 # <|startoftranscript|>\n",
"eos_token_id = 50256 # \"<|endoftext|>\"\n",
"notimestamps = 50362 # <|notimestamps|>\n",
"max_length = 448\n",
"sot = [decoder_start_token_id, notimestamps]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To kickstart the decoding, we will provide the `<|startoftranscript|>` and `<|notimestamps|>` tokens.\n",
"\n",
"Fill up the remaining tokens with `<|endoftext|>` and mask to ignore them."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def initial_decoder_inputs():\n",
" input_ids = np.array([sot + [eos_token_id] * (max_length - len(sot))])\n",
" # 0 masked | 1 un-masked\n",
" attention_mask = np.array([[1] * len(sot) + [0] * (max_length - len(sot))])\n",
" return (input_ids, attention_mask)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally the text generation part.\n",
"\n",
"With each decoding step, we will get the probabilities for the next token. We greedily get best match, add it to the decoded tokens and unmask it.\n",
"\n",
"If the token is `<|endoftext|>`, we finished with the transcribing."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def generate(input_features):\n",
" hidden_states = encode_features(input_features)\n",
" input_ids, attention_mask = initial_decoder_inputs()\n",
" for timestep in range(len(sot) - 1, max_length):\n",
" # get logits for the current timestep\n",
" logits = decode_step(input_ids, attention_mask, hidden_states)\n",
" # greedily get the highest probable token\n",
" new_token = np.argmax(logits[0][timestep])\n",
"\n",
" # add it to the tokens and unmask it\n",
" input_ids[0][timestep + 1] = new_token\n",
" attention_mask[0][timestep + 1] = 1\n",
"\n",
" print(\"Transcribing: \" + ''.join(\n",
" processor.decode(input_ids[0][:timestep + 1],\n",
" skip_special_tokens=True)),\n",
" end='\\r')\n",
"\n",
" if new_token == eos_token_id:\n",
" print(flush=True)\n",
" break"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To test this, we will get the fist audio from the dataset.\n",
"\n",
"Feel free to change it and experiment."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sample = ds[0][\"audio\"] # or load it from file\n",
"data, sampling_rate = sample[\"array\"], sample[\"sampling_rate\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"input_features = get_input_features_from_sample(data, sampling_rate)\n",
"generate(input_features)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The result should be:\n",
"\n",
"`Transcribing: Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.`"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}