# rnnt_test.py

import random
import torch
from torch_mutual_information import mutual_information_recursion, joint_mutual_information_recursion, get_rnnt_logprobs, rnnt_loss_simple


def test_rnnt_logprobs_basic():
    """Check rnnt_loss_simple against the explicit two-step computation.

    A tiny hand-written example is pushed through ``get_rnnt_logprobs`` and
    ``mutual_information_recursion``.  A random constant is then added to
    every frame of ``lm`` and ``am`` and ``rnnt_loss_simple`` is evaluated on
    the shifted inputs; the two results must agree, i.e. the value is
    expected to be invariant to per-frame offsets.

    The offsets are drawn from a locally seeded generator so that a failure
    is reproducible and the global torch RNG state is left untouched.
    """
    print("Running test_rnnt_logprobs_basic()")

    # Dimension sizes referred to by the shape comments and asserts below.
    B = 1
    S = 3
    T = 4
    C = 3

    # lm: [B][S+1][C]
    lm = torch.tensor(
        [[[0, 0, 1], [0, 1, 1], [1, 0, 1], [2, 2, 0]]], dtype=torch.float
    )
    # am: [B][T][C]
    am = torch.tensor(
        [[[0, 1, 2], [0, 0, 0], [0, 2, 4], [0, 3, 3]]], dtype=torch.float
    )
    # Guard against the literals above drifting away from B/S/T/C.
    assert lm.shape == (B, S + 1, C)
    assert am.shape == (B, T, C)

    termination_symbol = 2
    symbols = torch.tensor([[0, 1, 0]], dtype=torch.long)
    assert symbols.shape == (B, S)

    px, py = get_rnnt_logprobs(lm, am, symbols, termination_symbol)

    assert px.shape == (B, S, T + 1)
    assert py.shape == (B, S + 1, T)
    print("px = ", px)
    print("py = ", py)
    m = mutual_information_recursion(px, py)
    print("m = ", m)

    # should be invariant to adding a constant for any frame.
    # The trailing dimension of size 1 broadcasts one offset across all C
    # entries of a frame.
    rng = torch.Generator().manual_seed(0)
    lm += torch.randn(B, S + 1, 1, generator=rng)
    am += torch.randn(B, T, 1, generator=rng)

    m2 = rnnt_loss_simple(lm, am, symbols, termination_symbol, None)
    print("m2 = ", m2)
    assert torch.allclose(m, m2), f"m = {m}, m2 = {m2}"



def _main():
    """Run the tests in this file when it is executed as a script."""
    # Debugging aid: uncomment to print more elements of each tensor.
    # torch.set_printoptions(edgeitems=30)
    test_rnnt_logprobs_basic()


if __name__ == "__main__":
    _main()