Commit 30668afb authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Add merge_tokens / TokenSpan (#3535)

Summary:
This commit adds `merge_tokens` function which removes repeated tokens from CTC token sequences returned from `forced_align`.

Resolving repeated tokens is a necessary and almost universal post-processing step, so it makes sense to have such a helper function in torchaudio.

Pull Request resolved: https://github.com/pytorch/audio/pull/3535

Reviewed By: huangruizhe

Differential Revision: D48111202

Pulled By: mthrok

fbshipit-source-id: 25354bfa210aa5c03f8c1d3e201f253ca3761b24
parent cd80976e
...@@ -32,7 +32,16 @@ Utility ...@@ -32,7 +32,16 @@ Utility
preemphasis preemphasis
deemphasis deemphasis
speed speed
Forced Alignment
----------------
.. autosummary::
:toctree: generated
:nosignatures:
forced_align forced_align
merge_tokens
TokenSpan
Filtering Filtering
......
...@@ -1220,6 +1220,68 @@ class Functional(TestBaseMixin): ...@@ -1220,6 +1220,68 @@ class Functional(TestBaseMixin):
with self.assertRaisesRegex(RuntimeError, r"blank must be within \[0, num classes\)"): with self.assertRaisesRegex(RuntimeError, r"blank must be within \[0, num classes\)"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank) hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
def _assert_tokens(self, first, second):
assert len(first) == len(second)
for f, s in zip(first, second):
self.assertEqual(f.token, s.token)
self.assertEqual(f.score, s.score)
self.assertEqual(f.start, s.start)
self.assertEqual(f.end, s.end)
@parameterized.expand(
    [
        # (expected spans, input token sequence, per-frame scores)
        ([], [], []),
        ([F.TokenSpan(1, 0, 1, 1.0)], [1], [1.0]),
        ([F.TokenSpan(1, 0, 2, 0.5)], [1, 1], [0.4, 0.6]),
        ([F.TokenSpan(1, 0, 3, 0.6)], [1, 1, 1], [0.5, 0.6, 0.7]),
        ([F.TokenSpan(1, 0, 1, 0.8), F.TokenSpan(2, 1, 2, 0.9)], [1, 2], [0.8, 0.9]),
        ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(2, 1, 3, 0.5)], [1, 2, 2], [1.0, 0.4, 0.6]),
        ([F.TokenSpan(1, 0, 1, 0.8), F.TokenSpan(1, 2, 3, 1.0)], [1, 0, 1], [0.8, 0.9, 1.0]),
        ([F.TokenSpan(1, 0, 1, 0.8), F.TokenSpan(2, 2, 3, 1.0)], [1, 0, 2], [0.8, 0.9, 1.0]),
        ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(1, 2, 4, 0.5)], [1, 0, 1, 1], [1.0, 0.1, 0.4, 0.6]),
        ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(2, 2, 4, 0.5)], [1, 0, 2, 2], [1.0, 0.1, 0.4, 0.6]),
        ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(1, 3, 4, 0.4)], [1, 0, 0, 1], [1.0, 0.9, 0.7, 0.4]),
        ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(2, 3, 4, 0.4)], [1, 0, 0, 2], [1.0, 0.9, 0.7, 0.4]),
        ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(1, 3, 5, 0.5)], [1, 0, 0, 1, 1], [1.0, 0.9, 0.8, 0.6, 0.4]),
        ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(2, 3, 5, 0.5)], [1, 0, 0, 2, 2], [1.0, 0.9, 0.8, 0.6, 0.4]),
        ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(2, 2, 3, 0.5)], [1, 1, 2], [1.0, 0.8, 0.5]),
        ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(1, 3, 4, 0.7)], [1, 1, 0, 1], [1.0, 0.8, 0.1, 0.7]),
        ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(2, 3, 4, 0.7)], [1, 1, 0, 2], [1.0, 0.8, 0.1, 0.7]),
        ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(1, 3, 5, 0.4)], [1, 1, 0, 1, 1], [1.0, 0.8, 0.1, 0.5, 0.3]),
        ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(2, 3, 5, 0.4)], [1, 1, 0, 2, 2], [1.0, 0.8, 0.1, 0.5, 0.3]),
        ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(1, 4, 5, 0.3)], [1, 1, 0, 0, 1], [1.0, 0.8, 0.1, 0.5, 0.3]),
        ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(2, 4, 5, 0.3)], [1, 1, 0, 0, 2], [1.0, 0.8, 0.1, 0.5, 0.3]),
        (
            [F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(1, 4, 6, 0.2)],
            [1, 1, 0, 0, 1, 1],
            [1.0, 0.8, 0.6, 0.5, 0.3, 0.1],
        ),
        (
            [F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(2, 4, 6, 0.2)],
            [1, 1, 0, 0, 2, 2],
            [1.0, 0.8, 0.6, 0.5, 0.3, 0.1],
        ),
    ]
)
def test_merge_repeated_tokens(self, expected, tokens, scores):
    """merge_tokens collapses repeats and blanks, averaging scores over each span.

    Also verifies that padding the sequence with leading/trailing blanks only
    shifts the span boundaries by the number of leading blanks.
    """
    scores_ = torch.tensor(scores, dtype=torch.float32, device=self.device)
    tokens_ = torch.tensor(tokens, dtype=torch.int64, device=self.device)
    spans = F.merge_tokens(tokens_, scores_, blank=0)
    self._assert_tokens(spans, expected)

    # Append blanks at the beginning and at the end.
    for num_prefix, num_suffix in itertools.product([0, 1, 2], repeat=2):
        tokens_ = ([0] * num_prefix) + tokens + ([0] * num_suffix)
        scores_ = ([0.1] * num_prefix) + scores + ([0.1] * num_suffix)
        tokens_ = torch.tensor(tokens_, dtype=torch.int64, device=self.device)
        scores_ = torch.tensor(scores_, dtype=torch.float32, device=self.device)
        # Same spans as the unpadded case, shifted right by the prefix length.
        expected_ = [F.TokenSpan(s.token, s.start + num_prefix, s.end + num_prefix, s.score) for s in expected]
        spans = F.merge_tokens(tokens_, scores_, blank=0)
        self._assert_tokens(spans, expected_)
class FunctionalCPUOnly(TestBaseMixin): class FunctionalCPUOnly(TestBaseMixin):
def test_melscale_fbanks_no_warning_high_n_freq(self): def test_melscale_fbanks_no_warning_high_n_freq(self):
......
...@@ -43,6 +43,7 @@ from .functional import ( ...@@ -43,6 +43,7 @@ from .functional import (
mask_along_axis, mask_along_axis,
mask_along_axis_iid, mask_along_axis_iid,
melscale_fbanks, melscale_fbanks,
merge_tokens,
mu_law_decoding, mu_law_decoding,
mu_law_encoding, mu_law_encoding,
mvdr_weights_rtf, mvdr_weights_rtf,
...@@ -59,6 +60,7 @@ from .functional import ( ...@@ -59,6 +60,7 @@ from .functional import (
spectral_centroid, spectral_centroid,
spectrogram, spectrogram,
speed, speed,
TokenSpan,
) )
__all__ = [ __all__ = [
...@@ -94,6 +96,8 @@ __all__ = [ ...@@ -94,6 +96,8 @@ __all__ = [
"filtfilt", "filtfilt",
"flanger", "flanger",
"forced_align", "forced_align",
"merge_tokens",
"TokenSpan",
"gain", "gain",
"highpass_biquad", "highpass_biquad",
"lfilter", "lfilter",
......
...@@ -4,6 +4,7 @@ import math ...@@ -4,6 +4,7 @@ import math
import tempfile import tempfile
import warnings import warnings
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
...@@ -53,6 +54,8 @@ __all__ = [ ...@@ -53,6 +54,8 @@ __all__ = [
"preemphasis", "preemphasis",
"deemphasis", "deemphasis",
"forced_align", "forced_align",
"TokenSpan",
"merge_tokens",
] ]
...@@ -2566,3 +2569,61 @@ def forced_align( ...@@ -2566,3 +2569,61 @@ def forced_align(
paths, scores = torch.ops.torchaudio.forced_align(log_probs, targets, input_lengths, target_lengths, blank) paths, scores = torch.ops.torchaudio.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
return paths, scores return paths, scores
@dataclass
class TokenSpan:
    """TokenSpan()
    Token with time stamps and score. Returned by :py:func:`merge_tokens`.
    """

    token: int
    """The token"""
    start: int
    """The start time (inclusive) in emission time axis."""
    end: int
    """The end time (exclusive) in emission time axis."""
    score: float
    """The score of this token."""

    def __len__(self) -> int:
        """Returns the time span"""
        return self.end - self.start


def merge_tokens(tokens: Tensor, scores: Tensor, blank: int = 0) -> List[TokenSpan]:
    """Removes repeated tokens and blank tokens from the given CTC token sequence.

    Args:
        tokens (Tensor): Alignment tokens (unbatched) returned from :py:func:`forced_align`.
            Shape: `(time, )`.
        scores (Tensor): Alignment scores (unbatched) returned from :py:func:`forced_align`.
            Shape: `(time, )`. When computing the token-level score, the given score is averaged
            across the corresponding time span.
        blank (int, optional): The ID of the blank token. (Default: ``0``)

    Returns:
        list of TokenSpan

    Example:
        >>> aligned_tokens, scores = forced_align(emission, targets, input_lengths, target_lengths)
        >>> token_spans = merge_tokens(aligned_tokens[0], scores[0])
    """
    if tokens.ndim != 1 or scores.ndim != 1:
        raise ValueError("`tokens` and `scores` must be 1D Tensor.")
    if len(tokens) != len(scores):
        raise ValueError("`tokens` and `scores` must be the same length.")

    t_prev = blank  # token seen at the previous frame; starts as blank
    start = -1  # start frame of the non-blank run currently being accumulated
    spans = []
    for t, token in enumerate(tokens):
        if token != t_prev:
            # A new run begins here; close out the previous non-blank run, if any.
            if t_prev != blank:
                spans.append(TokenSpan(t_prev.item(), start, t, scores[start:t].mean().item()))
            if token != blank:
                start = t
            t_prev = token
    # Flush the trailing run when the sequence does not end on a blank.
    if t_prev != blank:
        spans.append(TokenSpan(t_prev.item(), start, len(tokens), scores[start:].mean().item()))
    return spans
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment