Commit b0ed23ef authored by pkufool
Browse files

Add constrained rnnt

parent dc35168d
This diff is collapsed.
...@@ -90,7 +90,9 @@ class TestRnntLoss(unittest.TestCase): ...@@ -90,7 +90,9 @@ class TestRnntLoss(unittest.TestCase):
assert px.shape == (B, S, T + 1) assert px.shape == (B, S, T + 1)
assert py.shape == (B, S + 1, T) assert py.shape == (B, S + 1, T)
assert symbols.shape == (B, S) assert symbols.shape == (B, S)
m = fast_rnnt.mutual_information_recursion(px=px, py=py, boundary=None) m = fast_rnnt.mutual_information_recursion(
px=px, py=py, boundary=None
)
if device == torch.device("cpu"): if device == torch.device("cpu"):
expected = -m expected = -m
...@@ -205,7 +207,7 @@ class TestRnntLoss(unittest.TestCase): ...@@ -205,7 +207,7 @@ class TestRnntLoss(unittest.TestCase):
boundary_[:, 2] = seq_length boundary_[:, 2] = seq_length
boundary_[:, 3] = frames boundary_[:, 3] = frames
for modified in [True, False]: for rnnt_type in ["regular", "modified", "constrained"]:
for device in self.devices: for device in self.devices:
# lm: [B][S+1][C] # lm: [B][S+1][C]
lm = lm_.to(device) lm = lm_.to(device)
...@@ -220,9 +222,13 @@ class TestRnntLoss(unittest.TestCase): ...@@ -220,9 +222,13 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols, symbols=symbols,
termination_symbol=termination_symbol, termination_symbol=termination_symbol,
boundary=boundary, boundary=boundary,
modified=modified, rnnt_type=rnnt_type,
)
assert (
px.shape == (B, S, T)
if rnnt_type != "regular"
else (B, S, T + 1)
) )
assert px.shape == (B, S, T) if modified else (B, S, T + 1)
assert py.shape == (B, S + 1, T) assert py.shape == (B, S + 1, T)
assert symbols.shape == (B, S) assert symbols.shape == (B, S)
m = fast_rnnt.mutual_information_recursion( m = fast_rnnt.mutual_information_recursion(
...@@ -239,7 +245,7 @@ class TestRnntLoss(unittest.TestCase): ...@@ -239,7 +245,7 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols, symbols=symbols,
termination_symbol=termination_symbol, termination_symbol=termination_symbol,
boundary=boundary, boundary=boundary,
modified=modified, rnnt_type=rnnt_type,
) )
assert torch.allclose(m, expected.to(device)) assert torch.allclose(m, expected.to(device))
...@@ -251,7 +257,7 @@ class TestRnntLoss(unittest.TestCase): ...@@ -251,7 +257,7 @@ class TestRnntLoss(unittest.TestCase):
lm_only_scale=0.0, lm_only_scale=0.0,
am_only_scale=0.0, am_only_scale=0.0,
boundary=boundary, boundary=boundary,
modified=modified, rnnt_type=rnnt_type,
) )
assert torch.allclose(m, expected.to(device)) assert torch.allclose(m, expected.to(device))
...@@ -261,12 +267,12 @@ class TestRnntLoss(unittest.TestCase): ...@@ -261,12 +267,12 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols, symbols=symbols,
termination_symbol=termination_symbol, termination_symbol=termination_symbol,
boundary=boundary, boundary=boundary,
modified=modified, rnnt_type=rnnt_type,
) )
assert torch.allclose(m, expected.to(device)) assert torch.allclose(m, expected.to(device))
# compare with torchaudio rnnt_loss # compare with torchaudio rnnt_loss
if self.has_torch_rnnt_loss and not modified: if self.has_torch_rnnt_loss and rnnt_type == "regular":
import torchaudio.functional import torchaudio.functional
m = torchaudio.functional.rnnt_loss( m = torchaudio.functional.rnnt_loss(
...@@ -288,7 +294,7 @@ class TestRnntLoss(unittest.TestCase): ...@@ -288,7 +294,7 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols, symbols=symbols,
termination_symbol=termination_symbol, termination_symbol=termination_symbol,
boundary=boundary, boundary=boundary,
modified=modified, rnnt_type=rnnt_type,
) )
assert torch.allclose(m, expected.to(device)) assert torch.allclose(m, expected.to(device))
...@@ -298,7 +304,7 @@ class TestRnntLoss(unittest.TestCase): ...@@ -298,7 +304,7 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols, symbols=symbols,
termination_symbol=termination_symbol, termination_symbol=termination_symbol,
boundary=boundary, boundary=boundary,
modified=modified, rnnt_type=rnnt_type,
) )
assert torch.allclose(m, expected.to(device)) assert torch.allclose(m, expected.to(device))
...@@ -310,7 +316,7 @@ class TestRnntLoss(unittest.TestCase): ...@@ -310,7 +316,7 @@ class TestRnntLoss(unittest.TestCase):
lm_only_scale=0.0, lm_only_scale=0.0,
am_only_scale=0.0, am_only_scale=0.0,
boundary=boundary, boundary=boundary,
modified=modified, rnnt_type=rnnt_type,
) )
assert torch.allclose(m, expected.to(device)) assert torch.allclose(m, expected.to(device))
...@@ -368,9 +374,13 @@ class TestRnntLoss(unittest.TestCase): ...@@ -368,9 +374,13 @@ class TestRnntLoss(unittest.TestCase):
torch_grad = torch.autograd.grad(torch_loss, logits2) torch_grad = torch.autograd.grad(torch_loss, logits2)
torch_grad = torch_grad[0] torch_grad = torch_grad[0]
assert torch.allclose(fast_loss, torch_loss, atol=1e-2, rtol=1e-2) assert torch.allclose(
fast_loss, torch_loss, atol=1e-2, rtol=1e-2
)
assert torch.allclose(fast_grad, torch_grad, atol=1e-2, rtol=1e-2) assert torch.allclose(
fast_grad, torch_grad, atol=1e-2, rtol=1e-2
)
def test_rnnt_loss_smoothed(self): def test_rnnt_loss_smoothed(self):
B = 1 B = 1
...@@ -443,7 +453,7 @@ class TestRnntLoss(unittest.TestCase): ...@@ -443,7 +453,7 @@ class TestRnntLoss(unittest.TestCase):
boundary_[:, 2] = seq_length boundary_[:, 2] = seq_length
boundary_[:, 3] = frames boundary_[:, 3] = frames
for modified in [True, False]: for rnnt_type in ["regular", "modified", "constrained"]:
for device in self.devices: for device in self.devices:
# normal rnnt # normal rnnt
am = am_.to(device) am = am_.to(device)
...@@ -460,12 +470,10 @@ class TestRnntLoss(unittest.TestCase): ...@@ -460,12 +470,10 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols, symbols=symbols,
termination_symbol=terminal_symbol, termination_symbol=terminal_symbol,
boundary=boundary, boundary=boundary,
modified=modified, rnnt_type=rnnt_type,
) )
print( print(f"Unpruned rnnt loss with {rnnt_loss} rnnt : {fast_loss}")
f"Unpruned rnnt loss with modified {modified} : {fast_loss}"
)
# pruning # pruning
simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple( simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
...@@ -474,7 +482,7 @@ class TestRnntLoss(unittest.TestCase): ...@@ -474,7 +482,7 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols, symbols=symbols,
termination_symbol=terminal_symbol, termination_symbol=terminal_symbol,
boundary=boundary, boundary=boundary,
modified=modified, rnnt_type=rnnt_type,
return_grad=True, return_grad=True,
reduction="none", reduction="none",
) )
...@@ -487,7 +495,9 @@ class TestRnntLoss(unittest.TestCase): ...@@ -487,7 +495,9 @@ class TestRnntLoss(unittest.TestCase):
s_range=r, s_range=r,
) )
# (B, T, r, C) # (B, T, r, C)
pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges) pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(
am=am, lm=lm, ranges=ranges
)
logits = pruned_am + pruned_lm logits = pruned_am + pruned_lm
...@@ -500,12 +510,11 @@ class TestRnntLoss(unittest.TestCase): ...@@ -500,12 +510,11 @@ class TestRnntLoss(unittest.TestCase):
ranges=ranges, ranges=ranges,
termination_symbol=terminal_symbol, termination_symbol=terminal_symbol,
boundary=boundary, boundary=boundary,
modified=modified, rnnt_type=rnnt_type,
reduction="none", reduction="none",
) )
print(f"Pruning loss with range {r} : {pruned_loss}") print(f"Pruning loss with range {r} : {pruned_loss}")
# Test the sequences that only have small number of symbols, # Test the sequences that only have small number of symbols,
# at this circumstance, the s_range would be greater than S, which will # at this circumstance, the s_range would be greater than S, which will
# raise errors (like, nan or inf loss) in our previous versions. # raise errors (like, nan or inf loss) in our previous versions.
...@@ -531,7 +540,7 @@ class TestRnntLoss(unittest.TestCase): ...@@ -531,7 +540,7 @@ class TestRnntLoss(unittest.TestCase):
print(f"B = {B}, T = {T}, S = {S}, C = {C}") print(f"B = {B}, T = {T}, S = {S}, C = {C}")
for modified in [True, False]: for rnnt_type in ["regular", "modified", "constrained"]:
for device in self.devices: for device in self.devices:
# normal rnnt # normal rnnt
am = am_.to(device) am = am_.to(device)
...@@ -550,13 +559,11 @@ class TestRnntLoss(unittest.TestCase): ...@@ -550,13 +559,11 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols, symbols=symbols,
termination_symbol=terminal_symbol, termination_symbol=terminal_symbol,
boundary=boundary, boundary=boundary,
modified=modified, rnnt_type=rnnt_type,
reduction="none", reduction="none",
) )
print( print(f"Unpruned rnnt loss with {rnnt_type} rnnt : {loss}")
f"Unpruned rnnt loss with modified {modified} : {loss}"
)
# pruning # pruning
simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple( simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
...@@ -565,13 +572,13 @@ class TestRnntLoss(unittest.TestCase): ...@@ -565,13 +572,13 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols, symbols=symbols,
termination_symbol=terminal_symbol, termination_symbol=terminal_symbol,
boundary=boundary, boundary=boundary,
modified=modified, rnnt_type=rnnt_type,
return_grad=True, return_grad=True,
reduction="none", reduction="none",
) )
S0 = 2 S0 = 2
if modified: if rnnt_type == "regular":
S0 = 1 S0 = 1
for r in range(S0, S + 2): for r in range(S0, S + 2):
...@@ -597,10 +604,11 @@ class TestRnntLoss(unittest.TestCase): ...@@ -597,10 +604,11 @@ class TestRnntLoss(unittest.TestCase):
ranges=ranges, ranges=ranges,
termination_symbol=terminal_symbol, termination_symbol=terminal_symbol,
boundary=boundary, boundary=boundary,
modified=modified, rnnt_type=rnnt_type,
reduction="none", reduction="none",
) )
print(f"Pruned loss with range {r} : {pruned_loss}") print(f"Pruned loss with range {r} : {pruned_loss}")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment