gaoqiong / MIGraphX · Commit 0fdb72ac

Add Whisper example

Authored Nov 30, 2023 by Attila Dusnoki
Parent: e3e00547

Showing 9 changed files with 921 additions and 0 deletions:
examples/transformers/README.md (+3 −0)
examples/transformers/python_whisper/README.md (+79 −0)
examples/transformers/python_whisper/asr.py (+208 −0)
examples/transformers/python_whisper/download_samples.sh (+31 −0)
examples/transformers/python_whisper/download_whisper.py (+101 −0)
examples/transformers/python_whisper/gradio_app.py (+56 −0)
examples/transformers/python_whisper/gradio_requirements.txt (+25 −0)
examples/transformers/python_whisper/requirements.txt (+28 −0)
examples/transformers/python_whisper/whisper.ipynb (+390 −0)
examples/transformers/README.md 0 → 100644
# Transformers Inference Examples

- [Python Whisper](./python_whisper)
examples/transformers/python_whisper/README.md 0 → 100644
# Whisper

This version was tested with the [rocm 5.7](https://github.com/ROCmSoftwarePlatform/AMDMIGraphX/tree/rocm-5.7.0) revision.
## Jupyter notebook

There is a dedicated step-by-step notebook. See [whisper.ipynb](./whisper.ipynb).
## Console application

To run the console application, follow the steps below.

Set up the Python environment:

```bash
# this requires the python venv module to be installed (e.g. apt install python3.8-venv)
python3 -m venv w_venv
. w_venv/bin/activate
```
Install the dependencies. `ffmpeg` is needed to handle audio files:

```bash
apt install ffmpeg
```

```bash
pip install -r requirements.txt
```
Use MIGraphX Python Module

```bash
export PYTHONPATH=/opt/rocm/lib:$PYTHONPATH
```
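A quick way to confirm the module can be found (a sanity check, assuming the default `/opt/rocm` install location):

```python
# should import without error and find the GPU target
import migraphx as mgx
print(mgx.get_target("gpu"))
```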
Use the helper script to download the model with optimum. The decoder's attention_mask is not exposed by default, but it is required for the model to work with MIGraphX.

```bash
python download_whisper.py
```

*Note: `models/whisper-tiny.en_modified` will be used by the scripts.*
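The export can also be invoked from Python, which is what the notebook does:

```python
# downloads openai/whisper-tiny.en and exports the modified ONNX models
from download_whisper import export
export()
```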
There are *optional* samples that can be downloaded, but the example can be tested without them:

```bash
./download_samples.sh
```
Run the automatic-speech-recognition script with the following example input:
```bash
python asr.py --audio audio/sample1.flac --log-process
```
Or run it without any audio input to use samples from the [Hugging Face dummy dataset](https://huggingface.co/datasets/hf-internal-testing/librispeech_asr_dummy).
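The `WhisperMGX` class in `asr.py` can also be driven directly from Python, the same way `gradio_app.py` uses it (a short sketch; it assumes the models are already exported and MIGraphX is on the `PYTHONPATH`):

```python
from asr import WhisperMGX

w = WhisperMGX()  # loads and compiles the encoder/decoder models
data, fr = WhisperMGX.load_audio_from_file("audio/sample1.flac")
input_features = w.get_input_features_from_sample(data, fr)
print(w.generate(input_features))
```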
## Gradio application

Note: this requires the `Console application` setup above.

Install the Gradio dependencies:

```bash
pip install -r gradio_requirements.txt
```
Usage:

```bash
python gradio_app.py
```

This will load the models (which can take several minutes) and, once setup is complete, start a server on `http://127.0.0.1:7860`.
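To reach the app from another machine, `demo.launch()` in `gradio_app.py` can be given Gradio's standard host and port options (an optional tweak, not part of this example):

```python
# in gradio_app.py, replace demo.launch() with:
demo.launch(server_name="0.0.0.0", server_port=7860)
```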
examples/transformers/python_whisper/asr.py 0 → 100755
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
from argparse import ArgumentParser
from transformers import WhisperProcessor
from datasets import load_dataset
from pydub import AudioSegment
import migraphx as mgx
import os
import sys
import numpy as np
import time
from functools import wraps


# measurement helper
def measure(fn):
    @wraps(fn)
    def measure_ms(*args, **kwargs):
        start_time = time.perf_counter_ns()
        result = fn(*args, **kwargs)
        end_time = time.perf_counter_ns()
        print(
            f"Elapsed time for {fn.__name__}: {(end_time - start_time) * 1e-6:.4f} ms\n"
        )
        return result

    return measure_ms
def get_args():
    parser = ArgumentParser()
    parser.add_argument(
        "-a",
        "--audio",
        type=str,
        help="Path to audio file. Default: HF test dataset",
    )
    parser.add_argument(
        "-l",
        "--log-process",
        action="store_true",
        help="Print the current state of transcribing.",
    )
    return parser.parse_args()
class WhisperMGX():
    def __init__(self):
        model_id = "openai/whisper-tiny.en"
        print(f"Using {model_id}")
        print("Creating Whisper processor")
        self.processor = WhisperProcessor.from_pretrained(model_id)
        self.decoder_start_token_id = 50257  # <|startoftranscript|>
        self.eos_token_id = 50256  # "<|endoftext|>"
        self.notimestamps = 50362  # <|notimestamps|>
        self.max_length = 448
        self.sot = [self.decoder_start_token_id, self.notimestamps]
        print("Load models...")
        self.encoder_model = WhisperMGX.load_mgx_model(
            "encoder", {"input_features": [1, 80, 3000]})
        self.decoder_model = WhisperMGX.load_mgx_model(
            "decoder", {
                "input_ids": [1, self.max_length],
                "attention_mask": [1, self.max_length],
                "encoder_hidden_states": [1, 1500, 384]
            })
    @staticmethod
    @measure
    def load_audio_from_file(filepath):
        audio = AudioSegment.from_file(filepath)
        # Only 16k is supported
        audio = audio.set_frame_rate(16000)
        data = np.array(audio.get_array_of_samples(), dtype=np.float32)
        data /= np.max(np.abs(data))
        return data, audio.frame_rate
    @staticmethod
    @measure
    def load_mgx_model(name, shapes):
        file = f"models/whisper-tiny.en_modified/{name}_model"
        print(f"Loading {name} model from {file}")
        if os.path.isfile(f"{file}.mxr"):
            print("Found mxr, loading it...")
            model = mgx.load(f"{file}.mxr", format="msgpack")
        elif os.path.isfile(f"{file}.onnx"):
            print("Parsing from onnx file...")
            model = mgx.parse_onnx(f"{file}.onnx", map_input_dims=shapes)
            model.compile(mgx.get_target("gpu"))
            print(f"Saving {name} model to mxr file...")
            mgx.save(model, f"{file}.mxr", format="msgpack")
        else:
            print(f"No {name} model found. Please download it and re-try.")
            sys.exit(1)
        return model
    @property
    def initial_decoder_inputs(self):
        input_ids = np.array(
            [self.sot + [self.eos_token_id] * (self.max_length - len(self.sot))])
        # 0 masked | 1 un-masked
        attention_mask = np.array(
            [[1] * len(self.sot) + [0] * (self.max_length - len(self.sot))])
        return (input_ids, attention_mask)
    @measure
    def get_input_features_from_sample(self, sample_data, sampling_rate):
        return self.processor(sample_data,
                              sampling_rate=sampling_rate,
                              return_tensors="np").input_features

    @measure
    def encode_features(self, input_features):
        return np.array(
            self.encoder_model.run(
                {"input_features": input_features.astype(np.float32)})[0])
    def decode_step(self, input_ids, attention_mask, hidden_states):
        return np.array(
            self.decoder_model.run({
                "input_ids": input_ids.astype(np.int64),
                "attention_mask": attention_mask.astype(np.int64),
                "encoder_hidden_states": hidden_states.astype(np.float32)
            })[0])
    @measure
    def generate(self, input_features, log_process=False):
        hidden_states = self.encode_features(input_features)
        input_ids, attention_mask = self.initial_decoder_inputs
        for timestep in range(len(self.sot) - 1, self.max_length):
            # get logits for the current timestep
            logits = self.decode_step(input_ids, attention_mask, hidden_states)
            # greedily get the highest probable token
            new_token = np.argmax(logits[0][timestep])
            # add it to the tokens and unmask it
            input_ids[0][timestep + 1] = new_token
            attention_mask[0][timestep + 1] = 1
            if log_process:
                print("Transcribing: " + ''.join(
                    self.processor.decode(input_ids[0][:timestep + 1],
                                          skip_special_tokens=True)),
                      end='\r')
            if new_token == self.eos_token_id:
                break
        if log_process:
            print(flush=True)
        return ''.join(
            self.processor.decode(input_ids[0][:timestep + 1],
                                  skip_special_tokens=True))
if __name__ == "__main__":
    args = get_args()
    if args.audio:
        data, fr = WhisperMGX.load_audio_from_file(args.audio)
        ds = [{"audio": {"array": data, "sampling_rate": fr}}]
    else:
        # load dummy dataset and read audio files
        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy",
                          "clean",
                          split="validation")
    w = WhisperMGX()
    for idx, data in enumerate(ds):
        print(f"# {idx + 1}/{len(ds)} Sample...")
        sample = data["audio"]
        input_features = w.get_input_features_from_sample(
            sample["array"], sample["sampling_rate"])
        result = w.generate(input_features, log_process=args.log_process)
        print(f"Result: {result}")
examples/transformers/python_whisper/download_samples.sh 0 → 100755
#!/bin/bash
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
DIR="$(dirname "$0")/audio/"

mkdir -p "$DIR"

wget -q https://cdn-media.huggingface.co/speech_samples/sample1.flac -O "$DIR/sample1.flac"
wget -q https://cdn-media.huggingface.co/speech_samples/sample2.flac -O "$DIR/sample2.flac"

echo "Samples downloaded to $DIR"
examples/transformers/python_whisper/download_whisper.py 0 → 100644
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
from optimum.exporters.onnx import main_export
from optimum.exporters.onnx.model_configs import WhisperOnnxConfig
from transformers import AutoConfig
from optimum.exporters.onnx.base import ConfigBehavior
from typing import Dict


class CustomWhisperOnnxConfig(WhisperOnnxConfig):
    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        common_inputs = {}
        if self._behavior is ConfigBehavior.ENCODER:
            common_inputs["input_features"] = {
                0: "batch_size",
                1: "feature_size",
                2: "encoder_sequence_length"
            }
        if self._behavior is ConfigBehavior.DECODER:
            common_inputs["decoder_input_ids"] = {
                0: "batch_size",
                1: "decoder_sequence_length"
            }
            common_inputs["decoder_attention_mask"] = {
                0: "batch_size",
                1: "decoder_sequence_length"
            }
            common_inputs["encoder_outputs"] = {
                0: "batch_size",
                1: "encoder_sequence_length"
            }
        return common_inputs

    @property
    def torch_to_onnx_input_map(self) -> Dict[str, str]:
        if self._behavior is ConfigBehavior.DECODER:
            return {
                "decoder_input_ids": "input_ids",
                "decoder_attention_mask": "attention_mask",
                "encoder_outputs": "encoder_hidden_states",
            }
        return {}


def export():
    model_id = "openai/whisper-tiny.en"
    config = AutoConfig.from_pretrained(model_id)

    custom_whisper_onnx_config = CustomWhisperOnnxConfig(
        config=config,
        task="automatic-speech-recognition",
    )

    encoder_config = custom_whisper_onnx_config.with_behavior("encoder")
    decoder_config = custom_whisper_onnx_config.with_behavior("decoder",
                                                              use_past=False)

    custom_onnx_configs = {
        "encoder_model": encoder_config,
        "decoder_model": decoder_config,
    }

    output = "models/whisper-tiny.en_modified"
    main_export(model_id,
                output=output,
                no_post_process=True,
                do_validation=False,
                custom_onnx_configs=custom_onnx_configs)
    print(f"Done. Check {output}")


if __name__ == "__main__":
    export()
examples/transformers/python_whisper/gradio_app.py 0 → 100644
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
from asr import WhisperMGX
import gradio as gr
import os


def main():
    # Note: This will load the models, which can take several minutes
    w = WhisperMGX()

    def gr_wrapper(audio):
        data, fr = WhisperMGX.load_audio_from_file(audio)
        input_features = w.get_input_features_from_sample(data, fr)
        return w.generate(input_features)

    examples = [
        os.path.join(os.path.dirname(__file__), "audio/sample1.flac"),
        os.path.join(os.path.dirname(__file__), "audio/sample2.flac"),
    ]
    # skip if there is no file
    examples = [e for e in examples if os.path.isfile(e)]

    demo = gr.Interface(
        gr_wrapper,
        gr.Audio(sources=["upload", "microphone"], type="filepath"),
        "text",
        examples=examples,
    )
    demo.launch()


if __name__ == "__main__":
    main()
examples/transformers/python_whisper/gradio_requirements.txt 0 → 100644
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
-r requirements.txt
gradio
examples/transformers/python_whisper/requirements.txt 0 → 100644
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
accelerate
datasets
optimum[onnxruntime]
pydub
transformers
examples/transformers/python_whisper/whisper.ipynb 0 → 100644
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The MIT License (MIT)\n",
"#\n",
"# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.\n",
"#\n",
"# Permission is hereby granted, free of charge, to any person obtaining a copy\n",
"# of this software and associated documentation files (the 'Software'), to deal\n",
"# in the Software without restriction, including without limitation the rights\n",
"# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n",
"# copies of the Software, and to permit persons to whom the Software is\n",
"# furnished to do so, subject to the following conditions:\n",
"#\n",
"# The above copyright notice and this permission notice shall be included in\n",
"# all copies or substantial portions of the Software.\n",
"#\n",
"# THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
"# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
"# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n",
"# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
"# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n",
"# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n",
"# THE SOFTWARE."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Whisper\n",
"\n",
"The following example will show how to run `Whisper` with `MIGraphX`.\n",
"\n",
"Install the required dependencies."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install dependencies\n",
"%pip install accelerate datasets optimum[onnxruntime] transformers"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use optimum to download the model.\n",
"\n",
"The attention_mask for decoder is not exposed by default, but required to work with MIGraphX.\n",
"The following script will do that:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# download and export models\n",
"from download_whisper import export\n",
"export()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now it is time to load these models with python.\n",
"\n",
"First, we make sure that MIGraphX module is found in the python path."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"mgx_lib_path = \"/opt/rocm/lib/\" # or \"/code/AMDMIGraphX/build/lib/\"\n",
"if mgx_lib_path not in sys.path:\n",
" sys.path.append(mgx_lib_path)\n",
"import migraphx as mgx\n",
"\n",
"import numpy as np\n",
"import os"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, a helper method to load and cache the models.\n",
"\n",
"This will use the `models/whisper-tiny.en_modified` path. If you changed it, make sure to update here as well."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def load_mgx_model(name, shapes):\n",
" file = f\"models/whisper-tiny.en_modified/{name}_model\"\n",
" print(f\"Loading {name} model from {file}\")\n",
" if os.path.isfile(f\"{file}.mxr\"):\n",
" print(\"Found mxr, loading it...\")\n",
" model = mgx.load(f\"{file}.mxr\", format=\"msgpack\")\n",
" elif os.path.isfile(f\"{file}.onnx\"):\n",
" print(\"Parsing from onnx file...\")\n",
" model = mgx.parse_onnx(f\"{file}.onnx\", map_input_dims=shapes)\n",
" model.compile(mgx.get_target(\"gpu\"))\n",
" print(f\"Saving {name} model to mxr file...\")\n",
" mgx.save(model, f\"{file}.mxr\", format=\"msgpack\")\n",
" else:\n",
" print(f\"No {name} model found. Please download it and re-try.\")\n",
" os.exit(1)\n",
" return model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With that, we can load the models. This could take several minutes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"encoder_model = load_mgx_model(\"encoder\", {\"input_features\": [1, 80, 3000]})\n",
"decoder_model = load_mgx_model(\n",
" \"decoder\", {\n",
" \"input_ids\": [1, 448],\n",
" \"attention_mask\": [1, 448],\n",
" \"encoder_hidden_states\": [1, 1500, 384]\n",
" })"
]
},
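{
"cell_type": "markdown",
"metadata": {},
"source": [
"To double-check what the compiled programs expect, we can print their parameter shapes (a quick sanity check; assumes the `get_parameter_shapes` API of the matching MIGraphX build):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(encoder_model.get_parameter_shapes())\n",
"print(decoder_model.get_parameter_shapes())"
]
},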
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Time to load the processor from the original source.\n",
"It will be used to get feature embeddings from the audio data and decode the output tokens."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import WhisperProcessor\n",
"processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny.en\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we will define all the steps one by one, to make the last step short and simple.\n",
"\n",
"The first step will be to get audio data.\n",
"For testing purposes, we will use Hugging Face's dummy samples."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"ds = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\",\n",
" \"clean\",\n",
" split=\"validation\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next step will be to get the input features from the audio data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_input_features_from_sample(sample_data, sampling_rate):\n",
" return processor(sample_data,\n",
" sampling_rate=sampling_rate,\n",
" return_tensors=\"np\").input_features"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will encode these, and use them in the decoding step."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def encode_features(input_features):\n",
" return np.array(\n",
" encoder_model.run(\n",
" {\"input_features\": input_features.astype(np.float32)})[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The decoding process will be explained later in `generate`.\n",
"\n",
"The decoder model will expect the encoded features, the input ids (decoded tokens), and the attention mask to ignore parts as needed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def decode_step(input_ids, attention_mask, hidden_states):\n",
" return np.array(\n",
" decoder_model.run({\n",
" \"input_ids\":\n",
" input_ids.astype(np.int64),\n",
" \"attention_mask\":\n",
" attention_mask.astype(np.int64),\n",
" \"encoder_hidden_states\":\n",
" hidden_states.astype(np.float32)\n",
" })[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following parameters are from [whisper-tiny.en's config](https://huggingface.co/openai/whisper-tiny.en/blob/main/config.json).\n",
"\n",
"You might need to change them if you change the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# model params\n",
"decoder_start_token_id = 50257 # <|startoftranscript|>\n",
"eos_token_id = 50256 # \"<|endoftext|>\"\n",
"notimestamps = 50362 # <|notimestamps|>\n",
"max_length = 448\n",
"sot = [decoder_start_token_id, notimestamps]"
]
},
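{
"cell_type": "markdown",
"metadata": {},
"source": [
"Most of these values can also be read from the Hugging Face config instead of being hard-coded (a sketch; the field names follow the model's `config.json`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoConfig\n",
"cfg = AutoConfig.from_pretrained(\"openai/whisper-tiny.en\")\n",
"print(cfg.decoder_start_token_id, cfg.eos_token_id, cfg.max_length)"
]
},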
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To kickstart the decoding, we will provide the `<|startoftranscript|>` and `<|notimestamps|>` tokens.\n",
"\n",
"Fill up the remaining tokens with `<|endoftext|>` and mask to ignore them."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def initial_decoder_inputs():\n",
" input_ids = np.array([sot + [eos_token_id] * (max_length - len(sot))])\n",
" # 0 masked | 1 un-masked\n",
" attention_mask = np.array([[1] * len(sot) + [0] * (max_length - len(sot))])\n",
" return (input_ids, attention_mask)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally the text generation part.\n",
"\n",
"With each decoding step, we will get the probabilities for the next token. We greedily get best match, add it to the decoded tokens and unmask it.\n",
"\n",
"If the token is `<|endoftext|>`, we finished with the transcribing."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def generate(input_features):\n",
" hidden_states = encode_features(input_features)\n",
" input_ids, attention_mask = initial_decoder_inputs()\n",
" for timestep in range(len(sot) - 1, max_length):\n",
" # get logits for the current timestep\n",
" logits = decode_step(input_ids, attention_mask, hidden_states)\n",
" # greedily get the highest probable token\n",
" new_token = np.argmax(logits[0][timestep])\n",
"\n",
" # add it to the tokens and unmask it\n",
" input_ids[0][timestep + 1] = new_token\n",
" attention_mask[0][timestep + 1] = 1\n",
"\n",
" print(\"Transcribing: \" + ''.join(\n",
" processor.decode(input_ids[0][:timestep + 1],\n",
" skip_special_tokens=True)),\n",
" end='\\r')\n",
"\n",
" if new_token == eos_token_id:\n",
" print(flush=True)\n",
" break"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To test this, we will get the fist audio from the dataset.\n",
"\n",
"Feel free to change it and experiment."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sample = ds[0][\"audio\"] # or load it from file\n",
"data, sampling_rate = sample[\"array\"], sample[\"sampling_rate\"]"
]
},
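{
"cell_type": "markdown",
"metadata": {},
"source": [
"To transcribe a local file instead, the audio can be loaded with `pydub` the same way `asr.py` does (a sketch; assumes `pydub` and `ffmpeg` are installed and a file such as `audio/sample1.flac` was downloaded with `download_samples.sh`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pydub import AudioSegment\n",
"audio = AudioSegment.from_file(\"audio/sample1.flac\")\n",
"audio = audio.set_frame_rate(16000)  # only 16 kHz is supported\n",
"data = np.array(audio.get_array_of_samples(), dtype=np.float32)\n",
"data /= np.max(np.abs(data))  # normalize to [-1, 1]\n",
"sampling_rate = audio.frame_rate"
]
},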
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"input_features = get_input_features_from_sample(data, sampling_rate)\n",
"generate(input_features)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The result should be:\n",
"\n",
"`Transcribing: Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.`"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}