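"""Tests for the RNN-T loss functions in torch_mutual_information
(get_rnnt_logprobs, rnnt_loss_simple, rnnt_loss_aux), run on both CPU and CUDA."""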

import random
import torch
from torch_mutual_information import (
    mutual_information_recursion,
    joint_mutual_information_recursion,
    get_rnnt_logprobs,
    rnnt_loss_simple,
    rnnt_loss_aux,
)


def test_rnnt_logprobs_basic():
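    """Check get_rnnt_logprobs + mutual_information_recursion against
    rnnt_loss_simple and rnnt_loss_aux, on CPU and CUDA, including
    invariance to per-frame constants added to lm and am."""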
    print("Running test_rnnt_logprobs_basic()")

    B = 1
    S = 3
    T = 4
    C = 3

    # lm: [B][S+1][C]
    lm = torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 0, 1], [2, 2, 0]]],
                      dtype=torch.float)
    # am: [B][T][C]
    am = torch.tensor([[[0, 1, 2], [0, 0, 0], [0, 2, 4], [0, 3, 3]]],
                      dtype=torch.float)

    # lm[:] = 0.0
    # am[:] = 0.0

    termination_symbol = 2
    symbols = torch.tensor([[0, 1, 0]], dtype=torch.long)

    px, py = get_rnnt_logprobs(lm, am, symbols, termination_symbol)
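    # px is expected to hold the log-probs of the symbols and py the log-probs
    # of the termination symbol (that interpretation is an assumption here;
    # only the shapes are asserted below).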

    assert px.shape == (B, S, T+1)
    assert py.shape == (B, S+1, T)
    assert symbols.shape == (B, S)
    print("px = ", px)
    print("py = ", py)
    m = mutual_information_recursion(px, py)
    print("m = ", m)


    # The result should be invariant to adding a per-frame constant to lm and
    # am, so the values computed below should still match m.
    lm += torch.randn(B, S+1, 1)
    am += torch.randn(B, T, 1)

    m2 = rnnt_loss_simple(lm, am, symbols, termination_symbol, None)
    print("m2 = ", m2)

    # Repeat on the GPU (assumes a CUDA device is available).
    device = torch.device('cuda')
    m3 = rnnt_loss_simple(lm.to(device), am.to(device), symbols.to(device),
                          termination_symbol, None)
    print("m3 = ", m3)

    # With both auxiliary scales at zero, rnnt_loss_aux should match the
    # simple loss.
    m4 = rnnt_loss_aux(lm.to(device), am.to(device), symbols.to(device),
                       termination_symbol, lm_only_scale=0.0,
                       am_only_scale=0.0, boundary=None)
    print("m4 = ", m4)

    assert torch.allclose(m, m2)

    assert torch.allclose(m, m3.to('cpu'))

    assert torch.allclose(m, m4.to('cpu'))

def test_rnnt_logprobs_aux():
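    """Check that rnnt_loss_aux with a nonzero am_only_scale is invariant
    to adding per-frame constants to lm and am."""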

    print("Running test_rnnt_logprobs_aux()")

    B = 1
    S = 3
    T = 4
    C = 3

    # lm: [B][S+1][C]
    lm = torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 0, 1], [2, 2, 0]]],
                      dtype=torch.float)
    # am: [B][T][C]
    am = torch.tensor([[[0, 1, 2], [0, 0, 0], [0, 2, 4], [0, 3, 3]]],
                      dtype=torch.float)

    termination_symbol = 2
    symbols = torch.tensor([[0, 1, 0]], dtype=torch.long)


    # This test runs on the GPU (assumes a CUDA device is available).
    device = torch.device('cuda')
    m1 = rnnt_loss_aux(lm.to(device), am.to(device), symbols.to(device),
                       termination_symbol, lm_only_scale=0.0,
                       am_only_scale=0.333, boundary=None)
    print("m1 = ", m1)


    # The result should be invariant to adding a per-frame constant to lm
    # and am.
    lm += torch.randn(B, S+1, 1)
    am += torch.randn(B, T, 1)

    m2 = rnnt_loss_aux(lm.to(device), am.to(device), symbols.to(device),
                       termination_symbol, lm_only_scale=0.0,
                       am_only_scale=0.333, boundary=None)
    print("m2 = ", m2)


    assert torch.allclose(m1, m2)



if __name__ == "__main__":
    # torch.set_printoptions(edgeitems=30)
    test_rnnt_logprobs_aux()
    test_rnnt_logprobs_basic()