{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# The MIT License (MIT)\n", "#\n", "# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.\n", "#\n", "# Permission is hereby granted, free of charge, to any person obtaining a copy\n", "# of this software and associated documentation files (the 'Software'), to deal\n", "# in the Software without restriction, including without limitation the rights\n", "# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n", "# copies of the Software, and to permit persons to whom the Software is\n", "# furnished to do so, subject to the following conditions:\n", "#\n", "# The above copyright notice and this permission notice shall be included in\n", "# all copies or substantial portions of the Software.\n", "#\n", "# THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n", "# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n", "# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n", "# THE SOFTWARE." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Whisper\n", "\n", "The following example will show how to run `Whisper` with `MIGraphX`.\n", "\n", "Install the required dependencies." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Install dependencies\n", "%pip install accelerate datasets optimum[onnxruntime] transformers" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will use `optimum` to download the model.\n", "\n", "The decoder's `attention_mask` input is not exposed by default, but MIGraphX requires it.\n", "The following script downloads the model and exports it with the mask exposed:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# download and export models\n", "from download_whisper import export\n", "export()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now it is time to load these models with Python.\n", "\n", "First, we make sure that the MIGraphX module can be found on the Python path." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "mgx_lib_path = \"/opt/rocm/lib/\"  # or \"/code/AMDMIGraphX/build/lib/\"\n", "if mgx_lib_path not in sys.path:\n", "    sys.path.append(mgx_lib_path)\n", "import migraphx as mgx\n", "\n", "import numpy as np\n", "import os" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, a helper function to load and cache the models.\n", "\n", "It uses the `models/whisper-tiny.en_modified` path. If you changed it, make sure to update it here as well." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def load_mgx_model(name, shapes):\n", "    file = f\"models/whisper-tiny.en_modified/{name}_model\"\n", "    print(f\"Loading {name} model from {file}\")\n", "    if os.path.isfile(f\"{file}.mxr\"):\n", "        # A compiled model was cached by an earlier run\n", "        print(\"Found mxr, loading it...\")\n", "        model = mgx.load(f\"{file}.mxr\", format=\"msgpack\")\n", "    elif os.path.isfile(f\"{file}.onnx\"):\n", "        # Parse with fixed input shapes, compile for the GPU, then cache\n", "        print(\"Parsing from onnx file...\")\n", "        model = mgx.parse_onnx(f\"{file}.onnx\", map_input_dims=shapes)\n", "        model.compile(mgx.get_target(\"gpu\"))\n", "        print(f\"Saving {name} model to mxr file...\")\n", "        mgx.save(model, f\"{file}.mxr\", format=\"msgpack\")\n", "    else:\n", "        print(f\"No {name} model found. Please download it and re-try.\")\n", "        sys.exit(1)\n", "    return model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With that, we can load the models. This could take several minutes." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "encoder_model = load_mgx_model(\"encoder\", {\"input_features\": [1, 80, 3000]})\n", "decoder_model = load_mgx_model(\n", "    \"decoder\", {\n", "        \"input_ids\": [1, 448],\n", "        \"attention_mask\": [1, 448],\n", "        \"encoder_hidden_states\": [1, 1500, 384]\n", "    })" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we load the processor from the original source.\n", "It is used to extract input features from the audio data and to decode the output tokens." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import WhisperProcessor\n", "processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny.en\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we define each step separately, so that the final step stays short and simple.\n", "\n", "The first step is to get audio data.\n", "For testing purposes, we will use Hugging Face's dummy samples." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "ds = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\",\n", "                  \"clean\",\n", "                  split=\"validation\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The next step is to get the input features from the audio data." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_input_features_from_sample(sample_data, sampling_rate):\n", "    return processor(sample_data,\n", "                     sampling_rate=sampling_rate,\n", "                     return_tensors=\"np\").input_features" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We encode these features and use the result in the decoding step." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def encode_features(input_features):\n", "    return np.array(\n", "        encoder_model.run(\n", "            {\"input_features\": input_features.astype(np.float32)})[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The decoding process will be explained later in `generate`.\n", "\n", "The decoder model expects the encoded features, the input ids (the tokens decoded so far), and an attention mask that tells it which positions to ignore." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def decode_step(input_ids, attention_mask, hidden_states):\n", "    return np.array(\n", "        decoder_model.run({\n", "            \"input_ids\":\n", "            input_ids.astype(np.int64),\n", "            \"attention_mask\":\n", "            attention_mask.astype(np.int64),\n", "            \"encoder_hidden_states\":\n", "            hidden_states.astype(np.float32)\n", "        })[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following parameters are from [whisper-tiny.en's config](https://huggingface.co/openai/whisper-tiny.en/blob/main/config.json).\n", "\n", "You might need to change them if you change the model." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# model params\n", "decoder_start_token_id = 50257  # <|startoftranscript|>\n", "eos_token_id = 50256  # <|endoftext|>\n", "notimestamps = 50362  # <|notimestamps|>\n", "max_length = 448\n", "sot = [decoder_start_token_id, notimestamps]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To kickstart the decoding, we will provide the `<|startoftranscript|>` and `<|notimestamps|>` tokens.\n", "\n", "The remaining positions are filled with `<|endoftext|>` and masked so that they are ignored." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def initial_decoder_inputs():\n", "    input_ids = np.array([sot + [eos_token_id] * (max_length - len(sot))])\n", "    # 0 masked | 1 un-masked\n", "    attention_mask = np.array([[1] * len(sot) + [0] * (max_length - len(sot))])\n", "    return (input_ids, attention_mask)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, the text generation part.\n", "\n", "Each decoding step returns the logits for the next token. We greedily pick the most likely token, add it to the decoded tokens and unmask it.\n", "\n", "If the token is `<|endoftext|>`, the transcription is finished." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def generate(input_features):\n", "    hidden_states = encode_features(input_features)\n", "    input_ids, attention_mask = initial_decoder_inputs()\n", "    # stop one position early so that timestep + 1 stays within max_length\n", "    for timestep in range(len(sot) - 1, max_length - 1):\n", "        # get logits for the current timestep\n", "        logits = decode_step(input_ids, attention_mask, hidden_states)\n", "        # greedily pick the most probable token\n", "        new_token = np.argmax(logits[0][timestep])\n", "\n", "        # add it to the tokens and unmask it\n", "        input_ids[0][timestep + 1] = new_token\n", "        attention_mask[0][timestep + 1] = 1\n", "\n", "        print(\"Transcribing: \" + ''.join(\n", "            processor.decode(input_ids[0][:timestep + 1],\n", "                             skip_special_tokens=True)),\n", "              end='\\r')\n", "\n", "        if new_token == eos_token_id:\n", "            print(flush=True)\n", "            break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To test this, we will take the first audio sample from the dataset.\n", "\n", "Feel free to change it and experiment." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sample = ds[0][\"audio\"]  # or load it from file\n", "data, sampling_rate = sample[\"array\"], sample[\"sampling_rate\"]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "input_features = get_input_features_from_sample(data, sampling_rate)\n", "generate(input_features)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The result should be:\n", "\n", "`Transcribing: Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.`" ] } ], "metadata": { "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 2 }