Commit 8a893fb3 authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Fix CPU kernel of forced_align function (#3354)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3354

When start == 0, the first item (index 0) — not the S-th item — of row t in backPtr_a should be set to 0.

Reviewed By: xiaohui-zhang

Differential Revision: D46059971

fbshipit-source-id: 89933134878513034eae033764b19f8562f24cb8
parent 011f7f3d
......@@ -1236,11 +1236,14 @@ class FunctionalCPUOnly(TestBaseMixin):
class FunctionalCUDAOnly(Functional):
@nested_params(
[torch.half, torch.float, torch.double], [torch.int32, torch.int64], [(5, 6), (10, 6)], [(1,), (4,), (5,)]
[torch.half, torch.float, torch.double],
[torch.int32, torch.int64],
[(50, 100), (100, 100)],
[(10,), (40,), (45,)],
)
def test_forced_align_same_result(self, log_probs_dtype, targets_dtype, log_probs_shape, targets_shape):
log_probs = torch.rand(log_probs_shape, dtype=log_probs_dtype, device=self.device)
targets = torch.randint(1, 6, targets_shape, dtype=targets_dtype, device=self.device)
targets = torch.randint(1, 100, targets_shape, dtype=targets_dtype, device=self.device)
input_lengths = torch.tensor((log_probs.shape[0]), device=self.device)
target_lengths = torch.tensor((targets.shape[0]), device=self.device)
log_probs_cuda = log_probs.cuda()
......
......@@ -47,7 +47,7 @@ void forced_align_impl(
", and number of repeats: ",
R);
auto start = T - (L + R) > 0 ? 0 : 1;
auto end = S == 1 ? 1 : 2;
auto end = (S == 1) ? 1 : 2;
for (auto i = start; i < end; i++) {
auto labelIdx = (i % 2 == 0) ? blank : targets_a[i / 2];
alphas_a[0][i] = logProbs_a[0][labelIdx];
......@@ -76,7 +76,7 @@ void forced_align_impl(
if (start == 0) {
alphas_a[curIdxOffset][0] =
alphas_a[prevIdxOffset][0] + logProbs_a[t][blank];
backPtr_a[t][S] = 0;
backPtr_a[t][0] = 0;
startloop += 1;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment