Commit b645c07b authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Update ctc forced alignment tutorial (#3529)

Summary:
- Simplify the step to generate token-level alignment

Pull Request resolved: https://github.com/pytorch/audio/pull/3529

Reviewed By: huangruizhe

Differential Revision: D48066787

Pulled By: mthrok

fbshipit-source-id: 452c243d278e508926a59894928e280fea76dcc6
parent 09aabcc1
...@@ -87,17 +87,20 @@ with torch.inference_mode(): ...@@ -87,17 +87,20 @@ with torch.inference_mode():
emission, _ = model(waveform.to(device)) emission, _ = model(waveform.to(device))
emission = torch.log_softmax(emission, dim=-1) emission = torch.log_softmax(emission, dim=-1)
num_frames = emission.size(1)
###################################################################### ######################################################################
# #
def plot_emission(emission):
    """Visualize the frame-wise class probabilities as an image.

    ``emission`` is a 2-D tensor (frames x labels); it is transposed so
    that time runs along the x-axis and label IDs along the y-axis.
    """
    fig, ax = plt.subplots()
    ax.imshow(emission.cpu().T)
    ax.set_title("Frame-wise class probabilities")
    ax.set_xlabel("Time")
    ax.set_ylabel("Labels")
    fig.tight_layout()
plot_emission(emission[0]) plot_emission(emission[0])
...@@ -114,9 +117,9 @@ for k, v in DICTIONARY.items(): ...@@ -114,9 +117,9 @@ for k, v in DICTIONARY.items():
###################################################################### ######################################################################
# converting transcript to tokens is as simple as # converting transcript to tokens is as simple as
tokens = [DICTIONARY[c] for c in TRANSCRIPT] tokenized_transcript = [DICTIONARY[c] for c in TRANSCRIPT]
print(" ".join(str(t) for t in tokens)) print(" ".join(str(t) for t in tokenized_transcript))
###################################################################### ######################################################################
# Computing frame-level alignments # Computing frame-level alignments
...@@ -138,20 +141,20 @@ def align(emission, tokens): ...@@ -138,20 +141,20 @@ def align(emission, tokens):
blank=0, blank=0,
) )
scores = scores.exp() # convert back to probability
alignments, scores = alignments[0], scores[0] # remove batch dimension for simplicity alignments, scores = alignments[0], scores[0] # remove batch dimension for simplicity
return alignments.tolist(), scores.tolist() scores = scores.exp() # convert back to probability
return alignments, scores
frame_alignment, frame_scores = align(emission, tokens) aligned_tokens, alignment_scores = align(emission, tokenized_transcript)
###################################################################### ######################################################################
# Now let's look at the output. # Now let's look at the output.
# Notice that the alignment is expressed in the frame coordinate of # Notice that the alignment is expressed in the frame coordinate of
# emission, which is different from the original waveform. # emission, which is different from the original waveform.
for i, (ali, score) in enumerate(zip(frame_alignment, frame_scores)): for i, (ali, score) in enumerate(zip(aligned_tokens, alignment_scores)):
print(f"{i:3d}: {ali:2d} [{labels[ali]}], {score:.2f}") print(f"{i:3d}:\t{ali:2d} [{labels[ali]}], {score:.2f}")
###################################################################### ######################################################################
# #
...@@ -177,116 +180,43 @@ for i, (ali, score) in enumerate(zip(frame_alignment, frame_scores)): ...@@ -177,116 +180,43 @@ for i, (ali, score) in enumerate(zip(frame_alignment, frame_scores)):
# .. code-block:: # .. code-block::
# #
# 29: 0 [-], 1.00 # 29: 0 [-], 1.00
# 30: 7 [I], 1.00 # Start of "I" # 30: 7 [I], 1.00 # "I" starts and ends
# 31: 0 [-], 0.98 # repeat (blank token) # 31: 0 [-], 0.98 #
# 32: 0 [-], 1.00 # repeat (blank token) # 32: 0 [-], 1.00 #
# 33: 1 [|], 0.85 # Start of "|" (word boundary) # 33: 1 [|], 0.85 # "|" (word boundary) starts
# 34: 1 [|], 1.00 # repeat (same token) # 34: 1 [|], 1.00 # "|" ends
# 35: 0 [-], 0.61 # repeat (blank token) # 35: 0 [-], 0.61 #
# 36: 8 [H], 1.00 # Start of "H" # 36: 8 [H], 1.00 # "H" starts and ends
# 37: 0 [-], 1.00 # repeat (blank token) # 37: 0 [-], 1.00 #
# 38: 4 [A], 1.00 # Start of "A" # 38: 4 [A], 1.00 # "A" starts and ends
# 39: 0 [-], 0.99 # repeat (blank token) # 39: 0 [-], 0.99 #
# 40: 11 [D], 0.92 # Start of "D" # 40: 11 [D], 0.92 # "D" starts and ends
# 41: 0 [-], 0.93 # repeat (blank token) # 41: 0 [-], 0.93 #
# 42: 1 [|], 0.98 # Start of "|" # 42: 1 [|], 0.98 # "|" starts
# 43: 1 [|], 1.00 # repeat (same token) # 43: 1 [|], 1.00 # "|" ends
# 44: 3 [T], 1.00 # Start of "T" # 44: 3 [T], 1.00 # "T" starts
# 45: 3 [T], 0.90 # repeat (same token) # 45: 3 [T], 0.90 # "T" ends
# 46: 8 [H], 1.00 # Start of "H" # 46: 8 [H], 1.00 # "H" starts and ends
# 47: 0 [-], 1.00 # repeat (blank token) # 47: 0 [-], 1.00 #
###################################################################### ######################################################################
# Resolve blank and repeated tokens # Obtain token-level alignment
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# #
# Next step is to resolve the repetition. So that what alignment represents # Next step is to resolve the repetition. So that what alignment represents
# does not depend on previous alignments. # does not depend on previous alignments.
# From the outputs ``alignment`` and ``scores``, we generate a # From the outputs ``alignment``, we compute the following ``Span`` object,
# list called ``frames`` storing information of all frames aligned to # which explains what token (in transcript) is present at what time span.
# non-blank tokens.
#
# Each element contains the following
#
# - ``token_index``: the aligned token’s index **in the transcript**
# - ``time_index``: the current frame’s index in emission
# - ``score``: scores of the current frame.
#
# ``token_index`` is the index of each token in the transcript,
# i.e. the current frame aligns to the N-th character from the transcript.
@dataclass
class TokenSpan:
    """Token-level alignment: which transcript token occupies which
    frame span, together with its averaged confidence score."""

    index: int  # index of the token in the transcript
    start: int  # start frame (inclusive)
    end: int  # end frame (exclusive)
    score: float  # averaged frame-level score for this span

    def __len__(self) -> int:
        # Number of emission frames the token spans.
        return self.end - self.start
######################################################################
#
def obtain_token_level_alignments(alignments, scores) -> List[Frame]:
assert len(alignments) == len(scores)
token_index = -1
prev_hyp = 0
frames = []
for i, (ali, score) in enumerate(zip(alignments, scores)):
if ali == 0:
prev_hyp = 0
continue
if ali != prev_hyp:
token_index += 1
frames.append(Frame(token_index, i, score))
prev_hyp = ali
return frames
######################################################################
#
frames = obtain_token_level_alignments(frame_alignment, frame_scores)
print("Time\tLabel\tScore")
for f in frames:
print(f"{f.time_index:3d}\t{TRANSCRIPT[f.token_index]}\t{f.score:.2f}")
######################################################################
# Obtain token-level alignments and confidence scores
# ---------------------------------------------------
#
# The frame-level alignments contains repetations for the same labels.
# Another format “token-level alignment”, which specifies the aligned
# frame ranges for each transcript token, contains the same information,
# while being more convenient to apply to some downstream tasks
# (e.g. computing word-level alignments).
#
# Now we demonstrate how to obtain token-level alignments and confidence
# scores by simply merging frame-level alignments and averaging
# frame-level confidence scores.
#
######################################################################
# The following class represents the label, its score and the time span
# of its occurrence.
#
@dataclass
class Segment:
label: str
start: int
end: int
score: float
def __repr__(self):
return f"{self.label:2s} ({self.score:4.2f}): [{self.start:4d}, {self.end:4d})"
def __len__(self):
return self.end - self.start return self.end - self.start
...@@ -294,32 +224,31 @@ class Segment: ...@@ -294,32 +224,31 @@ class Segment:
# #
def merge_tokens(tokens, scores, blank=0) -> List[TokenSpan]:
    """Merge consecutive identical frame-level tokens into ``TokenSpan``s.

    Frames aligned to ``blank`` are dropped.  Each maximal run of one
    non-blank token becomes a single span whose score is the mean of its
    frame-level scores.  ``scores`` is indexed/sliced per frame and
    reduced with ``.mean().item()`` (a 1-D tensor — confirm with caller).
    """
    prev_token = blank
    i = start = -1  # i: index into the transcript; start: first frame of the run
    spans = []
    for t, token in enumerate(tokens):
        if token != prev_token:
            if prev_token != blank:
                # A non-blank run just ended at frame t (exclusive).
                spans.append(TokenSpan(i, start, t, scores[start:t].mean().item()))
            if token != blank:
                i += 1
                start = t
        prev_token = token
    if prev_token != blank:
        # Flush a run that extends to the last frame.
        spans.append(TokenSpan(i, start, len(tokens), scores[start:].mean().item()))
    return spans
###################################################################### ######################################################################
# #
segments = merge_repeats(frames, TRANSCRIPT)
for seg in segments:
print(seg)
token_spans = merge_tokens(aligned_tokens, alignment_scores)
print("Token\tTime\tScore")
for s in token_spans:
print(f"{TRANSCRIPT[s.index]}\t[{s.start:3d}, {s.end:3d})\t{s.score:.2f}")
###################################################################### ######################################################################
# Visualization # Visualization
...@@ -327,51 +256,37 @@ for seg in segments: ...@@ -327,51 +256,37 @@ for seg in segments:
# #
def plot_scores(spans, scores, transcript):
    """Plot token-level scores (wide gray bars) over frame-level scores
    (narrow bars) for every non-separator token."""
    fig, ax = plt.subplots()
    ax.set_title("frame-level and token-level confidence scores")
    span_xs, span_hs, span_ws = [], [], []
    frame_xs, frame_hs = [], []
    for span in spans:
        token = transcript[span.index]
        if token != "|":
            # Center the wide bar on the middle of the span.
            span_xs.append((span.end + span.start) / 2 + 0.4)
            span_hs.append(span.score)
            span_ws.append(span.end - span.start)
            ax.annotate(token, (span.start + 0.8, -0.07), weight="bold")
            # NOTE(review): frame-level bars appear to be drawn only for
            # non-separator tokens — confirm against the upstream tutorial.
            for t in range(span.start, span.end):
                frame_xs.append(t + 1)
                frame_hs.append(scores[t].item())
    ax.bar(span_xs, span_hs, width=span_ws, color="gray", alpha=0.5, edgecolor="black")
    ax.bar(frame_xs, frame_hs, width=0.5, alpha=0.5)
    ax.set_ylim(-0.1, 1.1)
    ax.grid(True, axis="y")
    fig.tight_layout()
plot_label_prob(segments, TRANSCRIPT) plot_scores(token_spans, alignment_scores, TRANSCRIPT)
######################################################################
# From the visualized scores, we can see that, for tokens spanning over
# multiple frames, e.g. “T” in “THAT”, the token-level confidence
# score is the average of frame-level confidence scores. To make this
# clearer, we don’t plot confidence scores for blank frames, which was
# plotted in the “Label probability with and without repetition” figure in
# the previous tutorial
# `Forced Alignment with Wav2Vec2 <./forced_alignment_tutorial.html>`__.
#
###################################################################### ######################################################################
# Obtain word-level alignments and confidence scores # Obtain word-level alignments and confidence scores
# -------------------------------------------------- # --------------------------------------------------
# #
###################################################################### ######################################################################
# Now let’s merge the token-level alignments and confidence scores to get # Now let’s merge the token-level alignments and confidence scores to get
# word-level alignments and confidence scores. Then, finally, we verify # word-level alignments and confidence scores. Then, finally, we verify
...@@ -380,32 +295,30 @@ plot_label_prob(segments, TRANSCRIPT) ...@@ -380,32 +295,30 @@ plot_label_prob(segments, TRANSCRIPT)
# alignments and listening to them. # alignments and listening to them.
@dataclass
class WordSpan:
    """A word-level alignment: the token spans forming the word and the
    length-weighted average of their scores."""

    token_spans: List[TokenSpan]  # spans of the tokens that make up the word
    score: float  # length-weighted average of the token scores


# Obtain word alignments from token alignments
def merge_words(token_spans, transcript, separator="|") -> List[WordSpan]:
    """Group token spans into words, splitting at ``separator`` tokens.

    Separator spans themselves are not included in any word.
    """

    def _score(t_spans):
        # Longer token spans contribute proportionally more to the word score.
        return sum(s.score * len(s) for s in t_spans) / sum(len(s) for s in t_spans)

    words = []
    i = 0  # first token span of the current word
    for j, span in enumerate(token_spans):
        if transcript[span.index] == separator:
            words.append(WordSpan(token_spans[i:j], _score(token_spans[i:j])))
            i = j + 1  # skip the separator span
    if i < len(token_spans):
        # Flush the trailing word (no separator after the last word).
        words.append(WordSpan(token_spans[i:], _score(token_spans[i:])))
    return words
word_segments = merge_words(TRANSCRIPT, segments, "|") word_spans = merge_words(token_spans, TRANSCRIPT)
###################################################################### ######################################################################
...@@ -414,38 +327,40 @@ word_segments = merge_words(TRANSCRIPT, segments, "|") ...@@ -414,38 +327,40 @@ word_segments = merge_words(TRANSCRIPT, segments, "|")
# #
def plot_alignments(waveform, word_spans, num_frames, transcript, sample_rate=bundle.sample_rate):
    """Overlay word spans (hatched boxes), word scores and token labels on
    the spectrogram of ``waveform``.

    ``num_frames`` is the emission length; frame indices are converted to
    seconds via waveform_samples / sample_rate / num_frames.
    """
    fig, ax = plt.subplots()
    ax.specgram(waveform[0], Fs=sample_rate)
    ratio = waveform.size(1) / sample_rate / num_frames  # seconds per emission frame
    for w_span in word_spans:
        t_spans = w_span.token_spans
        t0, t1 = t_spans[0].start, t_spans[-1].end
        ax.axvspan(ratio * t0, ratio * t1, facecolor="None", hatch="/", edgecolor="white")
        ax.annotate(f"{w_span.score:.2f}", (ratio * t0, sample_rate * 0.51), annotation_clip=False)
        for span in t_spans:
            token = transcript[span.index]
            ax.annotate(token, (span.start * ratio, sample_rate * 0.53), annotation_clip=False)
    ax.set_xlabel("time [second]")
    ax.set_xlim([0, None])
    fig.tight_layout()
plot_alignments(waveform, emission, segments, word_segments) plot_alignments(waveform, word_spans, num_frames, TRANSCRIPT)
###################################################################### ######################################################################
def preview_word(waveform, word_span, num_frames, transcript, sample_rate=bundle.sample_rate):
    """Print the time range of one aligned word and return an audio widget
    playing just that segment of ``waveform``."""
    ratio = waveform.size(1) / num_frames  # samples per emission frame
    t0 = word_span.token_spans[0].start
    t1 = word_span.token_spans[-1].end
    x0 = int(ratio * t0)
    x1 = int(ratio * t1)
    # Reassemble the word's characters from the transcript for display.
    tokens = "".join(transcript[t.index] for t in word_span.token_spans)
    print(f"{tokens} ({word_span.score:.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec")
    segment = waveform[:, x0:x1]
    return IPython.display.Audio(segment.numpy(), rate=sample_rate)
...@@ -459,47 +374,47 @@ IPython.display.Audio(SPEECH_FILE) ...@@ -459,47 +374,47 @@ IPython.display.Audio(SPEECH_FILE)
###################################################################### ######################################################################
# #
display_segment(0, waveform, word_segments, frame_alignment) preview_word(waveform, word_spans[0], num_frames, TRANSCRIPT)
###################################################################### ######################################################################
# #
display_segment(1, waveform, word_segments, frame_alignment) preview_word(waveform, word_spans[1], num_frames, TRANSCRIPT)
###################################################################### ######################################################################
# #
display_segment(2, waveform, word_segments, frame_alignment) preview_word(waveform, word_spans[2], num_frames, TRANSCRIPT)
###################################################################### ######################################################################
# #
display_segment(3, waveform, word_segments, frame_alignment) preview_word(waveform, word_spans[3], num_frames, TRANSCRIPT)
###################################################################### ######################################################################
# #
display_segment(4, waveform, word_segments, frame_alignment) preview_word(waveform, word_spans[4], num_frames, TRANSCRIPT)
###################################################################### ######################################################################
# #
display_segment(5, waveform, word_segments, frame_alignment) preview_word(waveform, word_spans[5], num_frames, TRANSCRIPT)
###################################################################### ######################################################################
# #
display_segment(6, waveform, word_segments, frame_alignment) preview_word(waveform, word_spans[6], num_frames, TRANSCRIPT)
###################################################################### ######################################################################
# #
display_segment(7, waveform, word_segments, frame_alignment) preview_word(waveform, word_spans[7], num_frames, TRANSCRIPT)
###################################################################### ######################################################################
# #
display_segment(8, waveform, word_segments, frame_alignment) preview_word(waveform, word_spans[8], num_frames, TRANSCRIPT)
###################################################################### ######################################################################
...@@ -527,38 +442,53 @@ DICTIONARY["*"] = len(DICTIONARY) ...@@ -527,38 +442,53 @@ DICTIONARY["*"] = len(DICTIONARY)
# corresponding to the ``<star>`` token. # corresponding to the ``<star>`` token.
# #
extra_dim = torch.zeros(emission.shape[0], emission.shape[1], 1, device=device) star_dim = torch.zeros((1, num_frames, 1), device=device)
emission = torch.cat((emission, extra_dim), 2) emission = torch.cat((emission, star_dim), 2)
assert len(DICTIONARY) == emission.shape[2] assert len(DICTIONARY) == emission.shape[2]
plot_emission(emission[0])
###################################################################### ######################################################################
# The following function combines all the processes, and computes # The following function combines all the processes, and computes
# word segments from the emission in one go. # word segments from the emission in one go.
def compute_alignments(emission, transcript, dictionary):
    """Run the full pipeline: transcript -> tokens -> forced alignment ->
    token spans -> word spans."""
    tokens = [dictionary[c] for c in transcript]
    alignment, scores = align(emission, tokens)
    token_spans = merge_tokens(alignment, scores)
    word_spans = merge_words(token_spans, transcript)
    return word_spans
###################################################################### ######################################################################
# **Original** # **Original**
compute_and_plot_alignments(TRANSCRIPT, DICTIONARY, emission, waveform) word_spans = compute_alignments(emission, TRANSCRIPT, DICTIONARY)
plot_alignments(waveform, word_spans, num_frames, TRANSCRIPT)
###################################################################### ######################################################################
# **With <star> token** # **With <star> token**
# #
# Now we replace the first part of the transcript with the ``<star>`` token. # Now we replace the first part of the transcript with the ``<star>`` token.
compute_and_plot_alignments("*|THIS|MOMENT", DICTIONARY, emission, waveform) transcript = "*|THIS|MOMENT"
word_spans = compute_alignments(emission, transcript, DICTIONARY)
plot_alignments(waveform, word_spans, num_frames, transcript)
######################################################################
#
preview_word(waveform, word_spans[1], num_frames, transcript)
######################################################################
#
preview_word(waveform, word_spans[2], num_frames, transcript)
######################################################################
#
###################################################################### ######################################################################
# **Without <star> token** # **Without <star> token**
...@@ -567,7 +497,9 @@ compute_and_plot_alignments("*|THIS|MOMENT", DICTIONARY, emission, waveform) ...@@ -567,7 +497,9 @@ compute_and_plot_alignments("*|THIS|MOMENT", DICTIONARY, emission, waveform)
# without using ``<star>`` token. # without using ``<star>`` token.
# It demonstrates the effect of ``<star>`` token for dealing with deletion errors. # It demonstrates the effect of ``<star>`` token for dealing with deletion errors.
compute_and_plot_alignments("THIS|MOMENT", DICTIONARY, emission, waveform) transcript = "THIS|MOMENT"
word_spans = compute_alignments(emission, transcript, DICTIONARY)
plot_alignments(waveform, word_spans, num_frames, transcript)
###################################################################### ######################################################################
# Conclusion # Conclusion
......
...@@ -25,12 +25,13 @@ print(torchaudio.__version__) ...@@ -25,12 +25,13 @@ print(torchaudio.__version__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device) print(device)
from dataclasses import dataclass
###################################################################### ######################################################################
# Preparation # Preparation
# ----------- # -----------
# #
from typing import Dict, List
from dataclasses import dataclass
import IPython import IPython
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
...@@ -54,145 +55,117 @@ SAMPLE_RATE = 16000 ...@@ -54,145 +55,117 @@ SAMPLE_RATE = 16000
@dataclass
class TokenSpan:
    """Token-level alignment: which transcript token occupies which
    frame span, together with its averaged confidence score."""

    index: int  # index of the token in the transcript
    start: int  # start frame (inclusive)
    end: int  # end frame (exclusive)
    score: float  # averaged frame-level score for this span

    def __len__(self) -> int:
        # Number of emission frames the token spans.
        return self.end - self.start
###################################################################### ######################################################################
# #
@dataclass
class Segment:
label: str
start: int
end: int
score: float
def __repr__(self):
return f"{self.label:2s} ({self.score:4.2f}): [{self.start:4d}, {self.end:4d})"
@dataclass
class WordSpan:
    """A word-level alignment: the token spans forming the word and the
    length-weighted average of their scores."""

    token_spans: List[TokenSpan]  # spans of the tokens that make up the word
    score: float  # length-weighted average of the token scores
###################################################################### ######################################################################
# #
def align_emission_and_tokens(emission: torch.Tensor, tokens: List[int]):
    """Run CTC forced alignment of one utterance.

    ``emission``: batched log-probabilities; ``emission.size(1)`` is used
    as the number of frames (assumes batch size 1 — confirm with caller).
    ``tokens``: transcript token IDs without blanks.
    Returns ``(aligned_tokens, scores)``, both with the batch dimension
    removed, scores converted from log-probability to probability.
    """
    device = emission.device
    targets = torch.tensor([tokens], dtype=torch.int32, device=device)
    input_lengths = torch.tensor([emission.size(1)], device=device)
    target_lengths = torch.tensor([targets.size(1)], device=device)
    # Blank token ID is 0.
    aligned_tokens, scores = forced_align(emission, targets, input_lengths, target_lengths, 0)

    scores = scores.exp()  # convert back to probability
    aligned_tokens, scores = aligned_tokens[0], scores[0]  # remove batch dimension
    return aligned_tokens, scores
targets = torch.tensor([tokens], dtype=torch.int32, device=emission.device)
input_lengths = torch.tensor([emission.shape[1]], device=emission.device)
target_lengths = torch.tensor([targets.shape[1]], device=emission.device)
def merge_tokens(tokens, scores, blank=0) -> List[TokenSpan]:
    """Merge consecutive identical frame-level tokens into ``TokenSpan``s.

    Frames aligned to ``blank`` are dropped.  Each maximal run of one
    non-blank token becomes a single span whose score is the mean of its
    frame-level scores (``scores`` sliced per frame, reduced with
    ``.mean().item()`` — assumes a 1-D tensor; confirm with caller).
    """
    prev_token = blank
    i = start = -1  # i: index into the transcript; start: first frame of the run
    spans = []
    for t, token in enumerate(tokens):
        if token != prev_token:
            if prev_token != blank:
                # A non-blank run just ended at frame t (exclusive).
                spans.append(TokenSpan(i, start, t, scores[start:t].mean().item()))
            if token != blank:
                i += 1
                start = t
        prev_token = token
    if prev_token != blank:
        # Flush a run that extends to the last frame.
        spans.append(TokenSpan(i, start, len(tokens), scores[start:].mean().item()))
    return spans
def merge_words(token_spans: List[TokenSpan], transcript: List[str]) -> List[WordSpan]:
    """Group token spans into word spans.

    ``transcript`` is a list of words; token spans are consumed in order,
    ``len(word)`` spans per word.  Each word's score is the
    length-weighted average of its token scores.
    """

    def _score(t_spans):
        # Longer token spans contribute proportionally more to the word score.
        return sum(s.score * len(s) for s in t_spans) / sum(len(s) for s in t_spans)

    word_spans = []
    i = 0  # first unconsumed token span
    # Renamed loop variable from `words` to `word`: each item is one word.
    for word in transcript:
        j = i + len(word)
        word_spans.append(WordSpan(token_spans[i:j], _score(token_spans[i:j])))
        i = j
    return word_spans
scores = scores.exp() # convert back to probability
def compute_alignments(emission: torch.Tensor, transcript: List[str], dictionary: Dict[str, int]):
    """Run the full pipeline: word list -> token IDs -> forced alignment ->
    token spans -> word spans."""
    # Flatten the list of words into one token-ID sequence.
    tokens = [dictionary[c] for word in transcript for c in word]
    aligned_tokens, scores = align_emission_and_tokens(emission, tokens)
    token_spans = merge_tokens(aligned_tokens, scores)
    word_spans = merge_words(token_spans, transcript)
    return word_spans
frames = []
for i, (ali, score) in enumerate(zip(alignment, scores)):
if ali == 0:
prev_hyp = 0
continue
if ali != prev_hyp:
token_index += 1
frames.append(Frame(token_index, i, score))
prev_hyp = ali
# compute frame alignments from token alignments
transcript_nospace = transcript.replace(" ", "")
i1, i2 = 0, 0
segments = []
while i1 < len(frames):
while i2 < len(frames) and frames[i1].token_index == frames[i2].token_index:
i2 += 1
score = sum(frames[k].score for k in range(i1, i2)) / (i2 - i1)
segments.append(
Segment(
transcript_nospace[frames[i1].token_index],
frames[i1].time_index,
frames[i2 - 1].time_index + 1,
score,
)
)
i1 = i2
# compute word alignments from token alignments
separator = " "
words = []
i1, i2, i3 = 0, 0, 0
while i3 < len(transcript):
if i3 == len(transcript) - 1 or transcript[i3] == separator:
if i1 != i2:
if i3 == len(transcript) - 1:
i2 += 1
segs = segments[i1:i2]
word = "".join([s.label for s in segs])
score = sum(s.score * len(s) for s in segs) / sum(len(s) for s in segs)
words.append(Segment(word, segs[0].start, segs[-1].end + 1, score))
i1 = i2
else:
i2 += 1
i3 += 1
return segments, words
######################################################################
#
def plot_emission(emission):
    """Plot the emission matrix as an image (labels on the y-axis)."""
    fig, ax = plt.subplots()
    ax.imshow(emission.T, aspect="auto")
    ax.set_title("Emission")
    fig.tight_layout()
###################################################################### ######################################################################
# #
# utility function for plotting word alignments # utility function for plotting word alignments
# utility function for plotting word alignments
def plot_alignments(waveform, word_spans, emission, transcript, sample_rate=SAMPLE_RATE):
    """Plot the emission (top) and the spectrogram with word/token
    alignment overlays (bottom); return an audio widget for playback.

    ``transcript`` is a list of words, zipped pairwise with ``word_spans``.
    """
    ratio = waveform.size(1) / emission.size(1) / sample_rate  # seconds per emission frame

    fig, axes = plt.subplots(2, 1)
    axes[0].imshow(emission[0].detach().cpu().T, aspect="auto")
    axes[0].set_title("Emission")
    axes[0].set_xticks([])

    axes[1].specgram(waveform[0], Fs=sample_rate)
    for w_span, chars in zip(word_spans, transcript):
        t_spans = w_span.token_spans
        t0, t1 = t_spans[0].start, t_spans[-1].end
        axes[1].axvspan(ratio * t0, ratio * t1, facecolor="None", hatch="/", edgecolor="white")
        axes[1].annotate(f"{w_span.score:.2f}", (ratio * t0, sample_rate * 0.51), annotation_clip=False)

        for span, char in zip(t_spans, chars):
            axes[1].annotate(char, (span.start * ratio, sample_rate * 0.55), annotation_clip=False)

    axes[1].set_xlabel("time [second]")
    fig.tight_layout()
    return IPython.display.Audio(waveform, rate=sample_rate)
######################################################################
#
# Utility function for playing audio segments.


def preview_word(waveform, word_span, num_frames, transcript, sample_rate=SAMPLE_RATE):
    """Print timing/score info for one aligned word and return its audio.

    Args:
        waveform (Tensor): audio of shape ``(1, num_samples)``.
        word_span: span object for one word, with ``token_spans`` and ``score``.
        num_frames (int): number of emission frames for the whole waveform.
        transcript (str): the word text, used for display only.
        sample_rate (int): audio sample rate. Defaults to ``SAMPLE_RATE``.

    Returns:
        IPython.display.Audio: playable widget for just this word's segment.
    """
    # Audio samples covered by one emission frame.
    ratio = waveform.size(1) / num_frames
    t0 = word_span.token_spans[0].start
    t1 = word_span.token_spans[-1].end
    x0 = int(ratio * t0)
    x1 = int(ratio * t1)
    print(f"{transcript} ({word_span.score:.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec")
    segment = waveform[:, x0:x1]
    return IPython.display.Audio(segment.numpy(), rate=sample_rate)
...@@ -345,55 +318,54 @@ waveform, _ = torchaudio.load(speech_file, frame_offset=int(0.5 * SAMPLE_RATE), ...@@ -345,55 +318,54 @@ waveform, _ = torchaudio.load(speech_file, frame_offset=int(0.5 * SAMPLE_RATE),
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)

######################################################################
#

# Word-level alignment: split the normalized text into words, then align.
transcript = text_normalized.split()
word_spans = compute_alignments(emission, transcript, dictionary)

plot_alignments(waveform, word_spans, emission, transcript)

######################################################################
#

preview_word(waveform, word_spans[0], num_frames, transcript[0])

######################################################################
#

preview_word(waveform, word_spans[1], num_frames, transcript[1])

######################################################################
#

preview_word(waveform, word_spans[2], num_frames, transcript[2])

######################################################################
#

preview_word(waveform, word_spans[3], num_frames, transcript[3])

######################################################################
#

preview_word(waveform, word_spans[4], num_frames, transcript[4])

######################################################################
#

preview_word(waveform, word_spans[5], num_frames, transcript[5])

######################################################################
#

preview_word(waveform, word_spans[6], num_frames, transcript[6])

######################################################################
#

preview_word(waveform, word_spans[7], num_frames, transcript[7])
# Chinese # Chinese
...@@ -423,59 +395,59 @@ waveform = waveform[0:1] ...@@ -423,59 +395,59 @@ waveform = waveform[0:1]
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)

######################################################################
#

# Word-level alignment: split the normalized text into words, then align.
transcript = text_normalized.split()
word_spans = compute_alignments(emission, transcript, dictionary)

plot_alignments(waveform, word_spans, emission, transcript)

######################################################################
#

preview_word(waveform, word_spans[0], num_frames, transcript[0])

######################################################################
#

preview_word(waveform, word_spans[1], num_frames, transcript[1])

######################################################################
#

preview_word(waveform, word_spans[2], num_frames, transcript[2])

######################################################################
#

preview_word(waveform, word_spans[3], num_frames, transcript[3])

######################################################################
#

preview_word(waveform, word_spans[4], num_frames, transcript[4])

######################################################################
#

preview_word(waveform, word_spans[5], num_frames, transcript[5])

######################################################################
#

preview_word(waveform, word_spans[6], num_frames, transcript[6])

######################################################################
#

preview_word(waveform, word_spans[7], num_frames, transcript[7])

######################################################################
#

preview_word(waveform, word_spans[8], num_frames, transcript[8])
...@@ -497,54 +469,54 @@ waveform, _ = torchaudio.load(speech_file, num_frames=int(4.5 * SAMPLE_RATE)) ...@@ -497,54 +469,54 @@ waveform, _ = torchaudio.load(speech_file, num_frames=int(4.5 * SAMPLE_RATE))
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)

######################################################################
#

# Word-level alignment: split the normalized text into words, then align.
transcript = text_normalized.split()
word_spans = compute_alignments(emission, transcript, dictionary)

plot_alignments(waveform, word_spans, emission, transcript)

######################################################################
#

preview_word(waveform, word_spans[0], num_frames, transcript[0])

######################################################################
#

preview_word(waveform, word_spans[1], num_frames, transcript[1])

######################################################################
#

preview_word(waveform, word_spans[2], num_frames, transcript[2])

######################################################################
#

preview_word(waveform, word_spans[3], num_frames, transcript[3])

######################################################################
#

preview_word(waveform, word_spans[4], num_frames, transcript[4])

######################################################################
#

preview_word(waveform, word_spans[5], num_frames, transcript[5])

######################################################################
#

preview_word(waveform, word_spans[6], num_frames, transcript[6])

######################################################################
#

preview_word(waveform, word_spans[7], num_frames, transcript[7])
# Portuguese # Portuguese
...@@ -565,59 +537,60 @@ waveform, _ = torchaudio.load(speech_file, frame_offset=int(SAMPLE_RATE), num_fr ...@@ -565,59 +537,60 @@ waveform, _ = torchaudio.load(speech_file, frame_offset=int(SAMPLE_RATE), num_fr
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)

######################################################################
#

# Word-level alignment: split the normalized text into words, then align.
transcript = text_normalized.split()
word_spans = compute_alignments(emission, transcript, dictionary)

plot_alignments(waveform, word_spans, emission, transcript)

######################################################################
#

preview_word(waveform, word_spans[0], num_frames, transcript[0])

######################################################################
#

preview_word(waveform, word_spans[1], num_frames, transcript[1])

######################################################################
#

preview_word(waveform, word_spans[2], num_frames, transcript[2])

######################################################################
#

preview_word(waveform, word_spans[3], num_frames, transcript[3])

######################################################################
#

preview_word(waveform, word_spans[4], num_frames, transcript[4])

######################################################################
#

preview_word(waveform, word_spans[5], num_frames, transcript[5])

######################################################################
#

preview_word(waveform, word_spans[6], num_frames, transcript[6])

######################################################################
#

preview_word(waveform, word_spans[7], num_frames, transcript[7])

######################################################################
#

preview_word(waveform, word_spans[8], num_frames, transcript[8])
# Italian # Italian
...@@ -638,44 +611,44 @@ waveform, _ = torchaudio.load(speech_file, num_frames=int(4 * SAMPLE_RATE)) ...@@ -638,44 +611,44 @@ waveform, _ = torchaudio.load(speech_file, num_frames=int(4 * SAMPLE_RATE))
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)

######################################################################
#

# Word-level alignment: split the normalized text into words, then align.
transcript = text_normalized.split()
word_spans = compute_alignments(emission, transcript, dictionary)

plot_alignments(waveform, word_spans, emission, transcript)

######################################################################
#

preview_word(waveform, word_spans[0], num_frames, transcript[0])

######################################################################
#

preview_word(waveform, word_spans[1], num_frames, transcript[1])

######################################################################
#

preview_word(waveform, word_spans[2], num_frames, transcript[2])

######################################################################
#

preview_word(waveform, word_spans[3], num_frames, transcript[3])

######################################################################
#

preview_word(waveform, word_spans[4], num_frames, transcript[4])

######################################################################
#

preview_word(waveform, word_spans[5], num_frames, transcript[5])
# Conclusion # Conclusion
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment