# test.py - accuracy, gradient and speed checks for torch_discounted_cumsum.
import time
import torch

# Functions under test.
from torch_discounted_cumsum import discounted_cumsum_left, discounted_cumsum_right
# Reference ("gold") implementations: plain Python loops over the columns.


def discounted_cumsum_left_gold(input, gamma):
    """Reference left-to-right discounted cumulative sum over columns.

    Computes ``out[:, i] = input[:, i] + gamma * out[:, i - 1]`` with the
    column before the first one taken as zero, using a plain Python loop.

    Args:
        input: 2-D tensor; the recurrence runs along dimension 1.
        gamma: Discount factor, must satisfy ``0 <= gamma <= 1``.

    Returns:
        Tensor with the same shape as ``input``. For an input with zero
        columns an empty tensor of the same shape is returned.

    Raises:
        AssertionError: If ``input`` is not 2-D or ``gamma`` is out of range.
    """
    assert input.dim() == 2
    assert 0 <= gamma <= 1
    running = torch.zeros((input.shape[0], 1), dtype=input.dtype, device=input.device)
    columns = []
    for i in range(input.shape[1]):
        # Keep the column dimension so the pieces can be concatenated later.
        running = input[:, i].unsqueeze(-1) + gamma * running
        columns.append(running)
    if not columns:
        # torch.cat rejects an empty list; with no columns there is nothing
        # to accumulate, so hand back an empty tensor of the input's shape.
        return input.clone()
    return torch.cat(columns, dim=1)


def discounted_cumsum_right_gold(input, gamma):
    """Reference right-to-left discounted cumulative sum over columns.

    Computes ``out[:, i] = input[:, i] + gamma * out[:, i + 1]`` with the
    column after the last one taken as zero, using a plain Python loop.

    Args:
        input: 2-D tensor; the recurrence runs along dimension 1.
        gamma: Discount factor, must satisfy ``0 <= gamma <= 1``.

    Returns:
        Tensor with the same shape as ``input``. For an input with zero
        columns an empty tensor of the same shape is returned.

    Raises:
        AssertionError: If ``input`` is not 2-D or ``gamma`` is out of range.
    """
    assert input.dim() == 2
    assert 0 <= gamma <= 1
    running = torch.zeros((input.shape[0], 1), dtype=input.dtype, device=input.device)
    columns = []
    for i in reversed(range(input.shape[1])):
        # Keep the column dimension so the pieces can be concatenated later.
        running = input[:, i].unsqueeze(-1) + gamma * running
        columns.append(running)
    if not columns:
        # torch.cat rejects an empty list; with no columns there is nothing
        # to accumulate, so hand back an empty tensor of the input's shape.
        return input.clone()
    # Columns were produced last-to-first; one reverse restores input order
    # without paying for a front insertion on every iteration.
    columns.reverse()
    return torch.cat(columns, dim=1)


def test_left():
    """Print how far discounted_cumsum_left deviates from the gold loop.

    Uses a (10, 10000) float32 tensor of ones moved to CUDA and reports the
    maximum absolute difference against the reference evaluated in float32
    and in float64.
    """
    torch.manual_seed(0)
    discount = 0.99
    ones = torch.full((10, 10000), fill_value=1.0, dtype=torch.float32).cuda()
    references = {
        'left diff_32': discounted_cumsum_left_gold(ones, discount),
        'left diff_64': discounted_cumsum_left_gold(ones.double(), discount),
    }
    result = discounted_cumsum_left(ones, discount)
    deviations = {
        label: (result - reference).abs().max().item()
        for label, reference in references.items()
    }
    for label, deviation in deviations.items():
        print(label, deviation)


def test_right():
    """Print how far discounted_cumsum_right deviates from the gold loop.

    Uses a (10, 10000) float32 tensor of ones moved to CUDA and reports the
    maximum absolute difference against the reference evaluated in float32
    and in float64.
    """
    torch.manual_seed(0)
    discount = 0.99
    ones = torch.full((10, 10000), fill_value=1.0, dtype=torch.float32).cuda()
    references = {
        'right diff_32': discounted_cumsum_right_gold(ones, discount),
        'right diff_64': discounted_cumsum_right_gold(ones.double(), discount),
    }
    result = discounted_cumsum_right(ones, discount)
    deviations = {
        label: (result - reference).abs().max().item()
        for label, reference in references.items()
    }
    for label, deviation in deviations.items():
        print(label, deviation)


def test_grad_left():
    """Print the max abs gap between gold and extension input gradients.

    Both paths back-propagate the sum of the left discounted cumsum of a
    (10, 10000) float32 parameter of ones that lives on CUDA.
    """
    torch.manual_seed(0)
    discount = 0.99
    param = torch.nn.Parameter(
        torch.full((10, 10000), fill_value=1.0, dtype=torch.float32).cuda()
    )

    # Gradient through the reference loop.
    discounted_cumsum_left_gold(param, discount).sum().backward()
    grad_reference = param.grad.clone()

    # Drop the stored gradient so the second backward pass starts clean
    # instead of accumulating onto the first.
    del param.grad

    # Gradient through the function under test.
    discounted_cumsum_left(param, discount).sum().backward()
    grad_result = param.grad.clone()

    print('left diff_grad', (grad_reference - grad_result).abs().max().item())


def test_grad_right():
    """Print the max abs gap between gold and extension input gradients.

    Both paths back-propagate the sum of the right discounted cumsum of a
    (10, 10000) float32 parameter of ones that lives on CUDA.
    """
    torch.manual_seed(0)
    discount = 0.99
    param = torch.nn.Parameter(
        torch.full((10, 10000), fill_value=1.0, dtype=torch.float32).cuda()
    )

    # Gradient through the reference loop.
    discounted_cumsum_right_gold(param, discount).sum().backward()
    grad_reference = param.grad.clone()

    # Drop the stored gradient so the second backward pass starts clean
    # instead of accumulating onto the first.
    del param.grad

    # Gradient through the function under test.
    discounted_cumsum_right(param, discount).sum().backward()
    grad_result = param.grad.clone()

    print('right diff_grad', (grad_reference - grad_result).abs().max().item())


def test_speed(reps=10000):
    """Time repeated calls of discounted_cumsum_right on a CUDA tensor.

    Runs the function ``reps`` times on a (10, 100000) float32 random tensor
    and prints the elapsed wall-clock seconds.

    Args:
        reps: Number of back-to-back calls to time.
    """
    torch.manual_seed(0)
    x = torch.randn(10, 100000, dtype=torch.float32).cuda()
    gamma = 0.99
    # CUDA work is queued asynchronously: wait for the host-to-device copy
    # before starting the clock so it is not attributed to the timed loop.
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    for _ in range(reps):
        discounted_cumsum_right(x, gamma)
    # Wait for every queued kernel to finish; otherwise the clock stops while
    # work is still pending on the device and the result is too optimistic.
    torch.cuda.synchronize()
    t2 = time.perf_counter()
    print('sec:', t2 - t1)


if __name__ == '__main__':
    # Run the accuracy and gradient checks in order. The benchmark
    # (test_speed) is not part of the default run; call it manually.
    for check in (test_left, test_right, test_grad_left, test_grad_right):
        check()