test_cpu.py 2.63 KB
Newer Older
lishen's avatar
lishen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
import warpctc_pytorch as warp_ctc
import pytest


def test_simple():
    """Run one sequence through the raw CPU binding and check the loss value.

    Shapes: T=2 frames, batch of 1, alphabet of 5; the final ``0`` passed to
    ``cpu_ctc`` is the blank index. Because the sequence has exactly as many
    frames as labels and no label repeats, the label sequence itself is the
    only valid CTC alignment, so the expected loss is
    ``-sum_t log_softmax(acts[t])[label_t]``.
    """
    # Literal is (batch=1, T=2, alphabet=5); transposed to (T, batch, alphabet),
    # which is why minibatch_size reads dim 1 and the frame count reads dim 0.
    probs = torch.FloatTensor([[[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]]]).transpose(0, 1).contiguous()
    grads = torch.zeros(probs.size())
    labels = torch.IntTensor([1, 2])
    label_sizes = torch.IntTensor([2])
    # Every sequence in the batch uses all T frames.
    sizes = torch.full((probs.size(1),), probs.size(0), dtype=torch.int)
    minibatch_size = probs.size(1)
    costs = torch.zeros(minibatch_size)
    warp_ctc.cpu_ctc(probs,
                     grads,
                     labels,
                     label_sizes,
                     sizes,
                     minibatch_size,
                     costs,
                     0)
    print('CPU_cost: %f' % costs.sum())

    # warp-ctc takes unnormalised activations and applies the softmax itself,
    # so the reference is built from log_softmax over the alphabet axis
    # (computed in double to keep the reference itself exact).
    log_probs = torch.nn.functional.log_softmax(probs.double(), dim=2)
    expected = -(log_probs[0, 0, 1] + log_probs[1, 0, 2])
    assert costs.sum().item() == pytest.approx(expected.item(), rel=1e-4)


@pytest.mark.parametrize("multiplier", [1.0, 200.0])
def test_medium(multiplier):
    """Run a two-sequence batch through the raw CPU binding.

    Shapes: T=2 frames, batch of 2, alphabet of 5; ``probs[t][b]`` is the
    activation vector of sequence ``b`` at frame ``t``. ``multiplier`` scales
    every activation; 200.0 produces extremely peaked softmax outputs.
    """
    probs = torch.FloatTensor([
        [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]],
        [[0.6, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.5, 0.2, 0.1]]
    ]).contiguous() * multiplier

    grads = torch.zeros(probs.size())
    # Labels of both sequences, concatenated; label_sizes splits them.
    labels = torch.IntTensor([1, 2, 1, 2])
    label_sizes = torch.IntTensor([2, 2])
    sizes = torch.IntTensor([2, 2])
    minibatch_size = probs.size(1)
    costs = torch.zeros(minibatch_size)
    warp_ctc.cpu_ctc(probs,
                     grads,
                     labels,
                     label_sizes,
                     sizes,
                     minibatch_size,
                     costs,
                     0)
    print('CPU_cost: %f' % costs.sum())

    # At multiplier 200 the target-label probabilities are ~exp(-100), below
    # float32's normal range, so the exact float32 result depends on how the
    # extension treats subnormals; only the unscaled case is pinned.
    if multiplier == 1.0:
        # T equals each label length and no label repeats, so each label
        # sequence is its only valid alignment.
        log_probs = torch.nn.functional.log_softmax(probs.double(), dim=2)
        expected = -(log_probs[0, 0, 1] + log_probs[1, 0, 2]
                     + log_probs[0, 1, 1] + log_probs[1, 1, 2])
        assert costs.sum().item() == pytest.approx(expected.item(), rel=1e-4)


def test_empty_label():
    """Check a batch where the second sequence has an empty label.

    Shapes: T=2 frames, batch of 2, alphabet of 5; blank index is the final
    ``0`` argument. Sequence 0 is labelled ``[1, 2]``. Sequence 1 has label
    length 0, so its only valid alignment is blank at every frame.
    """
    probs = torch.FloatTensor([
        [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]],
        [[0.6, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.5, 0.2, 0.1]]
    ]).contiguous()

    grads = torch.zeros(probs.size())
    # Only sequence 0 contributes labels; sequence 1's label length is 0.
    labels = torch.IntTensor([1, 2])
    label_sizes = torch.IntTensor([2, 0])
    sizes = torch.IntTensor([2, 2])
    minibatch_size = probs.size(1)
    costs = torch.zeros(minibatch_size)
    warp_ctc.cpu_ctc(probs,
                     grads,
                     labels,
                     label_sizes,
                     sizes,
                     minibatch_size,
                     costs,
                     0)
    print('CPU_cost: %f' % costs.sum())

    # warp-ctc applies the softmax itself; reference from log_softmax in double.
    log_probs = torch.nn.functional.log_softmax(probs.double(), dim=2)
    expected = -(log_probs[0, 0, 1] + log_probs[1, 0, 2]    # sequence 0: "1 2"
                 + log_probs[0, 1, 0] + log_probs[1, 1, 0])  # sequence 1: blank, blank
    assert costs.sum().item() == pytest.approx(expected.item(), rel=1e-4)


def test_CTCLoss():
    """Forward and backward through the autograd wrapper ``warp_ctc.CTCLoss``.

    Single sequence: T=2 frames, batch of 1, alphabet of 5, labels ``[1, 2]``.
    With as many frames as labels and no repeat, the label sequence is the
    only valid alignment, so the loss is ``-sum_t log_softmax(acts[t])[label_t]``.
    """
    # Literal is (batch, T, alphabet); transposed to (T, batch, alphabet).
    probs = torch.FloatTensor([[
        [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]
    ]]).transpose(0, 1).contiguous()
    labels = torch.IntTensor([1, 2])
    label_sizes = torch.IntTensor([2])
    probs_sizes = torch.IntTensor([2])
    probs.requires_grad_(True)

    ctc_loss = warp_ctc.CTCLoss()
    cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
    cost.backward()

    # .sum() so the check works whether the loss is 0-d or a 1-element tensor.
    log_probs = torch.nn.functional.log_softmax(probs.detach().double(), dim=2)
    expected = -(log_probs[0, 0, 1] + log_probs[1, 0, 2])
    assert cost.sum().item() == pytest.approx(expected.item(), rel=1e-4)
    # backward() must have propagated a gradient back to the activations.
    assert probs.grad is not None
    assert probs.grad.size() == probs.size()


if __name__ == '__main__':
    # Propagate pytest's exit status so a failing run is visible to the shell/CI;
    # pytest.main() returns the code rather than exiting itself.
    raise SystemExit(pytest.main([__file__]))