Commit 90143e96 authored by moto and committed by Facebook GitHub Bot

Move alignment code to separate submodule (#3536)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/3536

Reviewed By: huangruizhe

Differential Revision: D48120170

Pulled By: mthrok

fbshipit-source-id: dec7575db07734490099b35a8bfc854252952c6e
parent 5e211d66
from ._alignment import forced_align, merge_tokens, TokenSpan
from .filtering import (
    allpass_biquad,
    band_biquad,
@@ -35,7 +36,6 @@ from .functional import (
    detect_pitch_frequency,
    edit_distance,
    fftconvolve,
    forced_align,
    griffinlim,
    inverse_spectrogram,
    linear_fbanks,
@@ -43,7 +43,6 @@ from .functional import (
    mask_along_axis,
    mask_along_axis_iid,
    melscale_fbanks,
    merge_tokens,
    mu_law_decoding,
    mu_law_encoding,
    mvdr_weights_rtf,
@@ -60,7 +59,6 @@ from .functional import (
    spectral_centroid,
    spectrogram,
    speed,
    TokenSpan,
)
__all__ = [
...
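Because the package-level `__init__` shown above now re-exports `forced_align`, `merge_tokens`, and `TokenSpan` from the new `_alignment` submodule, the public import path is unchanged by this move. A minimal sketch (not part of the commit; it assumes a torchaudio build that includes this change, and the `_alignment` module name is taken from the import line above):

import torchaudio.functional as F

# Still reachable from the public package namespace, as before the move.
print(F.forced_align, F.merge_tokens, F.TokenSpan)

# The same objects under their new internal home (module path assumed from the diff).
from torchaudio.functional._alignment import forced_align  # noqa: F401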
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from torch import Tensor

from torchaudio._extension import fail_if_no_align

__all__ = []


@fail_if_no_align
def forced_align(
    log_probs: Tensor,
    targets: Tensor,
    input_lengths: Optional[Tensor] = None,
    target_lengths: Optional[Tensor] = None,
    blank: int = 0,
) -> Tuple[Tensor, Tensor]:
    r"""Align a CTC label sequence to an emission.

    .. devices:: CPU CUDA

    .. properties:: TorchScript

    Args:
        log_probs (Tensor): log probability of CTC emission output.
            Tensor of shape `(B, T, C)`, where `B` is the batch size, `T` is the input length,
            and `C` is the number of characters in the alphabet including blank.
        targets (Tensor): Target sequence. Tensor of shape `(B, L)`,
            where `L` is the target length.
        input_lengths (Tensor or None, optional):
            Lengths of the inputs (each value must be <= `T`). 1-D Tensor of shape `(B,)`.
        target_lengths (Tensor or None, optional):
            Lengths of the targets. 1-D Tensor of shape `(B,)`.
        blank (int, optional): The index of the blank symbol in CTC emission. (Default: 0)

    Returns:
        Tuple(Tensor, Tensor):
            Tensor: Label for each time step in the alignment path computed using forced alignment.

            Tensor: Log probability scores of the labels for each time step.

    Note:
        The sequence length of `log_probs` must satisfy:

        .. math::
            L_{\text{log\_probs}} \ge L_{\text{label}} + N_{\text{repeat}}

        where :math:`N_{\text{repeat}}` is the number of consecutively repeated tokens.
        For example, in the string `"aabbc"`, the number of repeats is `2`.

    Note:
        The current version only supports ``batch_size==1``.
    """
    if blank in targets:
        raise ValueError(f"targets Tensor shouldn't contain blank index. Found {targets}.")
    if torch.max(targets) >= log_probs.shape[-1]:
        raise ValueError("targets values must be less than the CTC dimension")

    if input_lengths is None:
        batch_size, length = log_probs.size(0), log_probs.size(1)
        input_lengths = torch.full((batch_size,), length, dtype=torch.int64, device=log_probs.device)
    if target_lengths is None:
        batch_size, length = targets.size(0), targets.size(1)
        target_lengths = torch.full((batch_size,), length, dtype=torch.int64, device=targets.device)

    # For TorchScript compatibility
    assert input_lengths is not None
    assert target_lengths is not None

    paths, scores = torch.ops.torchaudio.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
    return paths, scores


@dataclass
class TokenSpan:
    """TokenSpan()
    Token with time stamps and score. Returned by :py:func:`merge_tokens`.
    """

    token: int
    """The token"""
    start: int
    """The start time (inclusive) in emission time axis."""
    end: int
    """The end time (exclusive) in emission time axis."""
    score: float
    """The score of this token."""

    def __len__(self) -> int:
        """Returns the time span"""
        return self.end - self.start


def merge_tokens(tokens: Tensor, scores: Tensor, blank: int = 0) -> List[TokenSpan]:
    """Removes repeated tokens and blank tokens from the given CTC token sequence.

    Args:
        tokens (Tensor): Alignment tokens (unbatched) returned from :py:func:`forced_align`.
            Shape: `(time, )`.
        scores (Tensor): Alignment scores (unbatched) returned from :py:func:`forced_align`.
            Shape: `(time, )`. When computing the token-level score, the given score is averaged
            across the corresponding time span.

    Returns:
        list of TokenSpan

    Example:
        >>> aligned_tokens, scores = forced_align(emission, targets, input_lengths, target_lengths)
        >>> token_spans = merge_tokens(aligned_tokens[0], scores[0])
    """
    if tokens.ndim != 1 or scores.ndim != 1:
        raise ValueError("`tokens` and `scores` must be 1D Tensor.")
    if len(tokens) != len(scores):
        raise ValueError("`tokens` and `scores` must be the same length.")

    t_prev = blank
    i = start = -1
    spans = []
    for t, token in enumerate(tokens):
        if token != t_prev:
            if t_prev != blank:
                spans.append(TokenSpan(t_prev.item(), start, t, scores[start:t].mean().item()))
            if token != blank:
                i += 1
                start = t
            t_prev = token
    if t_prev != blank:
        spans.append(TokenSpan(t_prev.item(), start, len(tokens), scores[start:].mean().item()))
    return spans
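For reference, a minimal end-to-end usage sketch of the functions defined in the new module above (not part of the commit). It assumes a torchaudio build with alignment support; the tensor shapes and token ids are illustrative and respect the ``batch_size==1`` limitation noted in the docstring:

import torch
from torchaudio.functional import forced_align, merge_tokens

# Illustrative emission: batch of 1 (the only supported batch size), 50 frames,
# a 10-symbol alphabet where index 0 is the blank.
log_probs = torch.randn(1, 50, 10).log_softmax(dim=-1)
# Target token ids must not contain the blank index and must fit within the 50 frames.
targets = torch.tensor([[1, 3, 2, 4]], dtype=torch.int32)

paths, scores = forced_align(log_probs, targets, blank=0)   # batched outputs, shape (1, T)
spans = merge_tokens(paths[0], scores[0], blank=0)          # list of TokenSpan
for span in spans:
    print(span.token, span.start, span.end, round(span.score, 3))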
@@ -4,13 +4,11 @@ import math
import tempfile
import warnings
from collections.abc import Sequence
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
import torchaudio
from torch import Tensor
from torchaudio._extension import fail_if_no_align
from torchaudio._internal.module_utils import deprecated
from .filtering import highpass_biquad, treble_biquad
@@ -53,9 +51,6 @@ __all__ = [
    "speed",
    "preemphasis",
    "deemphasis",
    "forced_align",
    "TokenSpan",
    "merge_tokens",
]
@@ -2504,126 +2499,3 @@ def deemphasis(waveform, coeff: float = 0.97) -> torch.Tensor:
    a_coeffs = torch.tensor([1.0, -coeff], dtype=waveform.dtype, device=waveform.device)
    b_coeffs = torch.tensor([1.0, 0.0], dtype=waveform.dtype, device=waveform.device)
    return torchaudio.functional.lfilter(waveform, a_coeffs=a_coeffs, b_coeffs=b_coeffs)
@fail_if_no_align
def forced_align(
log_probs: Tensor,
targets: Tensor,
input_lengths: Optional[Tensor] = None,
target_lengths: Optional[Tensor] = None,
blank: int = 0,
) -> Tuple[Tensor, Tensor]:
r"""Align a CTC label sequence to an emission.
.. devices:: CPU CUDA
.. properties:: TorchScript
Args:
log_probs (Tensor): log probability of CTC emission output.
Tensor of shape `(B, T, C)`, where `B` is the batch size, `T` is the input length,
and `C` is the number of characters in the alphabet including blank.
targets (Tensor): Target sequence. Tensor of shape `(B, L)`,
where `L` is the target length.
input_lengths (Tensor or None, optional):
Lengths of the inputs (each value must be <= `T`). 1-D Tensor of shape `(B,)`.
target_lengths (Tensor or None, optional):
Lengths of the targets. 1-D Tensor of shape `(B,)`.
blank (int, optional): The index of the blank symbol in CTC emission. (Default: 0)
Returns:
Tuple(Tensor, Tensor):
Tensor: Label for each time step in the alignment path computed using forced alignment.
Tensor: Log probability scores of the labels for each time step.
Note:
The sequence length of `log_probs` must satisfy:
.. math::
L_{\text{log\_probs}} \ge L_{\text{label}} + N_{\text{repeat}}
where :math:`N_{\text{repeat}}` is the number of consecutively repeated tokens.
For example, in the string `"aabbc"`, the number of repeats is `2`.
Note:
The current version only supports ``batch_size==1``.
"""
if blank in targets:
raise ValueError(f"targets Tensor shouldn't contain blank index. Found {targets}.")
if torch.max(targets) >= log_probs.shape[-1]:
raise ValueError("targets values must be less than the CTC dimension")
if input_lengths is None:
batch_size, length = log_probs.size(0), log_probs.size(1)
input_lengths = torch.full((batch_size,), length, dtype=torch.int64, device=log_probs.device)
if target_lengths is None:
batch_size, length = targets.size(0), targets.size(1)
target_lengths = torch.full((batch_size,), length, dtype=torch.int64, device=targets.device)
# For TorchScript compatibility
assert input_lengths is not None
assert target_lengths is not None
paths, scores = torch.ops.torchaudio.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
return paths, scores
@dataclass
class TokenSpan:
"""TokenSpan()
Token with time stamps and score. Returned by :py:func:`merge_tokens`.
"""
token: int
"""The token"""
start: int
"""The start time (inclusive) in emission time axis."""
end: int
"""The end time (exclusive) in emission time axis."""
score: float
"""The score of the this token."""
def __len__(self) -> int:
"""Returns the time span"""
return self.end - self.start
def merge_tokens(tokens: Tensor, scores: Tensor, blank: int = 0) -> List[TokenSpan]:
"""Removes repeated tokens and blank tokens from the given CTC token sequence.
Args:
tokens (Tensor): Alignment tokens (unbatched) returned from :py:func:`forced_align`.
Shape: `(time, )`.
scores (Tensor): Alignment scores (unbatched) returned from :py:func:`forced_align`.
Shape: `(time, )`. When computing the token-level score, the given score is averaged
across the corresponding time span.
Returns:
list of TokenSpan
Example:
>>> aligned_tokens, scores = forced_align(emission, targets, input_lengths, target_lengths)
>>> token_spans = merge_tokens(aligned_tokens[0], scores[0])
"""
if tokens.ndim != 1 or scores.ndim != 1:
raise ValueError("`tokens` and `scores` must be 1D Tensor.")
if len(tokens) != len(scores):
raise ValueError("`tokens` and `scores` must be the same length.")
t_prev = blank
i = start = -1
spans = []
for t, token in enumerate(tokens):
if token != t_prev:
if t_prev != blank:
spans.append(TokenSpan(t_prev.item(), start, t, scores[start:t].mean().item()))
if token != blank:
i += 1
start = t
t_prev = token
if t_prev != blank:
spans.append(TokenSpan(t_prev.item(), start, len(tokens), scores[start:].mean().item()))
return spans
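Whichever module defines them after this move, the `TokenSpan` values returned by `merge_tokens` index the emission time axis, so converting them to wall-clock time only needs the duration of one emission frame. A small helper sketch (not part of the commit; the 0.02 s frame shift and the helper name are placeholder assumptions that depend on the acoustic model actually used):

from typing import List

from torchaudio.functional import TokenSpan


def spans_to_seconds(spans: List[TokenSpan], frame_shift: float = 0.02) -> List[dict]:
    """Convert frame-indexed TokenSpans to second-based intervals.

    frame_shift is the duration of one emission frame in seconds; 0.02 is a
    placeholder assumption, not something defined by this commit.
    """
    return [
        {
            "token": span.token,
            "start_sec": span.start * frame_shift,
            "end_sec": span.end * frame_shift,
            "duration_sec": len(span) * frame_shift,  # TokenSpan.__len__ returns end - start
            "score": span.score,
        }
        for span in spans
    ]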