gather.py 3.82 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
import time
import itertools

rusty1s's avatar
rusty1s committed
4
import argparse
rusty1s's avatar
rusty1s committed
5
6
7
8
9
import torch
from scipy.io import loadmat

from torch_scatter import gather_coo, gather_csr

rusty1s's avatar
rusty1s committed
10
from scatter_segment import iters, sizes
rusty1s's avatar
rusty1s committed
11
12
13
14
15
16
17
from scatter_segment import short_rows, long_rows, download, bold


@torch.no_grad()
def correctness(dataset):
    """Verify ``gather_coo`` and ``gather_csr`` against ``index_select``.

    Loads ``f'{name}.mat'`` from the working directory, converts the stored
    matrix to CSR, and builds both its CSR row pointer and its COO row
    indices on ``args.device``. For every feature size in ``sizes[1:]`` (the
    first entry of ``sizes`` is skipped) a random dense input with one row per
    matrix row is gathered along dimension 0 by all three implementations and
    the results are compared.

    Args:
        dataset: ``(group, name)`` pair; only ``name`` is used here.

    Raises:
        AssertionError: If ``gather_coo`` or ``gather_csr`` disagrees with
            ``index_select`` beyond ``atol=1e-4``.
        RuntimeError: Any runtime error that is not an out-of-memory error.
            Sizes that run out of memory are skipped instead.
    """
    _, name = dataset
    mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
    row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
    dim_size = rowptr.size(0) - 1  # Number of rows of the sparse matrix.

    for size in sizes[1:]:
        try:
            x = torch.randn((dim_size, size), device=args.device)
            # Use a 1-D input for size 1 (only reachable if sizes[1:]
            # contains 1).
            x = x.squeeze(-1) if size == 1 else x

            out1 = x.index_select(0, row)  # Reference result.
            out2 = gather_coo(x, row)
            out3 = gather_csr(x, rowptr)

            assert torch.allclose(out1, out2, atol=1e-4), (
                f'gather_coo disagrees with index_select for {name}, '
                f'size={size}')
            assert torch.allclose(out1, out3, atol=1e-4), (
                f'gather_csr disagrees with index_select for {name}, '
                f'size={size}')
        except RuntimeError as e:
            # Skip sizes that do not fit in memory, but never hide genuine
            # kernel failures behind the OOM handler.
            if 'out of memory' not in str(e):
                raise
            torch.cuda.empty_cache()


rusty1s's avatar
rusty1s committed
37
38
39
40
def time_func(func, x):
    """Return the wall-clock time of ``iters`` calls of ``func(x)``.

    Without ``--with_backward`` every call runs under ``torch.no_grad()``.
    With it, ``x`` is marked as requiring grad (in place, via
    ``requires_grad_``). Each forward call is then followed by a backward
    pass through ``torch.autograd.grad``, using the output itself as the
    upstream gradient.

    When ``args.device`` is a CUDA device, it is synchronized before and
    after the loop so that asynchronously launched kernels are included in
    the measurement. Other devices are timed without synchronization.

    Args:
        func: Callable mapping a tensor to a tensor.
        x: Input tensor passed to ``func``.

    Returns:
        Elapsed seconds as a float, or ``float('inf')`` if the run ran out
        of memory.

    Raises:
        RuntimeError: Any runtime error that is not an out-of-memory error.
    """
    device = torch.device(args.device)
    is_cuda = device.type == 'cuda'
    try:
        if is_cuda:
            torch.cuda.synchronize(device)
        t = time.perf_counter()

        if not args.with_backward:
            with torch.no_grad():
                for _ in range(iters):
                    func(x)
        else:
            x = x.requires_grad_()
            for _ in range(iters):
                out = func(x)
                torch.autograd.grad(out, x, out)

        if is_cuda:
            torch.cuda.synchronize(device)
        return time.perf_counter() - t
    except RuntimeError as e:
        # Only an OOM is an expected "this size does not fit" outcome; any
        # other failure must surface instead of being reported as a timing.
        if 'out of memory' not in str(e):
            raise
        if is_cuda:
            torch.cuda.empty_cache()
        return float('inf')


rusty1s's avatar
rusty1s committed
59
60
61
def timing(dataset):
    """Benchmark four row-gather implementations on one sparse matrix.

    Loads ``f'{name}.mat'`` from the working directory and gathers the rows
    of a random dense input (one row per matrix row, ``size`` columns) along
    dimension 0 for every ``size`` in ``sizes``. The row indices are the
    matrix's COO rows, or its CSR row pointer for ``gather_csr``. The four
    implementations are:
    ``index_select``, ``gather``, ``gather_coo`` and ``gather_csr``.

    Each implementation is timed with ``time_func``. The results are printed
    as a table with the fastest finite timing per size highlighted via
    ``bold``.

    Args:
        dataset: ``(group, name)`` pair; ``group/name`` is the table title.

    Raises:
        RuntimeError: Any runtime error other than running out of memory
            while allocating the input (OOM sizes are recorded as ``inf``).
    """
    group, name = dataset
    mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
    row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
    dim_size = rowptr.size(0) - 1  # Number of rows of the sparse matrix.
    avg_row_len = row.size(0) / dim_size

    def select(x):
        return x.index_select(0, row)

    def gather(x):
        return x.gather(0, row.view(-1, 1).expand(-1, x.size(1)))

    def gat_coo(x):
        return gather_coo(x, row)

    def gat_csr(x):
        return gather_csr(x, rowptr)

    # Benchmarks in table-row order; labels are padded to a common width.
    benchmarks = [('SELECT ', select), ('GAT    ', gather),
                  ('GAT_COO', gat_coo), ('GAT_CSR', gat_csr)]
    timings = [[] for _ in benchmarks]

    for size in sizes:
        try:
            x = torch.randn((dim_size, size), device=args.device)
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise
            torch.cuda.empty_cache()
            results = [float('inf')] * len(benchmarks)
        else:
            results = [time_func(func, x) for _, func in benchmarks]
            del x  # Release the input before allocating the next one.
        # Append all results together so every row keeps one entry per size.
        for t, result in zip(timings, results):
            t.append(result)

    ts = torch.tensor(timings)
    winner = torch.zeros_like(ts, dtype=torch.bool)
    winner[ts.argmin(dim=0), torch.arange(len(sizes))] = 1
    # A column in which every benchmark failed (all inf) has no winner.
    winner &= torch.isfinite(ts)
    winner = winner.tolist()

    name = f'{group}/{name}'
    print(f'{bold(name)} (avg row length: {avg_row_len:.2f}):')
    print('\t'.join(['       '] + [f'{size:>5}' for size in sizes]))
    for (label, _), t, w in zip(benchmarks, timings, winner):
        print('\t'.join([bold(label)] +
                        [bold(f'{v:.5f}', f) for v, f in zip(t, w)]))
    print()


if __name__ == '__main__':
    cli = argparse.ArgumentParser()
    cli.add_argument('--with_backward', action='store_true')
    cli.add_argument('--device', type=str, default='cuda')
    # `args` is read as a module-level global by the benchmark functions,
    # so it must keep exactly this name.
    args = cli.parse_args()

    # Warmup: run a few tiny ops on the target device before any timing,
    # presumably so one-time initialisation cost is not measured.
    warmup_steps = 10
    for _ in range(warmup_steps):
        torch.randn(100, 100, device=args.device).sum()

    # Process the short-row datasets first, then the long-row ones.
    for dataset_group in (short_rows, long_rows):
        for dataset in dataset_group:
            download(dataset)
            correctness(dataset)
            timing(dataset)