Commit b0ed23ef authored by pkufool
Browse files

Add constrained rnnt

parent dc35168d
This diff is collapsed.
......@@ -90,7 +90,9 @@ class TestRnntLoss(unittest.TestCase):
assert px.shape == (B, S, T + 1)
assert py.shape == (B, S + 1, T)
assert symbols.shape == (B, S)
m = fast_rnnt.mutual_information_recursion(px=px, py=py, boundary=None)
m = fast_rnnt.mutual_information_recursion(
px=px, py=py, boundary=None
)
if device == torch.device("cpu"):
expected = -m
......@@ -205,7 +207,7 @@ class TestRnntLoss(unittest.TestCase):
boundary_[:, 2] = seq_length
boundary_[:, 3] = frames
for modified in [True, False]:
for rnnt_type in ["regular", "modified", "constrained"]:
for device in self.devices:
# lm: [B][S+1][C]
lm = lm_.to(device)
......@@ -220,9 +222,13 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
assert (
px.shape == (B, S, T)
if rnnt_type != "regular"
else (B, S, T + 1)
)
assert px.shape == (B, S, T) if modified else (B, S, T + 1)
assert py.shape == (B, S + 1, T)
assert symbols.shape == (B, S)
m = fast_rnnt.mutual_information_recursion(
......@@ -239,7 +245,7 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
assert torch.allclose(m, expected.to(device))
......@@ -251,7 +257,7 @@ class TestRnntLoss(unittest.TestCase):
lm_only_scale=0.0,
am_only_scale=0.0,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
assert torch.allclose(m, expected.to(device))
......@@ -261,12 +267,12 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
assert torch.allclose(m, expected.to(device))
# compare with torchaudio rnnt_loss
if self.has_torch_rnnt_loss and not modified:
if self.has_torch_rnnt_loss and rnnt_type == "regular":
import torchaudio.functional
m = torchaudio.functional.rnnt_loss(
......@@ -288,7 +294,7 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
assert torch.allclose(m, expected.to(device))
......@@ -298,7 +304,7 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
assert torch.allclose(m, expected.to(device))
......@@ -310,7 +316,7 @@ class TestRnntLoss(unittest.TestCase):
lm_only_scale=0.0,
am_only_scale=0.0,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
assert torch.allclose(m, expected.to(device))
......@@ -368,9 +374,13 @@ class TestRnntLoss(unittest.TestCase):
torch_grad = torch.autograd.grad(torch_loss, logits2)
torch_grad = torch_grad[0]
assert torch.allclose(fast_loss, torch_loss, atol=1e-2, rtol=1e-2)
assert torch.allclose(
fast_loss, torch_loss, atol=1e-2, rtol=1e-2
)
assert torch.allclose(fast_grad, torch_grad, atol=1e-2, rtol=1e-2)
assert torch.allclose(
fast_grad, torch_grad, atol=1e-2, rtol=1e-2
)
def test_rnnt_loss_smoothed(self):
B = 1
......@@ -443,7 +453,7 @@ class TestRnntLoss(unittest.TestCase):
boundary_[:, 2] = seq_length
boundary_[:, 3] = frames
for modified in [True, False]:
for rnnt_type in ["regular", "modified", "constrained"]:
for device in self.devices:
# normal rnnt
am = am_.to(device)
......@@ -460,12 +470,10 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=terminal_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
print(
f"Unpruned rnnt loss with modified {modified} : {fast_loss}"
)
print(f"Unpruned rnnt loss with {rnnt_loss} rnnt : {fast_loss}")
# pruning
simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
......@@ -474,7 +482,7 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=terminal_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
return_grad=True,
reduction="none",
)
......@@ -487,7 +495,9 @@ class TestRnntLoss(unittest.TestCase):
s_range=r,
)
# (B, T, r, C)
pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(
am=am, lm=lm, ranges=ranges
)
logits = pruned_am + pruned_lm
......@@ -500,12 +510,11 @@ class TestRnntLoss(unittest.TestCase):
ranges=ranges,
termination_symbol=terminal_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
reduction="none",
)
print(f"Pruning loss with range {r} : {pruned_loss}")
# Test the sequences that only have small number of symbols,
# at this circumstance, the s_range would be greater than S, which will
# raise errors (like, nan or inf loss) in our previous versions.
......@@ -531,7 +540,7 @@ class TestRnntLoss(unittest.TestCase):
print(f"B = {B}, T = {T}, S = {S}, C = {C}")
for modified in [True, False]:
for rnnt_type in ["regular", "modified", "constrained"]:
for device in self.devices:
# normal rnnt
am = am_.to(device)
......@@ -550,13 +559,11 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=terminal_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
reduction="none",
)
print(
f"Unpruned rnnt loss with modified {modified} : {loss}"
)
print(f"Unpruned rnnt loss with {rnnt_type} rnnt : {loss}")
# pruning
simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
......@@ -565,13 +572,13 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=terminal_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
return_grad=True,
reduction="none",
)
S0 = 2
if modified:
if rnnt_type == "regular":
S0 = 1
for r in range(S0, S + 2):
......@@ -597,10 +604,11 @@ class TestRnntLoss(unittest.TestCase):
ranges=ranges,
termination_symbol=terminal_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
reduction="none",
)
print(f"Pruned loss with range {r} : {pruned_loss}")
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment