Commit 77125678 authored by Daniel Povey's avatar Daniel Povey
Browse files

Add more test

parent 2c3a7e1d
...@@ -40,8 +40,15 @@ def test_rnnt_logprobs_basic(): ...@@ -40,8 +40,15 @@ def test_rnnt_logprobs_basic():
m2 = rnnt_loss_simple(lm, am, symbols, termination_symbol, None) m2 = rnnt_loss_simple(lm, am, symbols, termination_symbol, None)
print("m2 = ", m2) print("m2 = ", m2)
device = torch.device('cuda')
m3 = rnnt_loss_simple(lm.to(device), am.to(device), symbols.to(device), termination_symbol, None)
print("m3 = ", m2)
assert torch.allclose(m, m2) assert torch.allclose(m, m2)
assert torch.allclose(m, m3.to('cpu'))
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment