".github/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "5a6f4eba022a330e4cb9277ee204dee4eb06790c"
Commit cd80976e authored by moto, committed by Facebook GitHub Bot

Make target_lengths/input_lengths in forced_align optional (#3533)

Summary:
Currently, the `torchaudio.functional.forced_align` function requires full input/target length information.
When performing non-batched alignment, these lengths can be inferred from the sizes of the Tensors themselves.
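
For illustration, a minimal usage sketch of the new default behavior (the shapes and values below are hypothetical, and `forced_align` requires torchaudio to be built with alignment support):

import torch
import torchaudio.functional as F

# One utterance (batch size 1), 100 frames, 32-symbol alphabet with blank at index 0.
log_probs = torch.randn(1, 100, 32).log_softmax(dim=-1)
targets = torch.randint(1, 32, (1, 10))  # target labels; must not contain the blank index

# Previously required: explicit length Tensors.
paths, scores = F.forced_align(log_probs, targets, torch.tensor([100]), torch.tensor([10]))

# With this change, the lengths may be omitted and are inferred as T and L.
paths, scores = F.forced_align(log_probs, targets)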

Pull Request resolved: https://github.com/pytorch/audio/pull/3533

Reviewed By: nateanl

Differential Revision: D48111041

Pulled By: mthrok

fbshipit-source-id: fbf07124d3959c5cc5533dcd86296851587082fb
parent b976c8f1
@@ -2505,33 +2505,35 @@ def deemphasis(waveform, coeff: float = 0.97) -> torch.Tensor:
 @fail_if_no_align
 def forced_align(
-    log_probs: torch.Tensor,
-    targets: torch.Tensor,
-    input_lengths: torch.Tensor,
-    target_lengths: torch.Tensor,
+    log_probs: Tensor,
+    targets: Tensor,
+    input_lengths: Optional[Tensor] = None,
+    target_lengths: Optional[Tensor] = None,
     blank: int = 0,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    r"""Computes forced alignment given the emissions from a CTC-trained model and a target label.
+) -> Tuple[Tensor, Tensor]:
+    r"""Align a CTC label sequence to an emission.
 
     .. devices:: CPU CUDA
 
     .. properties:: TorchScript
 
     Args:
-        log_probs (torch.Tensor): log probability of CTC emission output.
+        log_probs (Tensor): log probability of CTC emission output.
             Tensor of shape `(B, T, C)`. where `B` is the batch size, `T` is the input length,
             `C` is the number of characters in alphabet including blank.
-        targets (torch.Tensor): Target sequence. Tensor of shape `(B, L)`,
+        targets (Tensor): Target sequence. Tensor of shape `(B, L)`,
             where `L` is the target length.
-        input_lengths (torch.Tensor): Lengths of the inputs (max value must each be <= `T`). 1-D Tensor of shape `(B,)`.
-        target_lengths (torch.Tensor): Lengths of the targets. 1-D Tensor of shape `(B,)`.
+        input_lengths (Tensor or None, optional):
+            Lengths of the inputs (max value must each be <= `T`). 1-D Tensor of shape `(B,)`.
+        target_lengths (Tensor or None, optional):
+            Lengths of the targets. 1-D Tensor of shape `(B,)`.
         blank_id (int, optional): The index of blank symbol in CTC emission. (Default: 0)
 
     Returns:
-        Tuple(torch.Tensor, torch.Tensor):
-            torch.Tensor: Label for each time step in the alignment path computed using forced alignment.
-            torch.Tensor: Log probability scores of the labels for each time step.
+        Tuple(Tensor, Tensor):
+            Tensor: Label for each time step in the alignment path computed using forced alignment.
+            Tensor: Log probability scores of the labels for each time step.
 
     Note:
         The sequence length of `log_probs` must satisfy:
@@ -2550,5 +2552,17 @@ def forced_align(
         raise ValueError(f"targets Tensor shouldn't contain blank index. Found {targets}.")
     if torch.max(targets) >= log_probs.shape[-1]:
         raise ValueError("targets values must be less than the CTC dimension")
+    if input_lengths is None:
+        batch_size, length = log_probs.size(0), log_probs.size(1)
+        input_lengths = torch.full((batch_size,), length, dtype=torch.int64, device=log_probs.device)
+    if target_lengths is None:
+        batch_size, length = targets.size(0), targets.size(1)
+        target_lengths = torch.full((batch_size,), length, dtype=torch.int64, device=targets.device)
+
+    # For TorchScript compatibility
+    assert input_lengths is not None
+    assert target_lengths is not None
+
     paths, scores = torch.ops.torchaudio.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
     return paths, scores
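
A note on the trailing asserts: under TorchScript, `assert ... is not None` refines an `Optional[Tensor]` to `Tensor` before it is passed to an op that expects a plain `Tensor`. A minimal, hypothetical sketch of the same pattern (the helper name `_default_lengths` is illustrative, not part of torchaudio):

import torch
from typing import Optional

@torch.jit.script
def _default_lengths(x: torch.Tensor, lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
    if lengths is None:
        # Assume every row of `x` is fully valid: length == x.size(1).
        lengths = torch.full([x.size(0)], x.size(1), dtype=torch.int64, device=x.device)
    # Refines Optional[Tensor] -> Tensor for the TorchScript compiler.
    assert lengths is not None
    return lengths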