Commit e7935cff authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot

[audio][PR] Add forced_align function to torchaudio (#3348)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3348

The pull request adds a CTC-based forced alignment function that supports both CPU and CUDA devices. The function takes the CTC emissions and target labels as inputs and generates the corresponding label for each frame.
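
A minimal usage sketch (the values below are illustrative only; the real reference data lives in the added tests):

import torch
import torchaudio.functional as F

# Hypothetical (T, C) log-probabilities from a CTC model with 6 classes; class 5 is the blank.
log_probs = torch.randn(5, 6).log_softmax(dim=-1)
targets = torch.tensor([0, 1, 1, 0], dtype=torch.int32)
input_lengths = torch.tensor(log_probs.shape[0])
target_lengths = torch.tensor(targets.shape[0])

# Returns the per-frame label path and the log-probability score of each frame's label.
path, scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank=5)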

Reviewed By: vineelpratap, xiaohui-zhang

Differential Revision: D45867265

fbshipit-source-id: 3e25b06bf9bc8bb1bdcdc08de7f4434d912154cb
parent 0db5ab25
...@@ -56,6 +56,7 @@ option(BUILD_SOX "Build libsox statically" ON)
option(BUILD_KALDI "Build kaldi statically" ON)
option(BUILD_RIR "Enable RIR simulation" ON)
option(BUILD_RNNT "Enable RNN transducer" ON)
option(BUILD_ALIGN "Enable forced alignment" ON)
option(BUILD_CUDA_CTC_DECODER "Build CUCTC decoder" OFF)
option(BUILD_TORCHAUDIO_PYTHON_EXTENSION "Build Python extension" OFF)
option(USE_FFMPEG "Enable ffmpeg-based features" OFF)
...@@ -80,6 +81,14 @@ endif()
if(USE_CUDA)
enable_language(CUDA)
set(
CMAKE_CUDA_FLAGS
"${CMAKE_CUDA_FLAGS} \
-DCUDA_HAS_FP16=1 \
-D__CUDA_NO_HALF_OPERATORS__ \
-D__CUDA_NO_HALF_CONVERSIONS__ \
-D__CUDA_NO_HALF2_OPERATORS__"
)
endif()
include(cmake/TorchAudioHelper.cmake)
...@@ -108,7 +117,16 @@ if(MSVC)
unsigned_compare_with_zero
declared_but_not_referenced
bad_friend_decl)
string(
APPEND
CMAKE_CUDA_FLAGS
" -Xcudafe \
--diag_suppress=${diag} \
-DCUDA_HAS_FP16=1 \
-D__CUDA_NO_HALF_OPERATORS__ \
-D__CUDA_NO_HALF_CONVERSIONS__ \
-D__CUDA_NO_HALF2_OPERATORS__"
)
endforeach()
CUDA_CONVERT_FLAGS(torch_cpu)
if(TARGET torch_cuda)
......
...@@ -3,7 +3,7 @@ import unittest
import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .functional_impl import Functional, FunctionalCUDAOnly
@skipIfNoCuda
...@@ -20,3 +20,15 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
class TestLFilterFloat64(Functional, PytorchTestCase):
dtype = torch.float64
device = torch.device("cuda")
@skipIfNoCuda
class TestFunctionalCUDAOnlyFloat32(FunctionalCUDAOnly, PytorchTestCase):
dtype = torch.float32
device = torch.device("cuda")
@skipIfNoCuda
class TestFunctionalCUDAOnlyFloat64(FunctionalCUDAOnly, PytorchTestCase):
dtype = torch.float64
device = torch.device("cuda")
...@@ -1114,6 +1114,105 @@ class Functional(TestBaseMixin):
deemphasized = F.deemphasis(preemphasized, coeff=coeff)
self.assertEqual(deemphasized, waveform)
@parameterized.expand(
[
([0, 1, 1, 0], [0, 1, 5, 1, 0], torch.int32),
([0, 1, 2, 3, 4], [0, 1, 2, 3, 4], torch.int32),
([3, 3, 3], [3, 5, 3, 5, 3], torch.int64),
([0, 1, 2], [0, 1, 1, 1, 2], torch.int64),
]
)
def test_forced_align(self, targets, ref_path, targets_dtype):
emission = torch.tensor(
[
[0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
[0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436],
[0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688],
[0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533],
[0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107],
],
dtype=self.dtype,
device=self.device,
)
blank = 5
ref_path = torch.tensor(ref_path, dtype=targets_dtype, device=self.device)
ref_scores = torch.tensor(
[torch.log(emission[i, ref_path[i]]).item() for i in range(emission.shape[0])],
dtype=emission.dtype,
device=self.device,
)
log_probs = torch.log(emission)
targets = torch.tensor(targets, dtype=targets_dtype, device=self.device)
input_lengths = torch.tensor((log_probs.shape[0]))
target_lengths = torch.tensor((targets.shape[0]))
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
self.assertEqual(hyp_path, ref_path)
self.assertEqual(hyp_scores, ref_scores)
@parameterized.expand([(torch.int32,), (torch.int64,)])
def test_forced_align_fail(self, targets_dtype):
log_probs = torch.rand(5, 6, dtype=self.dtype, device=self.device)
targets = torch.tensor([0, 1, 2, 3, 4, 4], dtype=targets_dtype, device=self.device)
blank = 5
input_lengths = torch.tensor((log_probs.shape[0]), device=self.device)
target_lengths = torch.tensor((targets.shape[0]), device=self.device)
with self.assertRaisesRegex(RuntimeError, r"targets length is too long for CTC"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
targets = torch.tensor([5, 3, 3], dtype=targets_dtype, device=self.device)
with self.assertRaisesRegex(ValueError, r"targets Tensor shouldn't contain blank index"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
log_probs = log_probs.int()
targets = torch.tensor([0, 1, 2, 3], dtype=targets_dtype, device=self.device)
with self.assertRaisesRegex(RuntimeError, r"log_probs must be float64, float32 or float16"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
log_probs = log_probs.float()
targets = targets.float()
with self.assertRaisesRegex(RuntimeError, r"targets must be int32 or int64 type"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
log_probs = torch.rand(3, 4, 6, dtype=self.dtype, device=self.device)
targets = targets.int()
with self.assertRaisesRegex(RuntimeError, r"3-D tensor is not yet supported for log_probs"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
targets = torch.randint(0, 4, (3, 4), device=self.device)
log_probs = torch.rand(3, 6, dtype=self.dtype, device=self.device)
with self.assertRaisesRegex(RuntimeError, r"2-D tensor is not yet supported for targets"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
targets = torch.tensor([0, 1, 2, 3], dtype=targets_dtype, device=self.device)
input_lengths = torch.randint(1, 5, (3,), device=self.device)
with self.assertRaisesRegex(RuntimeError, r"input_lengths must be 0-D"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
input_lengths = torch.tensor((log_probs.shape[0]), device=self.device)
target_lengths = torch.randint(1, 5, (3,), device=self.device)
with self.assertRaisesRegex(RuntimeError, r"target_lengths must be 0-D"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
input_lengths = torch.tensor((10000), device=self.device)
target_lengths = torch.tensor((targets.shape[0]), device=self.device)
with self.assertRaisesRegex(RuntimeError, r"input length mismatch"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
input_lengths = torch.tensor((log_probs.shape[0]))
target_lengths = torch.tensor((10000))
with self.assertRaisesRegex(RuntimeError, r"target length mismatch"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
targets = torch.tensor([7, 8, 9, 10], dtype=targets_dtype, device=self.device)
log_probs = torch.rand(10, 5, dtype=self.dtype, device=self.device)
with self.assertRaisesRegex(ValueError, r"targets values must be less than the CTC dimension"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
targets = torch.tensor([1, 3, 3], dtype=targets_dtype, device=self.device)
blank = 10000
with self.assertRaisesRegex(RuntimeError, r"blank must be within \[0, num classes\)"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
class FunctionalCPUOnly(TestBaseMixin):
def test_melscale_fbanks_no_warning_high_n_freq(self):
...@@ -1133,3 +1232,24 @@ class FunctionalCPUOnly(TestBaseMixin):
warnings.simplefilter("always")
F.melscale_fbanks(201, 0, 8000, 128, 16000)
assert len(w) == 1
class FunctionalCUDAOnly(Functional):
@nested_params(
[torch.half, torch.float, torch.double], [torch.int32, torch.int64], [(5, 6), (10, 6)], [(1,), (4,), (5,)]
)
def test_forced_align_same_result(self, log_probs_dtype, targets_dtype, log_probs_shape, targets_shape):
log_probs = torch.rand(log_probs_shape, dtype=log_probs_dtype, device=self.device)
targets = torch.randint(1, 6, targets_shape, dtype=targets_dtype, device=self.device)
input_lengths = torch.tensor((log_probs.shape[0]), device=self.device)
target_lengths = torch.tensor((targets.shape[0]), device=self.device)
log_probs_cuda = log_probs.cuda()
targets_cuda = targets.cuda()
input_lengths_cuda = input_lengths.cuda()
target_lengths_cuda = target_lengths.cuda()
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths)
hyp_path_cuda, hyp_scores_cuda = F.forced_align(
log_probs_cuda, targets_cuda, input_lengths_cuda, target_lengths_cuda
)
self.assertEqual(hyp_path, hyp_path_cuda.cpu())
self.assertEqual(hyp_scores, hyp_scores_cuda.cpu())
...@@ -40,6 +40,7 @@ _BUILD_RNNT = _get_build("BUILD_RNNT", True)
_USE_FFMPEG = _get_build("USE_FFMPEG", False)
_USE_ROCM = _get_build("USE_ROCM", torch.backends.cuda.is_built() and torch.version.hip is not None)
_USE_CUDA = _get_build("USE_CUDA", torch.backends.cuda.is_built() and torch.version.hip is None)
_BUILD_ALIGN = _get_build("BUILD_ALIGN", True)
_BUILD_CUDA_CTC_DECODER = _get_build("BUILD_CUDA_CTC_DECODER", _USE_CUDA)
_USE_OPENMP = _get_build("USE_OPENMP", True) and "ATen parallel backend: OpenMP" in torch.__config__.parallel_info()
_TORCH_CUDA_ARCH_LIST = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
...@@ -118,6 +119,7 @@ class CMakeBuild(build_ext):
f"-DBUILD_KALDI:BOOL={'ON' if _BUILD_KALDI else 'OFF'}",
f"-DBUILD_RIR:BOOL={'ON' if _BUILD_RIR else 'OFF'}",
f"-DBUILD_RNNT:BOOL={'ON' if _BUILD_RNNT else 'OFF'}",
f"-DBUILD_ALIGN:BOOL={'ON' if _BUILD_ALIGN else 'OFF'}",
f"-DBUILD_CUDA_CTC_DECODER:BOOL={'ON' if _BUILD_CUDA_CTC_DECODER else 'OFF'}", f"-DBUILD_CUDA_CTC_DECODER:BOOL={'ON' if _BUILD_CUDA_CTC_DECODER else 'OFF'}",
"-DBUILD_TORCHAUDIO_PYTHON_EXTENSION:BOOL=ON", "-DBUILD_TORCHAUDIO_PYTHON_EXTENSION:BOOL=ON",
f"-DUSE_ROCM:BOOL={'ON' if _USE_ROCM else 'OFF'}", f"-DUSE_ROCM:BOOL={'ON' if _USE_ROCM else 'OFF'}",
......
...@@ -39,6 +39,7 @@ _IS_TORCHAUDIO_EXT_AVAILABLE = is_module_available("torchaudio.lib._torchaudio")
# Kaldi or RIR features are found there.
_IS_RIR_AVAILABLE = False
_IS_KALDI_AVAILABLE = False
_IS_ALIGN_AVAILABLE = False
if _IS_TORCHAUDIO_EXT_AVAILABLE:
_load_lib("libtorchaudio")
...@@ -47,6 +48,7 @@ if _IS_TORCHAUDIO_EXT_AVAILABLE:
_check_cuda_version()
_IS_RIR_AVAILABLE = torchaudio.lib._torchaudio.is_rir_available()
_IS_KALDI_AVAILABLE = torchaudio.lib._torchaudio.is_kaldi_available()
_IS_ALIGN_AVAILABLE = torchaudio.lib._torchaudio.is_align_available()
# Similar to libtorchaudio, sox-related features should be importable when present.
...@@ -99,3 +101,12 @@ fail_if_no_rir = (
"requires RIR extension, but TorchAudio is not compiled with it. Please build TorchAudio with RIR support."
)
)
fail_if_no_align = (
no_op
if _IS_ALIGN_AVAILABLE
else fail_with_message(
"Requires alignment extension, but TorchAudio is not compiled with it. \
Please build TorchAudio with alignment support."
)
)
...@@ -46,6 +46,23 @@ if(BUILD_RIR)
list(APPEND compile_definitions INCLUDE_RIR)
endif()
if(BUILD_ALIGN)
list(
APPEND
sources
forced_align/compute.cpp
forced_align/cpu/compute.cpp
)
list(APPEND compile_definitions INCLUDE_ALIGN)
if (USE_CUDA)
list(
APPEND
sources
forced_align/gpu/compute.cu
)
endif()
endif()
if(USE_CUDA)
list(
APPEND
......
#include <torch/script.h>
#include <torchaudio/csrc/forced_align/compute.h>
std::tuple<torch::Tensor, torch::Tensor> forced_align(
const torch::Tensor& logProbs,
const torch::Tensor& targets,
const torch::Tensor& inputLengths,
const torch::Tensor& targetLengths,
const int64_t blank) {
static auto op = torch::Dispatcher::singleton()
.findSchemaOrThrow("torchaudio::forced_align", "")
.typed<decltype(forced_align)>();
return op.call(logProbs, targets, inputLengths, targetLengths, blank);
}
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"forced_align(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank) -> (Tensor, Tensor)");
}
#pragma once
#include <torch/script.h>
std::tuple<torch::Tensor, torch::Tensor> forced_align(
const torch::Tensor& logProbs,
const torch::Tensor& targets,
const torch::Tensor& inputLengths,
const torch::Tensor& targetLengths,
const int64_t blank);
#include <torch/script.h>
#include <torch/torch.h>
using namespace std;
namespace torchaudio {
namespace alignment {
namespace cpu {
// Inspired from
// https://github.com/flashlight/sequence/blob/main/flashlight/lib/sequence/criterion/cpu/ConnectionistTemporalClassificationCriterion.cpp
template <typename scalar_t, at::ScalarType target_scalar_type>
void forced_align_impl(
const torch::Tensor& logProbs,
const torch::Tensor& targets,
const int64_t blank,
torch::Tensor& paths) {
const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
using target_t = typename std::
conditional<target_scalar_type == torch::kInt, int, int64_t>::type;
const auto T = logProbs.size(0);
const auto L = targets.size(0);
const auto S = 2 * L + 1;
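// The trellis has S = 2 * L + 1 states per frame: blank states are interleaved
// with the target labels as (blank, targets[0], blank, targets[1], ..., blank),
// so even state indices are blanks and odd index i maps to targets[i / 2].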
torch::Tensor alphas = torch::empty(
{2, S},
torch::TensorOptions()
.device(logProbs.device())
.dtype(logProbs.dtype()))
.fill_(kNegInfinity);
torch::Tensor backPtr = torch::empty({T, S}, torch::kInt8).fill_(-1);
auto logProbs_a = logProbs.accessor<scalar_t, 2>();
auto targets_a = targets.accessor<target_t, 1>();
auto paths_a = paths.accessor<target_t, 1>();
auto alphas_a = alphas.accessor<scalar_t, 2>();
auto backPtr_a = backPtr.accessor<int8_t, 2>();
auto R = 0;
for (auto i = 1; i < L; i++) {
if (targets_a[i] == targets_a[i - 1]) {
++R;
}
}
TORCH_CHECK(
T >= L + R,
"targets length is too long for CTC. Found targets length: ",
T,
", log_probs length: ",
L,
", and number of repeats: ",
R);
auto start = T - (L + R) > 0 ? 0 : 1;
auto end = S == 1 ? 1 : 2;
for (auto i = start; i < end; i++) {
auto labelIdx = (i % 2 == 0) ? blank : targets_a[i / 2];
alphas_a[0][i] = logProbs_a[0][labelIdx];
}
for (auto t = 1; t < T; t++) {
if (T - t <= L + R) {
if ((start % 2 == 1) &&
targets_a[start / 2] != targets_a[start / 2 + 1]) {
start = start + 1;
}
start = start + 1;
}
if (t <= L + R) {
if (end % 2 == 0 && end < 2 * L &&
targets_a[end / 2 - 1] != targets_a[end / 2]) {
end = end + 1;
}
end = end + 1;
}
auto startloop = start;
auto curIdxOffset = t % 2;
auto prevIdxOffset = (t - 1) % 2;
for (auto j = 0; j < S; ++j) {
alphas_a[curIdxOffset][j] = -std::numeric_limits<scalar_t>::infinity();
}
if (start == 0) {
alphas_a[curIdxOffset][0] =
alphas_a[prevIdxOffset][0] + logProbs_a[t][blank];
backPtr_a[t][0] = 0;
startloop += 1;
}
for (auto i = startloop; i < end; i++) {
auto x0 = alphas_a[prevIdxOffset][i];
auto x1 = alphas_a[prevIdxOffset][i - 1];
auto x2 = -std::numeric_limits<scalar_t>::infinity();
auto labelIdx = (i % 2 == 0) ? blank : targets_a[i / 2];
// In CTC, the optimal path may optionally choose to skip a blank label.
// x2 represents skipping a letter, and can only happen if we're not
// currently on a blank label and we're not on a repeated letter.
// (i != 1) just ensures we don't access targets[i - 2] when i < 2.
if (i % 2 != 0 && i != 1 && targets_a[i / 2] != targets_a[i / 2 - 1]) {
x2 = alphas_a[prevIdxOffset][i - 2];
}
scalar_t result = 0.0;
if (x2 > x1 && x2 > x0) {
result = x2;
backPtr_a[t][i] = 2;
} else if (x1 > x0 && x1 > x2) {
result = x1;
backPtr_a[t][i] = 1;
} else {
result = x0;
backPtr_a[t][i] = 0;
}
alphas_a[curIdxOffset][i] = result + logProbs_a[t][labelIdx];
}
}
auto idx1 = (T - 1) % 2;
auto ltrIdx = alphas_a[idx1][S - 1] > alphas_a[idx1][S - 2] ? S - 1 : S - 2;
// path stores the token index for each time step after force alignment.
auto indexScores = 0;
for (auto t = T - 1; t > -1; t--) {
auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a[ltrIdx / 2];
paths_a[t] = lbl_idx;
++indexScores;
ltrIdx -= backPtr_a[t][ltrIdx];
}
}
std::tuple<torch::Tensor, torch::Tensor> compute(
const torch::Tensor& logProbs,
const torch::Tensor& targets,
const torch::Tensor& inputLengths,
const torch::Tensor& targetLengths,
const int64_t blank) {
TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
TORCH_CHECK(
logProbs.device() == targets.device(),
"log_probs and targets need to be on the same device");
TORCH_CHECK(
logProbs.dtype() == torch::kFloat64 ||
logProbs.dtype() == torch::kFloat32 ||
logProbs.dtype() == torch::kFloat16,
"log_probs must be float64, float32 or float16 (half) type");
TORCH_CHECK(
targets.dtype() == torch::kInt32 || targets.dtype() == torch::kInt64,
"targets must be int32 or int64 type");
TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
TORCH_CHECK(
logProbs.dim() != 3,
"3-D tensor is not yet supported for log_probs, please provide 2-D tensor.")
TORCH_CHECK(
targets.dim() != 2,
"2-D tensor is not yet supported for targets, please provide 1-D tensor.")
TORCH_CHECK(
logProbs.dim() == 2, "log_probs must be 2-D (input length, num classes)");
TORCH_CHECK(targets.dim() == 1, "targets must be 1-D (target length,)");
TORCH_CHECK(inputLengths.dim() == 0, "input_lengths must be 0-D");
TORCH_CHECK(targetLengths.dim() == 0, "target_lengths must be 0-D");
TORCH_CHECK(
blank >= 0 && blank < logProbs.size(-1),
"blank must be within [0, num classes)");
TORCH_CHECK(
logProbs.size(0) == at::max(inputLengths).item().toInt(),
"input length mismatch");
TORCH_CHECK(
targets.size(0) == at::max(targetLengths).item().toInt(),
"target length mismatch");
const auto T = logProbs.size(0);
auto paths = torch::zeros(
{T},
torch::TensorOptions().device(targets.device()).dtype(targets.dtype()));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
logProbs.scalar_type(), "forced_align_impl", [&] {
if (targets.scalar_type() == torch::kInt64) {
forced_align_impl<scalar_t, torch::kInt64>(
logProbs, targets, blank, paths);
} else {
forced_align_impl<scalar_t, torch::kInt32>(
logProbs, targets, blank, paths);
}
});
return std::make_tuple(
paths,
logProbs.index(
{torch::linspace(
0, T - 1, T, torch::TensorOptions().dtype(paths.dtype())),
paths}));
}
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
m.impl("forced_align", &compute);
}
} // namespace cpu
} // namespace alignment
} // namespace torchaudio
#include <ATen/core/TensorAccessor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <limits.h>
#include <torch/torch.h>
#include <cub/cub.cuh>
using namespace torch::indexing;
namespace {
constexpr int kNumThreads =
1024; // Number of threads to run CUDA kernel in parallel.
constexpr int kBackPtrBufferSize =
100; // Buffer size of backPtr on GPU. The data is transferred to CPU once
// the buffer reaches this max size.
} // anonymous namespace
namespace torchaudio {
namespace alignment {
namespace gpu {
template <typename scalar_t, typename target_t>
__global__ void falign_cuda_step_kernel(
const torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits>
logProbs_a,
const torch::PackedTensorAccessor32<target_t, 1, torch::RestrictPtrTraits>
targets_a,
const int T,
const int L,
const int N,
const int R,
const int t,
const int64_t blank,
int start,
int end,
int backPtrBufferLen,
torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits>
alphas_a,
torch::PackedTensorAccessor32<int8_t, 2, torch::RestrictPtrTraits>
backPtrBuffer_a) {
scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
int S = 2 * L + 1;
int curIdxOffset = (t % 2); // current time step frame for alpha
int prevIdxOffset = ((t - 1) % 2); // previous time step frame for alpha
// reset alpha and backPtrBuffer values
for (unsigned int i = threadIdx.x; i < S; i += blockDim.x) {
alphas_a[curIdxOffset][i] = kNegInfinity;
backPtrBuffer_a[backPtrBufferLen][i] = -1;
}
// This sync could potentially be removed through careful indexing inside each
// thread for the above for loop. But this is okay for now.
__syncthreads();
if (t == 0) {
for (unsigned int i = start + threadIdx.x; i < end; i += blockDim.x) {
int labelIdx = (i % 2 == 0) ? blank : targets_a[i / 2];
alphas_a[curIdxOffset][i] = logProbs_a[0][labelIdx];
}
return;
}
using BlockReduce = cub::BlockReduce<scalar_t, kNumThreads>;
__shared__ typename BlockReduce::TempStorage tempStorage;
__shared__ scalar_t maxValue;
scalar_t threadMax;
int startloop = start;
threadMax = kNegInfinity;
if (start == 0 && threadIdx.x == 0) {
alphas_a[curIdxOffset][0] =
alphas_a[prevIdxOffset][0] + logProbs_a[t][blank];
threadMax = max(threadMax, alphas_a[curIdxOffset][0]);
backPtrBuffer_a[backPtrBufferLen][0] = 0;
}
if (start == 0) {
startloop += 1;
}
for (unsigned int i = startloop + threadIdx.x; i < end; i += blockDim.x) {
scalar_t x0 = alphas_a[prevIdxOffset][i];
scalar_t x1 = alphas_a[prevIdxOffset][i - 1];
scalar_t x2 = kNegInfinity;
int labelIdx = (i % 2 == 0) ? blank : targets_a[i / 2];
if (i % 2 != 0 && i != 1 && targets_a[i / 2] != targets_a[i / 2 - 1]) {
x2 = alphas_a[prevIdxOffset][i - 2];
}
scalar_t result = 0.0;
if (x2 > x1 && x2 > x0) {
result = x2;
backPtrBuffer_a[backPtrBufferLen][i] = 2;
} else if (x1 > x0 && x1 > x2) {
result = x1;
backPtrBuffer_a[backPtrBufferLen][i] = 1;
} else {
result = x0;
backPtrBuffer_a[backPtrBufferLen][i] = 0;
}
alphas_a[curIdxOffset][i] = result + logProbs_a[t][labelIdx];
threadMax = max(threadMax, alphas_a[curIdxOffset][i]);
}
scalar_t maxResult = BlockReduce(tempStorage).Reduce(threadMax, cub::Max());
if (threadIdx.x == 0) {
maxValue = maxResult;
}
__syncthreads();
// normalize alpha values so that they don't overflow for large T
for (unsigned int i = threadIdx.x; i < S; i += blockDim.x) {
alphas_a[curIdxOffset][i] -= maxValue;
}
}
template <typename scalar_t, torch::ScalarType target_scalar_type>
void forced_align_impl(
const torch::Tensor& logProbs,
const torch::Tensor& targets,
const int64_t blank,
torch::Tensor& paths) {
auto defaultStream = at::cuda::getCurrentCUDAStream();
auto cpuDataTranferStream = at::cuda::getStreamFromPool();
const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
using target_t = typename std::
conditional<target_scalar_type == torch::kInt, int, int64_t>::type;
auto paths_a = paths.accessor<target_t, 1>();
const int T = logProbs.size(0); // num frames
const int N = logProbs.size(1); // alphabet size
const int L = targets.size(0); // label length
const int S = 2 * L + 1;
auto targetsCpu = targets.to(torch::kCPU);
// backPtrBuffer stores the index offset of the best path at the current
// position. We copy the values to CPU after every kBackPtrBufferSize frames.
torch::Tensor backPtrBuffer =
torch::empty(
{min(kBackPtrBufferSize, T), S},
torch::TensorOptions().dtype(torch::kInt8).device(logProbs.device()))
.contiguous()
.fill_(-1);
torch::Tensor backPtrCpu =
torch::empty(
{T, S},
torch::TensorOptions().dtype(torch::kInt8).device(torch::kCPU))
.contiguous()
.fill_(-1);
// We store only two time frames for alphas: the alphas for the current time
// frame can be computed from the previous time frame alone.
torch::Tensor alphas = torch::empty(
{2, S},
torch::TensorOptions()
.dtype(logProbs.dtype())
.device(logProbs.device()))
.fill_(kNegInfinity);
// CPU accessors
auto targetsCpu_a = targetsCpu.accessor<target_t, 1>();
auto backPtrCpu_a = backPtrCpu.accessor<int8_t, 2>();
// count the number of repeats in label
int R = 0;
for (int i = 1; i < L; ++i) {
if (targetsCpu_a[i] == targetsCpu_a[i - 1]) {
++R;
}
}
TORCH_CHECK(
T >= L + R,
"targets length is too long for CTC. Found targets length: ",
T,
", log_probs length: ",
L,
", and number of repeats: ",
R);
int start = (T - (L + R)) > 0 ? 0 : 1;
int end = (S == 1) ? 1 : 2;
int backPtrBufferLen = 0;
torch::Tensor bufferCopy;
for (int t = 0; t < T; ++t) {
if (t > 0) {
if (T - t <= L + R) {
if ((start % 2 == 1) &&
(targetsCpu_a[start / 2] != targetsCpu_a[start / 2 + 1])) {
start = start + 1;
}
start = start + 1;
}
if (t <= L + R) {
if ((end % 2 == 0) && (end < 2 * L) &&
(targetsCpu_a[end / 2 - 1] != targetsCpu_a[end / 2])) {
end = end + 1;
}
end = end + 1;
}
}
falign_cuda_step_kernel<scalar_t, target_t>
<<<1, kNumThreads, 0, defaultStream>>>(
logProbs.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
targets.packed_accessor32<target_t, 1, torch::RestrictPtrTraits>(),
T,
L,
N,
R,
t,
blank,
start,
end,
backPtrBufferLen,
alphas.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
backPtrBuffer
.packed_accessor32<int8_t, 2, torch::RestrictPtrTraits>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
++backPtrBufferLen;
if (backPtrBufferLen == kBackPtrBufferSize || t == T - 1) {
cpuDataTranferStream.synchronize();
// GPU -> GPU copy
bufferCopy = backPtrBuffer.clone().contiguous();
defaultStream.synchronize();
at::cuda::setCurrentCUDAStream(cpuDataTranferStream);
// Copy ASYNC from GPU to CPU
int64_t offset =
static_cast<int64_t>(t + 1 - backPtrBufferLen) * S * sizeof(int8_t);
C10_CUDA_CHECK(cudaMemcpyAsync(
static_cast<int8_t*>(backPtrCpu.data_ptr()) + offset,
bufferCopy.data_ptr(),
backPtrBufferLen * S * sizeof(int8_t),
cudaMemcpyDeviceToHost,
cpuDataTranferStream));
at::cuda::setCurrentCUDAStream(defaultStream);
backPtrBufferLen = 0;
}
}
cpuDataTranferStream.synchronize();
torch::Tensor alphasCpu = alphas.to(torch::kCPU);
auto alphasCpu_a = alphasCpu.accessor<scalar_t, 2>();
int curIdxOffset = ((T - 1) % 2);
int ltrIdx =
alphasCpu_a[curIdxOffset][S - 1] > alphasCpu_a[curIdxOffset][S - 2]
? S - 1
: S - 2;
int indexScores = 0;
for (int t = T - 1; t >= 0; --t) {
auto lbl_idx = ltrIdx % 2 == 0 ? blank : targetsCpu_a[ltrIdx / 2];
paths_a[t] = lbl_idx;
++indexScores;
ltrIdx -= backPtrCpu_a[t][ltrIdx];
}
}
std::tuple<torch::Tensor, torch::Tensor> compute(
const torch::Tensor& logProbs,
const torch::Tensor& targets,
const torch::Tensor& inputLengths,
const torch::Tensor& targetLengths,
const int64_t blank) {
TORCH_CHECK(logProbs.is_cuda(), "log_probs must be a CUDA tensor");
TORCH_CHECK(targets.is_cuda(), "targets must be a CUDA tensor");
TORCH_CHECK(
logProbs.device() == targets.device(),
"log_probs and targets need to be on the same device");
TORCH_CHECK(
logProbs.dtype() == torch::kFloat64 ||
logProbs.dtype() == torch::kFloat32 ||
logProbs.dtype() == torch::kFloat16,
"log_probs must be float64, float32 or float16 (half) type");
TORCH_CHECK(
targets.dtype() == torch::kInt32 || targets.dtype() == torch::kInt64,
"targets must be int32 or int64 type");
TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
TORCH_CHECK(
logProbs.dim() != 3,
"3-D tensor is not yet supported for log_probs, please provide 2-D tensor.")
TORCH_CHECK(
targets.dim() != 2,
"2-D tensor is not yet supported for targets, please provide 1-D tensor.")
TORCH_CHECK(
logProbs.dim() == 2, "log_probs must be 2-D (input length, num classes)");
TORCH_CHECK(targets.dim() == 1, "targets must be 1-D (target length,)");
TORCH_CHECK(inputLengths.dim() == 0, "input_lengths must be 0-D");
TORCH_CHECK(targetLengths.dim() == 0, "target_lengths must be 0-D");
TORCH_CHECK(
blank >= 0 && blank < logProbs.size(-1),
"blank must be within [0, num classes)");
TORCH_CHECK(
logProbs.size(0) == at::max(inputLengths).item().toInt(),
"input length mismatch");
TORCH_CHECK(
targets.size(0) == at::max(targetLengths).item().toInt(),
"target length mismatch");
auto T = logProbs.size(0); // num frames
auto paths = torch::zeros(
{T}, torch::TensorOptions().device(torch::kCPU).dtype(targets.dtype()));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
logProbs.scalar_type(), "forced_align_impl", [&] {
if (targets.scalar_type() == torch::kInt64) {
forced_align_impl<scalar_t, torch::kInt64>(
logProbs, targets, blank, paths);
} else {
forced_align_impl<scalar_t, torch::kInt32>(
logProbs, targets, blank, paths);
}
});
return std::make_tuple(
paths.to(logProbs.device()),
logProbs.index(
{torch::linspace(
0, T - 1, T, torch::TensorOptions().dtype(paths.dtype())),
paths}));
}
TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
m.impl("forced_align", &compute);
}
} // namespace gpu
} // namespace alignment
} // namespace torchaudio
...@@ -7,6 +7,7 @@ namespace {
PYBIND11_MODULE(_torchaudio, m) {
m.def("is_kaldi_available", &is_kaldi_available, "");
m.def("is_rir_available", &is_rir_available, "");
m.def("is_align_available", &is_align_available, "");
m.def("cuda_version", &cuda_version, ""); m.def("cuda_version", &cuda_version, "");
} }
......
...@@ -23,6 +23,14 @@ bool is_rir_available() {
#endif
}
bool is_align_available() {
#ifdef INCLUDE_ALIGN
return true;
#else
return false;
#endif
}
c10::optional<int64_t> cuda_version() {
#ifdef USE_CUDA
return CUDA_VERSION;
......
...@@ -4,5 +4,6 @@
namespace torchaudio {
bool is_kaldi_available();
bool is_rir_available();
bool is_align_available();
c10::optional<int64_t> cuda_version();
} // namespace torchaudio
...@@ -36,6 +36,7 @@ from .functional import (
detect_pitch_frequency,
edit_distance,
fftconvolve,
forced_align,
griffinlim,
inverse_spectrogram,
linear_fbanks,
...@@ -94,6 +95,7 @@ __all__ = [
"equalizer_biquad",
"filtfilt",
"flanger",
"forced_align",
"gain", "gain",
"highpass_biquad", "highpass_biquad",
"lfilter", "lfilter",
......
...@@ -9,6 +9,7 @@ from typing import List, Optional, Tuple, Union
import torch
import torchaudio
from torch import Tensor
from torchaudio._extension import fail_if_no_align
from .filtering import highpass_biquad, treble_biquad
...@@ -51,6 +52,7 @@ __all__ = [
"speed",
"preemphasis",
"deemphasis",
"forced_align",
]
...@@ -2596,3 +2598,43 @@ def deemphasis(waveform, coeff: float = 0.97) -> torch.Tensor:
a_coeffs = torch.tensor([1.0, -coeff], dtype=waveform.dtype, device=waveform.device)
b_coeffs = torch.tensor([1.0, 0.0], dtype=waveform.dtype, device=waveform.device)
return torchaudio.functional.lfilter(waveform, a_coeffs=a_coeffs, b_coeffs=b_coeffs)
@fail_if_no_align
def forced_align(
log_probs: torch.Tensor,
targets: torch.Tensor,
input_lengths: torch.Tensor,
target_lengths: torch.Tensor,
blank: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Computes forced alignment given the emissions from a CTC-trained model and a target label.
Args:
log_probs (torch.Tensor): log probability of CTC emission output.
Tensor with dimensions `(T, C)`, where `T` is the input length and
`C` is the number of characters in the alphabet including blank.
targets (torch.Tensor): Target sequence. Tensor with dimension `(L,)`,
where `L` is the target length.
input_lengths (torch.Tensor): Length of the input (must be <= `T`). Tensor with dimension `()`.
target_lengths (torch.Tensor): Lengths of the targets. Tensor with dimension `()`.
blank (int, optional): The index of the blank symbol in the CTC emission. (Default: 0)
Returns:
Tuple(torch.Tensor, torch.Tensor):
torch.Tensor: Label for each time step in the alignment path computed using forced alignment.
torch.Tensor: Log probability scores of the labels for each time step.
Note:
The sequence length of `log_probs` must satisfy:
.. math::
L_{\\text{log_probs}} \\ge L_{\\text{label}} + N_{\\text{repeat}}
where :math:`N_{\\text{repeat}}` is the number of consecutively repeated tokens.
For example, in str `"aabbc"`, the number of repeats is `2`.
"""
if blank in targets:
raise ValueError(f"targets Tensor shouldn't contain blank index. Found {targets}.")
if torch.max(targets) >= log_probs.shape[-1]:
raise ValueError("targets values must be less than the CTC dimension")
paths, scores = torch.ops.torchaudio.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
return paths, scores
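
The function returns one label per frame; a small, hypothetical post-processing helper (not part of this PR or the torchaudio API) can collapse consecutive identical frame labels into token spans:

def merge_frame_labels(path, blank=0):
    # Collapse runs of identical frame labels into (token, start_frame, end_frame)
    # spans, dropping blank frames. Purely illustrative post-processing.
    spans = []
    start = 0
    for i in range(1, len(path) + 1):
        if i == len(path) or path[i] != path[start]:
            if path[start] != blank:
                spans.append((int(path[start]), start, i))
            start = i
    return spans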