Commit bbc13b9a authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Revert D46126226: Update forced_align method to only support batch Tensors

Differential Revision:
D46126226

Original commit changeset: 42cb52b19d91

Original Phabricator Diff: D46126226

fbshipit-source-id: 372b2526d9e196e37e014f1556bf117d29bb1ac6
parent 5f17d81c
......@@ -1116,60 +1116,55 @@ class Functional(TestBaseMixin):
@parameterized.expand(
[
([[0, 1, 1, 0]], [[0, 1, 5, 1, 0]], torch.int32),
([[0, 1, 2, 3, 4]], [[0, 1, 2, 3, 4]], torch.int32),
([[3, 3, 3]], [[3, 5, 3, 5, 3]], torch.int64),
([[0, 1, 2]], [[0, 1, 1, 1, 2]], torch.int64),
([0, 1, 1, 0], [0, 1, 5, 1, 0], torch.int32),
([0, 1, 2, 3, 4], [0, 1, 2, 3, 4], torch.int32),
([3, 3, 3], [3, 5, 3, 5, 3], torch.int64),
([0, 1, 2], [0, 1, 1, 1, 2], torch.int64),
]
)
def test_forced_align(self, targets, ref_path, targets_dtype):
emission = torch.tensor(
[
[
[0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
[0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436],
[0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688],
[0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533],
[0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107],
]
[0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
[0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436],
[0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688],
[0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533],
[0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107],
],
dtype=self.dtype,
device=self.device,
)
blank = 5
batch_index = 0
ref_path = torch.tensor(ref_path, dtype=targets_dtype, device=self.device)
ref_scores = torch.tensor(
[torch.log(emission[batch_index, i, ref_path[batch_index, i]]).item() for i in range(emission.shape[1])],
[torch.log(emission[i, ref_path[i]]).item() for i in range(emission.shape[0])],
dtype=emission.dtype,
device=self.device,
).unsqueeze(0)
)
log_probs = torch.log(emission)
targets = torch.tensor(targets, dtype=targets_dtype, device=self.device)
input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
target_lengths = torch.tensor([targets.shape[1]], device=self.device)
input_lengths = torch.tensor((log_probs.shape[0]))
target_lengths = torch.tensor((targets.shape[0]))
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
assert hyp_path.shape == ref_path.shape
assert hyp_scores.shape == ref_scores.shape
self.assertEqual(hyp_path, ref_path)
self.assertEqual(hyp_scores, ref_scores)
@parameterized.expand([(torch.int32,), (torch.int64,)])
def test_forced_align_fail(self, targets_dtype):
log_probs = torch.rand(1, 5, 6, dtype=self.dtype, device=self.device)
targets = torch.tensor([[0, 1, 2, 3, 4, 4]], dtype=targets_dtype, device=self.device)
log_probs = torch.rand(5, 6, dtype=self.dtype, device=self.device)
targets = torch.tensor([0, 1, 2, 3, 4, 4], dtype=targets_dtype, device=self.device)
blank = 5
input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
target_lengths = torch.tensor([targets.shape[1]], device=self.device)
input_lengths = torch.tensor((log_probs.shape[0]), device=self.device)
target_lengths = torch.tensor((targets.shape[0]), device=self.device)
with self.assertRaisesRegex(RuntimeError, r"targets length is too long for CTC"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
targets = torch.tensor([[5, 3, 3]], dtype=targets_dtype, device=self.device)
targets = torch.tensor([5, 3, 3], dtype=targets_dtype, device=self.device)
with self.assertRaisesRegex(ValueError, r"targets Tensor shouldn't contain blank index"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
log_probs = log_probs.int()
targets = torch.tensor([[0, 1, 2, 3]], dtype=targets_dtype, device=self.device)
targets = torch.tensor([0, 1, 2, 3], dtype=targets_dtype, device=self.device)
with self.assertRaisesRegex(RuntimeError, r"log_probs must be float64, float32 or float16"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
......@@ -1180,42 +1175,40 @@ class Functional(TestBaseMixin):
log_probs = torch.rand(3, 4, 6, dtype=self.dtype, device=self.device)
targets = targets.int()
with self.assertRaisesRegex(
RuntimeError, r"The batch dimension for log_probs must be 1 at the current version"
):
with self.assertRaisesRegex(RuntimeError, r"3-D tensor is not yet supported for log_probs"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
targets = torch.randint(0, 4, (3, 4), device=self.device)
log_probs = torch.rand(1, 3, 6, dtype=self.dtype, device=self.device)
with self.assertRaisesRegex(RuntimeError, r"The batch dimension for targets must be 1 at the current version"):
log_probs = torch.rand(3, 6, dtype=self.dtype, device=self.device)
with self.assertRaisesRegex(RuntimeError, r"2-D tensor is not yet supported for targets"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
targets = torch.tensor([[0, 1, 2, 3]], dtype=targets_dtype, device=self.device)
input_lengths = torch.randint(1, 5, (3, 5), device=self.device)
with self.assertRaisesRegex(RuntimeError, r"input_lengths must be 1-D"):
targets = torch.tensor([0, 1, 2, 3], dtype=targets_dtype, device=self.device)
input_lengths = torch.randint(1, 5, (3,), device=self.device)
with self.assertRaisesRegex(RuntimeError, r"input_lengths must be 0-D"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
input_lengths = torch.tensor([log_probs.shape[0]], device=self.device)
target_lengths = torch.randint(1, 5, (3, 5), device=self.device)
with self.assertRaisesRegex(RuntimeError, r"target_lengths must be 1-D"):
input_lengths = torch.tensor((log_probs.shape[0]), device=self.device)
target_lengths = torch.randint(1, 5, (3,), device=self.device)
with self.assertRaisesRegex(RuntimeError, r"target_lengths must be 0-D"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
input_lengths = torch.tensor([10000], device=self.device)
target_lengths = torch.tensor([targets.shape[1]], device=self.device)
input_lengths = torch.tensor((10000), device=self.device)
target_lengths = torch.tensor((targets.shape[0]), device=self.device)
with self.assertRaisesRegex(RuntimeError, r"input length mismatch"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
target_lengths = torch.tensor([10000], device=self.device)
input_lengths = torch.tensor((log_probs.shape[0]))
target_lengths = torch.tensor((10000))
with self.assertRaisesRegex(RuntimeError, r"target length mismatch"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
targets = torch.tensor([[7, 8, 9, 10]], dtype=targets_dtype, device=self.device)
log_probs = torch.rand(1, 10, 5, dtype=self.dtype, device=self.device)
targets = torch.tensor([7, 8, 9, 10], dtype=targets_dtype, device=self.device)
log_probs = torch.rand(10, 5, dtype=self.dtype, device=self.device)
with self.assertRaisesRegex(ValueError, r"targets values must be less than the CTC dimension"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
targets = torch.tensor([[1, 3, 3]], dtype=targets_dtype, device=self.device)
targets = torch.tensor([1, 3, 3], dtype=targets_dtype, device=self.device)
blank = 10000
with self.assertRaisesRegex(RuntimeError, r"blank must be within \[0, num classes\)"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
......@@ -1245,14 +1238,14 @@ class FunctionalCUDAOnly(TestBaseMixin):
@nested_params(
[torch.half, torch.float, torch.double],
[torch.int32, torch.int64],
[(1, 50, 100), (1, 100, 100)],
[(1, 10), (1, 40), (1, 45)],
[(50, 100), (100, 100)],
[(10,), (40,), (45,)],
)
def test_forced_align_same_result(self, log_probs_dtype, targets_dtype, log_probs_shape, targets_shape):
log_probs = torch.rand(log_probs_shape, dtype=log_probs_dtype, device=self.device)
targets = torch.randint(1, 100, targets_shape, dtype=targets_dtype, device=self.device)
input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
target_lengths = torch.tensor([targets.shape[1]], device=self.device)
input_lengths = torch.tensor((log_probs.shape[0]), device=self.device)
target_lengths = torch.tensor((targets.shape[0]), device=self.device)
log_probs_cuda = log_probs.cuda()
targets_cuda = targets.cuda()
input_lengths_cuda = input_lengths.cuda()
......
......@@ -17,10 +17,8 @@ void forced_align_impl(
const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
using target_t = typename std::
conditional<target_scalar_type == torch::kInt, int, int64_t>::type;
const auto batchIndex =
0; // TODO: support batch version and use the real batch index
const auto T = logProbs.size(1);
const auto L = targets.size(1);
const auto T = logProbs.size(0);
const auto L = targets.size(0);
const auto S = 2 * L + 1;
torch::Tensor alphas = torch::empty(
{2, S},
......@@ -29,14 +27,14 @@ void forced_align_impl(
.dtype(logProbs.dtype()))
.fill_(kNegInfinity);
torch::Tensor backPtr = torch::empty({T, S}, torch::kInt8).fill_(-1);
auto logProbs_a = logProbs.accessor<scalar_t, 3>();
auto targets_a = targets.accessor<target_t, 2>();
auto paths_a = paths.accessor<target_t, 2>();
auto logProbs_a = logProbs.accessor<scalar_t, 2>();
auto targets_a = targets.accessor<target_t, 1>();
auto paths_a = paths.accessor<target_t, 1>();
auto alphas_a = alphas.accessor<scalar_t, 2>();
auto backPtr_a = backPtr.accessor<int8_t, 2>();
auto R = 0;
for (auto i = 1; i < L; i++) {
if (targets_a[batchIndex][i] == targets_a[batchIndex][i - 1]) {
if (targets_a[i] == targets_a[i - 1]) {
++R;
}
}
......@@ -51,22 +49,20 @@ void forced_align_impl(
auto start = T - (L + R) > 0 ? 0 : 1;
auto end = (S == 1) ? 1 : 2;
for (auto i = start; i < end; i++) {
auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
alphas_a[0][i] = logProbs_a[batchIndex][0][labelIdx];
auto labelIdx = (i % 2 == 0) ? blank : targets_a[i / 2];
alphas_a[0][i] = logProbs_a[0][labelIdx];
}
for (auto t = 1; t < T; t++) {
if (T - t <= L + R) {
if ((start % 2 == 1) &&
targets_a[batchIndex][start / 2] !=
targets_a[batchIndex][start / 2 + 1]) {
targets_a[start / 2] != targets_a[start / 2 + 1]) {
start = start + 1;
}
start = start + 1;
}
if (t <= L + R) {
if (end % 2 == 0 && end < 2 * L &&
targets_a[batchIndex][end / 2 - 1] !=
targets_a[batchIndex][end / 2]) {
targets_a[end / 2 - 1] != targets_a[end / 2]) {
end = end + 1;
}
end = end + 1;
......@@ -79,7 +75,7 @@ void forced_align_impl(
}
if (start == 0) {
alphas_a[curIdxOffset][0] =
alphas_a[prevIdxOffset][0] + logProbs_a[batchIndex][t][blank];
alphas_a[prevIdxOffset][0] + logProbs_a[t][blank];
backPtr_a[t][0] = 0;
startloop += 1;
}
......@@ -89,14 +85,13 @@ void forced_align_impl(
auto x1 = alphas_a[prevIdxOffset][i - 1];
auto x2 = -std::numeric_limits<scalar_t>::infinity();
auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
auto labelIdx = (i % 2 == 0) ? blank : targets_a[i / 2];
// In CTC, the optimal path may optionally chose to skip a blank label.
// x2 represents skipping a letter, and can only happen if we're not
// currently on a blank_label, and we're not on a repeat letter
// (i != 1) just ensures we don't access targets[i - 2] if its i < 2
if (i % 2 != 0 && i != 1 &&
targets_a[batchIndex][i / 2] != targets_a[batchIndex][i / 2 - 1]) {
if (i % 2 != 0 && i != 1 && targets_a[i / 2] != targets_a[i / 2 - 1]) {
x2 = alphas_a[prevIdxOffset][i - 2];
}
scalar_t result = 0.0;
......@@ -110,7 +105,7 @@ void forced_align_impl(
result = x0;
backPtr_a[t][i] = 0;
}
alphas_a[curIdxOffset][i] = result + logProbs_a[batchIndex][t][labelIdx];
alphas_a[curIdxOffset][i] = result + logProbs_a[t][labelIdx];
}
}
auto idx1 = (T - 1) % 2;
......@@ -118,8 +113,8 @@ void forced_align_impl(
// path stores the token index for each time step after force alignment.
auto indexScores = 0;
for (auto t = T - 1; t > -1; t--) {
auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a[batchIndex][ltrIdx / 2];
paths_a[batchIndex][t] = lbl_idx;
auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a[ltrIdx / 2];
paths_a[t] = lbl_idx;
++indexScores;
ltrIdx -= backPtr_a[t][ltrIdx];
}
......@@ -147,35 +142,30 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
TORCH_CHECK(
logProbs.dim() == 3,
"log_probs must be 3-D (batch_size, input length, num classes)");
logProbs.dim() != 3,
"3-D tensor is not yet supported for log_probs, please provide 2-D tensor.")
TORCH_CHECK(
targets.dim() == 2, "targets must be 2-D (batch_size, target length,)");
targets.dim() != 2,
"2-D tensor is not yet supported for targets, please provide 1-D tensor.")
TORCH_CHECK(
inputLengths.dim() == 1, "input_lengths must be 1-D (batch_size,)");
TORCH_CHECK(
targetLengths.dim() == 1, "target_lengths must be 1-D (batch_size,)");
TORCH_CHECK(
logProbs.size(0) == 1,
"The batch dimension for log_probs must be 1 at the current version.")
TORCH_CHECK(
targets.size(0) == 1,
"The batch dimension for targets must be 1 at the current version.")
logProbs.dim() == 2, "log_probs must be 2-D (input length, num classes)");
TORCH_CHECK(targets.dim() == 1, "targets must be 1-D (target length,)");
TORCH_CHECK(inputLengths.dim() == 0, "input_lengths must be 0-D");
TORCH_CHECK(targetLengths.dim() == 0, "target_lengths must be 0-D");
TORCH_CHECK(
blank >= 0 && blank < logProbs.size(-1),
"blank must be within [0, num classes)");
TORCH_CHECK(
logProbs.size(1) == at::max(inputLengths).item().toInt(),
logProbs.size(0) == at::max(inputLengths).item().toInt(),
"input length mismatch");
TORCH_CHECK(
targets.size(1) == at::max(targetLengths).item().toInt(),
targets.size(0) == at::max(targetLengths).item().toInt(),
"target length mismatch");
const auto B = logProbs.size(0);
const auto T = logProbs.size(1);
const auto T = logProbs.size(0);
auto paths = torch::zeros(
{B, T},
{T},
torch::TensorOptions().device(targets.device()).dtype(targets.dtype()));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
logProbs.scalar_type(), "forced_align_impl", [&] {
......@@ -190,10 +180,9 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
return std::make_tuple(
paths,
logProbs.index(
{torch::indexing::Slice(),
torch::linspace(
{torch::linspace(
0, T - 1, T, torch::TensorOptions().dtype(paths.dtype())),
paths.index({0})}));
paths}));
}
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
......
......@@ -18,9 +18,9 @@ namespace alignment {
namespace gpu {
template <typename scalar_t, typename target_t>
__global__ void falign_cuda_step_kernel(
const torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits>
const torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits>
logProbs_a,
const torch::PackedTensorAccessor32<target_t, 2, torch::RestrictPtrTraits>
const torch::PackedTensorAccessor32<target_t, 1, torch::RestrictPtrTraits>
targets_a,
const int T,
const int L,
......@@ -36,8 +36,6 @@ __global__ void falign_cuda_step_kernel(
torch::PackedTensorAccessor32<int8_t, 2, torch::RestrictPtrTraits>
backPtrBuffer_a) {
scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
const int batchIndex =
0; // TODO: support batch version and use the real batch index
int S = 2 * L + 1;
int curIdxOffset = (t % 2); // current time step frame for alpha
int prevIdxOffset = ((t - 1) % 2); // previous time step frame for alpha
......@@ -51,8 +49,8 @@ __global__ void falign_cuda_step_kernel(
__syncthreads();
if (t == 0) {
for (unsigned int i = start + threadIdx.x; i < end; i += blockDim.x) {
int labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
alphas_a[curIdxOffset][i] = logProbs_a[batchIndex][0][labelIdx];
int labelIdx = (i % 2 == 0) ? blank : targets_a[i / 2];
alphas_a[curIdxOffset][i] = logProbs_a[0][labelIdx];
}
return;
}
......@@ -64,7 +62,7 @@ __global__ void falign_cuda_step_kernel(
threadMax = kNegInfinity;
if (start == 0 && threadIdx.x == 0) {
alphas_a[curIdxOffset][0] =
alphas_a[prevIdxOffset][0] + logProbs_a[batchIndex][t][blank];
alphas_a[prevIdxOffset][0] + logProbs_a[t][blank];
threadMax = max(threadMax, alphas_a[curIdxOffset][0]);
backPtrBuffer_a[backPtrBufferLen][0] = 0;
}
......@@ -75,9 +73,8 @@ __global__ void falign_cuda_step_kernel(
scalar_t x0 = alphas_a[prevIdxOffset][i];
scalar_t x1 = alphas_a[prevIdxOffset][i - 1];
scalar_t x2 = kNegInfinity;
int labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
if (i % 2 != 0 && i != 1 &&
targets_a[batchIndex][i / 2] != targets_a[batchIndex][i / 2 - 1]) {
int labelIdx = (i % 2 == 0) ? blank : targets_a[i / 2];
if (i % 2 != 0 && i != 1 && targets_a[i / 2] != targets_a[i / 2 - 1]) {
x2 = alphas_a[prevIdxOffset][i - 2];
}
scalar_t result = 0.0;
......@@ -91,7 +88,7 @@ __global__ void falign_cuda_step_kernel(
result = x0;
backPtrBuffer_a[backPtrBufferLen][i] = 0;
}
alphas_a[curIdxOffset][i] = result + logProbs_a[batchIndex][t][labelIdx];
alphas_a[curIdxOffset][i] = result + logProbs_a[t][labelIdx];
threadMax = max(threadMax, alphas_a[curIdxOffset][i]);
}
scalar_t maxResult = BlockReduce(tempStorage).Reduce(threadMax, cub::Max());
......@@ -116,12 +113,10 @@ void forced_align_impl(
const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
using target_t = typename std::
conditional<target_scalar_type == torch::kInt, int, int64_t>::type;
auto paths_a = paths.accessor<target_t, 2>();
const int batchIndex =
0; // TODO: support batch version and use the real batch index
const int T = logProbs.size(1); // num frames
const int N = logProbs.size(2); // alphabet size
const int L = targets.size(1); // label length
auto paths_a = paths.accessor<target_t, 1>();
const int T = logProbs.size(0); // num frames
const int N = logProbs.size(1); // alphabet size
const int L = targets.size(0); // label length
const int S = 2 * L + 1;
auto targetsCpu = targets.to(torch::kCPU);
// backPtrBuffer stores the index offset fthe best path at current position
......@@ -149,12 +144,12 @@ void forced_align_impl(
.device(logProbs.device()))
.fill_(kNegInfinity);
// CPU accessors
auto targetsCpu_a = targetsCpu.accessor<target_t, 2>();
auto targetsCpu_a = targetsCpu.accessor<target_t, 1>();
auto backPtrCpu_a = backPtrCpu.accessor<int8_t, 2>();
// count the number of repeats in label
int R = 0;
for (int i = 1; i < L; ++i) {
if (targetsCpu_a[batchIndex][i] == targetsCpu_a[batchIndex][i - 1]) {
if (targetsCpu_a[i] == targetsCpu_a[i - 1]) {
++R;
}
}
......@@ -174,16 +169,14 @@ void forced_align_impl(
if (t > 0) {
if (T - t <= L + R) {
if ((start % 2 == 1) &&
(targetsCpu_a[batchIndex][start / 2] !=
targetsCpu_a[batchIndex][start / 2 + 1])) {
(targetsCpu_a[start / 2] != targetsCpu_a[start / 2 + 1])) {
start = start + 1;
}
start = start + 1;
}
if (t <= L + R) {
if ((end % 2 == 0) && (end < 2 * L) &&
(targetsCpu_a[batchIndex][end / 2 - 1] !=
targetsCpu_a[batchIndex][end / 2])) {
(targetsCpu_a[end / 2 - 1] != targetsCpu_a[end / 2])) {
end = end + 1;
}
end = end + 1;
......@@ -191,8 +184,8 @@ void forced_align_impl(
}
falign_cuda_step_kernel<scalar_t, target_t>
<<<1, kNumThreads, 0, defaultStream>>>(
logProbs.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
targets.packed_accessor32<target_t, 2, torch::RestrictPtrTraits>(),
logProbs.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
targets.packed_accessor32<target_t, 1, torch::RestrictPtrTraits>(),
T,
L,
N,
......@@ -236,9 +229,8 @@ void forced_align_impl(
: S - 2;
int indexScores = 0;
for (int t = T - 1; t >= 0; --t) {
auto lbl_idx =
ltrIdx % 2 == 0 ? blank : targetsCpu_a[batchIndex][ltrIdx / 2];
paths_a[batchIndex][t] = lbl_idx;
auto lbl_idx = ltrIdx % 2 == 0 ? blank : targetsCpu_a[ltrIdx / 2];
paths_a[t] = lbl_idx;
++indexScores;
ltrIdx -= backPtrCpu_a[t][ltrIdx];
}
......@@ -266,36 +258,30 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
TORCH_CHECK(
logProbs.dim() == 3,
"log_probs must be 3-D (batch_size, input length, num classes)");
logProbs.dim() != 3,
"3-D tensor is not yet supported for log_probs, please provide 2-D tensor.")
TORCH_CHECK(
targets.dim() == 2, "targets must be 2-D (batch_size, target length,)");
targets.dim() != 2,
"2-D tensor is not yet supported for targets, please provide 1-D tensor.")
TORCH_CHECK(
inputLengths.dim() == 1, "input_lengths must be 1-D (batch_size,)");
TORCH_CHECK(
targetLengths.dim() == 1, "target_lengths must be 1-D (batch_size,)");
TORCH_CHECK(
logProbs.size(0) == 1,
"The batch dimension for log_probs must be 1 at the current version.")
TORCH_CHECK(
targets.size(0) == 1,
"The batch dimension for targets must be 1 at the current version.")
logProbs.dim() == 2, "log_probs must be 2-D (input length, num classes)");
TORCH_CHECK(targets.dim() == 1, "targets must be 1-D (target length,)");
TORCH_CHECK(inputLengths.dim() == 0, "input_lengths must be 0-D");
TORCH_CHECK(targetLengths.dim() == 0, "target_lengths must be 0-D");
TORCH_CHECK(
blank >= 0 && blank < logProbs.size(-1),
"blank must be within [0, num classes)");
TORCH_CHECK(
logProbs.size(1) == at::max(inputLengths).item().toInt(),
logProbs.size(0) == at::max(inputLengths).item().toInt(),
"input length mismatch");
TORCH_CHECK(
targets.size(1) == at::max(targetLengths).item().toInt(),
targets.size(0) == at::max(targetLengths).item().toInt(),
"target length mismatch");
auto B = logProbs.size(0);
auto T = logProbs.size(1); // num frames
auto T = logProbs.size(0); // num frames
auto paths = torch::zeros(
{B, T},
torch::TensorOptions().device(torch::kCPU).dtype(targets.dtype()));
{T}, torch::TensorOptions().device(torch::kCPU).dtype(targets.dtype()));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
logProbs.scalar_type(), "forced_align_impl", [&] {
if (targets.scalar_type() == torch::kInt64) {
......@@ -309,10 +295,9 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
return std::make_tuple(
paths.to(logProbs.device()),
logProbs.index(
{torch::indexing::Slice(),
torch::linspace(
{torch::linspace(
0, T - 1, T, torch::TensorOptions().dtype(paths.dtype())),
paths.index({0})}));
paths}));
}
TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
......
......@@ -2508,12 +2508,12 @@ def forced_align(
Args:
log_probs (torch.Tensor): log probability of CTC emission output.
Tensor of shape `(B, T, C)`. where `B` is the batch size, `T` is the input length,
Tensor of shape `(T, C)`. where `T` is the input length,
`C` is the number of characters in alphabet including blank.
targets (torch.Tensor): Target sequence. Tensor of shape `(B, L)`,
targets (torch.Tensor): Target sequence. Tensor of shape `(L,)`,
where `L` is the target length.
input_lengths (torch.Tensor): Lengths of the inputs (max value must each be <= `T`). 1-D Tensor of shape `(B,)`.
target_lengths (torch.Tensor): Lengths of the targets. 1-D Tensor of shape `(B,)`.
input_lengths (torch.Tensor): Lengths of the inputs (max value must each be <= `T`). 0-D Tensor (scalar).
target_lengths (torch.Tensor): Lengths of the targets. 0-D Tensor (scalar).
blank_id (int, optional): The index of blank symbol in CTC emission. (Default: 0)
Returns:
......@@ -2531,9 +2531,6 @@ def forced_align(
where :math:`N_{\text{repeat}}` is the number of consecutively repeated tokens.
For example, in str `"aabbc"`, the number of repeats are `2`.
Note:
The current version only supports ``batch_size``==1.
"""
if blank in targets:
raise ValueError(f"targets Tensor shouldn't contain blank index. Found {targets}.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment