"examples/vscode:/vscode.git/clone" did not exist on "f05d75c07605e15354247c56057fb14830235017"
Unverified commit c268c3d5, authored by Daniel Povey, committed by GitHub

Merge pull request #11 from pkufool/fix_s_range

Fix pruning bounds
parents 134c1bcc 6afe9951
```diff
@@ -134,6 +134,8 @@ def get_rnnt_logprobs(
     (B, T, C) = am.shape
     S = lm.shape[1] - 1
     assert symbols.shape == (B, S)
+    assert S >= 1
+    assert T >= S
     # subtracting am_max and lm_max is to ensure the probs are in a good range
     # to do exp() without causing underflow or overflow.
```
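
The comment in this hunk refers to the standard max-subtraction trick for numerical stability. A minimal sketch (illustrative values, not the library's code) of why subtracting the per-frame max keeps `exp()` safe:

```python
import torch

# Subtracting the per-row max before exp() makes every exponent <= 0, so
# exp() cannot overflow; tiny values underflow gracefully to 0 instead.
am = torch.tensor([[1000.0, 999.0, 998.0]])  # exp(1000.) would overflow
am_max, _ = am.max(dim=-1, keepdim=True)
probs = (am - am_max).exp()  # exponents are 0, -1, -2
print(probs)  # tensor([[1.0000, 0.3679, 0.1353]])
```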
```diff
@@ -331,6 +333,8 @@ def get_rnnt_logprobs_joint(
     (B, T, S1, C) = logits.shape
     S = S1 - 1
     assert symbols.shape == (B, S)
+    assert S >= 1
+    assert T >= S
     normalizers = torch.logsumexp(logits, dim=3)
     normalizers = normalizers.permute((0, 2, 1))
```
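
For context, the normalizers touched by this hunk are per-node log-softmax denominators over the vocabulary. A minimal shape sketch (shapes taken from the hunk, values random):

```python
import torch

# logits: (B, T, S1, C). logsumexp over the vocabulary axis C yields the log
# of the softmax denominator for every (batch, frame, symbol) lattice node.
B, T, S1, C = 2, 6, 4, 10
logits = torch.randn(B, T, S1, C)
normalizers = torch.logsumexp(logits, dim=3)  # (B, T, S1)
normalizers = normalizers.permute((0, 2, 1))  # (B, S1, T)
assert normalizers.shape == (B, S1, T)
```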
```diff
@@ -478,7 +482,9 @@ def _adjust_pruning_lower_bound(
     )
     return s_begin
+# To get more insight into how we calculate pruning bounds, please read
+# chapter 3.2 (Pruning bounds) of our Pruned RNN-T paper
+# (https://arxiv.org/pdf/2206.13236.pdf)
 def get_rnnt_prune_ranges(
     px_grad: torch.Tensor,
     py_grad: torch.Tensor,
```
```diff
@@ -505,8 +511,8 @@ def get_rnnt_prune_ranges(
       of symbols given a particular frame.
     Note:
-      For the generated tensor ranges, ranges[:, 0] is a monotonic increasing
-      tensor from 0 to `len(symbols)` and it satisfies
+      For the generated tensor ranges (assuming batch size is 1), ranges[:, 0]
+      is a monotonically increasing tensor from 0 to `len(symbols)` satisfying
       `ranges[t+1, 0] - ranges[t, 0] < s_range`, which means we won't skip any
       symbols.
```
```diff
@@ -529,58 +535,78 @@ def get_rnnt_prune_ranges(
     (B, S, T1) = px_grad.shape
     T = py_grad.shape[-1]
     assert T1 in [T, T + 1]
+    S1 = S + 1
     assert py_grad.shape == (B, S + 1, T)
     assert boundary.shape == (B, 4)
-    assert s_range >= 1
+    assert S >= 1
+    assert T >= S
     # s_range > S means we won't prune out any symbols. To make indexing with
     # ranges run normally, s_range should be equal to or less than ``S + 1``.
     if s_range > S:
-        s_range = S
+        s_range = S + 1
-    px_pad = torch.zeros((B, 1, T1), dtype=px_grad.dtype, device=px_grad.device)
-    py_pad = torch.zeros(
-        (B, S + 1, 1), dtype=py_grad.dtype, device=py_grad.device
-    )
-    py_grad_padded = py_grad if T1 == T else torch.cat((py_grad, py_pad), dim=2)
-    tot_grad = (
-        torch.cat((px_grad, px_pad), dim=1) + py_grad_padded
-    )  # (B, S + 1, T1)
+    if T1 == T:
+        assert (
+            s_range >= 1
+        ), "Pruning range for modified RNN-T should be equal to or greater than 1, or no valid paths could survive pruning."
-    tot_grad = torch.cat(
-        (
-            torch.zeros(
-                (B, 1, T1), dtype=tot_grad.dtype, device=tot_grad.device
-            ),
-            tot_grad,
-        ),
-        dim=1,
+    else:
+        assert (
+            s_range >= 2
+        ), "Pruning range for standard RNN-T should be equal to or greater than 2, or no valid paths could survive pruning."
+    (B_stride, S_stride, T_stride) = py_grad.stride()
+    blk_grad = torch.as_strided(
+        py_grad,
+        (B, S1 - s_range + 1, s_range, T),
+        (B_stride, S_stride, S_stride, T_stride),
     )
-    tot_grad = torch.cumsum(tot_grad, dim=1)
-    diff_grad = tot_grad[:, s_range:, :] - tot_grad[:, 0:-s_range, :]
-    s_begin = torch.argmax(diff_grad, dim=1)
-    s_begin = s_begin[:, :T]
+    # (B, S1 - s_range + 1, T)
+    blk_sum_grad = torch.sum(blk_grad, axis=2)
+    px_pad = torch.zeros((B, 1, T1), dtype=px_grad.dtype, device=px_grad.device)
+    # (B, S1, T)
+    px_grad_pad = torch.cat((px_pad, px_grad), dim=1)
+    # (B, S1 - s_range + 1, T)
+    final_grad = blk_sum_grad - px_grad_pad[:, : S1 - s_range + 1, :T]
+    # (B, T)
+    s_begin = torch.argmax(final_grad, axis=1)
     # Handle the values of s_begin in padding positions.
-    # -1 here means we fill the position of the last frame of real data with
+    # -1 here means we fill the position of the last frame (before padding) with
     # padding value which is `len(symbols) - s_range + 1`.
-    # This is to guarantee that we reach the last symbol at last frame of real
-    # data.
+    # This is to guarantee that we reach the last symbol at the last frame
+    # (before padding).
     # The shape of the mask is (B, T). For example, for a batch containing
     # 3 sequences whose lengths are 3, 5, 6 (i.e. B = 3, T = 6), the mask is
     # [[True, True, False, False, False, False],
     #  [True, True, True, True, False, False],
     #  [True, True, True, True, True, False]]
     mask = torch.arange(0, T, device=px_grad.device).reshape(1, T).expand(B, T)
     mask = mask < boundary[:, 3].reshape(B, 1) - 1
     s_begin_padding = boundary[:, 2].reshape(B, 1) - s_range + 1
-    s_begin_padding = torch.where(s_begin_padding >= 0, s_begin_padding, 0)
+    # handle the cases when `len(symbols) < s_range`
+    s_begin_padding = torch.clamp(s_begin_padding, min=0)
     s_begin = torch.where(mask, s_begin, s_begin_padding)
     # adjusting the lower bound to make it satisfy some constraints; see docs in
-    # `adjust_pruning_lower_bound` for more details of these constraints.
+    # `_adjust_pruning_lower_bound` for more details of these constraints.
+    # T1 == T here means we are using the modified version of the transducer;
+    # the third constraint becomes `s_begin[i + 1] - s_begin[i] < 2`, because
+    # it only emits one symbol per frame.
     s_begin = _adjust_pruning_lower_bound(s_begin, 2 if T1 == T else s_range)
     ranges = s_begin.reshape((B, T, 1)).expand((B, T, s_range)) + torch.arange(
         s_range, device=px_grad.device
     )
     return ranges
```
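
The core of the new bound computation is the overlapping-window view built with `torch.as_strided`. A minimal, self-contained sketch (toy shapes, not the library's code) confirming that the view sums every block of `s_range` consecutive rows of `py_grad`:

```python
import torch

# py_grad: (B, S1, T). The strided view stacks, for each window start i,
# the rows i .. i + s_range - 1, so summing over dim=2 gives block sums.
B, S1, T, s_range = 1, 5, 4, 2
py_grad = torch.arange(B * S1 * T, dtype=torch.float32).reshape(B, S1, T)

B_stride, S_stride, T_stride = py_grad.stride()
blk_grad = torch.as_strided(
    py_grad,
    (B, S1 - s_range + 1, s_range, T),
    (B_stride, S_stride, S_stride, T_stride),
)
blk_sum_grad = torch.sum(blk_grad, dim=2)  # (B, S1 - s_range + 1, T)

# Naive reference: window i sums rows i .. i + s_range - 1.
expected = torch.stack(
    [py_grad[:, i : i + s_range, :].sum(dim=1) for i in range(S1 - s_range + 1)],
    dim=1,
)
assert torch.equal(blk_sum_grad, expected)
```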
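
And the padding mask described in the comments can be checked directly; this tiny snippet reproduces the B = 3, T = 6 example with frame lengths 3, 5, 6 stored in `boundary[:, 3]` (the other boundary columns are zeroed for brevity):

```python
import torch

# boundary[:, 3] holds the number of frames per sequence; the mask marks
# every frame strictly before the last real frame of each sequence.
B, T = 3, 6
boundary = torch.tensor([[0, 0, 0, 3], [0, 0, 0, 5], [0, 0, 0, 6]])
mask = torch.arange(0, T).reshape(1, T).expand(B, T)
mask = mask < boundary[:, 3].reshape(B, 1) - 1
print(mask)
# tensor([[ True,  True, False, False, False, False],
#         [ True,  True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True, False]])
```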
```diff
@@ -699,6 +725,8 @@ def get_rnnt_logprobs_pruned(
     (B, T, s_range, C) = logits.shape
     assert ranges.shape == (B, T, s_range)
     (B, S) = symbols.shape
+    assert S >= 1
+    assert T >= S
     normalizers = torch.logsumexp(logits, dim=3)
```
```diff
@@ -955,6 +983,8 @@ def get_rnnt_logprobs_smoothed(
     (B, T, C) = am.shape
     S = lm.shape[1] - 1
     assert symbols.shape == (B, S)
+    assert S >= 1
+    assert T >= S
     # Caution: some parts of this code are a little less clear than they could
     # be due to optimizations. In particular it may not be totally obvious that
```
```diff
@@ -120,11 +120,11 @@ class TestRnntLoss(unittest.TestCase):
             )
             assert torch.allclose(m, expected.to(device))
-            probs = am.unsqueeze(2) + lm.unsqueeze(1)
+            logits = am.unsqueeze(2) + lm.unsqueeze(1)
             # test rnnt_loss
             m = fast_rnnt.rnnt_loss(
-                logits=probs,
+                logits=logits,
                 symbols=symbols,
                 termination_symbol=termination_symbol,
                 boundary=None,
```
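
A side note on the probs → logits rename running through these test hunks: the tensor is an unnormalized joint score, so `logits` is the accurate name. A minimal sketch of the broadcast that builds it (shapes follow the library's convention, `am`: (B, T, C), `lm`: (B, S + 1, C)):

```python
import torch

# am scores each acoustic frame, lm scores each symbol position; unsqueezing
# and adding broadcasts them into joint logits of shape (B, T, S + 1, C).
B, T, S, C = 2, 5, 3, 7
am = torch.randn(B, T, C)
lm = torch.randn(B, S + 1, C)
logits = am.unsqueeze(2) + lm.unsqueeze(1)
assert logits.shape == (B, T, S + 1, C)
```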
```diff
@@ -137,7 +137,7 @@ class TestRnntLoss(unittest.TestCase):
             import torchaudio.functional
             m = torchaudio.functional.rnnt_loss(
-                logits=probs,
+                logits=logits,
                 targets=symbols.int(),
                 logit_lengths=torch.tensor(
                     [T] * B, dtype=torch.int32, device=device
```
```diff
@@ -176,9 +176,9 @@ class TestRnntLoss(unittest.TestCase):
             )
             assert torch.allclose(m, expected.to(device))
-            probs = am.unsqueeze(2) + lm.unsqueeze(1)
+            logits = am.unsqueeze(2) + lm.unsqueeze(1)
             m = fast_rnnt.rnnt_loss(
-                logits=probs,
+                logits=logits,
                 symbols=symbols,
                 termination_symbol=termination_symbol,
                 boundary=None,
```
```diff
@@ -255,9 +255,9 @@ class TestRnntLoss(unittest.TestCase):
             )
             assert torch.allclose(m, expected.to(device))
-            probs = am.unsqueeze(2) + lm.unsqueeze(1)
+            logits = am.unsqueeze(2) + lm.unsqueeze(1)
             m = fast_rnnt.rnnt_loss(
-                logits=probs,
+                logits=logits,
                 symbols=symbols,
                 termination_symbol=termination_symbol,
                 boundary=boundary,
```
```diff
@@ -270,7 +270,7 @@ class TestRnntLoss(unittest.TestCase):
             import torchaudio.functional
             m = torchaudio.functional.rnnt_loss(
-                logits=probs,
+                logits=logits,
                 targets=symbols.int(),
                 logit_lengths=boundary[:, 3].int(),
                 target_lengths=boundary[:, 2].int(),
```
```diff
@@ -292,9 +292,9 @@ class TestRnntLoss(unittest.TestCase):
            )
            assert torch.allclose(m, expected.to(device))
-           probs = am.unsqueeze(2) + lm.unsqueeze(1)
+           logits = am.unsqueeze(2) + lm.unsqueeze(1)
            m = fast_rnnt.rnnt_loss(
-               logits=probs,
+               logits=logits,
                symbols=symbols,
                termination_symbol=termination_symbol,
                boundary=boundary,
```
```diff
@@ -345,32 +345,32 @@ class TestRnntLoss(unittest.TestCase):
             symbols = symbols_.to(device)
             boundary = boundary_.to(device)
-            logprobs = am.unsqueeze(2) + lm.unsqueeze(1)
-            logprobs.requires_grad_()
-            k2_loss = fast_rnnt.rnnt_loss(
-                logits=logprobs,
+            logits = am.unsqueeze(2) + lm.unsqueeze(1)
+            logits.requires_grad_()
+            fast_loss = fast_rnnt.rnnt_loss(
+                logits=logits,
                 symbols=symbols,
                 termination_symbol=termination_symbol,
                 boundary=boundary,
             )
-            k2_grad = torch.autograd.grad(k2_loss, logprobs)
-            k2_grad = k2_grad[0]
+            fast_grad = torch.autograd.grad(fast_loss, logits)
+            fast_grad = fast_grad[0]
-            logprobs2 = logprobs.detach().clone().float()
-            logprobs2.requires_grad_()
+            logits2 = logits.detach().clone().float()
+            logits2.requires_grad_()
             torch_loss = torchaudio.functional.rnnt_loss(
-                logits=logprobs2,
+                logits=logits2,
                 targets=symbols.int(),
                 logit_lengths=boundary[:, 3].int(),
                 target_lengths=boundary[:, 2].int(),
                 blank=termination_symbol,
             )
-            torch_grad = torch.autograd.grad(torch_loss, logprobs2)
+            torch_grad = torch.autograd.grad(torch_loss, logits2)
             torch_grad = torch_grad[0]
-            assert torch.allclose(k2_loss, torch_loss, atol=1e-2, rtol=1e-2)
+            assert torch.allclose(fast_loss, torch_loss, atol=1e-2, rtol=1e-2)
-            assert torch.allclose(k2_grad, torch_grad, atol=1e-2, rtol=1e-2)
+            assert torch.allclose(fast_grad, torch_grad, atol=1e-2, rtol=1e-2)

     def test_rnnt_loss_smoothed(self):
         B = 1
```
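
The renames in the hunk above (logprobs → logits, k2_loss → fast_loss) leave the test logic untouched; it is the usual gradient-comparison pattern. A minimal sketch of that pattern with a toy function (not the RNN-T loss):

```python
import torch

# Build a scalar loss from a leaf tensor, pull the gradient out with
# torch.autograd.grad (which returns a tuple), and compare to a reference.
x = torch.randn(3, requires_grad=True)
loss = (x * x).sum()
(grad,) = torch.autograd.grad(loss, x)
assert torch.allclose(grad, 2 * x)  # d/dx sum(x^2) = 2x
```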
```diff
@@ -450,14 +450,13 @@ class TestRnntLoss(unittest.TestCase):
                 lm = lm_.to(device)
                 symbols = symbols_.to(device)
                 boundary = boundary_.to(device)
-                t_am = am.unsqueeze(2).float()
-                t_lm = lm.unsqueeze(1).float()
-                t_prob = t_am + t_lm
+                logits = am.unsqueeze(2) + lm.unsqueeze(1)
+                logits = logits.float()
                 # nonlinear transform
-                t_prob = torch.sigmoid(t_prob)
-                k2_loss = fast_rnnt.rnnt_loss(
-                    logits=t_prob,
+                logits = torch.sigmoid(logits)
+                fast_loss = fast_rnnt.rnnt_loss(
+                    logits=logits,
                     symbols=symbols,
                     termination_symbol=terminal_symbol,
                     boundary=boundary,
```
```diff
@@ -465,11 +464,11 @@ class TestRnntLoss(unittest.TestCase):
                 )
                 print(
-                    f"unpruned rnnt loss with modified {modified} : {k2_loss}"
+                    f"Unpruned rnnt loss with modified {modified} : {fast_loss}"
                 )
                 # pruning
-                k2_simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
+                simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
                     lm=lm,
                     am=am,
                     symbols=symbols,
```
```diff
@@ -488,15 +487,15 @@ class TestRnntLoss(unittest.TestCase):
                         s_range=r,
                     )
                     # (B, T, r, C)
-                    am_p, lm_p = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
+                    pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
-                    t_prob_p = am_p + lm_p
+                    logits = pruned_am + pruned_lm
                     # nonlinear transform
-                    t_prob_p = torch.sigmoid(t_prob_p)
+                    logits = torch.sigmoid(logits)
                     pruned_loss = fast_rnnt.rnnt_loss_pruned(
-                        logits=t_prob_p,
+                        logits=logits,
                         symbols=symbols,
                         ranges=ranges,
                         termination_symbol=terminal_symbol,
```
```diff
@@ -504,8 +503,104 @@ class TestRnntLoss(unittest.TestCase):
                         modified=modified,
                         reduction="none",
                     )
-                    print(f"pruning loss with range {r} : {pruned_loss}")
+                    print(f"Pruning loss with range {r} : {pruned_loss}")
+
+    # Test sequences that have only a small number of symbols; in this
+    # circumstance, s_range can be greater than S, which used to raise
+    # errors (like nan or inf loss) in our previous versions.
+    def test_rnnt_loss_pruned_small_symbols_number(self):
+        B = 2
+        T = 20
+        S = 3
+        C = 10
+
+        frames = torch.randint(S + 1, T, (B,))
+        seq_lengths = torch.randint(1, S, (B,))
+        T = torch.max(frames)
+        S = torch.max(seq_lengths)
+
+        am_ = torch.randn((B, T, C), dtype=torch.float64)
+        lm_ = torch.randn((B, S + 1, C), dtype=torch.float64)
+        symbols_ = torch.randint(0, C, (B, S))
+        terminal_symbol = C - 1
+
+        boundary_ = torch.zeros((B, 4), dtype=torch.int64)
+        boundary_[:, 2] = seq_lengths
+        boundary_[:, 3] = frames
+
+        print(f"B = {B}, T = {T}, S = {S}, C = {C}")
+
+        for modified in [True, False]:
+            for device in self.devices:
+                # normal rnnt
+                am = am_.to(device)
+                lm = lm_.to(device)
+                symbols = symbols_.to(device)
+                boundary = boundary_.to(device)
+
+                logits = am.unsqueeze(2) + lm.unsqueeze(1)
+                logits = logits.float()
+                # nonlinear transform
+                logits = torch.sigmoid(logits)
+
+                loss = fast_rnnt.rnnt_loss(
+                    logits=logits,
+                    symbols=symbols,
+                    termination_symbol=terminal_symbol,
+                    boundary=boundary,
+                    modified=modified,
+                    reduction="none",
+                )
+                print(
+                    f"Unpruned rnnt loss with modified {modified} : {loss}"
+                )
+
+                # pruning
+                simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
+                    lm=lm,
+                    am=am,
+                    symbols=symbols,
+                    termination_symbol=terminal_symbol,
+                    boundary=boundary,
+                    modified=modified,
+                    return_grad=True,
+                    reduction="none",
+                )
+
+                S0 = 2
+                if modified:
+                    S0 = 1
+
+                for r in range(S0, S + 2):
+                    ranges = fast_rnnt.get_rnnt_prune_ranges(
+                        px_grad=px_grad,
+                        py_grad=py_grad,
+                        boundary=boundary,
+                        s_range=r,
+                    )
+                    # (B, T, r, C)
+                    pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(
+                        am=am, lm=lm, ranges=ranges
+                    )
+
+                    logits = pruned_am + pruned_lm
+                    # nonlinear transform
+                    logits = torch.sigmoid(logits)
+
+                    pruned_loss = fast_rnnt.rnnt_loss_pruned(
+                        logits=logits,
+                        symbols=symbols,
+                        ranges=ranges,
+                        termination_symbol=terminal_symbol,
+                        boundary=boundary,
+                        modified=modified,
+                        reduction="none",
+                    )
+                    print(f"Pruned loss with range {r} : {pruned_loss}")
+
+
 if __name__ == "__main__":
     unittest.main()
```
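
To see the fix at the heart of this PR in isolation: with S symbols the pruning lattice has S + 1 symbol rows (row 0 precedes the first symbol), and `ranges` indexes those rows, so the widest legal window is S + 1. A minimal sketch (toy values, not the library's code) of the new clamp:

```python
import torch

# The old code clamped s_range to S, which dropped one reachable row and
# could leave no valid path; the new clamp keeps all S + 1 rows.
S, s_range, T = 2, 5, 4                      # requested window exceeds S
s_range = S + 1 if s_range > S else s_range  # new behavior: clamp to S + 1
s_begin = torch.zeros(T, dtype=torch.int64)  # no pruning: start at row 0
ranges = s_begin.reshape(T, 1).expand(T, s_range) + torch.arange(s_range)
assert int(ranges.max()) == S                # the last symbol row is kept
```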