gather.py 4.45 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import time
import itertools

import torch
from scipy.io import loadmat

from torch_scatter import gather_coo, gather_csr

from scatter_segment import iters, device, sizes
from scatter_segment import short_rows, long_rows, download, bold


@torch.no_grad()
def correctness(dataset):
    """Check that ``gather_coo`` and ``gather_csr`` agree with ``index_select``.

    Args:
        dataset: A ``(group, name)`` pair. Only ``name`` is used here; the
            matrix is read from ``'{name}.mat'`` in the working directory.

    For every feature size in ``sizes`` except the first one, a random input
    with one row per matrix row is gathered three ways and the results are
    compared with ``torch.allclose`` (``atol=1e-4``). A failed comparison
    raises ``AssertionError``. A ``RuntimeError`` raised for a size
    (presumably running out of device memory) only empties the CUDA cache
    and moves on to the next size.
    """
    _, name = dataset
    csr = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    rowptr = torch.from_numpy(csr.indptr).to(device, torch.long)
    row = torch.from_numpy(csr.tocoo().row).to(device, torch.long)
    # `rowptr` holds one more entry than there are rows.
    num_rows = rowptr.size(0) - 1

    for size in sizes[1:]:
        try:
            x = torch.randn((num_rows, size), device=device)
            if size == 1:
                x = x.squeeze(-1)

            # `index_select` serves as the reference result.
            expected = x.index_select(0, row)
            from_coo = gather_coo(x, row)
            from_csr = gather_csr(x, rowptr)

            assert torch.allclose(expected, from_coo, atol=1e-4)
            assert torch.allclose(expected, from_csr, atol=1e-4)
        except RuntimeError:
            torch.cuda.empty_cache()


def _time_op(fn, *args):
    """Return the wall-clock seconds taken by ``iters`` calls of ``fn(*args)``.

    ``torch.cuda.synchronize()`` is called right before the clock starts and
    right before it stops, so pending device work is not attributed to the
    wrong measurement. If a ``RuntimeError`` is raised, the CUDA cache is
    emptied and ``float('inf')`` is returned, which keeps the candidate from
    ever being selected as the fastest one.
    """
    try:
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            out = fn(*args)
            del out  # Release the result before the next iteration.
        torch.cuda.synchronize()
        return time.perf_counter() - start
    except RuntimeError:
        torch.cuda.empty_cache()
        return float('inf')


@torch.no_grad()
def timing(dataset):
    """Benchmark four row-gather variants on one dataset and print a table.

    Args:
        dataset: A ``(group, name)`` pair. The matrix is read from
            ``'{name}.mat'`` in the working directory; ``group`` is only used
            in the printed heading.

    For every feature size in ``sizes`` the following are timed over
    ``iters`` repetitions each: ``Tensor.index_select``, ``Tensor.gather``,
    ``gather_coo`` and ``gather_csr``. The printed table has one row per
    variant and one column per size; the fastest entry of each column is
    passed to ``bold`` with a true flag. Variants (or whole sizes) that fail
    with ``RuntimeError`` are reported as ``inf``.
    """
    group, name = dataset
    mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    rowptr = torch.from_numpy(mat.indptr).to(device, torch.long)
    row = torch.from_numpy(mat.tocoo().row).to(device, torch.long)
    # `rowptr` holds one more entry than there are rows.
    dim_size = rowptr.size(0) - 1
    avg_row_len = row.size(0) / dim_size

    t1, t2, t3, t4 = [], [], [], []
    for size in sizes:
        try:
            x = torch.randn((dim_size, size), device=device)
            # `Tensor.gather` needs an index with the same shape as its output.
            row_expand = row.view(-1, 1).expand(-1, x.size(-1))
            x = x.squeeze(-1) if size == 1 else x
            row_expand = row_expand.squeeze(-1) if size == 1 else row_expand
        except RuntimeError:
            torch.cuda.empty_cache()
            # Pad *all four* result lists so they stay the same length;
            # otherwise `torch.tensor` below fails on the ragged input.
            for times in (t1, t2, t3, t4):
                times.append(float('inf'))
            continue

        t1.append(_time_op(x.index_select, 0, row))
        t2.append(_time_op(x.gather, 0, row_expand))
        t3.append(_time_op(gather_coo, x, row))
        t4.append(_time_op(gather_csr, x, rowptr))

        del x

    ts = torch.tensor([t1, t2, t3, t4])
    # Mark, per column (size), the variant with the smallest time.
    winner = torch.zeros_like(ts, dtype=torch.bool)
    winner[ts.argmin(dim=0), torch.arange(len(sizes))] = 1
    winner = winner.tolist()

    name = f'{group}/{name}'
    print(f'{bold(name)} (avg row length: {avg_row_len:.2f}):')
    print('\t'.join(['       '] + [f'{size:>5}' for size in sizes]))
    labels = ('SELECT ', 'GAT    ', 'GAT_COO', 'GAT_CSR')
    for label, times, flags in zip(labels, (t1, t2, t3, t4), winner):
        print('\t'.join([bold(label)] +
                        [bold(f'{t:.5f}', f) for t, f in zip(times, flags)]))
    print()


if __name__ == '__main__':
    # Warmup: run a few small ops on the target device before measuring.
    warmup_rounds = 10
    for _ in range(warmup_rounds):
        torch.randn(100, 100, device=device).sum()

    # Process the short-row datasets first, then the long-row ones.
    for datasets in (short_rows, long_rows):
        for dataset in datasets:
            download(dataset)
            correctness(dataset)
            timing(dataset)