Unverified Commit a3363539 authored by moto, committed by GitHub

Add Sphinx-gallery to doc (#1967)

parent 65dbf2d2
[flake8]
max-line-length = 120
ignore = E305,E402,E721,E741,F405,W503,W504,F999
-exclude = build,docs/source,_ext,third_party
+exclude = build,docs/source,_ext,third_party,examples/gallery
@@ -69,6 +69,8 @@ instance/
# Sphinx documentation
docs/_build/
docs/src/
+docs/source/auto_examples
+docs/source/gen_modules
# PyBuilder
target/
......
@@ -4,3 +4,6 @@ sphinxcontrib.katex
sphinxcontrib.bibtex
matplotlib
pyparsing<3,>=2.0.2
+sphinx_gallery
+IPython
+deep-phonemizer
@@ -3,6 +3,8 @@
torchaudio.backend
==================
+.. py:module:: torchaudio.backend
Overview
~~~~~~~~
......
@@ -6,6 +6,8 @@ torchaudio.compliance.kaldi
.. currentmodule:: torchaudio.compliance.kaldi
+.. py:module:: torchaudio.compliance.kaldi
The useful processing operations of kaldi_ can be performed with torchaudio.
Various functions with identical parameters are given so that torchaudio can
produce similar outputs.
......
@@ -42,6 +42,7 @@ extensions = [
    'sphinx.ext.viewcode',
    'sphinxcontrib.katex',
    'sphinxcontrib.bibtex',
+    'sphinx_gallery.gen_gallery',
]
# katex options
@@ -58,6 +59,19 @@ delimiters : [
bibtex_bibfiles = ['refs.bib']
+sphinx_gallery_conf = {
+    'examples_dirs': [
+        '../../examples/gallery/wav2vec2',
+    ],
+    'gallery_dirs': [
+        'auto_examples/wav2vec2',
+    ],
+    'filename_pattern': 'tutorial.py',
+    'backreferences_dir': 'gen_modules/backreferences',
+    'doc_module': ('torchaudio',),
+}
+autosummary_generate = True
napoleon_use_ivar = True
napoleon_numpy_docstring = False
napoleon_google_docstring = True
......
torchaudio.datasets
====================
+.. py:module:: torchaudio.datasets
All datasets are subclasses of :class:`torch.utils.data.Dataset`
and have ``__getitem__`` and ``__len__`` methods implemented.
Hence, they can all be passed to a :class:`torch.utils.data.DataLoader`
......
@@ -4,6 +4,8 @@
torchaudio.functional
=====================
+.. py:module:: torchaudio.functional
.. currentmodule:: torchaudio.functional
Functions to perform common audio operations.
......
@@ -42,6 +42,11 @@ The :mod:`torchaudio` package consists of I/O, popular datasets and common audio
   utils
   prototype
+.. toctree::
+   :maxdepth: 2
+   :caption: Tutorials
+
+   auto_examples/wav2vec2/index
.. toctree::
   :maxdepth: 1
......
@@ -4,6 +4,8 @@
torchaudio.kaldi_io
======================
+.. py:module:: torchaudio.kaldi_io
.. currentmodule:: torchaudio.kaldi_io
To use this module, the dependency kaldi_io_ needs to be installed.
......
@@ -4,6 +4,8 @@
torchaudio.models
=================
+.. py:module:: torchaudio.models
.. currentmodule:: torchaudio.models
The models subpackage contains definitions of models for addressing common audio tasks.
......
@@ -3,6 +3,8 @@ torchaudio.pipelines
.. currentmodule:: torchaudio.pipelines
+.. py:module:: torchaudio.pipelines
The pipelines subpackage contains API to access the models with pretrained weights, and information/helper functions associated the pretrained weights.
wav2vec 2.0 / HuBERT - Representation Learning
@@ -73,6 +75,9 @@ HUBERT_XLARGE
wav2vec 2.0 / HuBERT - Fine-tuned ASR
-------------------------------------
+Wav2Vec2ASRBundle
+~~~~~~~~~~~~~~~~~
.. autoclass:: Wav2Vec2ASRBundle
   :members: sample_rate
@@ -80,6 +85,9 @@ wav2vec 2.0 / HuBERT - Fine-tuned ASR
   .. automethod:: get_labels
+.. minigallery:: torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
+   :add-heading: Examples using ``Wav2Vec2ASRBundle``
+   :heading-level: ~
WAV2VEC2_ASR_BASE_10M
~~~~~~~~~~~~~~~~~~~~~
......
@@ -4,6 +4,8 @@
torchaudio.prototype.emformer
=============================
+.. py:module:: torchaudio.prototype.emformer
.. currentmodule:: torchaudio.prototype.emformer
Emformer is a prototype feature; see `here <https://pytorch.org/audio>`_
......
@@ -3,6 +3,8 @@
torchaudio.sox_effects
======================
+.. py:module:: torchaudio.sox_effects
.. currentmodule:: torchaudio.sox_effects
Resource initialization / shutdown
......
@@ -4,6 +4,8 @@
torchaudio.transforms
======================
+.. py:module:: torchaudio.transforms
.. currentmodule:: torchaudio.transforms
Transforms are common audio transforms. They can be chained together using :class:`torch.nn.Sequential`
......
torchaudio.utils
================
+.. py:module:: torchaudio.utils
torchaudio.utils.sox_utils
~~~~~~~~~~~~~~~~~~~~~~~~~~
......
Wav2Vec2 Tutorials
==================
"""
Forced Alignment with Wav2Vec2
==============================
**Author** `Moto Hira <moto@fb.com>`__
This tutorial shows how to align a transcript to speech with
``torchaudio``, using the CTC segmentation algorithm described in
`CTC-Segmentation of Large Corpora for German End-to-end Speech
Recognition <https://arxiv.org/abs/2007.09127>`__.
"""
######################################################################
# Overview
# --------
#
# The process of alignment looks like the following.
#
# 1. Estimate the frame-wise label probability from audio waveform
# 2. Generate the trellis matrix which represents the probability of
#    labels aligned at each time step.
# 3. Find the most likely path from the trellis matrix.
#
# In this example, we use ``torchaudio``\ ’s ``Wav2Vec2`` model for
# acoustic feature extraction.
#
######################################################################
# Preparation
# -----------
#
# First we import the necessary packages and fetch the data that we work on.
#
# %matplotlib inline
import os
from dataclasses import dataclass
import torch
import torchaudio
import requests
import matplotlib
import matplotlib.pyplot as plt
import IPython
matplotlib.rcParams['figure.figsize'] = [16.0, 4.8]
torch.random.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(torch.__version__)
print(torchaudio.__version__)
print(device)
SPEECH_URL = 'https://download.pytorch.org/torchaudio/test-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.flac'
SPEECH_FILE = 'speech.flac'
if not os.path.exists(SPEECH_FILE):
with open(SPEECH_FILE, 'wb') as file:
file.write(requests.get(SPEECH_URL).content)
######################################################################
# Generate frame-wise label probability
# -------------------------------------
#
# The first step is to generate the label class probability of each audio
# frame. We can use a Wav2Vec2 model that is trained for ASR. Here we use
# :py:func:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`.
#
# ``torchaudio`` provides easy access to pretrained models with associated
# labels.
#
# .. note::
#
# In the subsequent sections, we will compute the probability in
# log-domain to avoid numerical instability. For this purpose, we
# normalize the ``emission`` with :py:func:`torch.log_softmax`.
#
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device)
labels = bundle.get_labels()
with torch.inference_mode():
waveform, _ = torchaudio.load(SPEECH_FILE)
emissions, _ = model(waveform.to(device))
emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions[0].cpu().detach()
################################################################################
# Visualization
################################################################################
print(labels)
plt.imshow(emission.T)
plt.colorbar()
plt.title("Frame-wise class probability")
plt.xlabel("Time")
plt.ylabel("Labels")
plt.show()
######################################################################
# Generate alignment probability (trellis)
# ----------------------------------------
#
# From the emission matrix, next we generate the trellis which represents
# the probability of each transcript label occurring at each time frame.
#
# The trellis is a 2D matrix with a time axis and a label axis. The label
# axis represents the transcript that we are aligning. In the following, we
# use :math:`t` to denote the index on the time axis and :math:`j` to denote
# the index on the label axis. :math:`c_j` represents the label at label
# index :math:`j`.
#
# To generate the trellis value at time step :math:`t+1`, we look at the
# trellis at time step :math:`t` and the emission at time step :math:`t+1`.
# There are two paths that reach time step :math:`t+1` with label
# :math:`c_{j+1}`. The first one is the case where the label was
# :math:`c_{j+1}` at :math:`t` and there was no label change from
# :math:`t` to :math:`t+1`. The other case is where the label was
# :math:`c_j` at :math:`t` and it transitioned to the next label
# :math:`c_{j+1}` at :math:`t+1`.
#
# The following diagram illustrates this transition.
#
# .. image:: https://download.pytorch.org/torchaudio/tutorial-assets/ctc-forward.png
#
# Since we are looking for the most likely transitions, we take the more
# likely path for the value of :math:`k_{(t+1, j+1)}`, that is
#
# :math:`k_{(t+1, j+1)} = max( k_{(t, j)} p(t+1, c_{j+1}), k_{(t, j+1)} p(t+1, repeat) )`
#
# where :math:`k` represents the trellis matrix, and :math:`p(t, c_j)`
# represents the probability of label :math:`c_j` at time step :math:`t`.
# :math:`repeat` represents the blank token from the CTC formulation. (For
# details of the CTC algorithm, please refer to *Sequence Modeling with CTC*
# [`distill.pub <https://distill.pub/2017/ctc/>`__].)
#
transcript = 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT'
dictionary = {c: i for i, c in enumerate(labels)}
tokens = [dictionary[c] for c in transcript]
print(list(zip(transcript, tokens)))
def get_trellis(emission, tokens, blank_id=0):
num_frame = emission.size(0)
num_tokens = len(tokens)
# The trellis has extra dimensions for both the time axis and tokens.
# The extra dim for tokens represents <SoS> (start-of-sentence)
# The extra dim for time axis is for simplification of the code.
trellis = torch.full((num_frame+1, num_tokens+1), -float('inf'))
trellis[:, 0] = 0
for t in range(num_frame):
trellis[t+1, 1:] = torch.maximum(
# Score for staying at the same token
trellis[t, 1:] + emission[t, blank_id],
# Score for changing to the next token
trellis[t, :-1] + emission[t, tokens],
)
return trellis
trellis = get_trellis(emission, tokens)
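# The vectorized update in ``get_trellis`` can be hard to read at first. The
# following is a hypothetical, unvectorized sketch of the same recursion (not
# part of the original tutorial), filling one trellis cell at a time; it
# should produce the same result, just more slowly.
def get_trellis_scalar(emission, tokens, blank_id=0):
    num_frame = emission.size(0)
    num_tokens = len(tokens)
    trellis = torch.full((num_frame + 1, num_tokens + 1), -float('inf'))
    trellis[:, 0] = 0
    for t in range(num_frame):
        for j in range(num_tokens):
            # Stay on label c_{j+1} (emit blank) vs. move from c_j to c_{j+1}
            stay = trellis[t, j + 1] + emission[t, blank_id]
            change = trellis[t, j] + emission[t, tokens[j]]
            trellis[t + 1, j + 1] = max(stay, change)
    return trellis

print(torch.equal(get_trellis_scalar(emission, tokens), trellis))  # should print True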
################################################################################
# Visualization
################################################################################
plt.imshow(trellis[1:, 1:].T, origin='lower')
plt.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
plt.colorbar()
plt.show()
######################################################################
# In the above visualization, we can see that there is a trace of high
# probability crossing the matrix diagonally.
#
######################################################################
# Find the most likely path (backtracking)
# ----------------------------------------
#
# Once the trellis is generated, we will traverse it following the
# elements with high probability.
#
# We will start from the last label index at the time step with the highest
# probability; then, we traverse back in time, picking stay
# (:math:`c_j \rightarrow c_j`) or transition
# (:math:`c_j \rightarrow c_{j+1}`), based on the post-transition
# probability :math:`k_{t, j} p(t+1, c_{j+1})` or
# :math:`k_{t, j+1} p(t+1, repeat)`.
#
# The traversal is done once the label index reaches the beginning.
#
# The trellis matrix is used for path-finding, but for the final
# probability of each segment, we take the frame-wise probability from
# the emission matrix.
#
@dataclass
class Point:
token_index: int
time_index: int
score: float
def backtrack(trellis, emission, tokens, blank_id=0):
# Note:
# j and t are indices for trellis, which has extra dimensions
# for time and tokens at the beginning.
# When referring to time frame index `T` in trellis,
# the corresponding index in emission is `T-1`.
# Similarly, when referring to token index `J` in trellis,
# the corresponding index in transcript is `J-1`.
j = trellis.size(1) - 1
t_start = torch.argmax(trellis[:, j]).item()
path = []
for t in range(t_start, 0, -1):
# 1. Figure out if the current position was stay or change
# Note (again):
# `emission[T-1]` is the emission at time frame `T` of the trellis dimension.
# Score for the token staying the same from time frame T-1 to T.
stayed = trellis[t-1, j] + emission[t-1, blank_id]
# Score for the token changing from c_{j-1} at T-1 to c_j at T.
changed = trellis[t-1, j-1] + emission[t-1, tokens[j-1]]
# 2. Store the path with frame-wise probability.
prob = emission[t-1, tokens[j-1] if changed > stayed else 0].exp().item()
# Return token index and time index in non-trellis coordinate.
path.append(Point(j-1, t-1, prob))
# 3. Update the token
if changed > stayed:
j -= 1
if j == 0:
break
else:
raise ValueError('Failed to align')
return path[::-1]
path = backtrack(trellis, emission, tokens)
print(path)
################################################################################
# Visualization
################################################################################
def plot_trellis_with_path(trellis, path):
# To plot the trellis with the path, we take advantage of the 'nan' value
trellis_with_path = trellis.clone()
for i, p in enumerate(path):
trellis_with_path[p.time_index, p.token_index] = float('nan')
plt.imshow(trellis_with_path[1:, 1:].T, origin='lower')
plot_trellis_with_path(trellis, path)
plt.title("The path found by backtracking")
plt.show()
######################################################################
# Looking good. Now this path contains repetitions of the same labels, so
# let’s merge them to make it close to the original transcript.
#
# When merging multiple path points, we simply take the average
# probability for the merged segments.
#
# Merge the labels
@dataclass
class Segment:
label: str
start: int
end: int
score: float
def __repr__(self):
return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
@property
def length(self):
return self.end - self.start
def merge_repeats(path):
i1, i2 = 0, 0
segments = []
while i1 < len(path):
while i2 < len(path) and path[i1].token_index == path[i2].token_index:
i2 += 1
score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
segments.append(Segment(transcript[path[i1].token_index], path[i1].time_index, path[i2-1].time_index + 1, score))
i1 = i2
return segments
segments = merge_repeats(path)
for seg in segments:
print(seg)
################################################################################
# Visualization
################################################################################
def plot_trellis_with_segments(trellis, segments, transcript):
# To plot the trellis with the path, we take advantage of the 'nan' value
trellis_with_path = trellis.clone()
for i, seg in enumerate(segments):
if seg.label != '|':
trellis_with_path[seg.start+1:seg.end+1, i+1] = float('nan')
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))
ax1.set_title("Path, label and probability for each label")
ax1.imshow(trellis_with_path.T, origin='lower')
ax1.set_xticks([])
for i, seg in enumerate(segments):
if seg.label != '|':
ax1.annotate(seg.label, (seg.start + .7, i + 0.3), weight='bold')
ax1.annotate(f'{seg.score:.2f}', (seg.start - .3, i + 4.3))
ax2.set_title("Label probability with and without repetation")
xs, hs, ws = [], [], []
for seg in segments:
if seg.label != '|':
xs.append((seg.end + seg.start) / 2 + .4)
hs.append(seg.score)
ws.append(seg.end - seg.start)
ax2.annotate(seg.label, (seg.start + .8, -0.07), weight='bold')
ax2.bar(xs, hs, width=ws, color='gray', alpha=0.5, edgecolor='black')
xs, hs = [], []
for p in path:
label = transcript[p.token_index]
if label != '|':
xs.append(p.time_index + 1)
hs.append(p.score)
ax2.bar(xs, hs, width=0.5, alpha=0.5)
ax2.axhline(0, color='black')
ax2.set_xlim(ax1.get_xlim())
ax2.set_ylim(-0.1, 1.1)
plot_trellis_with_segments(trellis, segments, transcript)
plt.tight_layout()
plt.show()
######################################################################
# Looks good. Now let’s merge the words. The Wav2Vec2 model uses ``'|'``
# as the word boundary, so we merge the segments before each occurrence of
# ``'|'``.
#
# Then, finally, we slice the original audio into segments and listen to
# them to verify that the segmentation is correct.
#
# Merge words
def merge_words(segments, separator='|'):
words = []
i1, i2 = 0, 0
while i1 < len(segments):
if i2 >= len(segments) or segments[i2].label == separator:
if i1 != i2:
segs = segments[i1:i2]
word = ''.join([seg.label for seg in segs])
score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
words.append(Segment(word, segments[i1].start, segments[i2-1].end, score))
i1 = i2 + 1
i2 = i1
else:
i2 += 1
return words
word_segments = merge_words(segments)
for word in word_segments:
print(word)
################################################################################
# Visualization
################################################################################
def plot_alignments(trellis, segments, word_segments, waveform):
trellis_with_path = trellis.clone()
for i, seg in enumerate(segments):
if seg.label != '|':
trellis_with_path[seg.start+1:seg.end+1, i+1] = float('nan')
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))
ax1.imshow(trellis_with_path[1:, 1:].T, origin='lower')
ax1.set_xticks([])
ax1.set_yticks([])
for word in word_segments:
ax1.axvline(word.start - 0.5)
ax1.axvline(word.end - 0.5)
for i, seg in enumerate(segments):
if seg.label != '|':
ax1.annotate(seg.label, (seg.start, i + 0.3))
ax1.annotate(f'{seg.score:.2f}', (seg.start , i + 4), fontsize=8)
# The original waveform
ratio = waveform.size(0) / (trellis.size(0) - 1)
ax2.plot(waveform)
for word in word_segments:
x0 = ratio * word.start
x1 = ratio * word.end
ax2.axvspan(x0, x1, alpha=0.1, color='red')
ax2.annotate(f'{word.score:.2f}', (x0, 0.8))
for seg in segments:
if seg.label != '|':
ax2.annotate(seg.label, (seg.start * ratio, 0.9))
xticks = ax2.get_xticks()
plt.xticks(xticks, xticks / bundle.sample_rate)
ax2.set_xlabel('time [second]')
ax2.set_yticks([])
ax2.set_ylim(-1.0, 1.0)
ax2.set_xlim(0, waveform.size(-1))
plot_alignments(trellis, segments, word_segments, waveform[0],)
plt.show()
# Generate the audio for each segment
print(transcript)
IPython.display.display(IPython.display.Audio(SPEECH_FILE))
ratio = waveform.size(1) / (trellis.size(0) - 1)
for i, word in enumerate(word_segments):
x0 = int(ratio * word.start)
x1 = int(ratio * word.end)
filename = f"{i}_{word.label}.wav"
torchaudio.save(filename, waveform[:, x0:x1], bundle.sample_rate)
print(f"{word.label}: {x0 / bundle.sample_rate:.3f} - {x1 / bundle.sample_rate:.3f}")
IPython.display.display(IPython.display.Audio(filename))
######################################################################
# Conclusion
# ----------
#
# In this tutorial, we looked at how to use torchaudio’s Wav2Vec2 model to
# perform CTC segmentation for forced alignment.
#
"""
Speech Recognition with Wav2Vec2
================================
**Author**: `Moto Hira <moto@fb.com>`__
This tutorial shows how to perform speech recognition using
pre-trained models from wav2vec 2.0
[`paper <https://arxiv.org/abs/2006.11477>`__].
"""
######################################################################
# Overview
# --------
#
# The process of speech recognition looks like the following.
#
# 1. Extract the acoustic features from audio waveform
#
# 2. Estimate the class of the acoustic features frame-by-frame
#
# 3. Generate hypothesis from the sequence of the class probabilities
#
# Torchaudio provides easy access to the pre-trained weights and
# associated information, such as the expected sample rate and class
# labels. They are bundled together and available under the
# ``torchaudio.pipelines`` module.
#
######################################################################
# Preparation
# -----------
#
# First we import the necessary packages and fetch the data that we work on.
#
# %matplotlib inline
import os
import torch
import torchaudio
import requests
import matplotlib
import matplotlib.pyplot as plt
import IPython
matplotlib.rcParams['figure.figsize'] = [16.0, 4.8]
torch.random.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(torch.__version__)
print(torchaudio.__version__)
print(device)
SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
SPEECH_FILE = "speech.wav"
if not os.path.exists(SPEECH_FILE):
with open(SPEECH_FILE, 'wb') as file:
file.write(requests.get(SPEECH_URL).content)
######################################################################
# Creating a pipeline
# -------------------
#
# First, we will create a Wav2Vec2 model that performs the feature
# extraction and the classification.
#
# There are two types of Wav2Vec2 pre-trained weights available in
# torchaudio: the ones fine-tuned for the ASR task, and the ones not
# fine-tuned.
#
# Wav2Vec2 (and HuBERT) models are trained in a self-supervised manner. They
# are first trained on audio only for representation learning, then
# fine-tuned for a specific task with additional labels.
#
# The pre-trained weights without fine-tuning can be fine-tuned
# for other downstream tasks as well, but this tutorial does not
# cover that.
#
# We will use :py:func:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H` here.
#
# There are multiple models available in
# :py:mod:`torchaudio.pipelines`. Please check the documentation for
# the details of how they are trained.
#
# The bundle object provides the interface to instantiate the model and other
# information. The sampling rate and the class labels are found as follows.
#
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
print("Sample Rate:", bundle.sample_rate)
print("Labels:", bundle.get_labels())
######################################################################
# The model can be constructed as follows. This process will automatically
# fetch the pre-trained weights and load them into the model.
#
model = bundle.get_model().to(device)
print(model.__class__)
######################################################################
# Loading data
# ------------
#
# We will use the speech data from the `VOiCES
# dataset <https://iqtlabs.github.io/voices/>`__, which is licensed under
# Creative Commons BY 4.0.
#
IPython.display.display(IPython.display.Audio(SPEECH_FILE))
######################################################################
# To load data, we use :py:func:`torchaudio.load`.
#
# If the sampling rate is different from what the pipeline expects, then
# we can use :py:func:`torchaudio.functional.resample` for resampling.
#
# .. note::
#
# - :py:func:`torchaudio.functional.resample` works on CUDA tensors as well.
# - When performing resampling multiple times on the same set of sample rates,
# using :py:func:`torchaudio.transforms.Resample` might improve the performance.
#
waveform, sample_rate = torchaudio.load(SPEECH_FILE)
waveform = waveform.to(device)
if sample_rate != bundle.sample_rate:
waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
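######################################################################
# As a side note (not part of the original tutorial): when resampling many
# clips that share the same pair of sample rates, the resampling kernel can be
# built once with :py:class:`torchaudio.transforms.Resample` and reused, as
# suggested in the note above. A minimal sketch, assuming a hypothetical
# 8 kHz source recording:
#
# ::
#
#    resampler = torchaudio.transforms.Resample(orig_freq=8000, new_freq=bundle.sample_rate)
#    resampled = resampler(waveform_8khz)  # ``waveform_8khz`` is a placeholder tensor
#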
######################################################################
# Extracting acoustic features
# ----------------------------
#
# The next step is to extract acoustic features from the audio.
#
# .. note::
# Wav2Vec2 models fine-tuned for the ASR task can perform feature
# extraction and classification in one step, but for the sake of the
# tutorial, we also show how to perform feature extraction here.
#
with torch.inference_mode():
features, _ = model.extract_features(waveform)
######################################################################
# The returned ``features`` is a list of tensors. Each tensor is the output of
# a transformer layer.
#
fig, ax = plt.subplots(len(features), 1, figsize=(16, 4.3 * len(features)))
for i, feats in enumerate(features):
ax[i].imshow(feats[0].cpu())
ax[i].set_title(f"Feature from transformer layer {i+1}")
ax[i].set_xlabel("Feature dimension")
ax[i].set_ylabel("Frame (time-axis)")
plt.tight_layout()
plt.show()
######################################################################
# Feature classification
# ----------------------
#
# Once the acoustic features are extracted, the next step is to classify
# them into a set of categories.
#
# The Wav2Vec2 model provides a method to perform the feature extraction and
# classification in one step.
#
with torch.inference_mode():
emission, _ = model(waveform)
######################################################################
# The output is in the form of logits, not probabilities.
#
# Let’s visualize this.
#
plt.imshow(emission[0].cpu().T)
plt.title("Classification result")
plt.xlabel("Frame (time-axis)")
plt.ylabel("Class")
plt.show()
print("Class labels:", bundle.get_labels())
######################################################################
# We can see that there are strong indications for certain labels across
# the timeline.
#
######################################################################
# Generating transcripts
# ----------------------
#
# From the sequence of label probabilities, now we want to generate
# transcripts. The process to generate hypotheses is often called
# “decoding”.
#
# Decoding is more elaborate than simple classification because
# decoding at a certain time step can be affected by surrounding
# observations.
#
# For example, take words like ``night`` and ``knight``. Even if their
# prior probability distributions are different (in typical conversations,
# ``night`` would occur far more often than ``knight``), to accurately
# generate transcripts with ``knight``, such as ``a knight with a sword``,
# the decoding process has to postpone the final decision until it sees
# enough context.
#
# Many decoding techniques have been proposed, and they require external
# resources, such as a word dictionary and language models.
#
# In this tutorial, for the sake of simplicity, we will perform greedy
# decoding, which does not depend on such external components and simply
# picks the best hypothesis at each time step. Therefore, the context
# information is not used, and only one transcript can be generated.
#
# We start by defining the greedy decoding algorithm.
#
class GreedyCTCDecoder(torch.nn.Module):
def __init__(self, labels, blank=0):
super().__init__()
self.labels = labels
self.blank = blank
def forward(self, emission: torch.Tensor) -> str:
"""Given a sequence emission over labels, get the best path string
Args:
emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.
Returns:
str: The resulting transcript
"""
indices = torch.argmax(emission, dim=-1) # [num_seq,]
indices = torch.unique_consecutive(indices, dim=-1)
indices = [i for i in indices if i != self.blank]
return ''.join([self.labels[i] for i in indices])
######################################################################
# Now create the decoder object and decode the transcript.
#
decoder = GreedyCTCDecoder(labels=bundle.get_labels())
transcript = decoder(emission[0])
######################################################################
# Let’s check the result and listen again to the audio.
#
print(transcript)
IPython.display.display(IPython.display.Audio(SPEECH_FILE))
######################################################################
# The ASR model is fine-tuned using a loss function called Connectionist Temporal Classification (CTC).
# The details of CTC loss are explained
# `here <https://distill.pub/2017/ctc/>`__. In CTC, a blank token (ϵ) is a
# special token which represents the absence of a label; during decoding,
# repeated predictions are collapsed and blank tokens are simply ignored.
#
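# As a toy illustration (not part of the original tutorial), here is how the
# greedy collapsing handles repeated predictions and blanks, using a made-up
# five-symbol label set in which ``'-'`` plays the role of the blank token.
toy_labels = ('-', 'H', 'E', 'L', 'O')
toy_indices = torch.tensor([1, 1, 2, 0, 3, 3, 0, 3, 4])  # per-frame picks: "HHE-LL-LO"
toy_indices = torch.unique_consecutive(toy_indices)      # collapse repeats: "HE-L-LO"
print(''.join(toy_labels[i] for i in toy_indices.tolist() if i != 0))  # drop blanks: "HELLO"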
######################################################################
# Conclusion
# ----------
#
# In this tutorial, we looked at how to use :py:mod:`torchaudio.pipelines` to
# perform acoustic feature extraction and speech recognition. Constructing
# a model and getting the emission is as short as two lines.
#
# ::
#
# model = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H.get_model()
# emission = model(waveforms, ...)
#