gather.py 3.79 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
# flake8: noqa

rusty1s's avatar
rusty1s committed
3
4
5
import time
import itertools

rusty1s's avatar
rusty1s committed
6
import argparse
rusty1s's avatar
rusty1s committed
7
8
9
10
11
import torch
from scipy.io import loadmat

from torch_scatter import gather_coo, gather_csr

rusty1s's avatar
rusty1s committed
12
from scatter_segment import iters, sizes
rusty1s's avatar
rusty1s committed
13
14
15
16
17
18
19
from scatter_segment import short_rows, long_rows, download, bold


@torch.no_grad()
def correctness(dataset):
    """Check ``gather_coo`` and ``gather_csr`` against ``index_select``.

    ``dataset`` is a ``(group, name)`` pair; only ``name`` is used here, to
    locate ``{name}.mat``. The matrix stored at
    ``['Problem'][0][0][2]`` of that file (presumably the SuiteSparse
    payload; confirm against ``download``) is converted to CSR. Its row
    pointer and per-non-zero row indices drive the two gather kernels.

    For every feature size in ``sizes[1:]`` (the first entry is skipped), a
    random dense input is gathered three ways and the results must agree
    within ``atol=1e-4``. A mismatch raises ``AssertionError``. Any
    ``RuntimeError`` (for example an allocation failure) is swallowed after
    releasing cached CUDA memory, and the next size is tried.
    """
    _, name = dataset
    csr = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()

    rowptr = torch.from_numpy(csr.indptr).to(args.device, torch.long)
    row = torch.from_numpy(csr.tocoo().row).to(args.device, torch.long)
    num_rows = rowptr.size(0) - 1

    for feat in sizes[1:]:
        try:
            src = torch.randn((num_rows, feat), device=args.device)
            if feat == 1:
                # A single feature column is exercised as a 1-D input.
                src = src.squeeze(-1)

            expected = src.index_select(0, row)
            # Both kernels run before either comparison is made.
            candidates = (gather_coo(src, row), gather_csr(src, rowptr))
            for actual in candidates:
                assert torch.allclose(expected, actual, atol=1e-4)
        except RuntimeError:
            torch.cuda.empty_cache()


rusty1s's avatar
rusty1s committed
39
40
41
42
def time_func(func, x):
    """Time ``iters`` repeated calls of ``func(x)``, in seconds.

    Reads the module-level ``iters`` count and the ``args.with_backward``
    flag. Without the flag, the calls run under ``torch.no_grad()``. With
    it, ``requires_grad`` is enabled on ``x`` *in place*. Each iteration
    then also back-propagates through the output, using the output itself
    as the upstream gradient.

    Args:
        func: Callable taking ``x``. It must return a tensor when the
            backward pass is enabled.
        x: Input tensor, on any device.

    Returns:
        Elapsed wall-clock seconds, or ``float('inf')`` if a
        ``RuntimeError`` (for example running out of memory) was raised.
    """
    device = x.device

    def sync():
        # CUDA kernels launch asynchronously, so the clock is meaningful only
        # after a device sync. CPU work is synchronous, and calling
        # ``torch.cuda.synchronize`` without CUDA fails, which used to break
        # ``--device cpu`` runs.
        if device.type == 'cuda':
            torch.cuda.synchronize(device)

    try:
        sync()
        start = time.perf_counter()

        if not args.with_backward:
            with torch.no_grad():
                for _ in range(iters):
                    func(x)
        else:
            x = x.requires_grad_()
            for _ in range(iters):
                out = func(x)
                # ``only_inputs`` is deprecated and ignored by autograd, so it
                # is no longer passed.
                torch.autograd.grad(out, x, out)

        sync()
        return time.perf_counter() - start
    except RuntimeError:
        torch.cuda.empty_cache()
        return float('inf')


rusty1s's avatar
rusty1s committed
61
62
63
def timing(dataset):
    """Benchmark four row-gather strategies on one matrix and print a table.

    ``dataset`` is a ``(group, name)`` pair. The matrix in ``{name}.mat`` is
    converted to CSR. A random dense input with one row per matrix row is
    gathered at the row index of every stored non-zero, using:

    * ``SELECT``  - ``Tensor.index_select``
    * ``GAT``     - ``Tensor.gather`` with an expanded index
    * ``GAT_COO`` - ``gather_coo`` with the COO row indices
    * ``GAT_CSR`` - ``gather_csr`` with the CSR row pointer

    Each strategy is timed with ``time_func`` for every entry of ``sizes``.
    If the input cannot be created (``RuntimeError``), every strategy is
    recorded as ``inf`` for that size. For each size, the argmin over the
    strategies is passed to ``bold`` as its highlight flag.
    """
    group, name = dataset
    csr = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    rowptr = torch.from_numpy(csr.indptr).to(args.device, torch.long)
    row = torch.from_numpy(csr.tocoo().row).to(args.device, torch.long)

    num_rows = rowptr.size(0) - 1
    avg_row_len = row.size(0) / num_rows

    def expanded_gather(x):
        index = row.view(-1, 1).expand(-1, x.size(1))
        return x.gather(0, index)

    # (printed label, benchmarked function), in table order.
    candidates = [
        ('SELECT ', lambda x: x.index_select(0, row)),
        ('GAT    ', expanded_gather),
        ('GAT_COO', lambda x: gather_coo(x, row)),
        ('GAT_CSR', lambda x: gather_csr(x, rowptr)),
    ]
    results = [[] for _ in candidates]

    for size in sizes:
        try:
            x = torch.randn((num_rows, size), device=args.device)
            for times, (_, fn) in zip(results, candidates):
                times.append(time_func(fn, x))
            del x
        except RuntimeError:
            torch.cuda.empty_cache()
            for times in results:
                times.append(float('inf'))

    ts = torch.tensor(results)
    winner = torch.zeros_like(ts, dtype=torch.bool)
    winner[ts.argmin(dim=0), torch.arange(len(sizes))] = 1
    winner = winner.tolist()

    label = f'{group}/{name}'
    print(f'{bold(label)} (avg row length: {avg_row_len:.2f}):')
    print('\t'.join(['       '] + [f'{size:>5}' for size in sizes]))
    for (title, _), times, wins in zip(candidates, results, winner):
        heading = bold(title)
        cells = [bold(f'{t:.5f}', f) for t, f in zip(times, wins)]
        print('\t'.join([heading] + cells))
    print()


if __name__ == '__main__':
    # ``args`` is bound at module level on purpose: ``correctness``,
    # ``time_func`` and ``timing`` read it as a global.
    parser = argparse.ArgumentParser()
    parser.add_argument('--with_backward', action='store_true')
    parser.add_argument('--device', type=str, default='cuda')
    args = parser.parse_args()

    # Warm-up: run a few small ops on the target device before any
    # dataset is processed.
    warmup_steps = 10
    for _ in range(warmup_steps):
        torch.randn(100, 100, device=args.device).sum()

    # Short-row matrices first, then long-row ones.
    for collection in (short_rows, long_rows):
        for dataset in collection:
            download(dataset)
            correctness(dataset)
            timing(dataset)