"docs_zh_CN/vscode:/vscode.git/clone" did not exist on "f69a3a44e3f54537a545659c62001247b079a62e"
Commit cc164478 authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Update forced_align method to only support batch Tensors (#3433)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3433

Current design of forced_align accept 2D Tensor for `log_probs` and 1D Tensor for `targets`. To make the API simple, the PR make changes to only support batch Tensors (3D Tensor for `log_probs` and 2D Tensor for `targets`).

Reviewed By: mthrok

Differential Revision: D46657526

fbshipit-source-id: af17ec3f92f1a2c46dba91c6db2488a11de36f89
parent 163157d3
......@@ -96,7 +96,7 @@ with torch.inference_mode():
emissions, _ = model(waveform.to(device))
emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions[0].cpu().detach()
emission = emissions.cpu().detach()
dictionary = {c: i for i, c in enumerate(labels)}
print(dictionary)
......@@ -107,7 +107,7 @@ print(dictionary)
# ^^^^^^^^^^^^^
#
plt.imshow(emission.T)
plt.imshow(emission[0].T)
plt.colorbar()
plt.title("Frame-wise class probabilities")
plt.xlabel("Time")
......@@ -205,27 +205,27 @@ def compute_alignments(transcript, dictionary, emission):
frames = []
tokens = [dictionary[c] for c in transcript.replace(" ", "")]
targets = torch.tensor(tokens, dtype=torch.int32)
input_lengths = torch.tensor(emission.shape[0])
target_lengths = torch.tensor(targets.shape[0])
targets = torch.tensor(tokens, dtype=torch.int32).unsqueeze(0)
input_lengths = torch.tensor([emission.shape[1]])
target_lengths = torch.tensor([targets.shape[1]])
# This is the key step, where we call the forced alignment API functional.forced_align to compute alignments.
frame_alignment, frame_scores = forced_align(emission, targets, input_lengths, target_lengths, 0)
assert len(frame_alignment) == input_lengths.item()
assert len(targets) == target_lengths.item()
assert frame_alignment.shape[1] == input_lengths[0].item()
assert targets.shape[1] == target_lengths[0].item()
token_index = -1
prev_hyp = 0
for i in range(len(frame_alignment)):
if frame_alignment[i].item() == 0:
for i in range(frame_alignment.shape[1]):
if frame_alignment[0][i].item() == 0:
prev_hyp = 0
continue
if frame_alignment[i].item() != prev_hyp:
if frame_alignment[0][i].item() != prev_hyp:
token_index += 1
frames.append(Frame(token_index, i, frame_scores[i].exp().item()))
prev_hyp = frame_alignment[i].item()
frames.append(Frame(token_index, i, frame_scores[0][i].exp().item()))
prev_hyp = frame_alignment[0][i].item()
return frames, frame_alignment, frame_scores
......@@ -390,7 +390,7 @@ def plot_alignments(segments, word_segments, waveform, input_lengths, scale=10):
plt.rcParams.update({"font.size": 30})
# The original waveform
ratio = waveform.size(0) / input_lengths
ratio = waveform.size(1) / input_lengths
ax2.plot(waveform)
ax2.set_ylim(-1.0 * scale, 1.0 * scale)
ax2.set_xlim(0, waveform.size(-1))
......@@ -414,8 +414,8 @@ def plot_alignments(segments, word_segments, waveform, input_lengths, scale=10):
plot_alignments(
segments,
word_segments,
waveform[0],
emission.shape[0],
waveform,
emission.shape[1],
1,
)
plt.show()
......@@ -428,7 +428,7 @@ plt.show()
# `IPython.display.Audio` has to be the last call in a cell,
# and there should be only one call par cell.
def display_segment(i, waveform, word_segments, frame_alignment):
ratio = waveform.size(1) / len(frame_alignment)
ratio = waveform.size(1) / frame_alignment.size(1)
word = word_segments[i]
x0 = int(ratio * word.start)
x1 = int(ratio * word.end)
......@@ -511,19 +511,19 @@ with torch.inference_mode():
# Append the extra dimension corresponding to the <star> token
extra_dim = torch.zeros(emissions.shape[0], emissions.shape[1], 1)
emissions = torch.cat((emissions.cpu(), extra_dim), 2)
emission = emissions[0].detach()
emission = emissions.detach()
# Extend the dictionary to include the <star> token.
dictionary["*"] = 29
assert len(dictionary) == emission.shape[1]
assert len(dictionary) == emission.shape[2]
def compute_and_plot_alignments(transcript, dictionary, emission, waveform):
frames, frame_alignment, _ = compute_alignments(transcript, dictionary, emission)
segments = merge_repeats(frames, transcript)
word_segments = merge_words(transcript, segments, "|")
plot_alignments(segments, word_segments, waveform[0], emission.shape[0], 1)
plot_alignments(segments, word_segments, waveform, emission.shape[1], 1)
plt.show()
return word_segments, frame_alignment
......
......@@ -90,27 +90,27 @@ def compute_alignments(transcript, dictionary, emission):
frames = []
tokens = [dictionary[c] for c in transcript.replace(" ", "")]
targets = torch.tensor(tokens, dtype=torch.int32)
input_lengths = torch.tensor(emission.shape[0])
target_lengths = torch.tensor(targets.shape[0])
targets = torch.tensor(tokens, dtype=torch.int32).unsqueeze(0)
input_lengths = torch.tensor([emission.shape[1]])
target_lengths = torch.tensor([targets.shape[1]])
# This is the key step, where we call the forced alignment API functional.forced_align to compute frame alignments.
frame_alignment, frame_scores = forced_align(emission, targets, input_lengths, target_lengths, 0)
assert len(frame_alignment) == input_lengths.item()
assert len(targets) == target_lengths.item()
assert frame_alignment.shape[1] == input_lengths[0].item()
assert targets.shape[1] == target_lengths[0].item()
token_index = -1
prev_hyp = 0
for i in range(len(frame_alignment)):
if frame_alignment[i].item() == 0:
for i in range(frame_alignment.shape[1]):
if frame_alignment[0][i].item() == 0:
prev_hyp = 0
continue
if frame_alignment[i].item() != prev_hyp:
if frame_alignment[0][i].item() != prev_hyp:
token_index += 1
frames.append(Frame(token_index, i, frame_scores[i].exp().item()))
prev_hyp = frame_alignment[i].item()
frames.append(Frame(token_index, i, frame_scores[0][i].exp().item()))
prev_hyp = frame_alignment[0][i].item()
# compute frame alignments from token alignments
transcript_nospace = transcript.replace(" ", "")
......@@ -150,7 +150,7 @@ def compute_alignments(transcript, dictionary, emission):
i2 += 1
i3 += 1
num_frames = len(frame_alignment)
num_frames = frame_alignment.shape[1]
return segments, words, num_frames
......@@ -160,7 +160,7 @@ def plot_alignments(segments, word_segments, waveform, input_lengths, scale=10):
plt.rcParams.update({"font.size": 30})
# The original waveform
ratio = waveform.size(0) / input_lengths
ratio = waveform.size(1) / input_lengths
ax2.plot(waveform)
ax2.set_ylim(-1.0 * scale, 1.0 * scale)
ax2.set_xlim(0, waveform.size(-1))
......@@ -249,12 +249,12 @@ def get_emission(waveform):
emissions, _ = model(waveform)
emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions[0].cpu().detach()
emission = emissions.cpu().detach()
# Append the extra dimension corresponding to the <star> token
extra_dim = torch.zeros(emissions.shape[0], emissions.shape[1], 1)
emissions = torch.cat((emissions.cpu(), extra_dim), 2)
emission = emissions[0].detach()
emission = emissions.detach()
return emission, waveform
......@@ -347,12 +347,12 @@ speech_file = torchaudio.utils.download_asset("tutorial-assets/10349_8674_000087
waveform, _ = torchaudio.load(speech_file)
emission, waveform = get_emission(waveform)
assert len(dictionary) == emission.shape[1]
assert len(dictionary) == emission.shape[2]
transcript = text_normalized
segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission)
plot_alignments(segments, word_segments, waveform[0], emission.shape[0])
plot_alignments(segments, word_segments, waveform, emission.shape[1])
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
......@@ -482,13 +482,14 @@ text_raw = "关 服务 高端 产品 仍 处于 供不应求 的 局面"
text_normalized = "guan fuwu gaoduan chanpin reng chuyu gongbuyingqiu de jumian"
speech_file = torchaudio.utils.download_asset("tutorial-assets/mvdr/clean_speech.wav", progress=False)
waveform, _ = torchaudio.load(speech_file)
waveform = waveform[0:1]
emission, waveform = get_emission(waveform)
transcript = text_normalized
segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission)
plot_alignments(segments, word_segments, waveform[0], emission.shape[0])
plot_alignments(segments, word_segments, waveform, emission.shape[1])
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
......@@ -557,7 +558,7 @@ emission, waveform = get_emission(waveform)
transcript = text_normalized
segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission)
plot_alignments(segments, word_segments, waveform[0], emission.shape[0])
plot_alignments(segments, word_segments, waveform, emission.shape[1])
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
......@@ -660,7 +661,7 @@ emission, waveform = get_emission(waveform)
transcript = text_normalized
segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission)
plot_alignments(segments, word_segments, waveform[0], emission.shape[0])
plot_alignments(segments, word_segments, waveform, emission.shape[1])
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
......@@ -785,7 +786,7 @@ emission, waveform = get_emission(waveform)
transcript = text_normalized
segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission)
plot_alignments(segments, word_segments, waveform[0], emission.shape[0])
plot_alignments(segments, word_segments, waveform, emission.shape[1])
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
......
......@@ -1116,55 +1116,60 @@ class Functional(TestBaseMixin):
@parameterized.expand(
[
([0, 1, 1, 0], [0, 1, 5, 1, 0], torch.int32),
([0, 1, 2, 3, 4], [0, 1, 2, 3, 4], torch.int32),
([3, 3, 3], [3, 5, 3, 5, 3], torch.int64),
([0, 1, 2], [0, 1, 1, 1, 2], torch.int64),
([[0, 1, 1, 0]], [[0, 1, 5, 1, 0]], torch.int32),
([[0, 1, 2, 3, 4]], [[0, 1, 2, 3, 4]], torch.int32),
([[3, 3, 3]], [[3, 5, 3, 5, 3]], torch.int64),
([[0, 1, 2]], [[0, 1, 1, 1, 2]], torch.int64),
]
)
def test_forced_align(self, targets, ref_path, targets_dtype):
emission = torch.tensor(
[
[
[0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
[0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436],
[0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688],
[0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533],
[0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107],
]
],
dtype=self.dtype,
device=self.device,
)
blank = 5
batch_index = 0
ref_path = torch.tensor(ref_path, dtype=targets_dtype, device=self.device)
ref_scores = torch.tensor(
[torch.log(emission[i, ref_path[i]]).item() for i in range(emission.shape[0])],
[torch.log(emission[batch_index, i, ref_path[batch_index, i]]).item() for i in range(emission.shape[1])],
dtype=emission.dtype,
device=self.device,
)
).unsqueeze(0)
log_probs = torch.log(emission)
targets = torch.tensor(targets, dtype=targets_dtype, device=self.device)
input_lengths = torch.tensor((log_probs.shape[0]))
target_lengths = torch.tensor((targets.shape[0]))
input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
target_lengths = torch.tensor([targets.shape[1]], device=self.device)
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
assert hyp_path.shape == ref_path.shape
assert hyp_scores.shape == ref_scores.shape
self.assertEqual(hyp_path, ref_path)
self.assertEqual(hyp_scores, ref_scores)
@parameterized.expand([(torch.int32,), (torch.int64,)])
def test_forced_align_fail(self, targets_dtype):
log_probs = torch.rand(5, 6, dtype=self.dtype, device=self.device)
targets = torch.tensor([0, 1, 2, 3, 4, 4], dtype=targets_dtype, device=self.device)
log_probs = torch.rand(1, 5, 6, dtype=self.dtype, device=self.device)
targets = torch.tensor([[0, 1, 2, 3, 4, 4]], dtype=targets_dtype, device=self.device)
blank = 5
input_lengths = torch.tensor((log_probs.shape[0]), device=self.device)
target_lengths = torch.tensor((targets.shape[0]), device=self.device)
input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
target_lengths = torch.tensor([targets.shape[1]], device=self.device)
with self.assertRaisesRegex(RuntimeError, r"targets length is too long for CTC"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
targets = torch.tensor([5, 3, 3], dtype=targets_dtype, device=self.device)
targets = torch.tensor([[5, 3, 3]], dtype=targets_dtype, device=self.device)
with self.assertRaisesRegex(ValueError, r"targets Tensor shouldn't contain blank index"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
log_probs = log_probs.int()
targets = torch.tensor([0, 1, 2, 3], dtype=targets_dtype, device=self.device)
targets = torch.tensor([[0, 1, 2, 3]], dtype=targets_dtype, device=self.device)
with self.assertRaisesRegex(RuntimeError, r"log_probs must be float64, float32 or float16"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
......@@ -1175,40 +1180,42 @@ class Functional(TestBaseMixin):
log_probs = torch.rand(3, 4, 6, dtype=self.dtype, device=self.device)
targets = targets.int()
with self.assertRaisesRegex(RuntimeError, r"3-D tensor is not yet supported for log_probs"):
with self.assertRaisesRegex(
RuntimeError, r"The batch dimension for log_probs must be 1 at the current version"
):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
targets = torch.randint(0, 4, (3, 4), device=self.device)
log_probs = torch.rand(3, 6, dtype=self.dtype, device=self.device)
with self.assertRaisesRegex(RuntimeError, r"2-D tensor is not yet supported for targets"):
log_probs = torch.rand(1, 3, 6, dtype=self.dtype, device=self.device)
with self.assertRaisesRegex(RuntimeError, r"The batch dimension for targets must be 1 at the current version"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
targets = torch.tensor([0, 1, 2, 3], dtype=targets_dtype, device=self.device)
input_lengths = torch.randint(1, 5, (3,), device=self.device)
with self.assertRaisesRegex(RuntimeError, r"input_lengths must be 0-D"):
targets = torch.tensor([[0, 1, 2, 3]], dtype=targets_dtype, device=self.device)
input_lengths = torch.randint(1, 5, (3, 5), device=self.device)
with self.assertRaisesRegex(RuntimeError, r"input_lengths must be 1-D"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
input_lengths = torch.tensor((log_probs.shape[0]), device=self.device)
target_lengths = torch.randint(1, 5, (3,), device=self.device)
with self.assertRaisesRegex(RuntimeError, r"target_lengths must be 0-D"):
input_lengths = torch.tensor([log_probs.shape[0]], device=self.device)
target_lengths = torch.randint(1, 5, (3, 5), device=self.device)
with self.assertRaisesRegex(RuntimeError, r"target_lengths must be 1-D"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
input_lengths = torch.tensor((10000), device=self.device)
target_lengths = torch.tensor((targets.shape[0]), device=self.device)
input_lengths = torch.tensor([10000], device=self.device)
target_lengths = torch.tensor([targets.shape[1]], device=self.device)
with self.assertRaisesRegex(RuntimeError, r"input length mismatch"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
input_lengths = torch.tensor((log_probs.shape[0]))
target_lengths = torch.tensor((10000))
input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
target_lengths = torch.tensor([10000], device=self.device)
with self.assertRaisesRegex(RuntimeError, r"target length mismatch"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
targets = torch.tensor([7, 8, 9, 10], dtype=targets_dtype, device=self.device)
log_probs = torch.rand(10, 5, dtype=self.dtype, device=self.device)
targets = torch.tensor([[7, 8, 9, 10]], dtype=targets_dtype, device=self.device)
log_probs = torch.rand(1, 10, 5, dtype=self.dtype, device=self.device)
with self.assertRaisesRegex(ValueError, r"targets values must be less than the CTC dimension"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
targets = torch.tensor([1, 3, 3], dtype=targets_dtype, device=self.device)
targets = torch.tensor([[1, 3, 3]], dtype=targets_dtype, device=self.device)
blank = 10000
with self.assertRaisesRegex(RuntimeError, r"blank must be within \[0, num classes\)"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
......@@ -1238,14 +1245,14 @@ class FunctionalCUDAOnly(TestBaseMixin):
@nested_params(
[torch.half, torch.float, torch.double],
[torch.int32, torch.int64],
[(50, 100), (100, 100)],
[(10,), (40,), (45,)],
[(1, 50, 100), (1, 100, 100)],
[(1, 10), (1, 40), (1, 45)],
)
def test_forced_align_same_result(self, log_probs_dtype, targets_dtype, log_probs_shape, targets_shape):
log_probs = torch.rand(log_probs_shape, dtype=log_probs_dtype, device=self.device)
targets = torch.randint(1, 100, targets_shape, dtype=targets_dtype, device=self.device)
input_lengths = torch.tensor((log_probs.shape[0]), device=self.device)
target_lengths = torch.tensor((targets.shape[0]), device=self.device)
input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
target_lengths = torch.tensor([targets.shape[1]], device=self.device)
log_probs_cuda = log_probs.cuda()
targets_cuda = targets.cuda()
input_lengths_cuda = input_lengths.cuda()
......
......@@ -17,8 +17,10 @@ void forced_align_impl(
const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
using target_t = typename std::
conditional<target_scalar_type == torch::kInt, int, int64_t>::type;
const auto T = logProbs.size(0);
const auto L = targets.size(0);
const auto batchIndex =
0; // TODO: support batch version and use the real batch index
const auto T = logProbs.size(1);
const auto L = targets.size(1);
const auto S = 2 * L + 1;
torch::Tensor alphas = torch::empty(
{2, S},
......@@ -27,14 +29,14 @@ void forced_align_impl(
.dtype(logProbs.dtype()))
.fill_(kNegInfinity);
torch::Tensor backPtr = torch::empty({T, S}, torch::kInt8).fill_(-1);
auto logProbs_a = logProbs.accessor<scalar_t, 2>();
auto targets_a = targets.accessor<target_t, 1>();
auto paths_a = paths.accessor<target_t, 1>();
auto logProbs_a = logProbs.accessor<scalar_t, 3>();
auto targets_a = targets.accessor<target_t, 2>();
auto paths_a = paths.accessor<target_t, 2>();
auto alphas_a = alphas.accessor<scalar_t, 2>();
auto backPtr_a = backPtr.accessor<int8_t, 2>();
auto R = 0;
for (auto i = 1; i < L; i++) {
if (targets_a[i] == targets_a[i - 1]) {
if (targets_a[batchIndex][i] == targets_a[batchIndex][i - 1]) {
++R;
}
}
......@@ -49,20 +51,22 @@ void forced_align_impl(
auto start = T - (L + R) > 0 ? 0 : 1;
auto end = (S == 1) ? 1 : 2;
for (auto i = start; i < end; i++) {
auto labelIdx = (i % 2 == 0) ? blank : targets_a[i / 2];
alphas_a[0][i] = logProbs_a[0][labelIdx];
auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
alphas_a[0][i] = logProbs_a[batchIndex][0][labelIdx];
}
for (auto t = 1; t < T; t++) {
if (T - t <= L + R) {
if ((start % 2 == 1) &&
targets_a[start / 2] != targets_a[start / 2 + 1]) {
targets_a[batchIndex][start / 2] !=
targets_a[batchIndex][start / 2 + 1]) {
start = start + 1;
}
start = start + 1;
}
if (t <= L + R) {
if (end % 2 == 0 && end < 2 * L &&
targets_a[end / 2 - 1] != targets_a[end / 2]) {
targets_a[batchIndex][end / 2 - 1] !=
targets_a[batchIndex][end / 2]) {
end = end + 1;
}
end = end + 1;
......@@ -75,7 +79,7 @@ void forced_align_impl(
}
if (start == 0) {
alphas_a[curIdxOffset][0] =
alphas_a[prevIdxOffset][0] + logProbs_a[t][blank];
alphas_a[prevIdxOffset][0] + logProbs_a[batchIndex][t][blank];
backPtr_a[t][0] = 0;
startloop += 1;
}
......@@ -85,13 +89,14 @@ void forced_align_impl(
auto x1 = alphas_a[prevIdxOffset][i - 1];
auto x2 = -std::numeric_limits<scalar_t>::infinity();
auto labelIdx = (i % 2 == 0) ? blank : targets_a[i / 2];
auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
// In CTC, the optimal path may optionally chose to skip a blank label.
// x2 represents skipping a letter, and can only happen if we're not
// currently on a blank_label, and we're not on a repeat letter
// (i != 1) just ensures we don't access targets[i - 2] if its i < 2
if (i % 2 != 0 && i != 1 && targets_a[i / 2] != targets_a[i / 2 - 1]) {
if (i % 2 != 0 && i != 1 &&
targets_a[batchIndex][i / 2] != targets_a[batchIndex][i / 2 - 1]) {
x2 = alphas_a[prevIdxOffset][i - 2];
}
scalar_t result = 0.0;
......@@ -105,7 +110,7 @@ void forced_align_impl(
result = x0;
backPtr_a[t][i] = 0;
}
alphas_a[curIdxOffset][i] = result + logProbs_a[t][labelIdx];
alphas_a[curIdxOffset][i] = result + logProbs_a[batchIndex][t][labelIdx];
}
}
auto idx1 = (T - 1) % 2;
......@@ -113,8 +118,8 @@ void forced_align_impl(
// path stores the token index for each time step after force alignment.
auto indexScores = 0;
for (auto t = T - 1; t > -1; t--) {
auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a[ltrIdx / 2];
paths_a[t] = lbl_idx;
auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a[batchIndex][ltrIdx / 2];
paths_a[batchIndex][t] = lbl_idx;
++indexScores;
ltrIdx -= backPtr_a[t][ltrIdx];
}
......@@ -142,30 +147,35 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
TORCH_CHECK(
logProbs.dim() != 3,
"3-D tensor is not yet supported for log_probs, please provide 2-D tensor.")
logProbs.dim() == 3,
"log_probs must be 3-D (batch_size, input length, num classes)");
TORCH_CHECK(
targets.dim() != 2,
"2-D tensor is not yet supported for targets, please provide 1-D tensor.")
targets.dim() == 2, "targets must be 2-D (batch_size, target length,)");
TORCH_CHECK(
logProbs.dim() == 2, "log_probs must be 2-D (input length, num classes)");
TORCH_CHECK(targets.dim() == 1, "targets must be 1-D (target length,)");
TORCH_CHECK(inputLengths.dim() == 0, "input_lengths must be 0-D");
TORCH_CHECK(targetLengths.dim() == 0, "target_lengths must be 0-D");
inputLengths.dim() == 1, "input_lengths must be 1-D (batch_size,)");
TORCH_CHECK(
targetLengths.dim() == 1, "target_lengths must be 1-D (batch_size,)");
TORCH_CHECK(
logProbs.size(0) == 1,
"The batch dimension for log_probs must be 1 at the current version.")
TORCH_CHECK(
targets.size(0) == 1,
"The batch dimension for targets must be 1 at the current version.")
TORCH_CHECK(
blank >= 0 && blank < logProbs.size(-1),
"blank must be within [0, num classes)");
TORCH_CHECK(
logProbs.size(0) == at::max(inputLengths).item().toInt(),
logProbs.size(1) == at::max(inputLengths).item().toInt(),
"input length mismatch");
TORCH_CHECK(
targets.size(0) == at::max(targetLengths).item().toInt(),
targets.size(1) == at::max(targetLengths).item().toInt(),
"target length mismatch");
const auto T = logProbs.size(0);
const auto B = logProbs.size(0);
const auto T = logProbs.size(1);
auto paths = torch::zeros(
{T},
{B, T},
torch::TensorOptions().device(targets.device()).dtype(targets.dtype()));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
logProbs.scalar_type(), "forced_align_impl", [&] {
......@@ -180,9 +190,10 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
return std::make_tuple(
paths,
logProbs.index(
{torch::linspace(
{torch::indexing::Slice(),
torch::linspace(
0, T - 1, T, torch::TensorOptions().dtype(paths.dtype())),
paths}));
paths.index({0})}));
}
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
......
......@@ -18,9 +18,9 @@ namespace alignment {
namespace gpu {
template <typename scalar_t, typename target_t>
__global__ void falign_cuda_step_kernel(
const torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits>
const torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits>
logProbs_a,
const torch::PackedTensorAccessor32<target_t, 1, torch::RestrictPtrTraits>
const torch::PackedTensorAccessor32<target_t, 2, torch::RestrictPtrTraits>
targets_a,
const int T,
const int L,
......@@ -36,6 +36,8 @@ __global__ void falign_cuda_step_kernel(
torch::PackedTensorAccessor32<int8_t, 2, torch::RestrictPtrTraits>
backPtrBuffer_a) {
scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
const int batchIndex =
0; // TODO: support batch version and use the real batch index
int S = 2 * L + 1;
int curIdxOffset = (t % 2); // current time step frame for alpha
int prevIdxOffset = ((t - 1) % 2); // previous time step frame for alpha
......@@ -49,8 +51,8 @@ __global__ void falign_cuda_step_kernel(
__syncthreads();
if (t == 0) {
for (unsigned int i = start + threadIdx.x; i < end; i += blockDim.x) {
int labelIdx = (i % 2 == 0) ? blank : targets_a[i / 2];
alphas_a[curIdxOffset][i] = logProbs_a[0][labelIdx];
int labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
alphas_a[curIdxOffset][i] = logProbs_a[batchIndex][0][labelIdx];
}
return;
}
......@@ -62,7 +64,7 @@ __global__ void falign_cuda_step_kernel(
threadMax = kNegInfinity;
if (start == 0 && threadIdx.x == 0) {
alphas_a[curIdxOffset][0] =
alphas_a[prevIdxOffset][0] + logProbs_a[t][blank];
alphas_a[prevIdxOffset][0] + logProbs_a[batchIndex][t][blank];
threadMax = max(threadMax, alphas_a[curIdxOffset][0]);
backPtrBuffer_a[backPtrBufferLen][0] = 0;
}
......@@ -73,8 +75,9 @@ __global__ void falign_cuda_step_kernel(
scalar_t x0 = alphas_a[prevIdxOffset][i];
scalar_t x1 = alphas_a[prevIdxOffset][i - 1];
scalar_t x2 = kNegInfinity;
int labelIdx = (i % 2 == 0) ? blank : targets_a[i / 2];
if (i % 2 != 0 && i != 1 && targets_a[i / 2] != targets_a[i / 2 - 1]) {
int labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
if (i % 2 != 0 && i != 1 &&
targets_a[batchIndex][i / 2] != targets_a[batchIndex][i / 2 - 1]) {
x2 = alphas_a[prevIdxOffset][i - 2];
}
scalar_t result = 0.0;
......@@ -88,7 +91,7 @@ __global__ void falign_cuda_step_kernel(
result = x0;
backPtrBuffer_a[backPtrBufferLen][i] = 0;
}
alphas_a[curIdxOffset][i] = result + logProbs_a[t][labelIdx];
alphas_a[curIdxOffset][i] = result + logProbs_a[batchIndex][t][labelIdx];
threadMax = max(threadMax, alphas_a[curIdxOffset][i]);
}
scalar_t maxResult = BlockReduce(tempStorage).Reduce(threadMax, cub::Max());
......@@ -113,10 +116,12 @@ void forced_align_impl(
const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
using target_t = typename std::
conditional<target_scalar_type == torch::kInt, int, int64_t>::type;
auto paths_a = paths.accessor<target_t, 1>();
const int T = logProbs.size(0); // num frames
const int N = logProbs.size(1); // alphabet size
const int L = targets.size(0); // label length
auto paths_a = paths.accessor<target_t, 2>();
const int batchIndex =
0; // TODO: support batch version and use the real batch index
const int T = logProbs.size(1); // num frames
const int N = logProbs.size(2); // alphabet size
const int L = targets.size(1); // label length
const int S = 2 * L + 1;
auto targetsCpu = targets.to(torch::kCPU);
// backPtrBuffer stores the index offset fthe best path at current position
......@@ -144,12 +149,12 @@ void forced_align_impl(
.device(logProbs.device()))
.fill_(kNegInfinity);
// CPU accessors
auto targetsCpu_a = targetsCpu.accessor<target_t, 1>();
auto targetsCpu_a = targetsCpu.accessor<target_t, 2>();
auto backPtrCpu_a = backPtrCpu.accessor<int8_t, 2>();
// count the number of repeats in label
int R = 0;
for (int i = 1; i < L; ++i) {
if (targetsCpu_a[i] == targetsCpu_a[i - 1]) {
if (targetsCpu_a[batchIndex][i] == targetsCpu_a[batchIndex][i - 1]) {
++R;
}
}
......@@ -169,14 +174,16 @@ void forced_align_impl(
if (t > 0) {
if (T - t <= L + R) {
if ((start % 2 == 1) &&
(targetsCpu_a[start / 2] != targetsCpu_a[start / 2 + 1])) {
(targetsCpu_a[batchIndex][start / 2] !=
targetsCpu_a[batchIndex][start / 2 + 1])) {
start = start + 1;
}
start = start + 1;
}
if (t <= L + R) {
if ((end % 2 == 0) && (end < 2 * L) &&
(targetsCpu_a[end / 2 - 1] != targetsCpu_a[end / 2])) {
(targetsCpu_a[batchIndex][end / 2 - 1] !=
targetsCpu_a[batchIndex][end / 2])) {
end = end + 1;
}
end = end + 1;
......@@ -184,8 +191,8 @@ void forced_align_impl(
}
falign_cuda_step_kernel<scalar_t, target_t>
<<<1, kNumThreads, 0, defaultStream>>>(
logProbs.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
targets.packed_accessor32<target_t, 1, torch::RestrictPtrTraits>(),
logProbs.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
targets.packed_accessor32<target_t, 2, torch::RestrictPtrTraits>(),
T,
L,
N,
......@@ -229,8 +236,9 @@ void forced_align_impl(
: S - 2;
int indexScores = 0;
for (int t = T - 1; t >= 0; --t) {
auto lbl_idx = ltrIdx % 2 == 0 ? blank : targetsCpu_a[ltrIdx / 2];
paths_a[t] = lbl_idx;
auto lbl_idx =
ltrIdx % 2 == 0 ? blank : targetsCpu_a[batchIndex][ltrIdx / 2];
paths_a[batchIndex][t] = lbl_idx;
++indexScores;
ltrIdx -= backPtrCpu_a[t][ltrIdx];
}
......@@ -258,30 +266,36 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
TORCH_CHECK(
logProbs.dim() != 3,
"3-D tensor is not yet supported for log_probs, please provide 2-D tensor.")
logProbs.dim() == 3,
"log_probs must be 3-D (batch_size, input length, num classes)");
TORCH_CHECK(
targets.dim() != 2,
"2-D tensor is not yet supported for targets, please provide 1-D tensor.")
targets.dim() == 2, "targets must be 2-D (batch_size, target length,)");
TORCH_CHECK(
logProbs.dim() == 2, "log_probs must be 2-D (input length, num classes)");
TORCH_CHECK(targets.dim() == 1, "targets must be 1-D (target length,)");
TORCH_CHECK(inputLengths.dim() == 0, "input_lengths must be 0-D");
TORCH_CHECK(targetLengths.dim() == 0, "target_lengths must be 0-D");
inputLengths.dim() == 1, "input_lengths must be 1-D (batch_size,)");
TORCH_CHECK(
targetLengths.dim() == 1, "target_lengths must be 1-D (batch_size,)");
TORCH_CHECK(
logProbs.size(0) == 1,
"The batch dimension for log_probs must be 1 at the current version.")
TORCH_CHECK(
targets.size(0) == 1,
"The batch dimension for targets must be 1 at the current version.")
TORCH_CHECK(
blank >= 0 && blank < logProbs.size(-1),
"blank must be within [0, num classes)");
TORCH_CHECK(
logProbs.size(0) == at::max(inputLengths).item().toInt(),
logProbs.size(1) == at::max(inputLengths).item().toInt(),
"input length mismatch");
TORCH_CHECK(
targets.size(0) == at::max(targetLengths).item().toInt(),
targets.size(1) == at::max(targetLengths).item().toInt(),
"target length mismatch");
auto T = logProbs.size(0); // num frames
auto B = logProbs.size(0);
auto T = logProbs.size(1); // num frames
auto paths = torch::zeros(
{T}, torch::TensorOptions().device(torch::kCPU).dtype(targets.dtype()));
{B, T},
torch::TensorOptions().device(torch::kCPU).dtype(targets.dtype()));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
logProbs.scalar_type(), "forced_align_impl", [&] {
if (targets.scalar_type() == torch::kInt64) {
......@@ -295,9 +309,10 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
return std::make_tuple(
paths.to(logProbs.device()),
logProbs.index(
{torch::linspace(
{torch::indexing::Slice(),
torch::linspace(
0, T - 1, T, torch::TensorOptions().dtype(paths.dtype())),
paths}));
paths.index({0})}));
}
TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
......
......@@ -2511,12 +2511,12 @@ def forced_align(
Args:
log_probs (torch.Tensor): log probability of CTC emission output.
Tensor of shape `(T, C)`. where `T` is the input length,
Tensor of shape `(B, T, C)`. where `B` is the batch size, `T` is the input length,
`C` is the number of characters in alphabet including blank.
targets (torch.Tensor): Target sequence. Tensor of shape `(L,)`,
targets (torch.Tensor): Target sequence. Tensor of shape `(B, L)`,
where `L` is the target length.
input_lengths (torch.Tensor): Lengths of the inputs (max value must each be <= `T`). 0-D Tensor (scalar).
target_lengths (torch.Tensor): Lengths of the targets. 0-D Tensor (scalar).
input_lengths (torch.Tensor): Lengths of the inputs (max value must each be <= `T`). 1-D Tensor of shape `(B,)`.
target_lengths (torch.Tensor): Lengths of the targets. 1-D Tensor of shape `(B,)`.
blank_id (int, optional): The index of blank symbol in CTC emission. (Default: 0)
Returns:
......@@ -2534,6 +2534,9 @@ def forced_align(
where :math:`N_{\text{repeat}}` is the number of consecutively repeated tokens.
For example, in str `"aabbc"`, the number of repeats are `2`.
Note:
The current version only supports ``batch_size``==1.
"""
if blank in targets:
raise ValueError(f"targets Tensor shouldn't contain blank index. Found {targets}.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment