"docs/git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "88429853b932471cfdfefd259855d6b8a23aa7c3"
Commit bbc13b9a authored by Moto Hira, committed by Facebook GitHub Bot

Revert D46126226: Update forced_align method to only support batch Tensors

Differential Revision: D46126226

Original commit changeset: 42cb52b19d91

Original Phabricator Diff: D46126226

fbshipit-source-id: 372b2526d9e196e37e014f1556bf117d29bb1ac6
parent 5f17d81c
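For context, this revert takes `forced_align` back to its unbatched signature: `log_probs` is again a 2-D `(T, C)` tensor, `targets` a 1-D `(L,)` tensor, and the two length arguments 0-D scalars. A minimal sketch of the restored call, with illustrative tensor contents that are not taken from this diff:

```python
import torch
import torchaudio.functional as F

# Hypothetical emission: T = 5 frames, C = 6 classes, index 5 used as the blank.
log_probs = torch.randn(5, 6).log_softmax(dim=-1)        # 2-D (T, C) after the revert
targets = torch.tensor([0, 1, 1, 0], dtype=torch.int32)  # 1-D (L,)
input_lengths = torch.tensor(log_probs.shape[0])          # 0-D scalar
target_lengths = torch.tensor(targets.shape[0])           # 0-D scalar

path, scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, 5)
# path: one token index per frame; scores: the log-probability of that token at each frame.
```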
@@ -1116,60 +1116,55 @@ class Functional(TestBaseMixin):
     @parameterized.expand(
         [
-            ([[0, 1, 1, 0]], [[0, 1, 5, 1, 0]], torch.int32),
-            ([[0, 1, 2, 3, 4]], [[0, 1, 2, 3, 4]], torch.int32),
-            ([[3, 3, 3]], [[3, 5, 3, 5, 3]], torch.int64),
-            ([[0, 1, 2]], [[0, 1, 1, 1, 2]], torch.int64),
+            ([0, 1, 1, 0], [0, 1, 5, 1, 0], torch.int32),
+            ([0, 1, 2, 3, 4], [0, 1, 2, 3, 4], torch.int32),
+            ([3, 3, 3], [3, 5, 3, 5, 3], torch.int64),
+            ([0, 1, 2], [0, 1, 1, 1, 2], torch.int64),
         ]
     )
     def test_forced_align(self, targets, ref_path, targets_dtype):
         emission = torch.tensor(
             [
-                [
-                    [0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
-                    [0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436],
-                    [0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688],
-                    [0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533],
-                    [0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107],
-                ]
+                [0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
+                [0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436],
+                [0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688],
+                [0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533],
+                [0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107],
             ],
             dtype=self.dtype,
             device=self.device,
         )
         blank = 5
-        batch_index = 0
         ref_path = torch.tensor(ref_path, dtype=targets_dtype, device=self.device)
         ref_scores = torch.tensor(
-            [torch.log(emission[batch_index, i, ref_path[batch_index, i]]).item() for i in range(emission.shape[1])],
+            [torch.log(emission[i, ref_path[i]]).item() for i in range(emission.shape[0])],
             dtype=emission.dtype,
             device=self.device,
-        ).unsqueeze(0)
+        )
         log_probs = torch.log(emission)
         targets = torch.tensor(targets, dtype=targets_dtype, device=self.device)
-        input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
-        target_lengths = torch.tensor([targets.shape[1]], device=self.device)
+        input_lengths = torch.tensor((log_probs.shape[0]))
+        target_lengths = torch.tensor((targets.shape[0]))
         hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
-        assert hyp_path.shape == ref_path.shape
-        assert hyp_scores.shape == ref_scores.shape
         self.assertEqual(hyp_path, ref_path)
         self.assertEqual(hyp_scores, ref_scores)

     @parameterized.expand([(torch.int32,), (torch.int64,)])
     def test_forced_align_fail(self, targets_dtype):
-        log_probs = torch.rand(1, 5, 6, dtype=self.dtype, device=self.device)
-        targets = torch.tensor([[0, 1, 2, 3, 4, 4]], dtype=targets_dtype, device=self.device)
+        log_probs = torch.rand(5, 6, dtype=self.dtype, device=self.device)
+        targets = torch.tensor([0, 1, 2, 3, 4, 4], dtype=targets_dtype, device=self.device)
         blank = 5
-        input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
-        target_lengths = torch.tensor([targets.shape[1]], device=self.device)
+        input_lengths = torch.tensor((log_probs.shape[0]), device=self.device)
+        target_lengths = torch.tensor((targets.shape[0]), device=self.device)
         with self.assertRaisesRegex(RuntimeError, r"targets length is too long for CTC"):
             hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
-        targets = torch.tensor([[5, 3, 3]], dtype=targets_dtype, device=self.device)
+        targets = torch.tensor([5, 3, 3], dtype=targets_dtype, device=self.device)
         with self.assertRaisesRegex(ValueError, r"targets Tensor shouldn't contain blank index"):
             hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
         log_probs = log_probs.int()
-        targets = torch.tensor([[0, 1, 2, 3]], dtype=targets_dtype, device=self.device)
+        targets = torch.tensor([0, 1, 2, 3], dtype=targets_dtype, device=self.device)
         with self.assertRaisesRegex(RuntimeError, r"log_probs must be float64, float32 or float16"):
             hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
@@ -1180,42 +1175,40 @@ class Functional(TestBaseMixin):
         log_probs = torch.rand(3, 4, 6, dtype=self.dtype, device=self.device)
         targets = targets.int()
-        with self.assertRaisesRegex(
-            RuntimeError, r"The batch dimension for log_probs must be 1 at the current version"
-        ):
+        with self.assertRaisesRegex(RuntimeError, r"3-D tensor is not yet supported for log_probs"):
             hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
         targets = torch.randint(0, 4, (3, 4), device=self.device)
-        log_probs = torch.rand(1, 3, 6, dtype=self.dtype, device=self.device)
-        with self.assertRaisesRegex(RuntimeError, r"The batch dimension for targets must be 1 at the current version"):
+        log_probs = torch.rand(3, 6, dtype=self.dtype, device=self.device)
+        with self.assertRaisesRegex(RuntimeError, r"2-D tensor is not yet supported for targets"):
             hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
-        targets = torch.tensor([[0, 1, 2, 3]], dtype=targets_dtype, device=self.device)
-        input_lengths = torch.randint(1, 5, (3, 5), device=self.device)
-        with self.assertRaisesRegex(RuntimeError, r"input_lengths must be 1-D"):
+        targets = torch.tensor([0, 1, 2, 3], dtype=targets_dtype, device=self.device)
+        input_lengths = torch.randint(1, 5, (3,), device=self.device)
+        with self.assertRaisesRegex(RuntimeError, r"input_lengths must be 0-D"):
             hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
-        input_lengths = torch.tensor([log_probs.shape[0]], device=self.device)
-        target_lengths = torch.randint(1, 5, (3, 5), device=self.device)
-        with self.assertRaisesRegex(RuntimeError, r"target_lengths must be 1-D"):
+        input_lengths = torch.tensor((log_probs.shape[0]), device=self.device)
+        target_lengths = torch.randint(1, 5, (3,), device=self.device)
+        with self.assertRaisesRegex(RuntimeError, r"target_lengths must be 0-D"):
             hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
-        input_lengths = torch.tensor([10000], device=self.device)
-        target_lengths = torch.tensor([targets.shape[1]], device=self.device)
+        input_lengths = torch.tensor((10000), device=self.device)
+        target_lengths = torch.tensor((targets.shape[0]), device=self.device)
         with self.assertRaisesRegex(RuntimeError, r"input length mismatch"):
             hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
-        input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
-        target_lengths = torch.tensor([10000], device=self.device)
+        input_lengths = torch.tensor((log_probs.shape[0]))
+        target_lengths = torch.tensor((10000))
         with self.assertRaisesRegex(RuntimeError, r"target length mismatch"):
             hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
-        targets = torch.tensor([[7, 8, 9, 10]], dtype=targets_dtype, device=self.device)
-        log_probs = torch.rand(1, 10, 5, dtype=self.dtype, device=self.device)
+        targets = torch.tensor([7, 8, 9, 10], dtype=targets_dtype, device=self.device)
+        log_probs = torch.rand(10, 5, dtype=self.dtype, device=self.device)
         with self.assertRaisesRegex(ValueError, r"targets values must be less than the CTC dimension"):
             hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
-        targets = torch.tensor([[1, 3, 3]], dtype=targets_dtype, device=self.device)
+        targets = torch.tensor([1, 3, 3], dtype=targets_dtype, device=self.device)
         blank = 10000
         with self.assertRaisesRegex(RuntimeError, r"blank must be within \[0, num classes\)"):
             hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
@@ -1245,14 +1238,14 @@ class FunctionalCUDAOnly(TestBaseMixin):
     @nested_params(
         [torch.half, torch.float, torch.double],
         [torch.int32, torch.int64],
-        [(1, 50, 100), (1, 100, 100)],
-        [(1, 10), (1, 40), (1, 45)],
+        [(50, 100), (100, 100)],
+        [(10,), (40,), (45,)],
     )
     def test_forced_align_same_result(self, log_probs_dtype, targets_dtype, log_probs_shape, targets_shape):
         log_probs = torch.rand(log_probs_shape, dtype=log_probs_dtype, device=self.device)
         targets = torch.randint(1, 100, targets_shape, dtype=targets_dtype, device=self.device)
-        input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
-        target_lengths = torch.tensor([targets.shape[1]], device=self.device)
+        input_lengths = torch.tensor((log_probs.shape[0]), device=self.device)
+        target_lengths = torch.tensor((targets.shape[0]), device=self.device)
         log_probs_cuda = log_probs.cuda()
         targets_cuda = targets.cuda()
         input_lengths_cuda = input_lengths.cuda()
...
@@ -17,10 +17,8 @@ void forced_align_impl(
   const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
   using target_t = typename std::
       conditional<target_scalar_type == torch::kInt, int, int64_t>::type;
-  const auto batchIndex =
-      0; // TODO: support batch version and use the real batch index
-  const auto T = logProbs.size(1);
-  const auto L = targets.size(1);
+  const auto T = logProbs.size(0);
+  const auto L = targets.size(0);
   const auto S = 2 * L + 1;
   torch::Tensor alphas = torch::empty(
       {2, S},
@@ -29,14 +27,14 @@ void forced_align_impl(
           .dtype(logProbs.dtype()))
           .fill_(kNegInfinity);
   torch::Tensor backPtr = torch::empty({T, S}, torch::kInt8).fill_(-1);
-  auto logProbs_a = logProbs.accessor<scalar_t, 3>();
-  auto targets_a = targets.accessor<target_t, 2>();
-  auto paths_a = paths.accessor<target_t, 2>();
+  auto logProbs_a = logProbs.accessor<scalar_t, 2>();
+  auto targets_a = targets.accessor<target_t, 1>();
+  auto paths_a = paths.accessor<target_t, 1>();
   auto alphas_a = alphas.accessor<scalar_t, 2>();
   auto backPtr_a = backPtr.accessor<int8_t, 2>();
   auto R = 0;
   for (auto i = 1; i < L; i++) {
-    if (targets_a[batchIndex][i] == targets_a[batchIndex][i - 1]) {
+    if (targets_a[i] == targets_a[i - 1]) {
       ++R;
     }
   }
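The `R` loop above counts consecutive repeats in the label sequence; together with the label length it bounds the trellis, since CTC needs a blank between two identical labels and an alignment only exists when `T >= L + R`. A small Python sketch of the same bookkeeping (helper name is mine, not from the diff):

```python
def count_repeats(targets):
    # Mirrors the R loop in forced_align_impl: count positions whose label
    # equals the immediately preceding label.
    return sum(1 for i in range(1, len(targets)) if targets[i] == targets[i - 1])

targets = [3, 3, 3]                          # one of the test cases above
L, R = len(targets), count_repeats(targets)  # L = 3, R = 2, so at least 5 frames are needed,
# which is why the reference path for [3, 3, 3] is [3, 5, 3, 5, 3] with blank = 5.
```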
@@ -51,22 +49,20 @@ void forced_align_impl(
   auto start = T - (L + R) > 0 ? 0 : 1;
   auto end = (S == 1) ? 1 : 2;
   for (auto i = start; i < end; i++) {
-    auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
-    alphas_a[0][i] = logProbs_a[batchIndex][0][labelIdx];
+    auto labelIdx = (i % 2 == 0) ? blank : targets_a[i / 2];
+    alphas_a[0][i] = logProbs_a[0][labelIdx];
   }
   for (auto t = 1; t < T; t++) {
     if (T - t <= L + R) {
       if ((start % 2 == 1) &&
-          targets_a[batchIndex][start / 2] !=
-              targets_a[batchIndex][start / 2 + 1]) {
+          targets_a[start / 2] != targets_a[start / 2 + 1]) {
         start = start + 1;
       }
       start = start + 1;
     }
     if (t <= L + R) {
       if (end % 2 == 0 && end < 2 * L &&
-          targets_a[batchIndex][end / 2 - 1] !=
-              targets_a[batchIndex][end / 2]) {
+          targets_a[end / 2 - 1] != targets_a[end / 2]) {
         end = end + 1;
       }
       end = end + 1;
@@ -79,7 +75,7 @@ void forced_align_impl(
     }
     if (start == 0) {
       alphas_a[curIdxOffset][0] =
-          alphas_a[prevIdxOffset][0] + logProbs_a[batchIndex][t][blank];
+          alphas_a[prevIdxOffset][0] + logProbs_a[t][blank];
       backPtr_a[t][0] = 0;
       startloop += 1;
     }
@@ -89,14 +85,13 @@ void forced_align_impl(
       auto x1 = alphas_a[prevIdxOffset][i - 1];
       auto x2 = -std::numeric_limits<scalar_t>::infinity();
-      auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
+      auto labelIdx = (i % 2 == 0) ? blank : targets_a[i / 2];
       // In CTC, the optimal path may optionally chose to skip a blank label.
       // x2 represents skipping a letter, and can only happen if we're not
       // currently on a blank_label, and we're not on a repeat letter
       // (i != 1) just ensures we don't access targets[i - 2] if its i < 2
-      if (i % 2 != 0 && i != 1 &&
-          targets_a[batchIndex][i / 2] != targets_a[batchIndex][i / 2 - 1]) {
+      if (i % 2 != 0 && i != 1 && targets_a[i / 2] != targets_a[i / 2 - 1]) {
         x2 = alphas_a[prevIdxOffset][i - 2];
       }
       scalar_t result = 0.0;
@@ -110,7 +105,7 @@ void forced_align_impl(
         result = x0;
         backPtr_a[t][i] = 0;
       }
-      alphas_a[curIdxOffset][i] = result + logProbs_a[batchIndex][t][labelIdx];
+      alphas_a[curIdxOffset][i] = result + logProbs_a[t][labelIdx];
     }
   }
   auto idx1 = (T - 1) % 2;
@@ -118,8 +113,8 @@ void forced_align_impl(
   // path stores the token index for each time step after force alignment.
   auto indexScores = 0;
   for (auto t = T - 1; t > -1; t--) {
-    auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a[batchIndex][ltrIdx / 2];
-    paths_a[batchIndex][t] = lbl_idx;
+    auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a[ltrIdx / 2];
+    paths_a[t] = lbl_idx;
     ++indexScores;
     ltrIdx -= backPtr_a[t][ltrIdx];
   }
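The `x0`/`x1`/`x2` comparison above is the standard CTC trellis step: stay in the same state, advance one state, or skip over a blank when two adjacent labels differ, with `backPtr` recording which option won so the final loop can walk the best path backwards. A hedged Python sketch of just that backtracking loop (function and argument names are mine; `start_state` stands for whichever of the last two trellis states scored higher):

```python
def backtrack(back_ptr, targets, blank, start_state):
    # back_ptr: T x S table of offsets (0, 1 or 2) filled by the forward pass.
    # Walks from the last frame to the first, emitting one token index per frame,
    # mirroring the loop over t at the end of forced_align_impl.
    T = len(back_ptr)
    ltr_idx = start_state
    path = [None] * T
    for t in range(T - 1, -1, -1):
        path[t] = blank if ltr_idx % 2 == 0 else targets[ltr_idx // 2]
        ltr_idx -= back_ptr[t][ltr_idx]
    return path
```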
@@ -147,35 +142,30 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
   TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
   TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
   TORCH_CHECK(
-      logProbs.dim() == 3,
-      "log_probs must be 3-D (batch_size, input length, num classes)");
-  TORCH_CHECK(
-      targets.dim() == 2, "targets must be 2-D (batch_size, target length,)");
-  TORCH_CHECK(
-      inputLengths.dim() == 1, "input_lengths must be 1-D (batch_size,)");
-  TORCH_CHECK(
-      targetLengths.dim() == 1, "target_lengths must be 1-D (batch_size,)");
-  TORCH_CHECK(
-      logProbs.size(0) == 1,
-      "The batch dimension for log_probs must be 1 at the current version.")
-  TORCH_CHECK(
-      targets.size(0) == 1,
-      "The batch dimension for targets must be 1 at the current version.")
+      logProbs.dim() != 3,
+      "3-D tensor is not yet supported for log_probs, please provide 2-D tensor.")
+  TORCH_CHECK(
+      targets.dim() != 2,
+      "2-D tensor is not yet supported for targets, please provide 1-D tensor.")
+  TORCH_CHECK(
+      logProbs.dim() == 2, "log_probs must be 2-D (input length, num classes)");
+  TORCH_CHECK(targets.dim() == 1, "targets must be 1-D (target length,)");
+  TORCH_CHECK(inputLengths.dim() == 0, "input_lengths must be 0-D");
+  TORCH_CHECK(targetLengths.dim() == 0, "target_lengths must be 0-D");
   TORCH_CHECK(
       blank >= 0 && blank < logProbs.size(-1),
       "blank must be within [0, num classes)");
   TORCH_CHECK(
-      logProbs.size(1) == at::max(inputLengths).item().toInt(),
+      logProbs.size(0) == at::max(inputLengths).item().toInt(),
       "input length mismatch");
   TORCH_CHECK(
-      targets.size(1) == at::max(targetLengths).item().toInt(),
+      targets.size(0) == at::max(targetLengths).item().toInt(),
       "target length mismatch");
-  const auto B = logProbs.size(0);
-  const auto T = logProbs.size(1);
+  const auto T = logProbs.size(0);
   auto paths = torch::zeros(
-      {B, T},
+      {T},
       torch::TensorOptions().device(targets.device()).dtype(targets.dtype()));
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       logProbs.scalar_type(), "forced_align_impl", [&] {
@@ -190,10 +180,9 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
   return std::make_tuple(
       paths,
       logProbs.index(
-          {torch::indexing::Slice(),
-           torch::linspace(
+          {torch::linspace(
                0, T - 1, T, torch::TensorOptions().dtype(paths.dtype())),
-           paths.index({0})}));
+           paths}));
 }

 TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
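After the change, the second element of the returned tuple is plain 2-D fancy indexing: the emission log-probability of the aligned token at every frame. In Python terms it is roughly equivalent to the following sketch (not the actual binding):

```python
import torch

def path_scores(log_probs, path):
    # log_probs: (T, C) emission log-probabilities; path: (T,) token index per frame.
    # Corresponds to logProbs.index({linspace(0, T - 1, T), paths}) in the C++ above.
    frames = torch.arange(log_probs.shape[0])
    return log_probs[frames, path]
```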
...
@@ -18,9 +18,9 @@ namespace alignment {
 namespace gpu {
 template <typename scalar_t, typename target_t>
 __global__ void falign_cuda_step_kernel(
-    const torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits>
+    const torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits>
         logProbs_a,
-    const torch::PackedTensorAccessor32<target_t, 2, torch::RestrictPtrTraits>
+    const torch::PackedTensorAccessor32<target_t, 1, torch::RestrictPtrTraits>
         targets_a,
     const int T,
     const int L,
@@ -36,8 +36,6 @@ __global__ void falign_cuda_step_kernel(
     torch::PackedTensorAccessor32<int8_t, 2, torch::RestrictPtrTraits>
         backPtrBuffer_a) {
   scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
-  const int batchIndex =
-      0; // TODO: support batch version and use the real batch index
   int S = 2 * L + 1;
   int curIdxOffset = (t % 2); // current time step frame for alpha
   int prevIdxOffset = ((t - 1) % 2); // previous time step frame for alpha
@@ -51,8 +49,8 @@ __global__ void falign_cuda_step_kernel(
   __syncthreads();
   if (t == 0) {
     for (unsigned int i = start + threadIdx.x; i < end; i += blockDim.x) {
-      int labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
-      alphas_a[curIdxOffset][i] = logProbs_a[batchIndex][0][labelIdx];
+      int labelIdx = (i % 2 == 0) ? blank : targets_a[i / 2];
+      alphas_a[curIdxOffset][i] = logProbs_a[0][labelIdx];
     }
     return;
   }
@@ -64,7 +62,7 @@ __global__ void falign_cuda_step_kernel(
   threadMax = kNegInfinity;
   if (start == 0 && threadIdx.x == 0) {
     alphas_a[curIdxOffset][0] =
-        alphas_a[prevIdxOffset][0] + logProbs_a[batchIndex][t][blank];
+        alphas_a[prevIdxOffset][0] + logProbs_a[t][blank];
     threadMax = max(threadMax, alphas_a[curIdxOffset][0]);
     backPtrBuffer_a[backPtrBufferLen][0] = 0;
   }
@@ -75,9 +73,8 @@ __global__ void falign_cuda_step_kernel(
     scalar_t x0 = alphas_a[prevIdxOffset][i];
     scalar_t x1 = alphas_a[prevIdxOffset][i - 1];
     scalar_t x2 = kNegInfinity;
-    int labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
-    if (i % 2 != 0 && i != 1 &&
-        targets_a[batchIndex][i / 2] != targets_a[batchIndex][i / 2 - 1]) {
+    int labelIdx = (i % 2 == 0) ? blank : targets_a[i / 2];
+    if (i % 2 != 0 && i != 1 && targets_a[i / 2] != targets_a[i / 2 - 1]) {
       x2 = alphas_a[prevIdxOffset][i - 2];
     }
     scalar_t result = 0.0;
@@ -91,7 +88,7 @@ __global__ void falign_cuda_step_kernel(
       result = x0;
       backPtrBuffer_a[backPtrBufferLen][i] = 0;
     }
-    alphas_a[curIdxOffset][i] = result + logProbs_a[batchIndex][t][labelIdx];
+    alphas_a[curIdxOffset][i] = result + logProbs_a[t][labelIdx];
     threadMax = max(threadMax, alphas_a[curIdxOffset][i]);
   }
   scalar_t maxResult = BlockReduce(tempStorage).Reduce(threadMax, cub::Max());
@@ -116,12 +113,10 @@ void forced_align_impl(
   const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
   using target_t = typename std::
       conditional<target_scalar_type == torch::kInt, int, int64_t>::type;
-  auto paths_a = paths.accessor<target_t, 2>();
-  const int batchIndex =
-      0; // TODO: support batch version and use the real batch index
-  const int T = logProbs.size(1); // num frames
-  const int N = logProbs.size(2); // alphabet size
-  const int L = targets.size(1); // label length
+  auto paths_a = paths.accessor<target_t, 1>();
+  const int T = logProbs.size(0); // num frames
+  const int N = logProbs.size(1); // alphabet size
+  const int L = targets.size(0); // label length
   const int S = 2 * L + 1;
   auto targetsCpu = targets.to(torch::kCPU);
   // backPtrBuffer stores the index offset fthe best path at current position
@@ -149,12 +144,12 @@ void forced_align_impl(
           .device(logProbs.device()))
           .fill_(kNegInfinity);
   // CPU accessors
-  auto targetsCpu_a = targetsCpu.accessor<target_t, 2>();
+  auto targetsCpu_a = targetsCpu.accessor<target_t, 1>();
   auto backPtrCpu_a = backPtrCpu.accessor<int8_t, 2>();
   // count the number of repeats in label
   int R = 0;
   for (int i = 1; i < L; ++i) {
-    if (targetsCpu_a[batchIndex][i] == targetsCpu_a[batchIndex][i - 1]) {
+    if (targetsCpu_a[i] == targetsCpu_a[i - 1]) {
       ++R;
     }
   }
@@ -174,16 +169,14 @@ void forced_align_impl(
     if (t > 0) {
       if (T - t <= L + R) {
         if ((start % 2 == 1) &&
-            (targetsCpu_a[batchIndex][start / 2] !=
-             targetsCpu_a[batchIndex][start / 2 + 1])) {
+            (targetsCpu_a[start / 2] != targetsCpu_a[start / 2 + 1])) {
           start = start + 1;
         }
         start = start + 1;
       }
       if (t <= L + R) {
         if ((end % 2 == 0) && (end < 2 * L) &&
-            (targetsCpu_a[batchIndex][end / 2 - 1] !=
-             targetsCpu_a[batchIndex][end / 2])) {
+            (targetsCpu_a[end / 2 - 1] != targetsCpu_a[end / 2])) {
           end = end + 1;
         }
         end = end + 1;
@@ -191,8 +184,8 @@ void forced_align_impl(
     }
     falign_cuda_step_kernel<scalar_t, target_t>
         <<<1, kNumThreads, 0, defaultStream>>>(
-            logProbs.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
-            targets.packed_accessor32<target_t, 2, torch::RestrictPtrTraits>(),
+            logProbs.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
+            targets.packed_accessor32<target_t, 1, torch::RestrictPtrTraits>(),
             T,
             L,
             N,
@@ -236,9 +229,8 @@ void forced_align_impl(
       : S - 2;
   int indexScores = 0;
   for (int t = T - 1; t >= 0; --t) {
-    auto lbl_idx =
-        ltrIdx % 2 == 0 ? blank : targetsCpu_a[batchIndex][ltrIdx / 2];
-    paths_a[batchIndex][t] = lbl_idx;
+    auto lbl_idx = ltrIdx % 2 == 0 ? blank : targetsCpu_a[ltrIdx / 2];
+    paths_a[t] = lbl_idx;
     ++indexScores;
     ltrIdx -= backPtrCpu_a[t][ltrIdx];
   }
@@ -266,36 +258,30 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
   TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
   TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
   TORCH_CHECK(
-      logProbs.dim() == 3,
-      "log_probs must be 3-D (batch_size, input length, num classes)");
-  TORCH_CHECK(
-      targets.dim() == 2, "targets must be 2-D (batch_size, target length,)");
-  TORCH_CHECK(
-      inputLengths.dim() == 1, "input_lengths must be 1-D (batch_size,)");
-  TORCH_CHECK(
-      targetLengths.dim() == 1, "target_lengths must be 1-D (batch_size,)");
-  TORCH_CHECK(
-      logProbs.size(0) == 1,
-      "The batch dimension for log_probs must be 1 at the current version.")
-  TORCH_CHECK(
-      targets.size(0) == 1,
-      "The batch dimension for targets must be 1 at the current version.")
+      logProbs.dim() != 3,
+      "3-D tensor is not yet supported for log_probs, please provide 2-D tensor.")
+  TORCH_CHECK(
+      targets.dim() != 2,
+      "2-D tensor is not yet supported for targets, please provide 1-D tensor.")
+  TORCH_CHECK(
+      logProbs.dim() == 2, "log_probs must be 2-D (input length, num classes)");
+  TORCH_CHECK(targets.dim() == 1, "targets must be 1-D (target length,)");
+  TORCH_CHECK(inputLengths.dim() == 0, "input_lengths must be 0-D");
+  TORCH_CHECK(targetLengths.dim() == 0, "target_lengths must be 0-D");
   TORCH_CHECK(
       blank >= 0 && blank < logProbs.size(-1),
       "blank must be within [0, num classes)");
   TORCH_CHECK(
-      logProbs.size(1) == at::max(inputLengths).item().toInt(),
+      logProbs.size(0) == at::max(inputLengths).item().toInt(),
       "input length mismatch");
   TORCH_CHECK(
-      targets.size(1) == at::max(targetLengths).item().toInt(),
+      targets.size(0) == at::max(targetLengths).item().toInt(),
       "target length mismatch");
-  auto B = logProbs.size(0);
-  auto T = logProbs.size(1); // num frames
+  auto T = logProbs.size(0); // num frames
   auto paths = torch::zeros(
-      {B, T},
-      torch::TensorOptions().device(torch::kCPU).dtype(targets.dtype()));
+      {T}, torch::TensorOptions().device(torch::kCPU).dtype(targets.dtype()));
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       logProbs.scalar_type(), "forced_align_impl", [&] {
         if (targets.scalar_type() == torch::kInt64) {
@@ -309,10 +295,9 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
   return std::make_tuple(
       paths.to(logProbs.device()),
       logProbs.index(
-          {torch::indexing::Slice(),
-           torch::linspace(
+          {torch::linspace(
                0, T - 1, T, torch::TensorOptions().dtype(paths.dtype())),
-           paths.index({0})}));
+           paths}));
 }

 TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
...
@@ -2508,12 +2508,12 @@ def forced_align(
     Args:
         log_probs (torch.Tensor): log probability of CTC emission output.
-            Tensor of shape `(B, T, C)`. where `B` is the batch size, `T` is the input length,
+            Tensor of shape `(T, C)`. where `T` is the input length,
             `C` is the number of characters in alphabet including blank.
-        targets (torch.Tensor): Target sequence. Tensor of shape `(B, L)`,
+        targets (torch.Tensor): Target sequence. Tensor of shape `(L,)`,
             where `L` is the target length.
-        input_lengths (torch.Tensor): Lengths of the inputs (max value must each be <= `T`). 1-D Tensor of shape `(B,)`.
-        target_lengths (torch.Tensor): Lengths of the targets. 1-D Tensor of shape `(B,)`.
+        input_lengths (torch.Tensor): Lengths of the inputs (max value must each be <= `T`). 0-D Tensor (scalar).
+        target_lengths (torch.Tensor): Lengths of the targets. 0-D Tensor (scalar).
         blank_id (int, optional): The index of blank symbol in CTC emission. (Default: 0)
     Returns:
@@ -2531,9 +2531,6 @@ def forced_align(
         where :math:`N_{\text{repeat}}` is the number of consecutively repeated tokens.
         For example, in str `"aabbc"`, the number of repeats are `2`.
-
-    Note:
-        The current version only supports ``batch_size``==1.
     """
     if blank in targets:
         raise ValueError(f"targets Tensor shouldn't contain blank index. Found {targets}.")
...