"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "5f57b0ef4268a6bd9e8043d54c351a608a7e1bca"
Commit b645c07b authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Update ctc forced alignment tutorial (#3529)

Summary:
- Simplify the step to generate token-level alignment

Pull Request resolved: https://github.com/pytorch/audio/pull/3529

Reviewed By: huangruizhe

Differential Revision: D48066787

Pulled By: mthrok

fbshipit-source-id: 452c243d278e508926a59894928e280fea76dcc6
parent 09aabcc1
......@@ -25,12 +25,13 @@ print(torchaudio.__version__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
from dataclasses import dataclass
######################################################################
# Preparation
# -----------
#
from dataclasses import dataclass
from typing import Dict, List
import IPython
import matplotlib.pyplot as plt
......@@ -54,145 +55,117 @@ SAMPLE_RATE = 16000
@dataclass
class Frame:
token_index: int
time_index: int
class TokenSpan:
index: int # index of token in transcript
start: int # start time (inclusive)
end: int # end time (exclusive)
score: float
def __len__(self) -> int:
return self.end - self.start
######################################################################
#
@dataclass
class Segment:
label: str
start: int
end: int
score: float
def __repr__(self):
return f"{self.label:2s} ({self.score:4.2f}): [{self.start:4d}, {self.end:4d})"
def __len__(self):
return self.end - self.start
@dataclass
class WordSpan:
token_spans: List[TokenSpan]
score: float
######################################################################
#
def align_emission_and_tokens(emission: torch.Tensor, tokens: List[int]):
device = emission.device
targets = torch.tensor([tokens], dtype=torch.int32, device=device)
input_lengths = torch.tensor([emission.size(1)], device=device)
target_lengths = torch.tensor([targets.size(1)], device=device)
aligned_tokens, scores = forced_align(emission, targets, input_lengths, target_lengths, 0)
def compute_alignments(transcript, dictionary, emission):
tokens = [dictionary[c] for c in transcript.replace(" ", "")]
scores = scores.exp() # convert back to probability
aligned_tokens, scores = aligned_tokens[0], scores[0] # remove batch dimension
return aligned_tokens, scores
targets = torch.tensor([tokens], dtype=torch.int32, device=emission.device)
input_lengths = torch.tensor([emission.shape[1]], device=emission.device)
target_lengths = torch.tensor([targets.shape[1]], device=emission.device)
alignment, scores = forced_align(emission, targets, input_lengths, target_lengths, 0)
def merge_tokens(tokens, scores, blank=0) -> List[TokenSpan]:
prev_token = blank
i = start = -1
spans = []
for t, token in enumerate(tokens):
if token != prev_token:
if prev_token != blank:
spans.append(TokenSpan(i, start, t, scores[start:t].mean().item()))
if token != blank:
i += 1
start = t
prev_token = token
if prev_token != blank:
spans.append(TokenSpan(i, start, len(tokens), scores[start:].mean().item()))
return spans
def merge_words(token_spans: List[TokenSpan], transcript: List[str]) -> List[WordSpan]:
def _score(t_spans):
return sum(s.score * len(s) for s in t_spans) / sum(len(s) for s in t_spans)
word_spans = []
i = 0
for words in transcript:
j = i + len(words)
word_spans.append(WordSpan(token_spans[i:j], _score(token_spans[i:j])))
i = j
return word_spans
scores = scores.exp() # convert back to probability
alignment, scores = alignment[0].tolist(), scores[0].tolist()
assert len(alignment) == len(scores) == emission.size(1)
token_index = -1
prev_hyp = 0
frames = []
for i, (ali, score) in enumerate(zip(alignment, scores)):
if ali == 0:
prev_hyp = 0
continue
if ali != prev_hyp:
token_index += 1
frames.append(Frame(token_index, i, score))
prev_hyp = ali
# compute frame alignments from token alignments
transcript_nospace = transcript.replace(" ", "")
i1, i2 = 0, 0
segments = []
while i1 < len(frames):
while i2 < len(frames) and frames[i1].token_index == frames[i2].token_index:
i2 += 1
score = sum(frames[k].score for k in range(i1, i2)) / (i2 - i1)
segments.append(
Segment(
transcript_nospace[frames[i1].token_index],
frames[i1].time_index,
frames[i2 - 1].time_index + 1,
score,
)
)
i1 = i2
# compue word alignments from token alignments
separator = " "
words = []
i1, i2, i3 = 0, 0, 0
while i3 < len(transcript):
if i3 == len(transcript) - 1 or transcript[i3] == separator:
if i1 != i2:
if i3 == len(transcript) - 1:
i2 += 1
segs = segments[i1:i2]
word = "".join([s.label for s in segs])
score = sum(s.score * len(s) for s in segs) / sum(len(s) for s in segs)
words.append(Segment(word, segs[0].start, segs[-1].end + 1, score))
i1 = i2
else:
i2 += 1
i3 += 1
return segments, words
######################################################################
#
def plot_emission(emission):
fig, ax = plt.subplots()
ax.imshow(emission.T, aspect="auto")
ax.set_title("Emission")
fig.tight_layout()
def compute_alignments(emission: torch.Tensor, transcript: List[str], dictionary: Dict[str, int]):
tokens = [dictionary[c] for word in transcript for c in word]
aligned_tokens, scores = align_emission_and_tokens(emission, tokens)
token_spans = merge_tokens(aligned_tokens, scores)
word_spans = merge_words(token_spans, transcript)
return word_spans
######################################################################
#
# utility function for plotting word alignments
def plot_alignments(waveform, emission, segments, word_segments, sample_rate=SAMPLE_RATE):
fig, ax = plt.subplots()
ax.specgram(waveform[0], Fs=sample_rate)
xlim = ax.get_xlim()
ratio = waveform.size(1) / sample_rate / emission.size(1)
for word in word_segments:
t0, t1 = word.start * ratio, word.end * ratio
ax.axvspan(t0, t1, facecolor="None", hatch="/", edgecolor="white")
ax.annotate(f"{word.score:.2f}", (t0, sample_rate * 0.51), annotation_clip=False)
for seg in segments:
if seg.label != "|":
ax.annotate(seg.label, (seg.start * ratio, sample_rate * 0.53), annotation_clip=False)
ax.set_xlabel("time [second]")
ax.set_xlim(xlim)
fig.tight_layout()
def plot_alignments(waveform, word_spans, emission, transcript, sample_rate=SAMPLE_RATE):
ratio = waveform.size(1) / emission.size(1) / sample_rate
fig, axes = plt.subplots(2, 1)
axes[0].imshow(emission[0].detach().cpu().T, aspect="auto")
axes[0].set_title("Emission")
axes[0].set_xticks([])
axes[1].specgram(waveform[0], Fs=sample_rate)
for w_span, chars in zip(word_spans, transcript):
t_spans = w_span.token_spans
t0, t1 = t_spans[0].start, t_spans[-1].end
axes[1].axvspan(ratio * t0, ratio * t1, facecolor="None", hatch="/", edgecolor="white")
axes[1].annotate(f"{w_span.score:.2f}", (ratio * t0, sample_rate * 0.51), annotation_clip=False)
for span, char in zip(t_spans, chars):
axes[1].annotate(char, (span.start * ratio, sample_rate * 0.55), annotation_clip=False)
axes[1].set_xlabel("time [second]")
fig.tight_layout()
return IPython.display.Audio(waveform, rate=sample_rate)
######################################################################
#
# utility function for playing audio segments.
def display_segment(i, waveform, word_segments, num_frames, sample_rate=SAMPLE_RATE):
def preview_word(waveform, word_span, num_frames, transcript, sample_rate=SAMPLE_RATE):
ratio = waveform.size(1) / num_frames
word = word_segments[i]
x0 = int(ratio * word.start)
x1 = int(ratio * word.end)
print(f"{word.label} ({word.score:.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec")
t0 = word_span.token_spans[0].start
t1 = word_span.token_spans[-1].end
x0 = int(ratio * t0)
x1 = int(ratio * t1)
print(f"{transcript} ({word_span.score:.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec")
segment = waveform[:, x0:x1]
return IPython.display.Audio(segment.numpy(), rate=sample_rate)
......@@ -345,55 +318,54 @@ waveform, _ = torchaudio.load(speech_file, frame_offset=int(0.5 * SAMPLE_RATE),
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
######################################################################
#
segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
transcript = text_normalized.split()
word_spans = compute_alignments(emission, transcript, dictionary)
plot_alignments(waveform, emission, segments, word_segments)
plot_alignments(waveform, word_spans, emission, transcript)
######################################################################
#
display_segment(0, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[0], num_frames, transcript[0])
######################################################################
#
display_segment(1, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[1], num_frames, transcript[1])
######################################################################
#
display_segment(2, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[2], num_frames, transcript[2])
######################################################################
#
display_segment(3, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[3], num_frames, transcript[3])
######################################################################
#
display_segment(4, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[4], num_frames, transcript[4])
######################################################################
#
display_segment(5, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[5], num_frames, transcript[5])
######################################################################
#
display_segment(6, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[6], num_frames, transcript[6])
######################################################################
#
display_segment(7, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[7], num_frames, transcript[7])
######################################################################
# Chinese
......@@ -423,59 +395,59 @@ waveform = waveform[0:1]
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
######################################################################
#
segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
transcript = text_normalized.split()
word_spans = compute_alignments(emission, transcript, dictionary)
plot_alignments(waveform, emission, segments, word_segments)
plot_alignments(waveform, word_spans, emission, transcript)
######################################################################
#
display_segment(0, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[0], num_frames, transcript[0])
######################################################################
#
display_segment(1, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[1], num_frames, transcript[1])
######################################################################
#
display_segment(2, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[2], num_frames, transcript[2])
######################################################################
#
display_segment(3, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[3], num_frames, transcript[3])
######################################################################
#
display_segment(4, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[4], num_frames, transcript[4])
######################################################################
#
display_segment(5, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[5], num_frames, transcript[5])
######################################################################
#
display_segment(6, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[6], num_frames, transcript[6])
######################################################################
#
display_segment(7, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[7], num_frames, transcript[7])
######################################################################
#
display_segment(8, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[8], num_frames, transcript[8])
######################################################################
......@@ -497,54 +469,54 @@ waveform, _ = torchaudio.load(speech_file, num_frames=int(4.5 * SAMPLE_RATE))
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
######################################################################
#
segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
transcript = text_normalized.split()
word_spans = compute_alignments(emission, transcript, dictionary)
plot_alignments(waveform, emission, segments, word_segments)
plot_alignments(waveform, word_spans, emission, transcript)
######################################################################
#
display_segment(0, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[0], num_frames, transcript[0])
######################################################################
#
display_segment(1, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[1], num_frames, transcript[1])
######################################################################
#
display_segment(2, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[2], num_frames, transcript[2])
######################################################################
#
display_segment(3, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[3], num_frames, transcript[3])
######################################################################
#
display_segment(4, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[4], num_frames, transcript[4])
######################################################################
#
display_segment(5, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[5], num_frames, transcript[5])
######################################################################
#
display_segment(6, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[6], num_frames, transcript[6])
######################################################################
#
display_segment(7, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[7], num_frames, transcript[7])
######################################################################
# Portuguese
......@@ -565,59 +537,60 @@ waveform, _ = torchaudio.load(speech_file, frame_offset=int(SAMPLE_RATE), num_fr
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
######################################################################
#
segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
transcript = text_normalized.split()
word_spans = compute_alignments(emission, transcript, dictionary)
plot_alignments(waveform, emission, segments, word_segments)
plot_alignments(waveform, word_spans, emission, transcript)
######################################################################
#
display_segment(0, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[0], num_frames, transcript[0])
######################################################################
#
display_segment(1, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[1], num_frames, transcript[1])
######################################################################
#
display_segment(2, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[2], num_frames, transcript[2])
######################################################################
#
display_segment(3, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[3], num_frames, transcript[3])
######################################################################
#
display_segment(4, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[4], num_frames, transcript[4])
######################################################################
#
display_segment(5, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[5], num_frames, transcript[5])
######################################################################
#
display_segment(6, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[6], num_frames, transcript[6])
######################################################################
#
display_segment(7, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[7], num_frames, transcript[7])
######################################################################
#
display_segment(8, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[8], num_frames, transcript[8])
######################################################################
# Italian
......@@ -638,44 +611,44 @@ waveform, _ = torchaudio.load(speech_file, num_frames=int(4 * SAMPLE_RATE))
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
######################################################################
#
segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
transcript = text_normalized.split()
word_spans = compute_alignments(emission, transcript, dictionary)
plot_alignments(waveform, emission, segments, word_segments)
plot_alignments(waveform, word_spans, emission, transcript)
######################################################################
#
display_segment(0, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[0], num_frames, transcript[0])
######################################################################
#
display_segment(1, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[1], num_frames, transcript[1])
######################################################################
#
display_segment(2, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[2], num_frames, transcript[2])
######################################################################
#
display_segment(3, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[3], num_frames, transcript[3])
######################################################################
#
display_segment(4, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[4], num_frames, transcript[4])
######################################################################
#
display_segment(5, waveform, word_segments, num_frames)
preview_word(waveform, word_spans[5], num_frames, transcript[5])
######################################################################
# Conclusion
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment