Commit 5f17d81c authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Update forced_align method to only support batch Tensors (#3365)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3365

The current design of forced_align accepts a 2D Tensor for `log_probs` and a 1D Tensor for `targets`. To simplify the API, this PR changes the method to only support batched Tensors (a 3D Tensor for `log_probs` and a 2D Tensor for `targets`).

Reviewed By: vineelpratap

Differential Revision: D46126226

fbshipit-source-id: 42cb52b19d91bbff7dc040ccf60350545d75b3a2
parent c076d1a8
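
A minimal usage sketch of the batched API described in the summary (illustrative tensor values; assumes torchaudio at this revision, where the batch size must be 1):

    import torch
    import torchaudio.functional as F

    # log_probs: (B, T, C) = (1, 5, 6); targets: (B, L) = (1, 4)
    log_probs = torch.randn(1, 5, 6).log_softmax(dim=-1)
    targets = torch.tensor([[0, 1, 1, 0]], dtype=torch.int32)
    input_lengths = torch.tensor([log_probs.shape[1]])    # (B,) = [T]
    target_lengths = torch.tensor([targets.shape[1]])     # (B,) = [L]

    # blank index 5 must not appear in targets and must be < C
    paths, scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, 5)
    # paths and scores now carry a leading batch dimension: both have shape (1, T)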
...@@ -1116,55 +1116,60 @@ class Functional(TestBaseMixin): ...@@ -1116,55 +1116,60 @@ class Functional(TestBaseMixin):
@parameterized.expand( @parameterized.expand(
[ [
([0, 1, 1, 0], [0, 1, 5, 1, 0], torch.int32), ([[0, 1, 1, 0]], [[0, 1, 5, 1, 0]], torch.int32),
([0, 1, 2, 3, 4], [0, 1, 2, 3, 4], torch.int32), ([[0, 1, 2, 3, 4]], [[0, 1, 2, 3, 4]], torch.int32),
([3, 3, 3], [3, 5, 3, 5, 3], torch.int64), ([[3, 3, 3]], [[3, 5, 3, 5, 3]], torch.int64),
([0, 1, 2], [0, 1, 1, 1, 2], torch.int64), ([[0, 1, 2]], [[0, 1, 1, 1, 2]], torch.int64),
] ]
) )
def test_forced_align(self, targets, ref_path, targets_dtype): def test_forced_align(self, targets, ref_path, targets_dtype):
emission = torch.tensor( emission = torch.tensor(
[ [
[0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553], [
[0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436], [0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
[0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688], [0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436],
[0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533], [0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688],
[0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107], [0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533],
[0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107],
]
], ],
dtype=self.dtype, dtype=self.dtype,
device=self.device, device=self.device,
) )
blank = 5 blank = 5
batch_index = 0
ref_path = torch.tensor(ref_path, dtype=targets_dtype, device=self.device) ref_path = torch.tensor(ref_path, dtype=targets_dtype, device=self.device)
ref_scores = torch.tensor( ref_scores = torch.tensor(
[torch.log(emission[i, ref_path[i]]).item() for i in range(emission.shape[0])], [torch.log(emission[batch_index, i, ref_path[batch_index, i]]).item() for i in range(emission.shape[1])],
dtype=emission.dtype, dtype=emission.dtype,
device=self.device, device=self.device,
) ).unsqueeze(0)
log_probs = torch.log(emission) log_probs = torch.log(emission)
targets = torch.tensor(targets, dtype=targets_dtype, device=self.device) targets = torch.tensor(targets, dtype=targets_dtype, device=self.device)
input_lengths = torch.tensor((log_probs.shape[0])) input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
target_lengths = torch.tensor((targets.shape[0])) target_lengths = torch.tensor([targets.shape[1]], device=self.device)
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank) hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
assert hyp_path.shape == ref_path.shape
assert hyp_scores.shape == ref_scores.shape
self.assertEqual(hyp_path, ref_path) self.assertEqual(hyp_path, ref_path)
self.assertEqual(hyp_scores, ref_scores) self.assertEqual(hyp_scores, ref_scores)
@parameterized.expand([(torch.int32,), (torch.int64,)]) @parameterized.expand([(torch.int32,), (torch.int64,)])
def test_forced_align_fail(self, targets_dtype): def test_forced_align_fail(self, targets_dtype):
log_probs = torch.rand(5, 6, dtype=self.dtype, device=self.device) log_probs = torch.rand(1, 5, 6, dtype=self.dtype, device=self.device)
targets = torch.tensor([0, 1, 2, 3, 4, 4], dtype=targets_dtype, device=self.device) targets = torch.tensor([[0, 1, 2, 3, 4, 4]], dtype=targets_dtype, device=self.device)
blank = 5 blank = 5
input_lengths = torch.tensor((log_probs.shape[0]), device=self.device) input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
target_lengths = torch.tensor((targets.shape[0]), device=self.device) target_lengths = torch.tensor([targets.shape[1]], device=self.device)
with self.assertRaisesRegex(RuntimeError, r"targets length is too long for CTC"): with self.assertRaisesRegex(RuntimeError, r"targets length is too long for CTC"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank) hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
targets = torch.tensor([5, 3, 3], dtype=targets_dtype, device=self.device) targets = torch.tensor([[5, 3, 3]], dtype=targets_dtype, device=self.device)
with self.assertRaisesRegex(ValueError, r"targets Tensor shouldn't contain blank index"): with self.assertRaisesRegex(ValueError, r"targets Tensor shouldn't contain blank index"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank) hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
log_probs = log_probs.int() log_probs = log_probs.int()
targets = torch.tensor([0, 1, 2, 3], dtype=targets_dtype, device=self.device) targets = torch.tensor([[0, 1, 2, 3]], dtype=targets_dtype, device=self.device)
with self.assertRaisesRegex(RuntimeError, r"log_probs must be float64, float32 or float16"): with self.assertRaisesRegex(RuntimeError, r"log_probs must be float64, float32 or float16"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank) hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
...@@ -1175,40 +1180,42 @@ class Functional(TestBaseMixin): ...@@ -1175,40 +1180,42 @@ class Functional(TestBaseMixin):
log_probs = torch.rand(3, 4, 6, dtype=self.dtype, device=self.device) log_probs = torch.rand(3, 4, 6, dtype=self.dtype, device=self.device)
targets = targets.int() targets = targets.int()
with self.assertRaisesRegex(RuntimeError, r"3-D tensor is not yet supported for log_probs"): with self.assertRaisesRegex(
RuntimeError, r"The batch dimension for log_probs must be 1 at the current version"
):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank) hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
targets = torch.randint(0, 4, (3, 4), device=self.device) targets = torch.randint(0, 4, (3, 4), device=self.device)
log_probs = torch.rand(3, 6, dtype=self.dtype, device=self.device) log_probs = torch.rand(1, 3, 6, dtype=self.dtype, device=self.device)
with self.assertRaisesRegex(RuntimeError, r"2-D tensor is not yet supported for targets"): with self.assertRaisesRegex(RuntimeError, r"The batch dimension for targets must be 1 at the current version"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank) hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
targets = torch.tensor([0, 1, 2, 3], dtype=targets_dtype, device=self.device) targets = torch.tensor([[0, 1, 2, 3]], dtype=targets_dtype, device=self.device)
input_lengths = torch.randint(1, 5, (3,), device=self.device) input_lengths = torch.randint(1, 5, (3, 5), device=self.device)
with self.assertRaisesRegex(RuntimeError, r"input_lengths must be 0-D"): with self.assertRaisesRegex(RuntimeError, r"input_lengths must be 1-D"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank) hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
input_lengths = torch.tensor((log_probs.shape[0]), device=self.device) input_lengths = torch.tensor([log_probs.shape[0]], device=self.device)
target_lengths = torch.randint(1, 5, (3,), device=self.device) target_lengths = torch.randint(1, 5, (3, 5), device=self.device)
with self.assertRaisesRegex(RuntimeError, r"target_lengths must be 0-D"): with self.assertRaisesRegex(RuntimeError, r"target_lengths must be 1-D"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank) hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
input_lengths = torch.tensor((10000), device=self.device) input_lengths = torch.tensor([10000], device=self.device)
target_lengths = torch.tensor((targets.shape[0]), device=self.device) target_lengths = torch.tensor([targets.shape[1]], device=self.device)
with self.assertRaisesRegex(RuntimeError, r"input length mismatch"): with self.assertRaisesRegex(RuntimeError, r"input length mismatch"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank) hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
input_lengths = torch.tensor((log_probs.shape[0])) input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
target_lengths = torch.tensor((10000)) target_lengths = torch.tensor([10000], device=self.device)
with self.assertRaisesRegex(RuntimeError, r"target length mismatch"): with self.assertRaisesRegex(RuntimeError, r"target length mismatch"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank) hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
targets = torch.tensor([7, 8, 9, 10], dtype=targets_dtype, device=self.device) targets = torch.tensor([[7, 8, 9, 10]], dtype=targets_dtype, device=self.device)
log_probs = torch.rand(10, 5, dtype=self.dtype, device=self.device) log_probs = torch.rand(1, 10, 5, dtype=self.dtype, device=self.device)
with self.assertRaisesRegex(ValueError, r"targets values must be less than the CTC dimension"): with self.assertRaisesRegex(ValueError, r"targets values must be less than the CTC dimension"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank) hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
targets = torch.tensor([1, 3, 3], dtype=targets_dtype, device=self.device) targets = torch.tensor([[1, 3, 3]], dtype=targets_dtype, device=self.device)
blank = 10000 blank = 10000
with self.assertRaisesRegex(RuntimeError, r"blank must be within \[0, num classes\)"): with self.assertRaisesRegex(RuntimeError, r"blank must be within \[0, num classes\)"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank) hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
...@@ -1238,14 +1245,14 @@ class FunctionalCUDAOnly(TestBaseMixin): ...@@ -1238,14 +1245,14 @@ class FunctionalCUDAOnly(TestBaseMixin):
@nested_params( @nested_params(
[torch.half, torch.float, torch.double], [torch.half, torch.float, torch.double],
[torch.int32, torch.int64], [torch.int32, torch.int64],
[(50, 100), (100, 100)], [(1, 50, 100), (1, 100, 100)],
[(10,), (40,), (45,)], [(1, 10), (1, 40), (1, 45)],
) )
def test_forced_align_same_result(self, log_probs_dtype, targets_dtype, log_probs_shape, targets_shape): def test_forced_align_same_result(self, log_probs_dtype, targets_dtype, log_probs_shape, targets_shape):
log_probs = torch.rand(log_probs_shape, dtype=log_probs_dtype, device=self.device) log_probs = torch.rand(log_probs_shape, dtype=log_probs_dtype, device=self.device)
targets = torch.randint(1, 100, targets_shape, dtype=targets_dtype, device=self.device) targets = torch.randint(1, 100, targets_shape, dtype=targets_dtype, device=self.device)
input_lengths = torch.tensor((log_probs.shape[0]), device=self.device) input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
target_lengths = torch.tensor((targets.shape[0]), device=self.device) target_lengths = torch.tensor([targets.shape[1]], device=self.device)
log_probs_cuda = log_probs.cuda() log_probs_cuda = log_probs.cuda()
targets_cuda = targets.cuda() targets_cuda = targets.cuda()
input_lengths_cuda = input_lengths.cuda() input_lengths_cuda = input_lengths.cuda()
......
...@@ -17,8 +17,10 @@ void forced_align_impl( ...@@ -17,8 +17,10 @@ void forced_align_impl(
const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity(); const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
using target_t = typename std:: using target_t = typename std::
conditional<target_scalar_type == torch::kInt, int, int64_t>::type; conditional<target_scalar_type == torch::kInt, int, int64_t>::type;
const auto T = logProbs.size(0); const auto batchIndex =
const auto L = targets.size(0); 0; // TODO: support batch version and use the real batch index
const auto T = logProbs.size(1);
const auto L = targets.size(1);
const auto S = 2 * L + 1; const auto S = 2 * L + 1;
torch::Tensor alphas = torch::empty( torch::Tensor alphas = torch::empty(
{2, S}, {2, S},
...@@ -27,14 +29,14 @@ void forced_align_impl( ...@@ -27,14 +29,14 @@ void forced_align_impl(
.dtype(logProbs.dtype())) .dtype(logProbs.dtype()))
.fill_(kNegInfinity); .fill_(kNegInfinity);
torch::Tensor backPtr = torch::empty({T, S}, torch::kInt8).fill_(-1); torch::Tensor backPtr = torch::empty({T, S}, torch::kInt8).fill_(-1);
auto logProbs_a = logProbs.accessor<scalar_t, 2>(); auto logProbs_a = logProbs.accessor<scalar_t, 3>();
auto targets_a = targets.accessor<target_t, 1>(); auto targets_a = targets.accessor<target_t, 2>();
auto paths_a = paths.accessor<target_t, 1>(); auto paths_a = paths.accessor<target_t, 2>();
auto alphas_a = alphas.accessor<scalar_t, 2>(); auto alphas_a = alphas.accessor<scalar_t, 2>();
auto backPtr_a = backPtr.accessor<int8_t, 2>(); auto backPtr_a = backPtr.accessor<int8_t, 2>();
auto R = 0; auto R = 0;
for (auto i = 1; i < L; i++) { for (auto i = 1; i < L; i++) {
if (targets_a[i] == targets_a[i - 1]) { if (targets_a[batchIndex][i] == targets_a[batchIndex][i - 1]) {
++R; ++R;
} }
} }
...@@ -49,20 +51,22 @@ void forced_align_impl( ...@@ -49,20 +51,22 @@ void forced_align_impl(
auto start = T - (L + R) > 0 ? 0 : 1; auto start = T - (L + R) > 0 ? 0 : 1;
auto end = (S == 1) ? 1 : 2; auto end = (S == 1) ? 1 : 2;
for (auto i = start; i < end; i++) { for (auto i = start; i < end; i++) {
auto labelIdx = (i % 2 == 0) ? blank : targets_a[i / 2]; auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
alphas_a[0][i] = logProbs_a[0][labelIdx]; alphas_a[0][i] = logProbs_a[batchIndex][0][labelIdx];
} }
for (auto t = 1; t < T; t++) { for (auto t = 1; t < T; t++) {
if (T - t <= L + R) { if (T - t <= L + R) {
if ((start % 2 == 1) && if ((start % 2 == 1) &&
targets_a[start / 2] != targets_a[start / 2 + 1]) { targets_a[batchIndex][start / 2] !=
targets_a[batchIndex][start / 2 + 1]) {
start = start + 1; start = start + 1;
} }
start = start + 1; start = start + 1;
} }
if (t <= L + R) { if (t <= L + R) {
if (end % 2 == 0 && end < 2 * L && if (end % 2 == 0 && end < 2 * L &&
targets_a[end / 2 - 1] != targets_a[end / 2]) { targets_a[batchIndex][end / 2 - 1] !=
targets_a[batchIndex][end / 2]) {
end = end + 1; end = end + 1;
} }
end = end + 1; end = end + 1;
...@@ -75,7 +79,7 @@ void forced_align_impl( ...@@ -75,7 +79,7 @@ void forced_align_impl(
} }
if (start == 0) { if (start == 0) {
alphas_a[curIdxOffset][0] = alphas_a[curIdxOffset][0] =
alphas_a[prevIdxOffset][0] + logProbs_a[t][blank]; alphas_a[prevIdxOffset][0] + logProbs_a[batchIndex][t][blank];
backPtr_a[t][0] = 0; backPtr_a[t][0] = 0;
startloop += 1; startloop += 1;
} }
...@@ -85,13 +89,14 @@ void forced_align_impl( ...@@ -85,13 +89,14 @@ void forced_align_impl(
auto x1 = alphas_a[prevIdxOffset][i - 1]; auto x1 = alphas_a[prevIdxOffset][i - 1];
auto x2 = -std::numeric_limits<scalar_t>::infinity(); auto x2 = -std::numeric_limits<scalar_t>::infinity();
auto labelIdx = (i % 2 == 0) ? blank : targets_a[i / 2]; auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
// In CTC, the optimal path may optionally chose to skip a blank label. // In CTC, the optimal path may optionally chose to skip a blank label.
// x2 represents skipping a letter, and can only happen if we're not // x2 represents skipping a letter, and can only happen if we're not
// currently on a blank_label, and we're not on a repeat letter // currently on a blank_label, and we're not on a repeat letter
// (i != 1) just ensures we don't access targets[i - 2] if its i < 2 // (i != 1) just ensures we don't access targets[i - 2] if its i < 2
if (i % 2 != 0 && i != 1 && targets_a[i / 2] != targets_a[i / 2 - 1]) { if (i % 2 != 0 && i != 1 &&
targets_a[batchIndex][i / 2] != targets_a[batchIndex][i / 2 - 1]) {
x2 = alphas_a[prevIdxOffset][i - 2]; x2 = alphas_a[prevIdxOffset][i - 2];
} }
scalar_t result = 0.0; scalar_t result = 0.0;
...@@ -105,7 +110,7 @@ void forced_align_impl( ...@@ -105,7 +110,7 @@ void forced_align_impl(
result = x0; result = x0;
backPtr_a[t][i] = 0; backPtr_a[t][i] = 0;
} }
alphas_a[curIdxOffset][i] = result + logProbs_a[t][labelIdx]; alphas_a[curIdxOffset][i] = result + logProbs_a[batchIndex][t][labelIdx];
} }
} }
auto idx1 = (T - 1) % 2; auto idx1 = (T - 1) % 2;
...@@ -113,8 +118,8 @@ void forced_align_impl( ...@@ -113,8 +118,8 @@ void forced_align_impl(
// path stores the token index for each time step after force alignment. // path stores the token index for each time step after force alignment.
auto indexScores = 0; auto indexScores = 0;
for (auto t = T - 1; t > -1; t--) { for (auto t = T - 1; t > -1; t--) {
auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a[ltrIdx / 2]; auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a[batchIndex][ltrIdx / 2];
paths_a[t] = lbl_idx; paths_a[batchIndex][t] = lbl_idx;
++indexScores; ++indexScores;
ltrIdx -= backPtr_a[t][ltrIdx]; ltrIdx -= backPtr_a[t][ltrIdx];
} }
...@@ -142,30 +147,35 @@ std::tuple<torch::Tensor, torch::Tensor> compute( ...@@ -142,30 +147,35 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous"); TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous"); TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
TORCH_CHECK( TORCH_CHECK(
logProbs.dim() != 3, logProbs.dim() == 3,
"3-D tensor is not yet supported for log_probs, please provide 2-D tensor.") "log_probs must be 3-D (batch_size, input length, num classes)");
TORCH_CHECK( TORCH_CHECK(
targets.dim() != 2, targets.dim() == 2, "targets must be 2-D (batch_size, target length,)");
"2-D tensor is not yet supported for targets, please provide 1-D tensor.")
TORCH_CHECK( TORCH_CHECK(
logProbs.dim() == 2, "log_probs must be 2-D (input length, num classes)"); inputLengths.dim() == 1, "input_lengths must be 1-D (batch_size,)");
TORCH_CHECK(targets.dim() == 1, "targets must be 1-D (target length,)"); TORCH_CHECK(
TORCH_CHECK(inputLengths.dim() == 0, "input_lengths must be 0-D"); targetLengths.dim() == 1, "target_lengths must be 1-D (batch_size,)");
TORCH_CHECK(targetLengths.dim() == 0, "target_lengths must be 0-D"); TORCH_CHECK(
logProbs.size(0) == 1,
"The batch dimension for log_probs must be 1 at the current version.")
TORCH_CHECK(
targets.size(0) == 1,
"The batch dimension for targets must be 1 at the current version.")
TORCH_CHECK( TORCH_CHECK(
blank >= 0 && blank < logProbs.size(-1), blank >= 0 && blank < logProbs.size(-1),
"blank must be within [0, num classes)"); "blank must be within [0, num classes)");
TORCH_CHECK( TORCH_CHECK(
logProbs.size(0) == at::max(inputLengths).item().toInt(), logProbs.size(1) == at::max(inputLengths).item().toInt(),
"input length mismatch"); "input length mismatch");
TORCH_CHECK( TORCH_CHECK(
targets.size(0) == at::max(targetLengths).item().toInt(), targets.size(1) == at::max(targetLengths).item().toInt(),
"target length mismatch"); "target length mismatch");
const auto T = logProbs.size(0); const auto B = logProbs.size(0);
const auto T = logProbs.size(1);
auto paths = torch::zeros( auto paths = torch::zeros(
{T}, {B, T},
torch::TensorOptions().device(targets.device()).dtype(targets.dtype())); torch::TensorOptions().device(targets.device()).dtype(targets.dtype()));
AT_DISPATCH_FLOATING_TYPES_AND_HALF( AT_DISPATCH_FLOATING_TYPES_AND_HALF(
logProbs.scalar_type(), "forced_align_impl", [&] { logProbs.scalar_type(), "forced_align_impl", [&] {
...@@ -180,9 +190,10 @@ std::tuple<torch::Tensor, torch::Tensor> compute( ...@@ -180,9 +190,10 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
return std::make_tuple( return std::make_tuple(
paths, paths,
logProbs.index( logProbs.index(
{torch::linspace( {torch::indexing::Slice(),
torch::linspace(
0, T - 1, T, torch::TensorOptions().dtype(paths.dtype())), 0, T - 1, T, torch::TensorOptions().dtype(paths.dtype())),
paths})); paths.index({0})}));
} }
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
......
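
In the updated compute functions, the returned score tensor is the frame-wise log-probability gathered along the aligned path. A rough Python equivalent of that final indexing (a sketch, assuming log_probs of shape (1, T, C) and paths of shape (1, T)):

    import torch

    frame_index = torch.arange(log_probs.shape[1])   # 0 .. T-1
    scores = log_probs[:, frame_index, paths[0]]     # shape (1, T)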
...@@ -18,9 +18,9 @@ namespace alignment { ...@@ -18,9 +18,9 @@ namespace alignment {
namespace gpu { namespace gpu {
template <typename scalar_t, typename target_t> template <typename scalar_t, typename target_t>
__global__ void falign_cuda_step_kernel( __global__ void falign_cuda_step_kernel(
const torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> const torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits>
logProbs_a, logProbs_a,
const torch::PackedTensorAccessor32<target_t, 1, torch::RestrictPtrTraits> const torch::PackedTensorAccessor32<target_t, 2, torch::RestrictPtrTraits>
targets_a, targets_a,
const int T, const int T,
const int L, const int L,
...@@ -36,6 +36,8 @@ __global__ void falign_cuda_step_kernel( ...@@ -36,6 +36,8 @@ __global__ void falign_cuda_step_kernel(
torch::PackedTensorAccessor32<int8_t, 2, torch::RestrictPtrTraits> torch::PackedTensorAccessor32<int8_t, 2, torch::RestrictPtrTraits>
backPtrBuffer_a) { backPtrBuffer_a) {
scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity(); scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
const int batchIndex =
0; // TODO: support batch version and use the real batch index
int S = 2 * L + 1; int S = 2 * L + 1;
int curIdxOffset = (t % 2); // current time step frame for alpha int curIdxOffset = (t % 2); // current time step frame for alpha
int prevIdxOffset = ((t - 1) % 2); // previous time step frame for alpha int prevIdxOffset = ((t - 1) % 2); // previous time step frame for alpha
...@@ -49,8 +51,8 @@ __global__ void falign_cuda_step_kernel( ...@@ -49,8 +51,8 @@ __global__ void falign_cuda_step_kernel(
__syncthreads(); __syncthreads();
if (t == 0) { if (t == 0) {
for (unsigned int i = start + threadIdx.x; i < end; i += blockDim.x) { for (unsigned int i = start + threadIdx.x; i < end; i += blockDim.x) {
int labelIdx = (i % 2 == 0) ? blank : targets_a[i / 2]; int labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
alphas_a[curIdxOffset][i] = logProbs_a[0][labelIdx]; alphas_a[curIdxOffset][i] = logProbs_a[batchIndex][0][labelIdx];
} }
return; return;
} }
...@@ -62,7 +64,7 @@ __global__ void falign_cuda_step_kernel( ...@@ -62,7 +64,7 @@ __global__ void falign_cuda_step_kernel(
threadMax = kNegInfinity; threadMax = kNegInfinity;
if (start == 0 && threadIdx.x == 0) { if (start == 0 && threadIdx.x == 0) {
alphas_a[curIdxOffset][0] = alphas_a[curIdxOffset][0] =
alphas_a[prevIdxOffset][0] + logProbs_a[t][blank]; alphas_a[prevIdxOffset][0] + logProbs_a[batchIndex][t][blank];
threadMax = max(threadMax, alphas_a[curIdxOffset][0]); threadMax = max(threadMax, alphas_a[curIdxOffset][0]);
backPtrBuffer_a[backPtrBufferLen][0] = 0; backPtrBuffer_a[backPtrBufferLen][0] = 0;
} }
...@@ -73,8 +75,9 @@ __global__ void falign_cuda_step_kernel( ...@@ -73,8 +75,9 @@ __global__ void falign_cuda_step_kernel(
scalar_t x0 = alphas_a[prevIdxOffset][i]; scalar_t x0 = alphas_a[prevIdxOffset][i];
scalar_t x1 = alphas_a[prevIdxOffset][i - 1]; scalar_t x1 = alphas_a[prevIdxOffset][i - 1];
scalar_t x2 = kNegInfinity; scalar_t x2 = kNegInfinity;
int labelIdx = (i % 2 == 0) ? blank : targets_a[i / 2]; int labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
if (i % 2 != 0 && i != 1 && targets_a[i / 2] != targets_a[i / 2 - 1]) { if (i % 2 != 0 && i != 1 &&
targets_a[batchIndex][i / 2] != targets_a[batchIndex][i / 2 - 1]) {
x2 = alphas_a[prevIdxOffset][i - 2]; x2 = alphas_a[prevIdxOffset][i - 2];
} }
scalar_t result = 0.0; scalar_t result = 0.0;
...@@ -88,7 +91,7 @@ __global__ void falign_cuda_step_kernel( ...@@ -88,7 +91,7 @@ __global__ void falign_cuda_step_kernel(
result = x0; result = x0;
backPtrBuffer_a[backPtrBufferLen][i] = 0; backPtrBuffer_a[backPtrBufferLen][i] = 0;
} }
alphas_a[curIdxOffset][i] = result + logProbs_a[t][labelIdx]; alphas_a[curIdxOffset][i] = result + logProbs_a[batchIndex][t][labelIdx];
threadMax = max(threadMax, alphas_a[curIdxOffset][i]); threadMax = max(threadMax, alphas_a[curIdxOffset][i]);
} }
scalar_t maxResult = BlockReduce(tempStorage).Reduce(threadMax, cub::Max()); scalar_t maxResult = BlockReduce(tempStorage).Reduce(threadMax, cub::Max());
...@@ -113,10 +116,12 @@ void forced_align_impl( ...@@ -113,10 +116,12 @@ void forced_align_impl(
const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity(); const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
using target_t = typename std:: using target_t = typename std::
conditional<target_scalar_type == torch::kInt, int, int64_t>::type; conditional<target_scalar_type == torch::kInt, int, int64_t>::type;
auto paths_a = paths.accessor<target_t, 1>(); auto paths_a = paths.accessor<target_t, 2>();
const int T = logProbs.size(0); // num frames const int batchIndex =
const int N = logProbs.size(1); // alphabet size 0; // TODO: support batch version and use the real batch index
const int L = targets.size(0); // label length const int T = logProbs.size(1); // num frames
const int N = logProbs.size(2); // alphabet size
const int L = targets.size(1); // label length
const int S = 2 * L + 1; const int S = 2 * L + 1;
auto targetsCpu = targets.to(torch::kCPU); auto targetsCpu = targets.to(torch::kCPU);
// backPtrBuffer stores the index offset fthe best path at current position // backPtrBuffer stores the index offset fthe best path at current position
...@@ -144,12 +149,12 @@ void forced_align_impl( ...@@ -144,12 +149,12 @@ void forced_align_impl(
.device(logProbs.device())) .device(logProbs.device()))
.fill_(kNegInfinity); .fill_(kNegInfinity);
// CPU accessors // CPU accessors
auto targetsCpu_a = targetsCpu.accessor<target_t, 1>(); auto targetsCpu_a = targetsCpu.accessor<target_t, 2>();
auto backPtrCpu_a = backPtrCpu.accessor<int8_t, 2>(); auto backPtrCpu_a = backPtrCpu.accessor<int8_t, 2>();
// count the number of repeats in label // count the number of repeats in label
int R = 0; int R = 0;
for (int i = 1; i < L; ++i) { for (int i = 1; i < L; ++i) {
if (targetsCpu_a[i] == targetsCpu_a[i - 1]) { if (targetsCpu_a[batchIndex][i] == targetsCpu_a[batchIndex][i - 1]) {
++R; ++R;
} }
} }
...@@ -169,14 +174,16 @@ void forced_align_impl( ...@@ -169,14 +174,16 @@ void forced_align_impl(
if (t > 0) { if (t > 0) {
if (T - t <= L + R) { if (T - t <= L + R) {
if ((start % 2 == 1) && if ((start % 2 == 1) &&
(targetsCpu_a[start / 2] != targetsCpu_a[start / 2 + 1])) { (targetsCpu_a[batchIndex][start / 2] !=
targetsCpu_a[batchIndex][start / 2 + 1])) {
start = start + 1; start = start + 1;
} }
start = start + 1; start = start + 1;
} }
if (t <= L + R) { if (t <= L + R) {
if ((end % 2 == 0) && (end < 2 * L) && if ((end % 2 == 0) && (end < 2 * L) &&
(targetsCpu_a[end / 2 - 1] != targetsCpu_a[end / 2])) { (targetsCpu_a[batchIndex][end / 2 - 1] !=
targetsCpu_a[batchIndex][end / 2])) {
end = end + 1; end = end + 1;
} }
end = end + 1; end = end + 1;
...@@ -184,8 +191,8 @@ void forced_align_impl( ...@@ -184,8 +191,8 @@ void forced_align_impl(
} }
falign_cuda_step_kernel<scalar_t, target_t> falign_cuda_step_kernel<scalar_t, target_t>
<<<1, kNumThreads, 0, defaultStream>>>( <<<1, kNumThreads, 0, defaultStream>>>(
logProbs.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(), logProbs.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
targets.packed_accessor32<target_t, 1, torch::RestrictPtrTraits>(), targets.packed_accessor32<target_t, 2, torch::RestrictPtrTraits>(),
T, T,
L, L,
N, N,
...@@ -229,8 +236,9 @@ void forced_align_impl( ...@@ -229,8 +236,9 @@ void forced_align_impl(
: S - 2; : S - 2;
int indexScores = 0; int indexScores = 0;
for (int t = T - 1; t >= 0; --t) { for (int t = T - 1; t >= 0; --t) {
auto lbl_idx = ltrIdx % 2 == 0 ? blank : targetsCpu_a[ltrIdx / 2]; auto lbl_idx =
paths_a[t] = lbl_idx; ltrIdx % 2 == 0 ? blank : targetsCpu_a[batchIndex][ltrIdx / 2];
paths_a[batchIndex][t] = lbl_idx;
++indexScores; ++indexScores;
ltrIdx -= backPtrCpu_a[t][ltrIdx]; ltrIdx -= backPtrCpu_a[t][ltrIdx];
} }
...@@ -258,30 +266,36 @@ std::tuple<torch::Tensor, torch::Tensor> compute( ...@@ -258,30 +266,36 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous"); TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous"); TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
TORCH_CHECK( TORCH_CHECK(
logProbs.dim() != 3, logProbs.dim() == 3,
"3-D tensor is not yet supported for log_probs, please provide 2-D tensor.") "log_probs must be 3-D (batch_size, input length, num classes)");
TORCH_CHECK( TORCH_CHECK(
targets.dim() != 2, targets.dim() == 2, "targets must be 2-D (batch_size, target length,)");
"2-D tensor is not yet supported for targets, please provide 1-D tensor.")
TORCH_CHECK( TORCH_CHECK(
logProbs.dim() == 2, "log_probs must be 2-D (input length, num classes)"); inputLengths.dim() == 1, "input_lengths must be 1-D (batch_size,)");
TORCH_CHECK(targets.dim() == 1, "targets must be 1-D (target length,)"); TORCH_CHECK(
TORCH_CHECK(inputLengths.dim() == 0, "input_lengths must be 0-D"); targetLengths.dim() == 1, "target_lengths must be 1-D (batch_size,)");
TORCH_CHECK(targetLengths.dim() == 0, "target_lengths must be 0-D"); TORCH_CHECK(
logProbs.size(0) == 1,
"The batch dimension for log_probs must be 1 at the current version.")
TORCH_CHECK(
targets.size(0) == 1,
"The batch dimension for targets must be 1 at the current version.")
TORCH_CHECK( TORCH_CHECK(
blank >= 0 && blank < logProbs.size(-1), blank >= 0 && blank < logProbs.size(-1),
"blank must be within [0, num classes)"); "blank must be within [0, num classes)");
TORCH_CHECK( TORCH_CHECK(
logProbs.size(0) == at::max(inputLengths).item().toInt(), logProbs.size(1) == at::max(inputLengths).item().toInt(),
"input length mismatch"); "input length mismatch");
TORCH_CHECK( TORCH_CHECK(
targets.size(0) == at::max(targetLengths).item().toInt(), targets.size(1) == at::max(targetLengths).item().toInt(),
"target length mismatch"); "target length mismatch");
auto T = logProbs.size(0); // num frames auto B = logProbs.size(0);
auto T = logProbs.size(1); // num frames
auto paths = torch::zeros( auto paths = torch::zeros(
{T}, torch::TensorOptions().device(torch::kCPU).dtype(targets.dtype())); {B, T},
torch::TensorOptions().device(torch::kCPU).dtype(targets.dtype()));
AT_DISPATCH_FLOATING_TYPES_AND_HALF( AT_DISPATCH_FLOATING_TYPES_AND_HALF(
logProbs.scalar_type(), "forced_align_impl", [&] { logProbs.scalar_type(), "forced_align_impl", [&] {
if (targets.scalar_type() == torch::kInt64) { if (targets.scalar_type() == torch::kInt64) {
...@@ -295,9 +309,10 @@ std::tuple<torch::Tensor, torch::Tensor> compute( ...@@ -295,9 +309,10 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
return std::make_tuple( return std::make_tuple(
paths.to(logProbs.device()), paths.to(logProbs.device()),
logProbs.index( logProbs.index(
{torch::linspace( {torch::indexing::Slice(),
torch::linspace(
0, T - 1, T, torch::TensorOptions().dtype(paths.dtype())), 0, T - 1, T, torch::TensorOptions().dtype(paths.dtype())),
paths})); paths.index({0})}));
} }
TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
......
...@@ -2508,12 +2508,12 @@ def forced_align( ...@@ -2508,12 +2508,12 @@ def forced_align(
Args: Args:
log_probs (torch.Tensor): log probability of CTC emission output. log_probs (torch.Tensor): log probability of CTC emission output.
Tensor of shape `(T, C)`. where `T` is the input length, Tensor of shape `(B, T, C)`. where `B` is the batch size, `T` is the input length,
`C` is the number of characters in alphabet including blank. `C` is the number of characters in alphabet including blank.
targets (torch.Tensor): Target sequence. Tensor of shape `(L,)`, targets (torch.Tensor): Target sequence. Tensor of shape `(B, L)`,
where `L` is the target length. where `L` is the target length.
input_lengths (torch.Tensor): Lengths of the inputs (max value must each be <= `T`). 0-D Tensor (scalar). input_lengths (torch.Tensor): Lengths of the inputs (max value must each be <= `T`). 1-D Tensor of shape `(B,)`.
target_lengths (torch.Tensor): Lengths of the targets. 0-D Tensor (scalar). target_lengths (torch.Tensor): Lengths of the targets. 1-D Tensor of shape `(B,)`.
blank_id (int, optional): The index of blank symbol in CTC emission. (Default: 0) blank_id (int, optional): The index of blank symbol in CTC emission. (Default: 0)
Returns: Returns:
...@@ -2531,6 +2531,9 @@ def forced_align( ...@@ -2531,6 +2531,9 @@ def forced_align(
where :math:`N_{\text{repeat}}` is the number of consecutively repeated tokens. where :math:`N_{\text{repeat}}` is the number of consecutively repeated tokens.
For example, in str `"aabbc"`, the number of repeats are `2`. For example, in str `"aabbc"`, the number of repeats are `2`.
Note:
The current version only supports ``batch_size``==1.
""" """
if blank in targets: if blank in targets:
raise ValueError(f"targets Tensor shouldn't contain blank index. Found {targets}.") raise ValueError(f"targets Tensor shouldn't contain blank index. Found {targets}.")
......