Unverified Commit a3363539 authored by moto, committed by GitHub

Add Sphinx-gallery to doc (#1967)

parent 65dbf2d2
[flake8]
max-line-length = 120
ignore = E305,E402,E721,E741,F405,W503,W504,F999
exclude = build,docs/source,_ext,third_party
exclude = build,docs/source,_ext,third_party,examples/gallery
......@@ -69,6 +69,8 @@ instance/
# Sphinx documentation
docs/_build/
docs/src/
docs/source/auto_examples
docs/source/gen_modules
# PyBuilder
target/
......
......@@ -4,3 +4,6 @@ sphinxcontrib.katex
sphinxcontrib.bibtex
matplotlib
pyparsing<3,>=2.0.2
sphinx_gallery
IPython
deep-phonemizer
......@@ -3,6 +3,8 @@
torchaudio.backend
==================
.. py:module:: torchaudio.backend
Overview
~~~~~~~~
......
......@@ -6,6 +6,8 @@ torchaudio.compliance.kaldi
.. currentmodule:: torchaudio.compliance.kaldi
.. py:module:: torchaudio.compliance.kaldi
The useful processing operations of kaldi_ can be performed with torchaudio.
Various functions with identical parameters are given so that torchaudio can
produce similar outputs.
......
......@@ -42,6 +42,7 @@ extensions = [
    'sphinx.ext.viewcode',
    'sphinxcontrib.katex',
    'sphinxcontrib.bibtex',
    'sphinx_gallery.gen_gallery',
]
# katex options
......@@ -58,6 +59,19 @@ delimiters : [
bibtex_bibfiles = ['refs.bib']
sphinx_gallery_conf = {
    'examples_dirs': [
        '../../examples/gallery/wav2vec2',
    ],
    'gallery_dirs': [
        'auto_examples/wav2vec2',
    ],
    'filename_pattern': 'tutorial.py',
    'backreferences_dir': 'gen_modules/backreferences',
    'doc_module': ('torchaudio',),
}
autosummary_generate = True
napoleon_use_ivar = True
napoleon_numpy_docstring = False
napoleon_google_docstring = True
......
torchaudio.datasets
====================
.. py:module:: torchaudio.datasets
All datasets are subclasses of :class:`torch.utils.data.Dataset`
and have ``__getitem__`` and ``__len__`` methods implemented.
Hence, they can all be passed to a :class:`torch.utils.data.DataLoader`
......
......@@ -4,6 +4,8 @@
torchaudio.functional
=====================
.. py:module:: torchaudio.functional
.. currentmodule:: torchaudio.functional
Functions to perform common audio operations.
......
......@@ -42,6 +42,11 @@ The :mod:`torchaudio` package consists of I/O, popular datasets and common audio
utils
prototype
.. toctree::
   :maxdepth: 2
   :caption: Tutorials

   auto_examples/wav2vec2/index

.. toctree::
   :maxdepth: 1
......
......@@ -4,6 +4,8 @@
torchaudio.kaldi_io
======================
.. py:module:: torchaudio.kaldi_io
.. currentmodule:: torchaudio.kaldi_io
To use this module, the dependency kaldi_io_ needs to be installed.
......
......@@ -4,6 +4,8 @@
torchaudio.models
=================
.. py:module:: torchaudio.models
.. currentmodule:: torchaudio.models
The models subpackage contains definitions of models for addressing common audio tasks.
......
......@@ -3,6 +3,8 @@ torchaudio.pipelines
.. currentmodule:: torchaudio.pipelines
.. py:module:: torchaudio.pipelines
The pipelines subpackage contains APIs to access the models with pretrained weights, and information/helper functions associated with the pretrained weights.
wav2vec 2.0 / HuBERT - Representation Learning
......@@ -73,6 +75,9 @@ HUBERT_XLARGE
wav2vec 2.0 / HuBERT - Fine-tuned ASR
-------------------------------------
Wav2Vec2ASRBundle
~~~~~~~~~~~~~~~~~
.. autoclass:: Wav2Vec2ASRBundle
   :members: sample_rate
......@@ -80,6 +85,9 @@ wav2vec 2.0 / HuBERT - Fine-tuned ASR
.. automethod:: get_labels
.. minigallery:: torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
   :add-heading: Examples using ``Wav2Vec2ASRBundle``
   :heading-level: ~
WAV2VEC2_ASR_BASE_10M
~~~~~~~~~~~~~~~~~~~~~
......
......@@ -4,6 +4,8 @@
torchaudio.prototype.emformer
=============================
.. py:module:: torchaudio.prototype.emformer
.. currentmodule:: torchaudio.prototype.emformer
Emformer is a prototype feature; see `here <https://pytorch.org/audio>`_
......
......@@ -3,6 +3,8 @@
torchaudio.sox_effects
======================
.. py:module:: torchaudio.sox_effects
.. currentmodule:: torchaudio.sox_effects
Resource initialization / shutdown
......
......@@ -4,6 +4,8 @@
torchaudio.transforms
======================
.. py:module:: torchaudio.transforms
.. currentmodule:: torchaudio.transforms
Transforms are common audio transforms. They can be chained together using :class:`torch.nn.Sequential`
......
torchaudio.utils
================
.. py:module:: torchaudio.utils
torchaudio.utils.sox_utils
~~~~~~~~~~~~~~~~~~~~~~~~~~
......
Wav2Vec2 Tutorials
==================
"""
Forced Alignment with Wav2Vec2
==============================
**Author** `Moto Hira <moto@fb.com>`__
This tutorial shows how to align a transcript to speech with
``torchaudio``, using the CTC segmentation algorithm described in
`CTC-Segmentation of Large Corpora for German End-to-end Speech
Recognition <https://arxiv.org/abs/2007.09127>`__.
"""
######################################################################
# Overview
# --------
#
# The process of alignment looks like the following.
#
# 1. Estimate the frame-wise label probability from audio waveform
# 2. Generate the trellis matrix which represents the probability of
#    labels being aligned at each time step.
# 3. Find the most likely path from the trellis matrix.
#
# In this example, we use ``torchaudio``\ ’s ``Wav2Vec2`` model for
# acoustic feature extraction.
#
######################################################################
# Preparation
# -----------
#
# First we import the necessary packages, and fetch data that we work on.
#
# %matplotlib inline
import os
from dataclasses import dataclass
import torch
import torchaudio
import requests
import matplotlib
import matplotlib.pyplot as plt
import IPython
matplotlib.rcParams['figure.figsize'] = [16.0, 4.8]
torch.random.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(torch.__version__)
print(torchaudio.__version__)
print(device)
SPEECH_URL = 'https://download.pytorch.org/torchaudio/test-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.flac'
SPEECH_FILE = 'speech.flac'
if not os.path.exists(SPEECH_FILE):
    with open(SPEECH_FILE, 'wb') as file:
        file.write(requests.get(SPEECH_URL).content)
######################################################################
# Generate frame-wise label probability
# -------------------------------------
#
# The first step is to generate the label class probability of each audio
# frame. We can use a Wav2Vec2 model that is trained for ASR. Here we use
# :py:func:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`.
#
# ``torchaudio`` provides easy access to pretrained models with associated
# labels.
#
# .. note::
#
#    In the subsequent sections, we will compute the probability in
#    log-domain to avoid numerical instability. For this purpose, we
#    normalize the ``emission`` with :py:func:`torch.log_softmax`.
#
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device)
labels = bundle.get_labels()
with torch.inference_mode():
    waveform, _ = torchaudio.load(SPEECH_FILE)
    emissions, _ = model(waveform.to(device))
    emissions = torch.log_softmax(emissions, dim=-1)

emission = emissions[0].cpu().detach()
################################################################################
# Visualization
################################################################################
print(labels)
plt.imshow(emission.T)
plt.colorbar()
plt.title("Frame-wise class probability")
plt.xlabel("Time")
plt.ylabel("Labels")
plt.show()
######################################################################
# Generate alignment probability (trellis)
# ----------------------------------------
#
# From the emission matrix, next we generate the trellis which represents
# the probability of transcript labels occurring at each time frame.
#
# The trellis is a 2D matrix with a time axis and a label axis. The label
# axis represents the transcript that we are aligning. In the following, we
# use :math:`t` to denote the index in the time axis and :math:`j` to denote
# the index in the label axis. :math:`c_j` represents the label at label
# index :math:`j`.
#
# To generate the probability of time step :math:`t+1`, we look at the
# trellis from time step :math:`t` and the emission at time step :math:`t+1`.
# There are two paths to reach time step :math:`t+1` with label
# :math:`c_{j+1}`. The first one is the case where the label was
# :math:`c_{j+1}` at :math:`t` and there was no label change from
# :math:`t` to :math:`t+1`. The other case is where the label was
# :math:`c_j` at :math:`t` and it transitioned to the next label
# :math:`c_{j+1}` at :math:`t+1`.
#
# The following diagram illustrates this transition.
#
# .. image:: https://download.pytorch.org/torchaudio/tutorial-assets/ctc-forward.png
#
# Since we are looking for the most likely transitions, we take the more
# likely path for the value of :math:`k_{(t+1, j+1)}`, that is
#
# :math:`k_{(t+1, j+1)} = \max( k_{(t, j)} p(t+1, c_{j+1}), k_{(t, j+1)} p(t+1, repeat) )`
#
# where :math:`k` represents the trellis matrix, and :math:`p(t, c_j)`
# represents the probability of label :math:`c_j` at time step :math:`t`.
# :math:`repeat` represents the blank token from the CTC formulation. (For the
# details of the CTC algorithm, please refer to *Sequence Modeling with CTC*
# [`distill.pub <https://distill.pub/2017/ctc/>`__])
#
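# Because the ``emission`` was normalized with :py:func:`torch.log_softmax`
# above, the implementation below works in the log domain, where the
# products in the formula become sums:
#
# :math:`k_{(t+1, j+1)} = \max( k_{(t, j)} + e(t+1, c_{j+1}), k_{(t, j+1)} + e(t+1, repeat) )`
#
# where :math:`e(t, c_j)` denotes the log-probability of label :math:`c_j`
# at time step :math:`t`.
#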
transcript = 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT'
dictionary = {c: i for i, c in enumerate(labels)}
tokens = [dictionary[c] for c in transcript]
print(list(zip(transcript, tokens)))
def get_trellis(emission, tokens, blank_id=0):
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    # Trellis has extra dimensions for both time axis and tokens.
    # The extra dim for tokens represents <SoS> (start-of-sentence)
    # The extra dim for time axis is for simplification of the code.
    trellis = torch.full((num_frame+1, num_tokens+1), -float('inf'))
    trellis[:, 0] = 0
    for t in range(num_frame):
        trellis[t+1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            trellis[t, :-1] + emission[t, tokens],
        )
    return trellis

trellis = get_trellis(emission, tokens)
################################################################################
# Visualization
################################################################################
plt.imshow(trellis[1:, 1:].T, origin='lower')
plt.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
plt.colorbar()
plt.show()
######################################################################
# In the above visualization, we can see that there is a trace of high
# probability crossing the matrix diagonally.
#
######################################################################
# Find the most likely path (backtracking)
# ----------------------------------------
#
# Once the trellis is generated, we will traverse it following the
# elements with high probability.
#
# We will start from the last label index at the time step of highest
# probability, then we traverse back in time, picking stay
# (:math:`c_j \rightarrow c_j`) or transition
# (:math:`c_j \rightarrow c_{j+1}`), based on the post-transition
# probability :math:`k_{t, j} p(t+1, c_{j+1})` or
# :math:`k_{t, j+1} p(t+1, repeat)`.
#
# The traversal is complete once it reaches the beginning of the transcript.
#
# The trellis matrix is used for path-finding, but for the final
# probability of each segment, we take the frame-wise probability from
# the emission matrix.
#
@dataclass
class Point:
    token_index: int
    time_index: int
    score: float
def backtrack(trellis, emission, tokens, blank_id=0):
    # Note:
    # j and t are indices for trellis, which has extra dimensions
    # for time and tokens at the beginning.
    # When referring to time frame index `T` in trellis,
    # the corresponding index in emission is `T-1`.
    # Similarly, when referring to token index `J` in trellis,
    # the corresponding index in transcript is `J-1`.
    j = trellis.size(1) - 1
    t_start = torch.argmax(trellis[:, j]).item()

    path = []
    for t in range(t_start, 0, -1):
        # 1. Figure out if the current position was stay or change
        # Note (again):
        # `emission[T-1]` is the emission at time frame `T` of the trellis dimension.
        # Score for staying at token `J` from time frame `T-1` to `T`.
        stayed = trellis[t-1, j] + emission[t-1, blank_id]
        # Score for changing from token `J-1` at `T-1` to token `J` at `T`.
        changed = trellis[t-1, j-1] + emission[t-1, tokens[j-1]]

        # 2. Store the path with frame-wise probability.
        prob = emission[t-1, tokens[j-1] if changed > stayed else 0].exp().item()
        # Return token index and time index in non-trellis coordinate.
        path.append(Point(j-1, t-1, prob))

        # 3. Update the token
        if changed > stayed:
            j -= 1
            if j == 0:
                break
    else:
        raise ValueError('Failed to align')
    return path[::-1]

path = backtrack(trellis, emission, tokens)
print(path)
################################################################################
# Visualization
################################################################################
def plot_trellis_with_path(trellis, path):
    # To plot trellis with path, we take advantage of 'nan' value
    trellis_with_path = trellis.clone()
    for i, p in enumerate(path):
        trellis_with_path[p.time_index, p.token_index] = float('nan')
    plt.imshow(trellis_with_path[1:, 1:].T, origin='lower')

plot_trellis_with_path(trellis, path)
plt.title("The path found by backtracking")
plt.show()
######################################################################
# Looking good. Now this path contains repetitions of the same labels, so
# let’s merge them to make it close to the original transcript.
#
# When merging multiple path points, we simply take the average
# probability for the merged segments.
#
# Merge the labels
@dataclass
class Segment:
    label: str
    start: int
    end: int
    score: float

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        return self.end - self.start

def merge_repeats(path):
    i1, i2 = 0, 0
    segments = []
    while i1 < len(path):
        while i2 < len(path) and path[i1].token_index == path[i2].token_index:
            i2 += 1
        score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
        segments.append(Segment(transcript[path[i1].token_index], path[i1].time_index, path[i2-1].time_index + 1, score))
        i1 = i2
    return segments

segments = merge_repeats(path)
for seg in segments:
    print(seg)
################################################################################
# Visualization
################################################################################
def plot_trellis_with_segments(trellis, segments, transcript):
    # To plot trellis with path, we take advantage of 'nan' value
    trellis_with_path = trellis.clone()
    for i, seg in enumerate(segments):
        if seg.label != '|':
            trellis_with_path[seg.start+1:seg.end+1, i+1] = float('nan')

    fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))
    ax1.set_title("Path, label and probability for each label")
    ax1.imshow(trellis_with_path.T, origin='lower')
    ax1.set_xticks([])

    for i, seg in enumerate(segments):
        if seg.label != '|':
            ax1.annotate(seg.label, (seg.start + .7, i + 0.3), weight='bold')
            ax1.annotate(f'{seg.score:.2f}', (seg.start - .3, i + 4.3))

    ax2.set_title("Label probability with and without repetition")
    xs, hs, ws = [], [], []
    for seg in segments:
        if seg.label != '|':
            xs.append((seg.end + seg.start) / 2 + .4)
            hs.append(seg.score)
            ws.append(seg.end - seg.start)
            ax2.annotate(seg.label, (seg.start + .8, -0.07), weight='bold')
    ax2.bar(xs, hs, width=ws, color='gray', alpha=0.5, edgecolor='black')

    xs, hs = [], []
    for p in path:
        label = transcript[p.token_index]
        if label != '|':
            xs.append(p.time_index + 1)
            hs.append(p.score)

    ax2.bar(xs, hs, width=0.5, alpha=0.5)
    ax2.axhline(0, color='black')
    ax2.set_xlim(ax1.get_xlim())
    ax2.set_ylim(-0.1, 1.1)

plot_trellis_with_segments(trellis, segments, transcript)
plt.tight_layout()
plt.show()
######################################################################
# Looks good. Now let’s merge the words. The Wav2Vec2 model uses ``'|'``
# as the word boundary, so we merge the segments before each occurrence of
# ``'|'``.
#
# Then, finally, we segment the original audio into segmented audio and
# listen to them to see if the segmentation is correct.
#
# Merge words
def merge_words(segments, separator='|'):
    words = []
    i1, i2 = 0, 0
    while i1 < len(segments):
        if i2 >= len(segments) or segments[i2].label == separator:
            if i1 != i2:
                segs = segments[i1:i2]
                word = ''.join([seg.label for seg in segs])
                score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
                words.append(Segment(word, segments[i1].start, segments[i2-1].end, score))
            i1 = i2 + 1
            i2 = i1
        else:
            i2 += 1
    return words

word_segments = merge_words(segments)
for word in word_segments:
    print(word)
################################################################################
# Visualization
################################################################################
def plot_alignments(trellis, segments, word_segments, waveform):
    trellis_with_path = trellis.clone()
    for i, seg in enumerate(segments):
        if seg.label != '|':
            trellis_with_path[seg.start+1:seg.end+1, i+1] = float('nan')

    fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))

    ax1.imshow(trellis_with_path[1:, 1:].T, origin='lower')
    ax1.set_xticks([])
    ax1.set_yticks([])

    for word in word_segments:
        ax1.axvline(word.start - 0.5)
        ax1.axvline(word.end - 0.5)

    for i, seg in enumerate(segments):
        if seg.label != '|':
            ax1.annotate(seg.label, (seg.start, i + 0.3))
            ax1.annotate(f'{seg.score:.2f}', (seg.start, i + 4), fontsize=8)

    # The original waveform
    ratio = waveform.size(0) / (trellis.size(0) - 1)
    ax2.plot(waveform)
    for word in word_segments:
        x0 = ratio * word.start
        x1 = ratio * word.end
        ax2.axvspan(x0, x1, alpha=0.1, color='red')
        ax2.annotate(f'{word.score:.2f}', (x0, 0.8))

    for seg in segments:
        if seg.label != '|':
            ax2.annotate(seg.label, (seg.start * ratio, 0.9))

    xticks = ax2.get_xticks()
    plt.xticks(xticks, xticks / bundle.sample_rate)
    ax2.set_xlabel('time [second]')
    ax2.set_yticks([])
    ax2.set_ylim(-1.0, 1.0)
    ax2.set_xlim(0, waveform.size(-1))

plot_alignments(trellis, segments, word_segments, waveform[0])
plt.show()
# Generate the audio for each segment
print(transcript)
IPython.display.display(IPython.display.Audio(SPEECH_FILE))

ratio = waveform.size(1) / (trellis.size(0) - 1)
for i, word in enumerate(word_segments):
    x0 = int(ratio * word.start)
    x1 = int(ratio * word.end)
    filename = f"{i}_{word.label}.wav"
    torchaudio.save(filename, waveform[:, x0:x1], bundle.sample_rate)
    print(f"{word.label}: {x0 / bundle.sample_rate:.3f} - {x1 / bundle.sample_rate:.3f}")
    IPython.display.display(IPython.display.Audio(filename))
######################################################################
# Conclusion
# ----------
#
# In this tutorial, we looked at how to use torchaudio’s Wav2Vec2 model to
# perform CTC segmentation for forced alignment.
#
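######################################################################
# As a recap, the overall flow implemented above, using the helper
# functions defined in this tutorial, is roughly:
#
# ::
#
#    emissions, _ = model(waveform.to(device))   # frame-wise label scores
#    emission = torch.log_softmax(emissions, dim=-1)[0].cpu().detach()
#    trellis = get_trellis(emission, tokens)      # alignment probabilities
#    path = backtrack(trellis, emission, tokens)  # most likely path
#    segments = merge_repeats(path)               # label-level segments
#    word_segments = merge_words(segments)        # word-level segments
#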
"""
Speech Recognition with Wav2Vec2
================================
**Author**: `Moto Hira <moto@fb.com>`__
This tutorial shows how to perform speech recognition using
pre-trained models from wav2vec 2.0
[`paper <https://arxiv.org/abs/2006.11477>`__].
"""
######################################################################
# Overview
# --------
#
# The process of speech recognition looks like the following.
#
# 1. Extract the acoustic features from audio waveform
#
# 2. Estimate the class of the acoustic features frame-by-frame
#
# 3. Generate hypothesis from the sequence of the class probabilities
#
# Torchaudio provides easy access to the pre-trained weights and
# associated information, such as the expected sample rate and class
# labels. They are bundled together and available under
# the ``torchaudio.pipelines`` module.
#
######################################################################
# Preparation
# -----------
#
# First we import the necessary packages, and fetch data that we work on.
#
# %matplotlib inline
import os
import torch
import torchaudio
import requests
import matplotlib
import matplotlib.pyplot as plt
import IPython
matplotlib.rcParams['figure.figsize'] = [16.0, 4.8]
torch.random.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(torch.__version__)
print(torchaudio.__version__)
print(device)
SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
SPEECH_FILE = "speech.wav"
if not os.path.exists(SPEECH_FILE):
    with open(SPEECH_FILE, 'wb') as file:
        file.write(requests.get(SPEECH_URL).content)
######################################################################
# Creating a pipeline
# -------------------
#
# First, we will create a Wav2Vec2 model that performs the feature
# extraction and the classification.
#
# There are two types of Wav2Vec2 pre-trained weights available in
# torchaudio: the ones fine-tuned for the ASR task, and the ones not
# fine-tuned.
#
# Wav2Vec2 (and HuBERT) models are trained in a self-supervised manner. They
# are first trained on audio only for representation learning, then
# fine-tuned for a specific task with additional labels.
#
# The pre-trained weights without fine-tuning can be fine-tuned
# for other downstream tasks as well, but this tutorial does not
# cover that.
#
# We will use :py:func:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H` here.
#
# There are multiple models available in
# :py:mod:`torchaudio.pipelines`. Please check the documentation for
# the details of how they are trained.
#
# The bundle object provides the interface to instantiate the model and other
# information. The sampling rate and the class labels can be found as follows.
#
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
print("Sample Rate:", bundle.sample_rate)
print("Labels:", bundle.get_labels())
######################################################################
# The model can be constructed as follows. This process will automatically
# fetch the pre-trained weights and load them into the model.
#
model = bundle.get_model().to(device)
print(model.__class__)
######################################################################
# Loading data
# ------------
#
# We will use the speech data from `VOiCES
# dataset <https://iqtlabs.github.io/voices/>`__, which is licensed under
# Creative Commons BY 4.0.
#
IPython.display.display(IPython.display.Audio(SPEECH_FILE))
######################################################################
# To load data, we use :py:func:`torchaudio.load`.
#
# If the sampling rate is different from what the pipeline expects, then
# we can use :py:func:`torchaudio.functional.resample` for resampling.
#
# .. note::
#
#    - :py:func:`torchaudio.functional.resample` works on CUDA tensors as well.
#    - When performing resampling multiple times on the same set of sample rates,
#      using :py:func:`torchaudio.transforms.Resample` might improve the performance.
#
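# As a rough sketch (assuming a hypothetical 8 kHz source, which is not
# the case for the file used in this tutorial), a ``Resample`` transform
# can be created once and then reused across calls:
#
# ::
#
#    resampler = torchaudio.transforms.Resample(orig_freq=8000, new_freq=bundle.sample_rate)
#    resampled_waveform = resampler(waveform)
#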
waveform, sample_rate = torchaudio.load(SPEECH_FILE)
waveform = waveform.to(device)
if sample_rate != bundle.sample_rate:
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
######################################################################
# Extracting acoustic features
# ----------------------------
#
# The next step is to extract acoustic features from the audio.
#
# .. note::
#    Wav2Vec2 models fine-tuned for the ASR task can perform feature
#    extraction and classification in one step, but for the sake of the
#    tutorial, we also show how to perform feature extraction here.
#
with torch.inference_mode():
    features, _ = model.extract_features(waveform)
######################################################################
# The returned features are a list of tensors. Each tensor is the output of
# a transformer layer.
#
fig, ax = plt.subplots(len(features), 1, figsize=(16, 4.3 * len(features)))
for i, feats in enumerate(features):
    ax[i].imshow(feats[0].cpu())
    ax[i].set_title(f"Feature from transformer layer {i+1}")
    ax[i].set_xlabel("Feature dimension")
    ax[i].set_ylabel("Frame (time-axis)")
plt.tight_layout()
plt.show()
######################################################################
# Feature classification
# ----------------------
#
# Once the acoustic features are extracted, the next step is to classify
# them into a set of categories.
#
# The Wav2Vec2 model provides a method to perform the feature extraction and
# classification in one step.
#
with torch.inference_mode():
    emission, _ = model(waveform)
######################################################################
# The output is in the form of logits. It is not in the form of
# probability.
#
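# If probabilities are needed, the logits can be converted with
# :py:func:`torch.softmax` (a small sketch; the greedy decoding used later
# works directly on the logits, since ``softmax`` does not change the argmax):
#
# ::
#
#    probabilities = torch.softmax(emission, dim=-1)
#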
# Let’s visualize this.
#
plt.imshow(emission[0].cpu().T)
plt.title("Classification result")
plt.xlabel("Frame (time-axis)")
plt.ylabel("Class")
plt.show()
print("Class labels:", bundle.get_labels())
######################################################################
# We can see that there are strong indications for certain labels across
# the timeline.
#
######################################################################
# Generating transcripts
# ----------------------
#
# From the sequence of label probabilities, now we want to generate
# transcripts. The process to generate hypotheses is often called
# “decoding”.
#
# Decoding is more elaborate than simple classification because
# decoding at certain time step can be affected by surrounding
# observations.
#
# For example, take words like ``night`` and ``knight``. Even if their
# prior probability distributions are different (in typical conversations,
# ``night`` would occur way more often than ``knight``), to accurately
# generate transcripts with ``knight``, such as ``a knight with a sword``,
# the decoding process has to postpone the final decision until it sees
# enough context.
#
# Many decoding techniques have been proposed, and they typically require
# external resources, such as a word dictionary and a language model.
#
# In this tutorial, for the sake of simplicity, we will perform greedy
# decoding, which does not depend on such external components, and simply
# picks the best hypothesis at each time step. Therefore, the context
# information is not used, and only one transcript can be generated.
#
# We start by defining the greedy decoding algorithm.
#
class GreedyCTCDecoder(torch.nn.Module):
    def __init__(self, labels, blank=0):
        super().__init__()
        self.labels = labels
        self.blank = blank

    def forward(self, emission: torch.Tensor) -> str:
        """Given a sequence emission over labels, get the best path string

        Args:
            emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

        Returns:
            str: The resulting transcript
        """
        indices = torch.argmax(emission, dim=-1)  # [num_seq,]
        indices = torch.unique_consecutive(indices, dim=-1)
        indices = [i for i in indices if i != self.blank]
        return ''.join([self.labels[i] for i in indices])
######################################################################
# Now create the decoder object and decode the transcript.
#
decoder = GreedyCTCDecoder(labels=bundle.get_labels())
transcript = decoder(emission[0])
######################################################################
# Let’s check the result and listen again to the audio.
#
print(transcript)
IPython.display.display(IPython.display.Audio(SPEECH_FILE))
######################################################################
# The ASR model is fine-tuned using a loss function called Connectionist Temporal Classification (CTC).
# The details of the CTC loss are explained
# `here <https://distill.pub/2017/ctc/>`__. In CTC, a blank token (ϵ) is a
# special token which represents "no output"; it separates repetitions of
# the same symbol. In decoding, blank tokens are simply removed, and
# consecutive repetitions of a symbol are collapsed into one.
#
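# As a small worked example (with a hypothetical label sequence, not the
# output above): the frame-wise argmax sequence
# ``H H ϵ E E ϵ L L ϵ L O`` first collapses consecutive duplicates to
# ``H ϵ E ϵ L ϵ L O``, and dropping the blanks then yields ``HELLO``.
# The blank between the two ``L`` s is what preserves the double letter.
#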
######################################################################
# Conclusion
# ----------
#
# In this tutorial, we looked at how to use :py:mod:`torchaudio.pipelines` to
# perform acoustic feature extraction and speech recognition. Constructing
# a model and getting the emission is as short as two lines.
#
# ::
#
#    model = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H.get_model()
#    emission = model(waveforms, ...)
#