"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "52d4449810c8e13eb22b57e706e0e03806247da2"
Commit 15a3d1cd authored by pkufool's avatar pkufool
Browse files

Fix pruning bounds

parent 134c1bcc
@@ -134,6 +134,8 @@ def get_rnnt_logprobs(
     (B, T, C) = am.shape
     S = lm.shape[1] - 1
     assert symbols.shape == (B, S)
+    assert S >= 1
+    assert T >= S
 
     # subtracting am_max and lm_max is to ensure the probs are in a good range
     # to do exp() without causing underflow or overflow.
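The comment above is the whole story of the `am_max`/`lm_max` subtraction: shifting by the max before exponentiating keeps values in a representable range without changing the resulting probabilities. A minimal, self-contained sketch (toy values, not from this codebase):

```python
import torch

# Toy logits large enough to overflow float32 under exp().
x = torch.tensor([1000.0, 1001.0])
print(torch.exp(x))        # tensor([inf, inf]) -- overflow

# Subtracting the max keeps exp() in a safe range and is a no-op
# for the normalized probabilities.
shifted = x - x.max()
print(torch.exp(shifted))  # tensor([0.3679, 1.0000])
print(torch.softmax(x, dim=0), torch.softmax(shifted, dim=0))  # identical
```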
@@ -331,6 +333,8 @@ def get_rnnt_logprobs_joint(
     (B, T, S1, C) = logits.shape
     S = S1 - 1
     assert symbols.shape == (B, S)
+    assert S >= 1
+    assert T >= S
 
     normalizers = torch.logsumexp(logits, dim=3)
     normalizers = normalizers.permute((0, 2, 1))
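For context on the `normalizers` computed here: `torch.logsumexp` over the vocabulary dimension gives the log-partition term per (frame, symbol-position) pair, so subtracting it from the logits yields log-probabilities. A small sketch with made-up shapes (`B`, `T`, `S1`, `C` values are hypothetical):

```python
import torch

B, T, S1, C = 2, 5, 4, 10            # hypothetical sizes
logits = torch.randn(B, T, S1, C)

normalizers = torch.logsumexp(logits, dim=3)    # (B, T, S1)
log_probs = logits - normalizers.unsqueeze(-1)  # log-softmax over C

# Each (b, t, s) slice now sums to 1 in probability space.
assert torch.allclose(log_probs.exp().sum(dim=3), torch.ones(B, T, S1))

# As in the diff above, the normalizers are then laid out as (B, S1, T).
normalizers = normalizers.permute((0, 2, 1))
```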
@@ -478,7 +482,9 @@ def _adjust_pruning_lower_bound(
     )
     return s_begin
 
+# To get more insight into how we calculate pruning bounds, please read
+# chapter 3.2 (Pruning bounds) of our Pruned RNN-T paper
+# (https://arxiv.org/pdf/2206.13236.pdf)
 def get_rnnt_prune_ranges(
     px_grad: torch.Tensor,
     py_grad: torch.Tensor,
@@ -505,8 +511,8 @@ def get_rnnt_prune_ranges(
       of symbols given a particular frame.
 
     Note:
-      For the generated tensor ranges, ranges[:, 0] is a monotonic increasing
-      tensor from 0 to `len(symbols)` and it satisfies
+      For the generated tensor ranges (assuming batch size is 1), ranges[:, 0]
+      is a monotonically increasing tensor from 0 to `len(symbols)` and it satisfies
       `ranges[t+1, 0] - ranges[t, 0] < s_range` which means we won't skip any
       symbols.
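The two invariants stated in this Note are easy to check mechanically. A hedged sketch with a hand-made `ranges` for one sequence (the values are illustrative only):

```python
import torch

s_range = 2
# Hypothetical (T, s_range) ranges for a single sequence (batch size 1).
ranges = torch.tensor([[0, 1], [0, 1], [1, 2], [2, 3]])

starts = ranges[:, 0]
assert torch.all(starts[1:] >= starts[:-1])           # monotonically increasing
assert torch.all(starts[1:] - starts[:-1] < s_range)  # no symbol is skipped
```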
@@ -529,33 +535,43 @@ def get_rnnt_prune_ranges(
     (B, S, T1) = px_grad.shape
     T = py_grad.shape[-1]
     assert T1 in [T, T + 1]
+    S1 = S + 1
     assert py_grad.shape == (B, S + 1, T)
     assert boundary.shape == (B, 4)
-    assert s_range >= 1
+    assert S >= 1
+    assert T >= S
+
+    # s_range > S means we won't prune out any symbols. To make indexing with
+    # ranges run normally, s_range should be equal to or less than ``S + 1``.
     if s_range > S:
-        s_range = S
+        s_range = S + 1
+
+    if T1 == T:
+        assert (
+            s_range >= 1
+        ), "Pruning range for modified RNN-T should be equal to or greater than 1, or no valid paths could survive pruning."
+    else:
+        assert (
+            s_range >= 2
+        ), "Pruning range for standard RNN-T should be equal to or greater than 2, or no valid paths could survive pruning."
 
-    px_pad = torch.zeros((B, 1, T1), dtype=px_grad.dtype, device=px_grad.device)
-    py_pad = torch.zeros(
-        (B, S + 1, 1), dtype=py_grad.dtype, device=py_grad.device
-    )
-    py_grad_padded = py_grad if T1 == T else torch.cat((py_grad, py_pad), dim=2)
-    tot_grad = (
-        torch.cat((px_grad, px_pad), dim=1) + py_grad_padded
-    )  # (B, S + 1, T1)
-    tot_grad = torch.cat(
-        (
-            torch.zeros(
-                (B, 1, T1), dtype=tot_grad.dtype, device=tot_grad.device
-            ),
-            tot_grad,
-        ),
-        dim=1,
-    )
-    tot_grad = torch.cumsum(tot_grad, dim=1)
-    diff_grad = tot_grad[:, s_range:, :] - tot_grad[:, 0:-s_range, :]
-    s_begin = torch.argmax(diff_grad, dim=1)
+    blk_grad = torch.as_strided(
+        py_grad, (B, S1 - s_range + 1, s_range, T), (S1 * T, T, T, 1)
+    )
+    # (B, S1 - s_range + 1, T)
+    blk_sum_grad = torch.sum(blk_grad, axis=2)
+
+    px_pad = torch.zeros((B, 1, T1), dtype=px_grad.dtype, device=px_grad.device)
+    # (B, S1, T)
+    px_grad_pad = torch.cat((px_pad, px_grad), dim=1)
+    # (B, S1 - s_range + 1, T)
+    final_grad = blk_sum_grad - px_grad_pad[:, : S1 - s_range + 1, :T]
+    # (B, T)
+    s_begin = torch.argmax(final_grad, axis=1)
     s_begin = s_begin[:, :T]
 
     # Handle the values of s_begin in padding positions.
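The `torch.as_strided` call above is the heart of the new bound computation: `blk_grad[b, s, k, t]` aliases `py_grad[b, s + k, t]`, so summing over `k` yields, for every candidate start symbol `s`, the total occupation gradient of the next `s_range` symbols at each frame. A standalone sketch (hypothetical sizes) that cross-checks it against `Tensor.unfold`:

```python
import torch

B, S1, T, s_range = 2, 5, 7, 3   # hypothetical sizes
py_grad = torch.randn(B, S1, T)  # must be contiguous for these strides

blk_grad = torch.as_strided(
    py_grad, (B, S1 - s_range + 1, s_range, T), (S1 * T, T, T, 1)
)
blk_sum = blk_grad.sum(dim=2)    # (B, S1 - s_range + 1, T)

# Reference: unfold builds the same sliding windows (window dim last).
ref = py_grad.unfold(1, s_range, 1).sum(dim=-1)
assert torch.allclose(blk_sum, ref)
```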
@@ -568,7 +584,7 @@ def get_rnnt_prune_ranges(
     s_begin_padding = boundary[:, 2].reshape(B, 1) - s_range + 1
     # handle the cases when `len(symbols) < s_range`
-    s_begin_padding = torch.where(s_begin_padding >= 0, s_begin_padding, 0)
+    s_begin_padding = torch.clamp(s_begin_padding, min=0)
 
     s_begin = torch.where(mask, s_begin, s_begin_padding)
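The `torch.where` → `torch.clamp` change is purely cosmetic; the two are equivalent for flooring at zero. A one-line sanity check (toy values):

```python
import torch

x = torch.tensor([-2, -1, 0, 3])
assert torch.equal(torch.clamp(x, min=0),
                   torch.where(x >= 0, x, torch.zeros_like(x)))
```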
@@ -578,9 +594,11 @@ def get_rnnt_prune_ranges(
     # the third constraint becomes `s_begin[i + 1] - s_begin[i] < 2`, because
     # it only emits one symbol per frame.
     s_begin = _adjust_pruning_lower_bound(s_begin, 2 if T1 == T else s_range)
 
     ranges = s_begin.reshape((B, T, 1)).expand((B, T, s_range)) + torch.arange(
         s_range, device=px_grad.device
     )
 
     return ranges
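Finally, `ranges` is just `s_begin` broadcast against an `arange(s_range)` offset, giving one window of symbol indices per frame. A toy illustration (made-up `s_begin`, batch size 1):

```python
import torch

B, T, s_range = 1, 4, 3
s_begin = torch.tensor([[0, 0, 1, 2]])  # (B, T), as produced above

ranges = s_begin.reshape((B, T, 1)).expand((B, T, s_range)) + torch.arange(s_range)
print(ranges)
# tensor([[[0, 1, 2],
#          [0, 1, 2],
#          [1, 2, 3],
#          [2, 3, 4]]])
```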
@@ -699,6 +717,8 @@ def get_rnnt_logprobs_pruned(
     (B, T, s_range, C) = logits.shape
     assert ranges.shape == (B, T, s_range)
     (B, S) = symbols.shape
+    assert S >= 1
+    assert T >= S
 
     normalizers = torch.logsumexp(logits, dim=3)
@@ -955,6 +975,8 @@ def get_rnnt_logprobs_smoothed(
     (B, T, C) = am.shape
     S = lm.shape[1] - 1
     assert symbols.shape == (B, S)
+    assert S >= 1
+    assert T >= S
 
     # Caution: some parts of this code are a little less clear than they could
     # be due to optimizations. In particular it may not be totally obvious that
...
@@ -120,11 +120,11 @@ class TestRnntLoss(unittest.TestCase):
             )
             assert torch.allclose(m, expected.to(device))
 
-            probs = am.unsqueeze(2) + lm.unsqueeze(1)
+            logits = am.unsqueeze(2) + lm.unsqueeze(1)
 
             # test rnnt_loss
             m = fast_rnnt.rnnt_loss(
-                logits=probs,
+                logits=logits,
                 symbols=symbols,
                 termination_symbol=termination_symbol,
                 boundary=None,
@@ -137,7 +137,7 @@ class TestRnntLoss(unittest.TestCase):
            import torchaudio.functional
 
            m = torchaudio.functional.rnnt_loss(
-               logits=probs,
+               logits=logits,
                targets=symbols.int(),
                logit_lengths=torch.tensor(
                    [T] * B, dtype=torch.int32, device=device
@@ -176,9 +176,9 @@ class TestRnntLoss(unittest.TestCase):
            )
            assert torch.allclose(m, expected.to(device))
 
-           probs = am.unsqueeze(2) + lm.unsqueeze(1)
+           logits = am.unsqueeze(2) + lm.unsqueeze(1)
            m = fast_rnnt.rnnt_loss(
-               logits=probs,
+               logits=logits,
                symbols=symbols,
                termination_symbol=termination_symbol,
                boundary=None,
@@ -255,9 +255,9 @@ class TestRnntLoss(unittest.TestCase):
            )
            assert torch.allclose(m, expected.to(device))
 
-           probs = am.unsqueeze(2) + lm.unsqueeze(1)
+           logits = am.unsqueeze(2) + lm.unsqueeze(1)
            m = fast_rnnt.rnnt_loss(
-               logits=probs,
+               logits=logits,
                symbols=symbols,
                termination_symbol=termination_symbol,
                boundary=boundary,
@@ -270,7 +270,7 @@ class TestRnntLoss(unittest.TestCase):
            import torchaudio.functional
 
            m = torchaudio.functional.rnnt_loss(
-               logits=probs,
+               logits=logits,
                targets=symbols.int(),
                logit_lengths=boundary[:, 3].int(),
                target_lengths=boundary[:, 2].int(),
@@ -292,9 +292,9 @@ class TestRnntLoss(unittest.TestCase):
            )
            assert torch.allclose(m, expected.to(device))
 
-           probs = am.unsqueeze(2) + lm.unsqueeze(1)
+           logits = am.unsqueeze(2) + lm.unsqueeze(1)
            m = fast_rnnt.rnnt_loss(
-               logits=probs,
+               logits=logits,
                symbols=symbols,
                termination_symbol=termination_symbol,
                boundary=boundary,
@@ -345,32 +345,32 @@ class TestRnntLoss(unittest.TestCase):
            symbols = symbols_.to(device)
            boundary = boundary_.to(device)
 
-           logprobs = am.unsqueeze(2) + lm.unsqueeze(1)
-           logprobs.requires_grad_()
-           k2_loss = fast_rnnt.rnnt_loss(
-               logits=logprobs,
+           logits = am.unsqueeze(2) + lm.unsqueeze(1)
+           logits.requires_grad_()
+           fast_loss = fast_rnnt.rnnt_loss(
+               logits=logits,
                symbols=symbols,
                termination_symbol=termination_symbol,
                boundary=boundary,
            )
-           k2_grad = torch.autograd.grad(k2_loss, logprobs)
-           k2_grad = k2_grad[0]
+           fast_grad = torch.autograd.grad(fast_loss, logits)
+           fast_grad = fast_grad[0]
 
-           logprobs2 = logprobs.detach().clone().float()
-           logprobs2.requires_grad_()
+           logits2 = logits.detach().clone().float()
+           logits2.requires_grad_()
            torch_loss = torchaudio.functional.rnnt_loss(
-               logits=logprobs2,
+               logits=logits2,
                targets=symbols.int(),
                logit_lengths=boundary[:, 3].int(),
                target_lengths=boundary[:, 2].int(),
                blank=termination_symbol,
            )
-           torch_grad = torch.autograd.grad(torch_loss, logprobs2)
+           torch_grad = torch.autograd.grad(torch_loss, logits2)
            torch_grad = torch_grad[0]
 
-           assert torch.allclose(k2_loss, torch_loss, atol=1e-2, rtol=1e-2)
-           assert torch.allclose(k2_grad, torch_grad, atol=1e-2, rtol=1e-2)
+           assert torch.allclose(fast_loss, torch_loss, atol=1e-2, rtol=1e-2)
+           assert torch.allclose(fast_grad, torch_grad, atol=1e-2, rtol=1e-2)
 
    def test_rnnt_loss_smoothed(self):
        B = 1
@@ -450,14 +450,13 @@ class TestRnntLoss(unittest.TestCase):
                lm = lm_.to(device)
                symbols = symbols_.to(device)
                boundary = boundary_.to(device)
 
-               t_am = am.unsqueeze(2).float()
-               t_lm = lm.unsqueeze(1).float()
-               t_prob = t_am + t_lm
+               logits = am.unsqueeze(2) + lm.unsqueeze(1)
+               logits = logits.float()
                # nonlinear transform
-               t_prob = torch.sigmoid(t_prob)
-               k2_loss = fast_rnnt.rnnt_loss(
-                   logits=t_prob,
+               logits = torch.sigmoid(logits)
+               fast_loss = fast_rnnt.rnnt_loss(
+                   logits=logits,
                    symbols=symbols,
                    termination_symbol=terminal_symbol,
                    boundary=boundary,
@@ -465,11 +464,11 @@ class TestRnntLoss(unittest.TestCase):
                )
 
                print(
-                   f"unpruned rnnt loss with modified {modified} : {k2_loss}"
+                   f"Unpruned rnnt loss with modified {modified} : {fast_loss}"
                )
 
                # pruning
-               k2_simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
+               simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
                    lm=lm,
                    am=am,
                    symbols=symbols,
@@ -488,15 +487,15 @@ class TestRnntLoss(unittest.TestCase):
                        s_range=r,
                    )
 
                    # (B, T, r, C)
-                   am_p, lm_p = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
-                   t_prob_p = am_p + lm_p
+                   pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
+                   logits = pruned_am + pruned_lm
 
                    # nonlinear transform
-                   t_prob_p = torch.sigmoid(t_prob_p)
+                   logits = torch.sigmoid(logits)
 
                    pruned_loss = fast_rnnt.rnnt_loss_pruned(
-                       logits=t_prob_p,
+                       logits=logits,
                        symbols=symbols,
                        ranges=ranges,
                        termination_symbol=terminal_symbol,
@@ -504,8 +503,104 @@ class TestRnntLoss(unittest.TestCase):
                        modified=modified,
                        reduction="none",
                    )
-                   print(f"pruning loss with range {r} : {pruned_loss}")
+                   print(f"Pruning loss with range {r} : {pruned_loss}")
 
+   # Test sequences that have only a small number of symbols; in this
+   # circumstance s_range would be greater than S, which raised errors
+   # (e.g. nan or inf loss) in our previous versions.
+   def test_rnnt_loss_pruned_small_symbols_number(self):
+       B = 2
+       T = 20
+       S = 3
+       C = 10
+
+       frames = torch.randint(S + 1, T, (B,))
+       seq_lengths = torch.randint(1, S, (B,))
+       T = torch.max(frames)
+       S = torch.max(seq_lengths)
+
+       am_ = torch.randn((B, T, C), dtype=torch.float64)
+       lm_ = torch.randn((B, S + 1, C), dtype=torch.float64)
+       symbols_ = torch.randint(0, C, (B, S))
+       terminal_symbol = C - 1
+
+       boundary_ = torch.zeros((B, 4), dtype=torch.int64)
+       boundary_[:, 2] = seq_lengths
+       boundary_[:, 3] = frames
+
+       print(f"B = {B}, T = {T}, S = {S}, C = {C}")
+
+       for modified in [True, False]:
+           for device in self.devices:
+               # normal rnnt
+               am = am_.to(device)
+               lm = lm_.to(device)
+               symbols = symbols_.to(device)
+               boundary = boundary_.to(device)
+
+               logits = am.unsqueeze(2) + lm.unsqueeze(1)
+               logits = logits.float()
+
+               # nonlinear transform
+               logits = torch.sigmoid(logits)
+
+               loss = fast_rnnt.rnnt_loss(
+                   logits=logits,
+                   symbols=symbols,
+                   termination_symbol=terminal_symbol,
+                   boundary=boundary,
+                   modified=modified,
+                   reduction="none",
+               )
+
+               print(
+                   f"Unpruned rnnt loss with modified {modified} : {loss}"
+               )
+
+               # pruning
+               simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
+                   lm=lm,
+                   am=am,
+                   symbols=symbols,
+                   termination_symbol=terminal_symbol,
+                   boundary=boundary,
+                   modified=modified,
+                   return_grad=True,
+                   reduction="none",
+               )
+
+               S0 = 2
+               if modified:
+                   S0 = 1
+
+               for r in range(S0, S + 2):
+                   ranges = fast_rnnt.get_rnnt_prune_ranges(
+                       px_grad=px_grad,
+                       py_grad=py_grad,
+                       boundary=boundary,
+                       s_range=r,
+                   )
+                   # (B, T, r, C)
+                   pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(
+                       am=am, lm=lm, ranges=ranges
+                   )
+
+                   logits = pruned_am + pruned_lm
+                   # nonlinear transform
+                   logits = torch.sigmoid(logits)
+
+                   pruned_loss = fast_rnnt.rnnt_loss_pruned(
+                       logits=logits,
+                       symbols=symbols,
+                       ranges=ranges,
+                       termination_symbol=terminal_symbol,
+                       boundary=boundary,
+                       modified=modified,
+                       reduction="none",
+                   )
+                   print(f"Pruned loss with range {r} : {pruned_loss}")
 if __name__ == "__main__":
     unittest.main()