"notebooks/vscode:/vscode.git/clone" did not exist on "d1e5700676a80450baf70df1769e1efd8f60efa0"
Commit b645c07b authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Update ctc forced alignment tutorial (#3529)

Summary:
- Simplify the step to generate token-level alignment

Pull Request resolved: https://github.com/pytorch/audio/pull/3529

Reviewed By: huangruizhe

Differential Revision: D48066787

Pulled By: mthrok

fbshipit-source-id: 452c243d278e508926a59894928e280fea76dcc6
parent 09aabcc1
......@@ -87,17 +87,20 @@ with torch.inference_mode():
emission, _ = model(waveform.to(device))
emission = torch.log_softmax(emission, dim=-1)
num_frames = emission.size(1)
######################################################################
#
def plot_emission(emission):
plt.imshow(emission.cpu().T)
plt.title("Frame-wise class probabilities")
plt.xlabel("Time")
plt.ylabel("Labels")
plt.tight_layout()
fig, ax = plt.subplots()
ax.imshow(emission.cpu().T)
ax.set_title("Frame-wise class probabilities")
ax.set_xlabel("Time")
ax.set_ylabel("Labels")
fig.tight_layout()
plot_emission(emission[0])
......@@ -114,9 +117,9 @@ for k, v in DICTIONARY.items():
######################################################################
# converting transcript to tokens is as simple as
tokens = [DICTIONARY[c] for c in TRANSCRIPT]
tokenized_transcript = [DICTIONARY[c] for c in TRANSCRIPT]
print(" ".join(str(t) for t in tokens))
print(" ".join(str(t) for t in tokenized_transcript))
######################################################################
# Computing frame-level alignments
......@@ -138,20 +141,20 @@ def align(emission, tokens):
blank=0,
)
scores = scores.exp() # convert back to probability
alignments, scores = alignments[0], scores[0] # remove batch dimension for simplicity
return alignments.tolist(), scores.tolist()
scores = scores.exp() # convert back to probability
return alignments, scores
frame_alignment, frame_scores = align(emission, tokens)
aligned_tokens, alignment_scores = align(emission, tokenized_transcript)
######################################################################
# Now let's look at the output.
# Notice that the alignment is expressed in the frame cordinate of
# emission, which is different from the original waveform.
for i, (ali, score) in enumerate(zip(frame_alignment, frame_scores)):
print(f"{i:3d}: {ali:2d} [{labels[ali]}], {score:.2f}")
for i, (ali, score) in enumerate(zip(aligned_tokens, alignment_scores)):
print(f"{i:3d}:\t{ali:2d} [{labels[ali]}], {score:.2f}")
######################################################################
#
......@@ -177,116 +180,43 @@ for i, (ali, score) in enumerate(zip(frame_alignment, frame_scores)):
# .. code-block::
#
# 29: 0 [-], 1.00
# 30: 7 [I], 1.00 # Start of "I"
# 31: 0 [-], 0.98 # repeat (blank token)
# 32: 0 [-], 1.00 # repeat (blank token)
# 33: 1 [|], 0.85 # Start of "|" (word boundary)
# 34: 1 [|], 1.00 # repeat (same token)
# 35: 0 [-], 0.61 # repeat (blank token)
# 36: 8 [H], 1.00 # Start of "H"
# 37: 0 [-], 1.00 # repeat (blank token)
# 38: 4 [A], 1.00 # Start of "A"
# 39: 0 [-], 0.99 # repeat (blank token)
# 40: 11 [D], 0.92 # Start of "D"
# 41: 0 [-], 0.93 # repeat (blank token)
# 42: 1 [|], 0.98 # Start of "|"
# 43: 1 [|], 1.00 # repeat (same token)
# 44: 3 [T], 1.00 # Start of "T"
# 45: 3 [T], 0.90 # repeat (same token)
# 46: 8 [H], 1.00 # Start of "H"
# 47: 0 [-], 1.00 # repeat (blank token)
######################################################################
# Resolve blank and repeated tokens
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 30: 7 [I], 1.00 # "I" starts and ends
# 31: 0 [-], 0.98 #
# 32: 0 [-], 1.00 #
# 33: 1 [|], 0.85 # "|" (word boundary) starts
# 34: 1 [|], 1.00 # "|" ends
# 35: 0 [-], 0.61 #
# 36: 8 [H], 1.00 # "H" starts and ends
# 37: 0 [-], 1.00 #
# 38: 4 [A], 1.00 # "A" starts and ends
# 39: 0 [-], 0.99 #
# 40: 11 [D], 0.92 # "D" starts and ends
# 41: 0 [-], 0.93 #
# 42: 1 [|], 0.98 # "|" starts
# 43: 1 [|], 1.00 # "|" ends
# 44: 3 [T], 1.00 # "T" starts
# 45: 3 [T], 0.90 # "T" ends
# 46: 8 [H], 1.00 # "H" starts and ends
# 47: 0 [-], 1.00 #
######################################################################
# Obtain token-level alignment
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Next step is to resolve the repetation. So that what alignment represents
# do not depend on previous alignments.
# From the outputs ``alignment`` and ``scores``, we generate a
# list called ``frames`` storing information of all frames aligned to
# non-blank tokens.
#
# Each element contains the following
#
# - ``token_index``: the aligned token’s index **in the transcript**
# - ``time_index``: the current frame’s index in emission
# - ``score``: scores of the current frame.
#
# ``token_index`` is the index of each token in the transcript,
# i.e. the current frame aligns to the N-th character from the transcript.
# From the outputs ``alignment``, we compute the following ``Span`` object,
# which explains what token (in transcript) is present at what time span.
@dataclass
class Frame:
token_index: int
time_index: int
class TokenSpan:
index: int # index of token in transcript
start: int # start time (inclusive)
end: int # end time (exclusive)
score: float
######################################################################
#
def obtain_token_level_alignments(alignments, scores) -> List[Frame]:
assert len(alignments) == len(scores)
token_index = -1
prev_hyp = 0
frames = []
for i, (ali, score) in enumerate(zip(alignments, scores)):
if ali == 0:
prev_hyp = 0
continue
if ali != prev_hyp:
token_index += 1
frames.append(Frame(token_index, i, score))
prev_hyp = ali
return frames
######################################################################
#
frames = obtain_token_level_alignments(frame_alignment, frame_scores)
print("Time\tLabel\tScore")
for f in frames:
print(f"{f.time_index:3d}\t{TRANSCRIPT[f.token_index]}\t{f.score:.2f}")
######################################################################
# Obtain token-level alignments and confidence scores
# ---------------------------------------------------
#
# The frame-level alignments contains repetations for the same labels.
# Another format “token-level alignment”, which specifies the aligned
# frame ranges for each transcript token, contains the same information,
# while being more convenient to apply to some downstream tasks
# (e.g. computing word-level alignments).
#
# Now we demonstrate how to obtain token-level alignments and confidence
# scores by simply merging frame-level alignments and averaging
# frame-level confidence scores.
#
######################################################################
# The following class represents the label, its score and the time span
# of its occurance.
#
@dataclass
class Segment:
label: str
start: int
end: int
score: float
def __repr__(self):
return f"{self.label:2s} ({self.score:4.2f}): [{self.start:4d}, {self.end:4d})"
def __len__(self):
def __len__(self) -> int:
return self.end - self.start
......@@ -294,32 +224,31 @@ class Segment:
#
def merge_repeats(frames, transcript):
transcript_nospace = transcript.replace(" ", "")
i1, i2 = 0, 0
segments = []
while i1 < len(frames):
while i2 < len(frames) and frames[i1].token_index == frames[i2].token_index:
i2 += 1
score = sum(frames[k].score for k in range(i1, i2)) / (i2 - i1)
segments.append(
Segment(
transcript_nospace[frames[i1].token_index],
frames[i1].time_index,
frames[i2 - 1].time_index + 1,
score,
)
)
i1 = i2
return segments
def merge_tokens(tokens, scores, blank=0) -> List[TokenSpan]:
prev_token = blank
i = start = -1
spans = []
for t, token in enumerate(tokens):
if token != prev_token:
if prev_token != blank:
spans.append(TokenSpan(i, start, t, scores[start:t].mean().item()))
if token != blank:
i += 1
start = t
prev_token = token
if prev_token != blank:
spans.append(TokenSpan(i, start, len(tokens), scores[start:].mean().item()))
return spans
######################################################################
#
segments = merge_repeats(frames, TRANSCRIPT)
for seg in segments:
print(seg)
token_spans = merge_tokens(aligned_tokens, alignment_scores)
print("Token\tTime\tScore")
for s in token_spans:
print(f"{TRANSCRIPT[s.index]}\t[{s.start:3d}, {s.end:3d})\t{s.score:.2f}")
######################################################################
# Visualization
......@@ -327,51 +256,37 @@ for seg in segments:
#
def plot_label_prob(segments, transcript):
def plot_scores(spans, scores, transcript):
fig, ax = plt.subplots()
ax.set_title("frame-level and token-level confidence scores")
xs, hs, ws = [], [], []
for seg in segments:
if seg.label != "|":
xs.append((seg.end + seg.start) / 2 + 0.4)
hs.append(seg.score)
ws.append(seg.end - seg.start)
ax.annotate(seg.label, (seg.start + 0.8, -0.07), weight="bold")
ax.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black")
xs, hs = [], []
for p in frames:
label = transcript[p.token_index]
if label != "|":
xs.append(p.time_index + 1)
hs.append(p.score)
ax.bar(xs, hs, width=0.5, alpha=0.5)
span_xs, span_hs, span_ws = [], [], []
frame_xs, frame_hs = [], []
for span in spans:
token = transcript[span.index]
if token != "|":
span_xs.append((span.end + span.start) / 2 + 0.4)
span_hs.append(span.score)
span_ws.append(span.end - span.start)
ax.annotate(token, (span.start + 0.8, -0.07), weight="bold")
for t in range(span.start, span.end):
frame_xs.append(t + 1)
frame_hs.append(scores[t].item())
ax.bar(span_xs, span_hs, width=span_ws, color="gray", alpha=0.5, edgecolor="black")
ax.bar(frame_xs, frame_hs, width=0.5, alpha=0.5)
ax.set_ylim(-0.1, 1.1)
ax.grid(True, axis="y")
fig.tight_layout()
plot_label_prob(segments, TRANSCRIPT)
plot_scores(token_spans, alignment_scores, TRANSCRIPT)
######################################################################
# From the visualized scores, we can see that, for tokens spanning over
# more multiple frames, e.g. “T” in “THAT, the token-level confidence
# score is the average of frame-level confidence scores. To make this
# clearer, we don’t plot confidence scores for blank frames, which was
# plotted in the”Label probability with and without repeatation” figure in
# the previous tutorial
# `Forced Alignment with Wav2Vec2 <./forced_alignment_tutorial.html>`__.
#
######################################################################
# Obtain word-level alignments and confidence scores
# --------------------------------------------------
#
######################################################################
# Now let’s merge the token-level alignments and confidence scores to get
# word-level alignments and confidence scores. Then, finally, we verify
......@@ -380,32 +295,30 @@ plot_label_prob(segments, TRANSCRIPT)
# alignments and listening to them.
@dataclass
class WordSpan:
token_spans: List[TokenSpan]
score: float
# Obtain word alignments from token alignments
def merge_words(transcript, segments, separator=" "):
def merge_words(token_spans, transcript, separator="|") -> List[WordSpan]:
def _score(t_spans):
return sum(s.score * len(s) for s in t_spans) / sum(len(s) for s in t_spans)
words = []
i1, i2, i3 = 0, 0, 0
while i3 < len(transcript):
if i3 == len(transcript) - 1 or transcript[i3] == separator:
if i1 != i2:
if i3 == len(transcript) - 1:
i2 += 1
if separator == "|":
# s is the number of separators (counted as a valid modeling unit) we've seen
s = len(words)
else:
s = 0
segs = segments[i1 + s : i2 + s]
word = "".join([seg.label for seg in segs])
score = sum(seg.score * len(seg) for seg in segs) / sum(len(seg) for seg in segs)
words.append(Segment(word, segments[i1 + s].start, segments[i2 + s - 1].end, score))
i1 = i2
else:
i2 += 1
i3 += 1
i = 0
for j, span in enumerate(token_spans):
if transcript[span.index] == separator:
words.append(WordSpan(token_spans[i:j], _score(token_spans[i:j])))
i = j + 1
if i < len(token_spans):
words.append(WordSpan(token_spans[i:], _score(token_spans[i:])))
return words
word_segments = merge_words(TRANSCRIPT, segments, "|")
word_spans = merge_words(token_spans, TRANSCRIPT)
######################################################################
......@@ -414,38 +327,40 @@ word_segments = merge_words(TRANSCRIPT, segments, "|")
#
def plot_alignments(waveform, emission, segments, word_segments, sample_rate=bundle.sample_rate):
def plot_alignments(waveform, word_spans, num_frames, transcript, sample_rate=bundle.sample_rate):
fig, ax = plt.subplots()
ax.specgram(waveform[0], Fs=sample_rate)
ratio = waveform.size(1) / sample_rate / num_frames
for w_span in word_spans:
t_spans = w_span.token_spans
t0, t1 = t_spans[0].start, t_spans[-1].end
ax.axvspan(ratio * t0, ratio * t1, facecolor="None", hatch="/", edgecolor="white")
ax.annotate(f"{w_span.score:.2f}", (ratio * t0, sample_rate * 0.51), annotation_clip=False)
# The original waveform
ratio = waveform.size(1) / sample_rate / emission.size(1)
for word in word_segments:
t0, t1 = ratio * word.start, ratio * word.end
ax.axvspan(t0, t1, facecolor="None", hatch="/", edgecolor="white")
ax.annotate(f"{word.score:.2f}", (t0, sample_rate * 0.51), annotation_clip=False)
for seg in segments:
if seg.label != "|":
ax.annotate(seg.label, (seg.start * ratio, sample_rate * 0.53), annotation_clip=False)
for span in t_spans:
token = transcript[span.index]
ax.annotate(token, (span.start * ratio, sample_rate * 0.53), annotation_clip=False)
ax.set_xlabel("time [second]")
ax.set_xlim([0, None])
fig.tight_layout()
plot_alignments(waveform, emission, segments, word_segments)
plot_alignments(waveform, word_spans, num_frames, TRANSCRIPT)
######################################################################
def display_segment(i, waveform, word_segments, frame_alignment, sample_rate=bundle.sample_rate):
ratio = waveform.size(1) / len(frame_alignment)
word = word_segments[i]
x0 = int(ratio * word.start)
x1 = int(ratio * word.end)
print(f"{word.label} ({word.score:.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec")
def preview_word(waveform, word_span, num_frames, transcript, sample_rate=bundle.sample_rate):
ratio = waveform.size(1) / num_frames
t0 = word_span.token_spans[0].start
t1 = word_span.token_spans[-1].end
x0 = int(ratio * t0)
x1 = int(ratio * t1)
tokens = "".join(transcript[t.index] for t in word_span.token_spans)
print(f"{tokens} ({word_span.score:.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec")
segment = waveform[:, x0:x1]
return IPython.display.Audio(segment.numpy(), rate=sample_rate)
......@@ -459,47 +374,47 @@ IPython.display.Audio(SPEECH_FILE)
######################################################################
#
display_segment(0, waveform, word_segments, frame_alignment)
preview_word(waveform, word_spans[0], num_frames, TRANSCRIPT)
######################################################################
#
display_segment(1, waveform, word_segments, frame_alignment)
preview_word(waveform, word_spans[1], num_frames, TRANSCRIPT)
######################################################################
#
display_segment(2, waveform, word_segments, frame_alignment)
preview_word(waveform, word_spans[2], num_frames, TRANSCRIPT)
######################################################################
#
display_segment(3, waveform, word_segments, frame_alignment)
preview_word(waveform, word_spans[3], num_frames, TRANSCRIPT)
######################################################################
#
display_segment(4, waveform, word_segments, frame_alignment)
preview_word(waveform, word_spans[4], num_frames, TRANSCRIPT)
######################################################################
#
display_segment(5, waveform, word_segments, frame_alignment)
preview_word(waveform, word_spans[5], num_frames, TRANSCRIPT)
######################################################################
#
display_segment(6, waveform, word_segments, frame_alignment)
preview_word(waveform, word_spans[6], num_frames, TRANSCRIPT)
######################################################################
#
display_segment(7, waveform, word_segments, frame_alignment)
preview_word(waveform, word_spans[7], num_frames, TRANSCRIPT)
######################################################################
#
display_segment(8, waveform, word_segments, frame_alignment)
preview_word(waveform, word_spans[8], num_frames, TRANSCRIPT)
######################################################################
......@@ -527,38 +442,53 @@ DICTIONARY["*"] = len(DICTIONARY)
# corresponding to the ``<star>`` token.
#
extra_dim = torch.zeros(emission.shape[0], emission.shape[1], 1, device=device)
emission = torch.cat((emission, extra_dim), 2)
star_dim = torch.zeros((1, num_frames, 1), device=device)
emission = torch.cat((emission, star_dim), 2)
assert len(DICTIONARY) == emission.shape[2]
plot_emission(emission[0])
######################################################################
# The following function combines all the processes, and compute
# word segments from emission in one-go.
def compute_and_plot_alignments(transcript, dictionary, emission, waveform):
def compute_alignments(emission, transcript, dictionary):
tokens = [dictionary[c] for c in transcript]
alignment, scores = align(emission, tokens)
frames = obtain_token_level_alignments(alignment, scores)
segments = merge_repeats(frames, transcript)
word_segments = merge_words(transcript, segments, "|")
plot_alignments(waveform, emission, segments, word_segments)
plt.xlim([0, None])
token_spans = merge_tokens(alignment, scores)
word_spans = merge_words(token_spans, transcript)
return word_spans
######################################################################
# **Original**
compute_and_plot_alignments(TRANSCRIPT, DICTIONARY, emission, waveform)
word_spans = compute_alignments(emission, TRANSCRIPT, DICTIONARY)
plot_alignments(waveform, word_spans, num_frames, TRANSCRIPT)
######################################################################
# **With <star> token**
#
# Now we replace the first part of the transcript with the ``<star>`` token.
compute_and_plot_alignments("*|THIS|MOMENT", DICTIONARY, emission, waveform)
transcript = "*|THIS|MOMENT"
word_spans = compute_alignments(emission, transcript, DICTIONARY)
plot_alignments(waveform, word_spans, num_frames, transcript)
######################################################################
#
preview_word(waveform, word_spans[1], num_frames, transcript)
######################################################################
#
preview_word(waveform, word_spans[2], num_frames, transcript)
######################################################################
#
######################################################################
# **Without <star> token**
......@@ -567,7 +497,9 @@ compute_and_plot_alignments("*|THIS|MOMENT", DICTIONARY, emission, waveform)
# without using ``<star>`` token.
# It demonstrates the effect of ``<star>`` token for dealing with deletion errors.
compute_and_plot_alignments("THIS|MOMENT", DICTIONARY, emission, waveform)
transcript = "THIS|MOMENT"
word_spans = compute_alignments(emission, transcript, DICTIONARY)
plot_alignments(waveform, word_spans, num_frames, transcript)
######################################################################
# Conclusion
......
......@@ -25,12 +25,13 @@ print(torchaudio.__version__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
from dataclasses import dataclass
######################################################################
# Preparation
# -----------
#
from dataclasses import dataclass
from typing import Dict, List
import IPython
import matplotlib.pyplot as plt
......@@ -54,145 +55,117 @@ SAMPLE_RATE = 16000
@dataclass
class Frame:
token_index: int
time_index: int
class TokenSpan:
index: int # index of token in transcript
start: int # start time (inclusive)
end: int # end time (exclusive)
score: float
def __len__(self) -> int:
return self.end - self.start
######################################################################
#
@dataclass
class Segment:
label: str
start: int
end: int
score: float
def __repr__(self):
return f"{self.label:2s} ({self.score:4.2f}): [{self.start:4d}, {self.end:4d})"
def __len__(self):
return self.end - self.start
@dataclass
class WordSpan:
token_spans: List[TokenSpan]
score: float
######################################################################
#
def align_emission_and_tokens(emission: torch.Tensor, tokens: List[int]):
device = emission.device
targets = torch.tensor([tokens], dtype=torch.int32, device=device)
input_lengths = torch.tensor([emission.size(1)], device=device)
target_lengths = torch.tensor([targets.size(1)], device=device)
aligned_tokens, scores = forced_align(emission, targets, input_lengths, target_lengths, 0)
def compute_alignments(transcript, dictionary, emission):
tokens = [dictionary[c] for c in transcript.replace(" ", "")]
scores = scores.exp() # convert back to probability
aligned_tokens, scores = aligned_tokens[0], scores[0] # remove batch dimension
return aligned_tokens, scores
targets = torch.tensor([tokens], dtype=torch.int32, device=emission.device)
input_lengths = torch.tensor([emission.shape[1]], device=emission.device)
target_lengths = torch.tensor([targets.shape[1]], device=emission.device)
alignment, scores = forced_align(emission, targets, input_lengths, target_lengths, 0)
def merge_tokens(tokens, scores, blank=0) -> List[TokenSpan]:
prev_token = blank
i = start = -1
spans = []
for t, token in enumerate(tokens):
if token != prev_token:
if prev_token != blank:
spans.append(TokenSpan(i, start, t, scores[start:t].mean().item()))
if token != blank:
i += 1
start = t
prev_token = token
if prev_token != blank:
spans.append(TokenSpan(i, start, len(tokens), scores[start:].mean().item()))
return spans
def merge_words(token_spans: List[TokenSpan], transcript: List[str]) -> List[WordSpan]:
def _score(t_spans):
return sum(s.score * len(s) for s in t_spans) / sum(len(s) for s in t_spans)
word_spans = []
i = 0
for words in transcript:
j = i + len(words)
word_spans.append(WordSpan(token_spans[i:j], _score(token_spans[i:j])))
i = j
return word_spans
scores = scores.exp() # convert back to probability
alignment, scores = alignment[0].tolist(), scores[0].tolist()
assert len(alignment) == len(scores) == emission.size(1)
token_index = -1
prev_hyp = 0
frames = []
for i, (ali, score) in enumerate(zip(alignment, scores)):
if ali == 0:
prev_hyp = 0
continue
if ali != prev_hyp:
token_index += 1
frames.append(Frame(token_index, i, score))
prev_hyp = ali
# compute frame alignments from token alignments
transcript_nospace = transcript.replace(" ", "")
i1, i2 = 0, 0
segments = []
while i1 < len(frames):
while i2 < len(frames) and frames[i1].token_index == frames[i2].token_index:
i2 += 1
score = sum(frames[k].score for k in range(i1, i2)) / (i2 - i1)
segments.append(
Segment(
transcript_nospace[frames[i1].token_index],
frames[i1].time_index,
frames[i2 - 1].time_index + 1,
score,
)
)
i1 = i2
# compue word alignments from token alignments
separator = " "
words = []
i1, i2, i3 = 0, 0, 0
while i3 < len(transcript):
if i3 == len(transcript) - 1 or transcript[i3] == separator:
if i1 != i2:
if i3 == len(transcript) - 1:
i2 += 1
segs = segments[i1:i2]
word = "".join([s.label for s in segs])
score = sum(s.score * len(s) for s in segs) / sum(len(s) for s in segs)
words.append(Segment(word, segs[0].start, segs[-1].end + 1, score))
i1 = i2
else:
i2 += 1
i3 += 1
return segments, words
######################################################################
#
def plot_emission(emission):
fig, ax = plt.subplots()
ax.imshow(emission.T, aspect="auto")
ax.set_title("Emission")
fig.tight_layout()
def compute_alignments(emission: torch.Tensor, transcript: List[str], dictionary: Dict[str, int]):
tokens = [dictionary[c] for word in transcript for c in word]
aligned_tokens, scores = align_emission_and_tokens(emission, tokens)
token_spans = merge_tokens(aligned_tokens, scores)
word_spans = merge_words(token_spans, transcript)
return word_spans
######################################################################
#
# utility function for plotting word alignments
def plot_alignments(waveform, emission, segments, word_segments, sample_rate=SAMPLE_RATE):
fig, ax = plt.subplots()
ax.specgram(waveform[0], Fs=sample_rate)
xlim = ax.get_xlim()
ratio = waveform.size(1) / sample_rate / emission.size(1)
for word in word_segments:
t0, t1 = word.start * ratio, word.end * ratio
ax.axvspan(t0, t1, facecolor="None", hatch="/", edgecolor="white")
ax.annotate(f"{word.score:.2f}", (t0, sample_rate * 0.51), annotation_clip=False)
for seg in segments:
if seg.label != "|":
ax.annotate(seg.label, (seg.start * ratio, sample_rate * 0.53), annotation_clip=False)
ax.set_xlabel("time [second]")
ax.set_xlim(xlim)
fig.tight_layout()
def plot_alignments(waveform, word_spans, emission, transcript, sample_rate=SAMPLE_RATE):
ratio = waveform.size(1) / emission.size(1) / sample_rate
fig, axes = plt.subplots(2, 1)
axes[0].imshow(emission[0].detach().cpu().T, aspect="auto")
axes[0].set_title("Emission")
axes[0].set_xticks([])
axes[1].specgram(waveform[0], Fs=sample_rate)
for w_span, chars in zip(word_spans, transcript):
t_spans = w_span.token_spans
t0, t1 = t_spans[0].start, t_spans[-1].end
axes[1].axvspan(ratio * t0, ratio * t1, facecolor="None", hatch="/", edgecolor="white")
axes[1].annotate(f"{w_span.score:.2f}", (ratio * t0, sample_rate * 0.51), annotation_clip=False)
for span, char in zip(t_spans, chars):
axes[1].annotate(char, (span.start * ratio, sample_rate * 0.55), annotation_clip=False)
axes[1].set_xlabel("time [second]")
fig.tight_layout()
return IPython.display.Audio(waveform, rate=sample_rate)
######################################################################
#
# utility function for playing audio segments.
def display_segment(i, waveform, word_segments, num_frames, sample_rate=SAMPLE_RATE):
def preview_word(waveform, word_span, num_frames, transcript, sample_rate=SAMPLE_RATE):
ratio = waveform.size(1) / num_frames
word = word_segments[i]
x0 = int(ratio * word.start)
x1 = int(ratio * word.end)
print(f"{word.label} ({word.score:.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec")
t0 = word_span.token_spans[0].start
t1 = word_span.token_spans[-1].end
x0 = int(ratio * t0)
x1 = int(ratio * t1)
print(f"{transcript} ({word_span.score:.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec")
segment = waveform[:, x0:x1]
return IPython.display.Audio(segment.numpy(), rate=sample_rate)
......@@ -345,55 +318,54 @@ waveform, _ = torchaudio.load(speech_file, frame_offset=int(0.5 * SAMPLE_RATE),
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
######################################################################
#
segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
transcript = text_normalized.split()
word_spans = compute_alignments(emission, transcript, dictionary)
plot_alignments(waveform, emission, segments, word_segments)
plot_alignments(waveform, word_spans, emission, transcript)
######################################################################
#
display_segment(0, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[0], num_frames, transcript[0])
######################################################################
#
display_segment(1, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[1], num_frames, transcript[1])
######################################################################
#
display_segment(2, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[2], num_frames, transcript[2])
######################################################################
#
display_segment(3, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[3], num_frames, transcript[3])
######################################################################
#
display_segment(4, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[4], num_frames, transcript[4])
######################################################################
#
display_segment(5, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[5], num_frames, transcript[5])
######################################################################
#
display_segment(6, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[6], num_frames, transcript[6])
######################################################################
#
display_segment(7, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[7], num_frames, transcript[7])
######################################################################
# Chinese
......@@ -423,59 +395,59 @@ waveform = waveform[0:1]
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
######################################################################
#
segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
transcript = text_normalized.split()
word_spans = compute_alignments(emission, transcript, dictionary)
plot_alignments(waveform, emission, segments, word_segments)
plot_alignments(waveform, word_spans, emission, transcript)
######################################################################
#
display_segment(0, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[0], num_frames, transcript[0])
######################################################################
#
display_segment(1, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[1], num_frames, transcript[1])
######################################################################
#
display_segment(2, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[2], num_frames, transcript[2])
######################################################################
#
display_segment(3, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[3], num_frames, transcript[3])
######################################################################
#
display_segment(4, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[4], num_frames, transcript[4])
######################################################################
#
display_segment(5, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[5], num_frames, transcript[5])
######################################################################
#
display_segment(6, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[6], num_frames, transcript[6])
######################################################################
#
display_segment(7, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[7], num_frames, transcript[7])
######################################################################
#
display_segment(8, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[8], num_frames, transcript[8])
######################################################################
......@@ -497,54 +469,54 @@ waveform, _ = torchaudio.load(speech_file, num_frames=int(4.5 * SAMPLE_RATE))
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
######################################################################
#
segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
transcript = text_normalized.split()
word_spans = compute_alignments(emission, transcript, dictionary)
plot_alignments(waveform, emission, segments, word_segments)
plot_alignments(waveform, word_spans, emission, transcript)
######################################################################
#
display_segment(0, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[0], num_frames, transcript[0])
######################################################################
#
display_segment(1, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[1], num_frames, transcript[1])
######################################################################
#
display_segment(2, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[2], num_frames, transcript[2])
######################################################################
#
display_segment(3, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[3], num_frames, transcript[3])
######################################################################
#
display_segment(4, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[4], num_frames, transcript[4])
######################################################################
#
display_segment(5, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[5], num_frames, transcript[5])
######################################################################
#
display_segment(6, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[6], num_frames, transcript[6])
######################################################################
#
display_segment(7, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[7], num_frames, transcript[7])
######################################################################
# Portuguese
......@@ -565,59 +537,60 @@ waveform, _ = torchaudio.load(speech_file, frame_offset=int(SAMPLE_RATE), num_fr
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
######################################################################
#
segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
transcript = text_normalized.split()
word_spans = compute_alignments(emission, transcript, dictionary)
plot_alignments(waveform, emission, segments, word_segments)
plot_alignments(waveform, word_spans, emission, transcript)
######################################################################
#
display_segment(0, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[0], num_frames, transcript[0])
######################################################################
#
display_segment(1, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[1], num_frames, transcript[1])
######################################################################
#
display_segment(2, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[2], num_frames, transcript[2])
######################################################################
#
display_segment(3, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[3], num_frames, transcript[3])
######################################################################
#
display_segment(4, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[4], num_frames, transcript[4])
######################################################################
#
display_segment(5, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[5], num_frames, transcript[5])
######################################################################
#
display_segment(6, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[6], num_frames, transcript[6])
######################################################################
#
display_segment(7, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[7], num_frames, transcript[7])
######################################################################
#
display_segment(8, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[8], num_frames, transcript[8])
######################################################################
# Italian
......@@ -638,44 +611,44 @@ waveform, _ = torchaudio.load(speech_file, num_frames=int(4 * SAMPLE_RATE))
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
######################################################################
#
segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
transcript = text_normalized.split()
word_spans = compute_alignments(emission, transcript, dictionary)
plot_alignments(waveform, emission, segments, word_segments)
plot_alignments(waveform, word_spans, emission, transcript)
######################################################################
#
display_segment(0, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[0], num_frames, transcript[0])
######################################################################
#
display_segment(1, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[1], num_frames, transcript[1])
######################################################################
#
display_segment(2, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[2], num_frames, transcript[2])
######################################################################
#
display_segment(3, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[3], num_frames, transcript[3])
######################################################################
#
display_segment(4, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[4], num_frames, transcript[4])
######################################################################
#
display_segment(5, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[5], num_frames, transcript[5])
######################################################################
# Conclusion
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment