"vscode:/vscode.git/clone" did not exist on "152c2b3a9403811198beb8fffc5070fc80d1273c"
Commit b645c07b authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Update ctc forced alignment tutorial (#3529)

Summary:
- Simplify the step to generate token-level alignment

Pull Request resolved: https://github.com/pytorch/audio/pull/3529

Reviewed By: huangruizhe

Differential Revision: D48066787

Pulled By: mthrok

fbshipit-source-id: 452c243d278e508926a59894928e280fea76dcc6
parent 09aabcc1
...@@ -25,12 +25,13 @@ print(torchaudio.__version__) ...@@ -25,12 +25,13 @@ print(torchaudio.__version__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device) print(device)
from dataclasses import dataclass
###################################################################### ######################################################################
# Preparation # Preparation
# ----------- # -----------
# #
from typing import Dict, List
from dataclasses import dataclass
import IPython import IPython
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
...@@ -54,145 +55,117 @@ SAMPLE_RATE = 16000 ...@@ -54,145 +55,117 @@ SAMPLE_RATE = 16000
@dataclass @dataclass
class Frame: class TokenSpan:
token_index: int index: int # index of token in transcript
time_index: int start: int # start time (inclusive)
end: int # end time (exclusive)
score: float score: float
def __len__(self) -> int:
return self.end - self.start
###################################################################### ######################################################################
# #
@dataclass
class Segment:
label: str
start: int
end: int
score: float
def __repr__(self):
return f"{self.label:2s} ({self.score:4.2f}): [{self.start:4d}, {self.end:4d})"
def __len__(self): @dataclass
return self.end - self.start class WordSpan:
token_spans: List[TokenSpan]
score: float
###################################################################### ######################################################################
# #
def align_emission_and_tokens(emission: torch.Tensor, tokens: List[int]):
device = emission.device
targets = torch.tensor([tokens], dtype=torch.int32, device=device)
input_lengths = torch.tensor([emission.size(1)], device=device)
target_lengths = torch.tensor([targets.size(1)], device=device)
aligned_tokens, scores = forced_align(emission, targets, input_lengths, target_lengths, 0)
def compute_alignments(transcript, dictionary, emission): scores = scores.exp() # convert back to probability
tokens = [dictionary[c] for c in transcript.replace(" ", "")] aligned_tokens, scores = aligned_tokens[0], scores[0] # remove batch dimension
return aligned_tokens, scores
targets = torch.tensor([tokens], dtype=torch.int32, device=emission.device)
input_lengths = torch.tensor([emission.shape[1]], device=emission.device)
target_lengths = torch.tensor([targets.shape[1]], device=emission.device)
alignment, scores = forced_align(emission, targets, input_lengths, target_lengths, 0) def merge_tokens(tokens, scores, blank=0) -> List[TokenSpan]:
prev_token = blank
i = start = -1
spans = []
for t, token in enumerate(tokens):
if token != prev_token:
if prev_token != blank:
spans.append(TokenSpan(i, start, t, scores[start:t].mean().item()))
if token != blank:
i += 1
start = t
prev_token = token
if prev_token != blank:
spans.append(TokenSpan(i, start, len(tokens), scores[start:].mean().item()))
return spans
scores = scores.exp() # convert back to probability
alignment, scores = alignment[0].tolist(), scores[0].tolist()
assert len(alignment) == len(scores) == emission.size(1)
token_index = -1
prev_hyp = 0
frames = []
for i, (ali, score) in enumerate(zip(alignment, scores)):
if ali == 0:
prev_hyp = 0
continue
if ali != prev_hyp:
token_index += 1
frames.append(Frame(token_index, i, score))
prev_hyp = ali
# compute frame alignments from token alignments
transcript_nospace = transcript.replace(" ", "")
i1, i2 = 0, 0
segments = []
while i1 < len(frames):
while i2 < len(frames) and frames[i1].token_index == frames[i2].token_index:
i2 += 1
score = sum(frames[k].score for k in range(i1, i2)) / (i2 - i1)
segments.append(
Segment(
transcript_nospace[frames[i1].token_index],
frames[i1].time_index,
frames[i2 - 1].time_index + 1,
score,
)
)
i1 = i2
# compue word alignments from token alignments def merge_words(token_spans: List[TokenSpan], transcript: List[str]) -> List[WordSpan]:
separator = " " def _score(t_spans):
words = [] return sum(s.score * len(s) for s in t_spans) / sum(len(s) for s in t_spans)
i1, i2, i3 = 0, 0, 0
while i3 < len(transcript):
if i3 == len(transcript) - 1 or transcript[i3] == separator:
if i1 != i2:
if i3 == len(transcript) - 1:
i2 += 1
segs = segments[i1:i2]
word = "".join([s.label for s in segs])
score = sum(s.score * len(s) for s in segs) / sum(len(s) for s in segs)
words.append(Segment(word, segs[0].start, segs[-1].end + 1, score))
i1 = i2
else:
i2 += 1
i3 += 1
return segments, words
word_spans = []
i = 0
for words in transcript:
j = i + len(words)
word_spans.append(WordSpan(token_spans[i:j], _score(token_spans[i:j])))
i = j
return word_spans
######################################################################
#
def compute_alignments(emission: torch.Tensor, transcript: List[str], dictionary: Dict[str, int]):
def plot_emission(emission): tokens = [dictionary[c] for word in transcript for c in word]
fig, ax = plt.subplots() aligned_tokens, scores = align_emission_and_tokens(emission, tokens)
ax.imshow(emission.T, aspect="auto") token_spans = merge_tokens(aligned_tokens, scores)
ax.set_title("Emission") word_spans = merge_words(token_spans, transcript)
fig.tight_layout() return word_spans
###################################################################### ######################################################################
# #
# utility function for plotting word alignments # utility function for plotting word alignments
def plot_alignments(waveform, emission, segments, word_segments, sample_rate=SAMPLE_RATE): def plot_alignments(waveform, word_spans, emission, transcript, sample_rate=SAMPLE_RATE):
fig, ax = plt.subplots() ratio = waveform.size(1) / emission.size(1) / sample_rate
ax.specgram(waveform[0], Fs=sample_rate)
xlim = ax.get_xlim() fig, axes = plt.subplots(2, 1)
axes[0].imshow(emission[0].detach().cpu().T, aspect="auto")
ratio = waveform.size(1) / sample_rate / emission.size(1) axes[0].set_title("Emission")
for word in word_segments: axes[0].set_xticks([])
t0, t1 = word.start * ratio, word.end * ratio
ax.axvspan(t0, t1, facecolor="None", hatch="/", edgecolor="white")
ax.annotate(f"{word.score:.2f}", (t0, sample_rate * 0.51), annotation_clip=False)
for seg in segments:
if seg.label != "|":
ax.annotate(seg.label, (seg.start * ratio, sample_rate * 0.53), annotation_clip=False)
ax.set_xlabel("time [second]")
ax.set_xlim(xlim)
fig.tight_layout()
axes[1].specgram(waveform[0], Fs=sample_rate)
for w_span, chars in zip(word_spans, transcript):
t_spans = w_span.token_spans
t0, t1 = t_spans[0].start, t_spans[-1].end
axes[1].axvspan(ratio * t0, ratio * t1, facecolor="None", hatch="/", edgecolor="white")
axes[1].annotate(f"{w_span.score:.2f}", (ratio * t0, sample_rate * 0.51), annotation_clip=False)
for span, char in zip(t_spans, chars):
axes[1].annotate(char, (span.start * ratio, sample_rate * 0.55), annotation_clip=False)
axes[1].set_xlabel("time [second]")
fig.tight_layout()
return IPython.display.Audio(waveform, rate=sample_rate) return IPython.display.Audio(waveform, rate=sample_rate)
###################################################################### ######################################################################
# #
# utility function for playing audio segments.
def display_segment(i, waveform, word_segments, num_frames, sample_rate=SAMPLE_RATE): def preview_word(waveform, word_span, num_frames, transcript, sample_rate=SAMPLE_RATE):
ratio = waveform.size(1) / num_frames ratio = waveform.size(1) / num_frames
word = word_segments[i] t0 = word_span.token_spans[0].start
x0 = int(ratio * word.start) t1 = word_span.token_spans[-1].end
x1 = int(ratio * word.end) x0 = int(ratio * t0)
print(f"{word.label} ({word.score:.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec") x1 = int(ratio * t1)
print(f"{transcript} ({word_span.score:.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec")
segment = waveform[:, x0:x1] segment = waveform[:, x0:x1]
return IPython.display.Audio(segment.numpy(), rate=sample_rate) return IPython.display.Audio(segment.numpy(), rate=sample_rate)
...@@ -345,55 +318,54 @@ waveform, _ = torchaudio.load(speech_file, frame_offset=int(0.5 * SAMPLE_RATE), ...@@ -345,55 +318,54 @@ waveform, _ = torchaudio.load(speech_file, frame_offset=int(0.5 * SAMPLE_RATE),
emission = get_emission(waveform.to(device)) emission = get_emission(waveform.to(device))
num_frames = emission.size(1) num_frames = emission.size(1)
plot_emission(emission[0].cpu())
###################################################################### ######################################################################
# #
segments, word_segments = compute_alignments(text_normalized, dictionary, emission) transcript = text_normalized.split()
word_spans = compute_alignments(emission, transcript, dictionary)
plot_alignments(waveform, emission, segments, word_segments) plot_alignments(waveform, word_spans, emission, transcript)
###################################################################### ######################################################################
# #
display_segment(0, waveform, word_segments, num_frames) preview_word(waveform, word_spans[0], num_frames, transcript[0])
###################################################################### ######################################################################
# #
display_segment(1, waveform, word_segments, num_frames) preview_word(waveform, word_spans[1], num_frames, transcript[1])
###################################################################### ######################################################################
# #
display_segment(2, waveform, word_segments, num_frames) preview_word(waveform, word_spans[2], num_frames, transcript[2])
###################################################################### ######################################################################
# #
display_segment(3, waveform, word_segments, num_frames) preview_word(waveform, word_spans[3], num_frames, transcript[3])
###################################################################### ######################################################################
# #
display_segment(4, waveform, word_segments, num_frames) preview_word(waveform, word_spans[4], num_frames, transcript[4])
###################################################################### ######################################################################
# #
display_segment(5, waveform, word_segments, num_frames) preview_word(waveform, word_spans[5], num_frames, transcript[5])
###################################################################### ######################################################################
# #
display_segment(6, waveform, word_segments, num_frames) preview_word(waveform, word_spans[6], num_frames, transcript[6])
###################################################################### ######################################################################
# #
display_segment(7, waveform, word_segments, num_frames) preview_word(waveform, word_spans[7], num_frames, transcript[7])
###################################################################### ######################################################################
# Chinese # Chinese
...@@ -423,59 +395,59 @@ waveform = waveform[0:1] ...@@ -423,59 +395,59 @@ waveform = waveform[0:1]
emission = get_emission(waveform.to(device)) emission = get_emission(waveform.to(device))
num_frames = emission.size(1) num_frames = emission.size(1)
plot_emission(emission[0].cpu())
###################################################################### ######################################################################
# #
segments, word_segments = compute_alignments(text_normalized, dictionary, emission) transcript = text_normalized.split()
word_spans = compute_alignments(emission, transcript, dictionary)
plot_alignments(waveform, emission, segments, word_segments) plot_alignments(waveform, word_spans, emission, transcript)
###################################################################### ######################################################################
# #
display_segment(0, waveform, word_segments, num_frames) preview_word(waveform, word_spans[0], num_frames, transcript[0])
###################################################################### ######################################################################
# #
display_segment(1, waveform, word_segments, num_frames) preview_word(waveform, word_spans[1], num_frames, transcript[1])
###################################################################### ######################################################################
# #
display_segment(2, waveform, word_segments, num_frames) preview_word(waveform, word_spans[2], num_frames, transcript[2])
###################################################################### ######################################################################
# #
display_segment(3, waveform, word_segments, num_frames) preview_word(waveform, word_spans[3], num_frames, transcript[3])
###################################################################### ######################################################################
# #
display_segment(4, waveform, word_segments, num_frames) preview_word(waveform, word_spans[4], num_frames, transcript[4])
###################################################################### ######################################################################
# #
display_segment(5, waveform, word_segments, num_frames) preview_word(waveform, word_spans[5], num_frames, transcript[5])
###################################################################### ######################################################################
# #
display_segment(6, waveform, word_segments, num_frames) preview_word(waveform, word_spans[6], num_frames, transcript[6])
###################################################################### ######################################################################
# #
display_segment(7, waveform, word_segments, num_frames) preview_word(waveform, word_spans[7], num_frames, transcript[7])
###################################################################### ######################################################################
# #
display_segment(8, waveform, word_segments, num_frames) preview_word(waveform, word_spans[8], num_frames, transcript[8])
###################################################################### ######################################################################
...@@ -497,54 +469,54 @@ waveform, _ = torchaudio.load(speech_file, num_frames=int(4.5 * SAMPLE_RATE)) ...@@ -497,54 +469,54 @@ waveform, _ = torchaudio.load(speech_file, num_frames=int(4.5 * SAMPLE_RATE))
emission = get_emission(waveform.to(device)) emission = get_emission(waveform.to(device))
num_frames = emission.size(1) num_frames = emission.size(1)
plot_emission(emission[0].cpu())
###################################################################### ######################################################################
# #
segments, word_segments = compute_alignments(text_normalized, dictionary, emission) transcript = text_normalized.split()
word_spans = compute_alignments(emission, transcript, dictionary)
plot_alignments(waveform, emission, segments, word_segments) plot_alignments(waveform, word_spans, emission, transcript)
###################################################################### ######################################################################
# #
display_segment(0, waveform, word_segments, num_frames) preview_word(waveform, word_spans[0], num_frames, transcript[0])
###################################################################### ######################################################################
# #
display_segment(1, waveform, word_segments, num_frames) preview_word(waveform, word_spans[1], num_frames, transcript[1])
###################################################################### ######################################################################
# #
display_segment(2, waveform, word_segments, num_frames) preview_word(waveform, word_spans[2], num_frames, transcript[2])
###################################################################### ######################################################################
# #
display_segment(3, waveform, word_segments, num_frames) preview_word(waveform, word_spans[3], num_frames, transcript[3])
###################################################################### ######################################################################
# #
display_segment(4, waveform, word_segments, num_frames) preview_word(waveform, word_spans[4], num_frames, transcript[4])
###################################################################### ######################################################################
# #
display_segment(5, waveform, word_segments, num_frames) preview_word(waveform, word_spans[5], num_frames, transcript[5])
###################################################################### ######################################################################
# #
display_segment(6, waveform, word_segments, num_frames) preview_word(waveform, word_spans[6], num_frames, transcript[6])
###################################################################### ######################################################################
# #
display_segment(7, waveform, word_segments, num_frames) preview_word(waveform, word_spans[7], num_frames, transcript[7])
###################################################################### ######################################################################
# Portuguese # Portuguese
...@@ -565,59 +537,60 @@ waveform, _ = torchaudio.load(speech_file, frame_offset=int(SAMPLE_RATE), num_fr ...@@ -565,59 +537,60 @@ waveform, _ = torchaudio.load(speech_file, frame_offset=int(SAMPLE_RATE), num_fr
emission = get_emission(waveform.to(device)) emission = get_emission(waveform.to(device))
num_frames = emission.size(1) num_frames = emission.size(1)
plot_emission(emission[0].cpu())
###################################################################### ######################################################################
# #
segments, word_segments = compute_alignments(text_normalized, dictionary, emission) transcript = text_normalized.split()
word_spans = compute_alignments(emission, transcript, dictionary)
plot_alignments(waveform, emission, segments, word_segments) plot_alignments(waveform, word_spans, emission, transcript)
###################################################################### ######################################################################
# #
display_segment(0, waveform, word_segments, num_frames) preview_word(waveform, word_spans[0], num_frames, transcript[0])
###################################################################### ######################################################################
# #
display_segment(1, waveform, word_segments, num_frames) preview_word(waveform, word_spans[1], num_frames, transcript[1])
###################################################################### ######################################################################
# #
display_segment(2, waveform, word_segments, num_frames) preview_word(waveform, word_spans[2], num_frames, transcript[2])
###################################################################### ######################################################################
# #
display_segment(3, waveform, word_segments, num_frames) preview_word(waveform, word_spans[3], num_frames, transcript[3])
###################################################################### ######################################################################
# #
display_segment(4, waveform, word_segments, num_frames) preview_word(waveform, word_spans[4], num_frames, transcript[4])
###################################################################### ######################################################################
# #
display_segment(5, waveform, word_segments, num_frames) preview_word(waveform, word_spans[5], num_frames, transcript[5])
###################################################################### ######################################################################
# #
display_segment(6, waveform, word_segments, num_frames) preview_word(waveform, word_spans[6], num_frames, transcript[6])
###################################################################### ######################################################################
# #
display_segment(7, waveform, word_segments, num_frames) preview_word(waveform, word_spans[7], num_frames, transcript[7])
###################################################################### ######################################################################
# #
display_segment(8, waveform, word_segments, num_frames) preview_word(waveform, word_spans[8], num_frames, transcript[8])
###################################################################### ######################################################################
# Italian # Italian
...@@ -638,44 +611,44 @@ waveform, _ = torchaudio.load(speech_file, num_frames=int(4 * SAMPLE_RATE)) ...@@ -638,44 +611,44 @@ waveform, _ = torchaudio.load(speech_file, num_frames=int(4 * SAMPLE_RATE))
emission = get_emission(waveform.to(device)) emission = get_emission(waveform.to(device))
num_frames = emission.size(1) num_frames = emission.size(1)
plot_emission(emission[0].cpu())
###################################################################### ######################################################################
# #
segments, word_segments = compute_alignments(text_normalized, dictionary, emission) transcript = text_normalized.split()
word_spans = compute_alignments(emission, transcript, dictionary)
plot_alignments(waveform, emission, segments, word_segments) plot_alignments(waveform, word_spans, emission, transcript)
###################################################################### ######################################################################
# #
display_segment(0, waveform, word_segments, num_frames) preview_word(waveform, word_spans[0], num_frames, transcript[0])
###################################################################### ######################################################################
# #
display_segment(1, waveform, word_segments, num_frames) preview_word(waveform, word_spans[1], num_frames, transcript[1])
###################################################################### ######################################################################
# #
display_segment(2, waveform, word_segments, num_frames) preview_word(waveform, word_spans[2], num_frames, transcript[2])
###################################################################### ######################################################################
# #
display_segment(3, waveform, word_segments, num_frames) preview_word(waveform, word_spans[3], num_frames, transcript[3])
###################################################################### ######################################################################
# #
display_segment(4, waveform, word_segments, num_frames) preview_word(waveform, word_spans[4], num_frames, transcript[4])
###################################################################### ######################################################################
# #
display_segment(5, waveform, word_segments, num_frames) preview_word(waveform, word_spans[5], num_frames, transcript[5])
###################################################################### ######################################################################
# Conclusion # Conclusion
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment