OpenDAS / Torchaudio / Commits / a3363539

Unverified commit a3363539, authored Nov 04, 2021 by moto, committed by GitHub on Nov 04, 2021

Add Sphinx-gallery to doc (#1967)

parent 65dbf2d2

Showing 20 changed files with 802 additions and 1 deletion
.flake8                                                            +1   -1
.gitignore                                                         +2   -0
docs/requirements.txt                                              +3   -0
docs/source/backend.rst                                            +2   -0
docs/source/compliance.kaldi.rst                                   +2   -0
docs/source/conf.py                                                +14  -0
docs/source/datasets.rst                                           +2   -0
docs/source/functional.rst                                         +2   -0
docs/source/index.rst                                              +5   -0
docs/source/kaldi_io.rst                                           +2   -0
docs/source/models.rst                                             +2   -0
docs/source/pipelines.rst                                          +8   -0
docs/source/prototype.rst                                          +2   -0
docs/source/sox_effects.rst                                        +2   -0
docs/source/transforms.rst                                         +2   -0
docs/source/utils.rst                                              +2   -0
examples/gallery/.gitignore                                        +3   -0
examples/gallery/wav2vec2/README.rst                               +2   -0
examples/gallery/wav2vec2/forced_alignment_tutorial.py             +444 -0
examples/gallery/wav2vec2/speech_recognition_pipeline_tutorial.py  +300 -0
.flake8  (view file @ a3363539)

[flake8]
max-line-length = 120
ignore = E305,E402,E721,E741,F405,W503,W504,F999
-exclude = build,docs/source,_ext,third_party
+exclude = build,docs/source,_ext,third_party,examples/gallery
.gitignore  (view file @ a3363539)

@@ -69,6 +69,8 @@ instance/

# Sphinx documentation
docs/_build/
docs/src/
docs/source/auto_examples
docs/source/gen_modules

# PyBuilder
target/
docs/requirements.txt  (view file @ a3363539)

@@ -4,3 +4,6 @@ sphinxcontrib.katex
sphinxcontrib.bibtex
matplotlib
pyparsing<3,>=2.0.2
sphinx_gallery
IPython
deep-phonemizer
docs/source/backend.rst  (view file @ a3363539)

@@ -3,6 +3,8 @@
torchaudio.backend
==================

.. py:module:: torchaudio.backend

Overview
~~~~~~~~
docs/source/compliance.kaldi.rst  (view file @ a3363539)

@@ -6,6 +6,8 @@ torchaudio.compliance.kaldi
.. currentmodule:: torchaudio.compliance.kaldi
.. py:module:: torchaudio.compliance.kaldi

The useful processing operations of kaldi_ can be performed with torchaudio.
Various functions with identical parameters are given so that torchaudio can
produce similar outputs.
docs/source/conf.py  (view file @ a3363539)

@@ -42,6 +42,7 @@ extensions = [
    'sphinx.ext.viewcode',
    'sphinxcontrib.katex',
    'sphinxcontrib.bibtex',
    'sphinx_gallery.gen_gallery',
]

# katex options

@@ -58,6 +59,19 @@ delimiters : [
bibtex_bibfiles = ['refs.bib']

sphinx_gallery_conf = {
    'examples_dirs': [
        '../../examples/gallery/wav2vec2',
    ],
    'gallery_dirs': [
        'auto_examples/wav2vec2',
    ],
    'filename_pattern': 'tutorial.py',
    'backreferences_dir': 'gen_modules/backreferences',
    'doc_module': ('torchaudio',),
}

autosummary_generate = True
napoleon_use_ivar = True
napoleon_numpy_docstring = False
napoleon_google_docstring = True
docs/source/datasets.rst  (view file @ a3363539)

torchaudio.datasets
====================

.. py:module:: torchaudio.datasets

All datasets are subclasses of :class:`torch.utils.data.Dataset`
and have ``__getitem__`` and ``__len__`` methods implemented.
Hence, they can all be passed to a :class:`torch.utils.data.DataLoader`
...
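As a sketch of the protocol the snippet above describes, any object with ``__getitem__`` and ``__len__`` can serve as a map-style dataset. The toy class below is hypothetical (not part of torchaudio); it only shows the shape of the interface that ``torch.utils.data.DataLoader`` consumes.

```python
# A minimal map-style "dataset": only __getitem__ and __len__ are required.
# ToySpeechDataset is a hypothetical stand-in for torchaudio.datasets classes.
class ToySpeechDataset:
    def __init__(self, n_items):
        # Each item mimics a (waveform, sample_rate, transcript) tuple.
        self._items = [([0.0] * 16, 16000, f"utterance-{i}") for i in range(n_items)]

    def __getitem__(self, index):
        return self._items[index]

    def __len__(self):
        return len(self._items)


dataset = ToySpeechDataset(4)
print(len(dataset))                        # 4
waveform, sample_rate, text = dataset[0]
print(sample_rate, text)                   # 16000 utterance-0
```

With torch installed, ``torch.utils.data.DataLoader(dataset, batch_size=2)`` would iterate over this object in batches, which is exactly how the torchaudio datasets are meant to be consumed.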
docs/source/functional.rst  (view file @ a3363539)

@@ -4,6 +4,8 @@
torchaudio.functional
=====================

.. py:module:: torchaudio.functional
.. currentmodule:: torchaudio.functional

Functions to perform common audio operations.
docs/source/index.rst  (view file @ a3363539)

@@ -42,6 +42,11 @@ The :mod:`torchaudio` package consists of I/O, popular datasets and common audio
   utils
   prototype

.. toctree::
   :maxdepth: 2
   :caption: Tutorials

   auto_examples/wav2vec2/index

.. toctree::
   :maxdepth: 1
docs/source/kaldi_io.rst  (view file @ a3363539)

@@ -4,6 +4,8 @@
torchaudio.kaldi_io
======================

.. py:module:: torchaudio.kaldi_io
.. currentmodule:: torchaudio.kaldi_io

To use this module, the dependency kaldi_io_ needs to be installed.
docs/source/models.rst  (view file @ a3363539)

@@ -4,6 +4,8 @@
torchaudio.models
=================

.. py:module:: torchaudio.models
.. currentmodule:: torchaudio.models

The models subpackage contains definitions of models for addressing common audio tasks.
docs/source/pipelines.rst  (view file @ a3363539)

@@ -3,6 +3,8 @@ torchaudio.pipelines
.. currentmodule:: torchaudio.pipelines
.. py:module:: torchaudio.pipelines

The pipelines subpackage contains APIs to access models with pretrained weights, and information/helper functions associated with the pretrained weights.

wav2vec 2.0 / HuBERT - Representation Learning

@@ -73,6 +75,9 @@ HUBERT_XLARGE
wav2vec 2.0 / HuBERT - Fine-tuned ASR
-------------------------------------

Wav2Vec2ASRBundle
~~~~~~~~~~~~~~~~~

.. autoclass:: Wav2Vec2ASRBundle
   :members: sample_rate

@@ -80,6 +85,9 @@ wav2vec 2.0 / HuBERT - Fine-tuned ASR
   .. automethod:: get_labels

.. minigallery:: torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
   :add-heading: Examples using ``Wav2Vec2ASRBundle``
   :heading-level: ~

WAV2VEC2_ASR_BASE_10M
~~~~~~~~~~~~~~~~~~~~~
docs/source/prototype.rst  (view file @ a3363539)

@@ -4,6 +4,8 @@
torchaudio.prototype.emformer
=============================

.. py:module:: torchaudio.prototype.emformer
.. currentmodule:: torchaudio.prototype.emformer

Emformer is a prototype feature; see `here <https://pytorch.org/audio>`_
docs/source/sox_effects.rst  (view file @ a3363539)

@@ -3,6 +3,8 @@
torchaudio.sox_effects
======================

.. py:module:: torchaudio.sox_effects
.. currentmodule:: torchaudio.sox_effects

Resource initialization / shutdown
docs/source/transforms.rst  (view file @ a3363539)

@@ -4,6 +4,8 @@
torchaudio.transforms
======================

.. py:module:: torchaudio.transforms
.. currentmodule:: torchaudio.transforms

Transforms are common audio transforms. They can be chained together using :class:`torch.nn.Sequential`
docs/source/utils.rst  (view file @ a3363539)

torchaudio.utils
================

.. py:module:: torchaudio.utils

torchaudio.utils.sox_utils
~~~~~~~~~~~~~~~~~~~~~~~~~~
examples/gallery/.gitignore  (new file, mode 100644, view file @ a3363539)

*.*
!*.rst
!*.py

examples/gallery/wav2vec2/README.rst  (new file, mode 100644, view file @ a3363539)

Wav2Vec2 Tutorials
==================
examples/gallery/wav2vec2/forced_alignment_tutorial.py  (new file, mode 100644, view file @ a3363539)

"""
Forced Alignment with Wav2Vec2
==============================

**Author** `Moto Hira <moto@fb.com>`__

This tutorial shows how to align a transcript to speech with
``torchaudio``, using the CTC segmentation algorithm described in
`CTC-Segmentation of Large Corpora for German End-to-end Speech
Recognition <https://arxiv.org/abs/2007.09127>`__.
"""
######################################################################
# Overview
# --------
#
# The process of alignment looks like the following.
#
# 1. Estimate the frame-wise label probability from the audio waveform
# 2. Generate the trellis matrix which represents the probability of
#    labels aligned at each time step.
# 3. Find the most likely path from the trellis matrix.
#
# In this example, we use ``torchaudio``\ ’s ``Wav2Vec2`` model for
# acoustic feature extraction.
#
######################################################################
# Preparation
# -----------
#
# First we import the necessary packages, and fetch data that we work on.
#
# %matplotlib inline
import os
from dataclasses import dataclass

import torch
import torchaudio
import requests
import matplotlib
import matplotlib.pyplot as plt
import IPython

matplotlib.rcParams['figure.figsize'] = [16.0, 4.8]

torch.random.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(torch.__version__)
print(torchaudio.__version__)
print(device)

SPEECH_URL = 'https://download.pytorch.org/torchaudio/test-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.flac'
SPEECH_FILE = 'speech.flac'

if not os.path.exists(SPEECH_FILE):
    with open(SPEECH_FILE, 'wb') as file:
        file.write(requests.get(SPEECH_URL).content)
######################################################################
# Generate frame-wise label probability
# -------------------------------------
#
# The first step is to generate the label class probability of each audio
# frame. We can use a Wav2Vec2 model that is trained for ASR. Here we use
# :py:func:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`.
#
# ``torchaudio`` provides easy access to pretrained models with associated
# labels.
#
# .. note::
#
# In the subsequent sections, we will compute the probability in
# log-domain to avoid numerical instability. For this purpose, we
# normalize the ``emission`` with :py:func:`torch.log_softmax`.
#
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device)
labels = bundle.get_labels()
with torch.inference_mode():
    waveform, _ = torchaudio.load(SPEECH_FILE)
    emissions, _ = model(waveform.to(device))
    emissions = torch.log_softmax(emissions, dim=-1)

emission = emissions[0].cpu().detach()
################################################################################
# Visualization
################################################################################
print(labels)
plt.imshow(emission.T)
plt.colorbar()
plt.title("Frame-wise class probability")
plt.xlabel("Time")
plt.ylabel("Labels")
plt.show()
######################################################################
# Generate alignment probability (trellis)
# ----------------------------------------
#
# From the emission matrix, next we generate the trellis, which represents
# the probability of each transcript label occurring at each time frame.
#
# The trellis is a 2D matrix with a time axis and a label axis. The label axis
# represents the transcript that we are aligning. In the following, we use
# :math:`t` to denote the index in the time axis and :math:`j` to denote the
# index in the label axis. :math:`c_j` represents the label at label index
# :math:`j`.
#
# To generate the probability of time step :math:`t+1`, we look at the
# trellis from time step :math:`t` and the emission at time step :math:`t+1`.
# There are two paths to reach time step :math:`t+1` with label
# :math:`c_{j+1}`. The first one is the case where the label was
# :math:`c_{j+1}` at :math:`t` and there was no label change from
# :math:`t` to :math:`t+1`. The other case is where the label was
# :math:`c_j` at :math:`t` and it transitioned to the next label
# :math:`c_{j+1}` at :math:`t+1`.
#
# The following diagram illustrates this transition.
#
# .. image:: https://download.pytorch.org/torchaudio/tutorial-assets/ctc-forward.png
#
# Since we are looking for the most likely transitions, we take the more
# likely path for the value of :math:`k_{(t+1, j+1)}`, that is
#
# :math:`k_{(t+1, j+1)} = max( k_{(t, j)} p(t+1, c_{j+1}), k_{(t, j+1)} p(t+1, repeat) )`
#
# where :math:`k` represents the trellis matrix, and :math:`p(t, c_j)`
# represents the probability of label :math:`c_j` at time step :math:`t`.
# :math:`repeat` represents the blank token from the CTC formulation. (For the
# details of the CTC algorithm, please refer to *Sequence Modeling with CTC*
# [`distill.pub <https://distill.pub/2017/ctc/>`__].)
#
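The recurrence above can be traced by hand on toy numbers. The sketch below uses hypothetical probabilities (not model output) and plain Python lists, working in the probability domain rather than the log domain used by the tensorized ``get_trellis`` below: at each step we take the max of "stay" (emit the blank/repeat token) and "change" (emit the next label).

```python
# Toy emission: toy_p[t][c] = probability of symbol c at frame t
# (hypothetical numbers). Column 0 is the blank/repeat token.
toy_p = [
    [0.1, 0.8, 0.1],  # frame 1
    [0.6, 0.2, 0.2],  # frame 2
    [0.1, 0.1, 0.8],  # frame 3
]
toy_tokens = [1, 2]  # transcript: label 1 followed by label 2

num_frames, num_tokens = len(toy_p), len(toy_tokens)
NEG_INF = float('-inf')
# toy_trellis[t][j]: best score of emitting the first j tokens within the first t frames.
toy_trellis = [[NEG_INF] * (num_tokens + 1) for _ in range(num_frames + 1)]
for t in range(num_frames + 1):
    toy_trellis[t][0] = 1.0  # no token emitted yet

for t in range(num_frames):
    for j in range(1, num_tokens + 1):
        stay = toy_trellis[t][j] * toy_p[t][0]                        # repeat/blank
        change = toy_trellis[t][j - 1] * toy_p[t][toy_tokens[j - 1]]  # next label
        toy_trellis[t + 1][j] = max(stay, change)

# Best alignment: emit label 1 at frame 1 (0.8), blank at frame 2 (0.6),
# label 2 at frame 3 (0.8), so the final score is approximately 0.384.
print(toy_trellis[num_frames][num_tokens])
```

``get_trellis`` below does the same thing vectorized with ``torch.maximum``, and in log space, where the products become sums.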
transcript = 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT'
dictionary = {c: i for i, c in enumerate(labels)}

tokens = [dictionary[c] for c in transcript]
print(list(zip(transcript, tokens)))


def get_trellis(emission, tokens, blank_id=0):
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    # Trellis has extra dimensions for both time axis and tokens.
    # The extra dim for tokens represents <SoS> (start-of-sentence)
    # The extra dim for time axis is for simplification of the code.
    trellis = torch.full((num_frame + 1, num_tokens + 1), -float('inf'))
    trellis[:, 0] = 0
    for t in range(num_frame):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            trellis[t, :-1] + emission[t, tokens],
        )
    return trellis


trellis = get_trellis(emission, tokens)
################################################################################
# Visualization
################################################################################
plt.imshow(trellis[1:, 1:].T, origin='lower')
plt.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
plt.colorbar()
plt.show()
######################################################################
# In the above visualization, we can see that there is a trace of high
# probability crossing the matrix diagonally.
#
######################################################################
# Find the most likely path (backtracking)
# ----------------------------------------
#
# Once the trellis is generated, we will traverse it following the
# elements with high probability.
#
# We will start from the last label index at the time step with the highest
# probability. Then, we traverse back in time, picking either stay
# (:math:`c_j \rightarrow c_j`) or transition
# (:math:`c_j \rightarrow c_{j+1}`), based on the post-transition
# probability :math:`k_{t, j} p(t+1, c_{j+1})` or
# :math:`k_{t, j+1} p(t+1, repeat)`.
#
# The traversal is complete once the label index reaches the beginning.
#
# The trellis matrix is used for path-finding, but for the final
# probability of each segment, we take the frame-wise probability from
# the emission matrix.
#
@dataclass
class Point:
    token_index: int
    time_index: int
    score: float


def backtrack(trellis, emission, tokens, blank_id=0):
    # Note:
    # j and t are indices for trellis, which has extra dimensions
    # for time and tokens at the beginning.
    # When referring to time frame index `T` in trellis,
    # the corresponding index in emission is `T-1`.
    # Similarly, when referring to token index `J` in trellis,
    # the corresponding index in transcript is `J-1`.
    j = trellis.size(1) - 1
    t_start = torch.argmax(trellis[:, j]).item()

    path = []
    for t in range(t_start, 0, -1):
        # 1. Figure out if the current position was stay or change
        # Note (again):
        # `emission[J-1]` is the emission at time frame `J` of trellis dimension.
        # Score for token staying the same from time frame J-1 to T.
        stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
        # Score for token changing from C-1 at T-1 to J at T.
        changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]

        # 2. Store the path with frame-wise probability.
        prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
        # Return token index and time index in non-trellis coordinate.
        path.append(Point(j - 1, t - 1, prob))

        # 3. Update the token
        if changed > stayed:
            j -= 1
            if j == 0:
                break
    else:
        raise ValueError('Failed to align')
    return path[::-1]


path = backtrack(trellis, emission, tokens)
print(path)
################################################################################
# Visualization
################################################################################
def plot_trellis_with_path(trellis, path):
    # To plot trellis with path, we take advantage of 'nan' value
    trellis_with_path = trellis.clone()
    for i, p in enumerate(path):
        trellis_with_path[p.time_index, p.token_index] = float('nan')
    plt.imshow(trellis_with_path[1:, 1:].T, origin='lower')


plot_trellis_with_path(trellis, path)
plt.title("The path found by backtracking")
plt.show()
######################################################################
# Looking good. Now this path contains repetitions of the same labels, so
# let’s merge them to make it close to the original transcript.
#
# When merging multiple path points, we simply take the average
# probability for the merged segments.
#
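As an illustration of the averaging described above, with toy scores (hypothetical, not model output), merging consecutive path points that share a token index reduces to a mean over their frame-wise probabilities:

```python
# Toy path: (token_index, score) pairs for consecutive frames.
# Token 0 is emitted over three frames, token 1 over two.
toy_path = [(0, 0.9), (0, 0.8), (0, 0.7), (1, 0.6), (1, 0.4)]

toy_merged = []
i = 0
while i < len(toy_path):
    j = i
    while j < len(toy_path) and toy_path[j][0] == toy_path[i][0]:
        j += 1
    # Segment score = average of the frame-wise scores it covers.
    avg = sum(score for _, score in toy_path[i:j]) / (j - i)
    toy_merged.append((toy_path[i][0], i, j, avg))  # (token, start, end, score)
    i = j

# Token 0 spans frames [0, 3) with score ~0.8; token 1 spans [3, 5) with ~0.5.
print(toy_merged)
```

``merge_repeats`` below implements the same idea over ``Point`` objects, keeping the half-open ``[start, end)`` frame ranges.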
# Merge the labels
@dataclass
class Segment:
    label: str
    start: int
    end: int
    score: float

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        return self.end - self.start


def merge_repeats(path):
    i1, i2 = 0, 0
    segments = []
    while i1 < len(path):
        while i2 < len(path) and path[i1].token_index == path[i2].token_index:
            i2 += 1
        score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
        segments.append(Segment(
            transcript[path[i1].token_index],
            path[i1].time_index,
            path[i2 - 1].time_index + 1,
            score,
        ))
        i1 = i2
    return segments


segments = merge_repeats(path)
for seg in segments:
    print(seg)
################################################################################
# Visualization
################################################################################
def plot_trellis_with_segments(trellis, segments, transcript):
    # To plot trellis with path, we take advantage of 'nan' value
    trellis_with_path = trellis.clone()
    for i, seg in enumerate(segments):
        if seg.label != '|':
            trellis_with_path[seg.start + 1:seg.end + 1, i + 1] = float('nan')

    fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))
    ax1.set_title("Path, label and probability for each label")
    ax1.imshow(trellis_with_path.T, origin='lower')
    ax1.set_xticks([])

    for i, seg in enumerate(segments):
        if seg.label != '|':
            ax1.annotate(seg.label, (seg.start + .7, i + 0.3), weight='bold')
            ax1.annotate(f'{seg.score:.2f}', (seg.start - .3, i + 4.3))

    ax2.set_title("Label probability with and without repetition")
    xs, hs, ws = [], [], []
    for seg in segments:
        if seg.label != '|':
            xs.append((seg.end + seg.start) / 2 + .4)
            hs.append(seg.score)
            ws.append(seg.end - seg.start)
            ax2.annotate(seg.label, (seg.start + .8, -0.07), weight='bold')
    ax2.bar(xs, hs, width=ws, color='gray', alpha=0.5, edgecolor='black')

    xs, hs = [], []
    for p in path:
        label = transcript[p.token_index]
        if label != '|':
            xs.append(p.time_index + 1)
            hs.append(p.score)

    ax2.bar(xs, hs, width=0.5, alpha=0.5)
    ax2.axhline(0, color='black')
    ax2.set_xlim(ax1.get_xlim())
    ax2.set_ylim(-0.1, 1.1)


plot_trellis_with_segments(trellis, segments, transcript)
plt.tight_layout()
plt.show()
######################################################################
# Looks good. Now let’s merge the words. The Wav2Vec2 model uses ``'|'``
# as the word boundary, so we merge the segments before each occurrence of
# ``'|'``.
#
# Then, finally, we segment the original audio into segmented audio and
# listen to them to see if the segmentation is correct.
#
# Merge words
def merge_words(segments, separator='|'):
    words = []
    i1, i2 = 0, 0
    while i1 < len(segments):
        if i2 >= len(segments) or segments[i2].label == separator:
            if i1 != i2:
                segs = segments[i1:i2]
                word = ''.join([seg.label for seg in segs])
                score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
                words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
            i1 = i2 + 1
            i2 = i1
        else:
            i2 += 1
    return words


word_segments = merge_words(segments)
for word in word_segments:
    print(word)
################################################################################
# Visualization
################################################################################
def plot_alignments(trellis, segments, word_segments, waveform):
    trellis_with_path = trellis.clone()
    for i, seg in enumerate(segments):
        if seg.label != '|':
            trellis_with_path[seg.start + 1:seg.end + 1, i + 1] = float('nan')

    fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))

    ax1.imshow(trellis_with_path[1:, 1:].T, origin='lower')
    ax1.set_xticks([])
    ax1.set_yticks([])

    for word in word_segments:
        ax1.axvline(word.start - 0.5)
        ax1.axvline(word.end - 0.5)

    for i, seg in enumerate(segments):
        if seg.label != '|':
            ax1.annotate(seg.label, (seg.start, i + 0.3))
            ax1.annotate(f'{seg.score:.2f}', (seg.start, i + 4), fontsize=8)

    # The original waveform
    ratio = waveform.size(0) / (trellis.size(0) - 1)
    ax2.plot(waveform)
    for word in word_segments:
        x0 = ratio * word.start
        x1 = ratio * word.end
        ax2.axvspan(x0, x1, alpha=0.1, color='red')
        ax2.annotate(f'{word.score:.2f}', (x0, 0.8))

    for seg in segments:
        if seg.label != '|':
            ax2.annotate(seg.label, (seg.start * ratio, 0.9))
    xticks = ax2.get_xticks()
    plt.xticks(xticks, xticks / bundle.sample_rate)
    ax2.set_xlabel('time [second]')
    ax2.set_yticks([])
    ax2.set_ylim(-1.0, 1.0)
    ax2.set_xlim(0, waveform.size(-1))


plot_alignments(trellis, segments, word_segments, waveform[0],)
plt.show()

# Generate the audio for each segment
print(transcript)
IPython.display.display(IPython.display.Audio(SPEECH_FILE))

ratio = waveform.size(1) / (trellis.size(0) - 1)
for i, word in enumerate(word_segments):
    x0 = int(ratio * word.start)
    x1 = int(ratio * word.end)
    filename = f"{i}_{word.label}.wav"
    torchaudio.save(filename, waveform[:, x0:x1], bundle.sample_rate)
    print(f"{word.label}: {x0 / bundle.sample_rate:.3f} - {x1 / bundle.sample_rate:.3f}")
    IPython.display.display(IPython.display.Audio(filename))
######################################################################
# Conclusion
# ----------
#
# In this tutorial, we looked at how to use torchaudio’s Wav2Vec2 model to
# perform CTC segmentation for forced alignment.
#
examples/gallery/wav2vec2/speech_recognition_pipeline_tutorial.py  (new file, mode 100644, view file @ a3363539)

"""
Speech Recognition with Wav2Vec2
================================

**Author**: `Moto Hira <moto@fb.com>`__

This tutorial shows how to perform speech recognition using
pre-trained models from wav2vec 2.0
[`paper <https://arxiv.org/abs/2006.11477>`__].
"""
######################################################################
# Overview
# --------
#
# The process of speech recognition looks like the following.
#
# 1. Extract the acoustic features from audio waveform
#
# 2. Estimate the class of the acoustic features frame-by-frame
#
# 3. Generate hypothesis from the sequence of the class probabilities
#
# Torchaudio provides easy access to the pre-trained weights and
# associated information, such as the expected sample rate and class
# labels. They are bundled together and available under the
# ``torchaudio.pipelines`` module.
#
######################################################################
# Preparation
# -----------
#
# First we import the necessary packages, and fetch data that we work on.
#
# %matplotlib inline
import os

import torch
import torchaudio
import requests
import matplotlib
import matplotlib.pyplot as plt
import IPython

matplotlib.rcParams['figure.figsize'] = [16.0, 4.8]

torch.random.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(torch.__version__)
print(torchaudio.__version__)
print(device)

SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
SPEECH_FILE = "speech.wav"

if not os.path.exists(SPEECH_FILE):
    with open(SPEECH_FILE, 'wb') as file:
        file.write(requests.get(SPEECH_URL).content)
######################################################################
# Creating a pipeline
# -------------------
#
# First, we will create a Wav2Vec2 model that performs the feature
# extraction and the classification.
#
# There are two types of Wav2Vec2 pre-trained weights available in
# torchaudio: the ones fine-tuned for the ASR task, and the ones not
# fine-tuned.
#
# Wav2Vec2 (and HuBERT) models are trained in a self-supervised manner. They
# are first trained with audio only for representation learning, then
# fine-tuned for a specific task with additional labels.
#
# The pre-trained weights without fine-tuning can be fine-tuned
# for other downstream tasks as well, but this tutorial does not
# cover that.
#
# We will use :py:func:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H` here.
#
# There are multiple models available in
# :py:mod:`torchaudio.pipelines`. Please check the documentation for
# the details of how they are trained.
#
# The bundle object provides the interface to instantiate the model and other
# information. The sampling rate and the class labels are found as follows.
#
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H

print("Sample Rate:", bundle.sample_rate)
print("Labels:", bundle.get_labels())
######################################################################
# The model can be constructed as follows. This process automatically
# fetches the pre-trained weights and loads them into the model.
#
model = bundle.get_model().to(device)

print(model.__class__)
######################################################################
# Loading data
# ------------
#
# We will use the speech data from the `VOiCES
# dataset <https://iqtlabs.github.io/voices/>`__, which is licensed under
# Creative Commons BY 4.0.
#
IPython.display.display(IPython.display.Audio(SPEECH_FILE))
######################################################################
# To load data, we use :py:func:`torchaudio.load`.
#
# If the sampling rate is different from what the pipeline expects, then
# we can use :py:func:`torchaudio.functional.resample` for resampling.
#
# .. note::
#
# - :py:func:`torchaudio.functional.resample` works on CUDA tensors as well.
# - When performing resampling multiple times on the same set of sample rates,
#   using :py:func:`torchaudio.transforms.Resample` might improve the performance.
#
waveform, sample_rate = torchaudio.load(SPEECH_FILE)
waveform = waveform.to(device)

if sample_rate != bundle.sample_rate:
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
######################################################################
# Extracting acoustic features
# ----------------------------
#
# The next step is to extract acoustic features from the audio.
#
# .. note::
# Wav2Vec2 models fine-tuned for the ASR task can perform feature
# extraction and classification in one step, but for the sake of the
# tutorial, we also show how to perform feature extraction here.
#
with torch.inference_mode():
    features, _ = model.extract_features(waveform)
######################################################################
# The returned features are a list of tensors. Each tensor is the output of
# a transformer layer.
#
fig, ax = plt.subplots(len(features), 1, figsize=(16, 4.3 * len(features)))
for i, feats in enumerate(features):
    ax[i].imshow(feats[0].cpu())
    ax[i].set_title(f"Feature from transformer layer {i + 1}")
    ax[i].set_xlabel("Feature dimension")
    ax[i].set_ylabel("Frame (time-axis)")
plt.tight_layout()
plt.show()
######################################################################
# Feature classification
# ----------------------
#
# Once the acoustic features are extracted, the next step is to classify
# them into a set of categories.
#
# The Wav2Vec2 model provides a method to perform the feature extraction and
# classification in one step.
#
with torch.inference_mode():
    emission, _ = model(waveform)
######################################################################
# The output is in the form of logits, not probabilities.
#
# Let’s visualize this.
#
plt.imshow(emission[0].cpu().T)
plt.title("Classification result")
plt.xlabel("Frame (time-axis)")
plt.ylabel("Class")
plt.show()
print("Class labels:", bundle.get_labels())
######################################################################
# We can see that there are strong indications for certain labels across
# the timeline.
#
######################################################################
# Generating transcripts
# ----------------------
#
# From the sequence of label probabilities, now we want to generate
# transcripts. The process to generate hypotheses is often called
# “decoding”.
#
# Decoding is more elaborate than simple classification because
# decoding at certain time step can be affected by surrounding
# observations.
#
# For example, take words like ``night`` and ``knight``. Even if their
# prior probability distributions are different (in typical conversations,
# ``night`` would occur way more often than ``knight``), to accurately
# generate transcripts with ``knight``, such as ``a knight with a sword``,
# the decoding process has to postpone the final decision until it sees
# enough context.
#
# There are many decoding techniques proposed, and they require external
# resources, such as a word dictionary and language models.
#
# In this tutorial, for the sake of simplicity, we will perform greedy
# decoding, which does not depend on such external components, and simply
# picks the best hypothesis at each time step. Therefore, the context
# information is not used, and only one transcript can be generated.
#
# We start by defining the greedy decoding algorithm.
#
class GreedyCTCDecoder(torch.nn.Module):
    def __init__(self, labels, blank=0):
        super().__init__()
        self.labels = labels
        self.blank = blank

    def forward(self, emission: torch.Tensor) -> str:
        """Given a sequence emission over labels, get the best path string

        Args:
          emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

        Returns:
          str: The resulting transcript
        """
        indices = torch.argmax(emission, dim=-1)  # [num_seq,]
        indices = torch.unique_consecutive(indices, dim=-1)
        indices = [i for i in indices if i != self.blank]
        return ''.join([self.labels[i] for i in indices])
######################################################################
# Now create the decoder object and decode the transcript.
#
decoder = GreedyCTCDecoder(labels=bundle.get_labels())
transcript = decoder(emission[0])
######################################################################
# Let’s check the result and listen again to the audio.
#
print(transcript)
IPython.display.display(IPython.display.Audio(SPEECH_FILE))
######################################################################
# The ASR model is fine-tuned using a loss function called Connectionist Temporal Classification (CTC).
# The detail of CTC loss is explained
# `here <https://distill.pub/2017/ctc/>`__. In CTC a blank token (ϵ) is a
# special token which represents a repetition of the previous symbol. In
# decoding, these are simply ignored.
#
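To make the collapsing rule concrete, here is a sketch of greedy CTC decoding on a made-up frame sequence (the label set and per-frame indices are hypothetical, not model output): repeated symbols are merged first, then blanks are dropped, mirroring the two steps inside ``GreedyCTCDecoder.forward`` above.

```python
# Hypothetical label set; index 0 ('-') stands for the CTC blank token.
ctc_labels = ['-', 'H', 'E', 'L', 'O']
frame_best = [1, 1, 0, 2, 0, 3, 3, 0, 3, 4]  # per-frame argmax class indices

# Step 1: merge consecutive duplicates (what torch.unique_consecutive does).
collapsed = []
for idx in frame_best:
    if not collapsed or collapsed[-1] != idx:
        collapsed.append(idx)

# Step 2: drop blanks and map indices back to characters.
decoded = ''.join(ctc_labels[idx] for idx in collapsed if idx != 0)
print(decoded)  # HELLO
```

Note that the blank between the two runs of ``3`` is what allows a genuine double letter ("LL") to survive the duplicate-merging step.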
######################################################################
# Conclusion
# ----------
#
# In this tutorial, we looked at how to use :py:mod:`torchaudio.pipelines` to
# perform acoustic feature extraction and speech recognition. Constructing
# a model and getting the emission is as short as two lines.
#
# ::
#
# model = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H.get_model()
# emission = model(waveforms, ...)
#