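"""Benchmark gather_coo/gather_csr from torch_scatter against native
torch.index_select and torch.gather on sparse matrices from `.mat`
files (datasets and helpers are shared with scatter_segment.py)."""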
import time
import itertools
import argparse

import torch
from scipy.io import loadmat

from torch_scatter import gather_coo, gather_csr

from scatter_segment import iters, sizes
from scatter_segment import short_rows, long_rows, download, bold


@torch.no_grad()
def correctness(dataset):
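    """Check gather_coo/gather_csr against index_select on random inputs."""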
    group, name = dataset
    mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
    row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
    dim_size = rowptr.size(0) - 1

    for size in sizes[1:]:
        try:
            x = torch.randn((dim_size, size), device=args.device)
            x = x.squeeze(-1) if size == 1 else x

            out1 = x.index_select(0, row)
            out2 = gather_coo(x, row)
            out3 = gather_csr(x, rowptr)

            assert torch.allclose(out1, out2, atol=1e-4)
            assert torch.allclose(out1, out3, atol=1e-4)
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise RuntimeError(e)
            torch.cuda.empty_cache()


def time_func(func, x):
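    """Time `iters` calls of `func` on `x`; return seconds, or inf on OOM."""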
    try:
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t = time.perf_counter()

        if not args.with_backward:
            with torch.no_grad():
                for _ in range(iters):
                    func(x)
        else:
            x = x.requires_grad_()
            for _ in range(iters):
                out = func(x)
                torch.autograd.grad(out, x, out, only_inputs=True)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return time.perf_counter() - t
    except RuntimeError as e:
        if 'out of memory' not in str(e):
            raise RuntimeError(e)
        torch.cuda.empty_cache()
        return float('inf')


def timing(dataset):
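    """Benchmark the gather variants over `sizes` and print a table."""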
    group, name = dataset
    mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
    row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
    dim_size = rowptr.size(0) - 1
    avg_row_len = row.size(0) / dim_size

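    # Gather variants: native index_select/gather vs. torch_scatter's
    # gather_coo/gather_csr.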
    def select(x):
        return x.index_select(0, row)

    def gather(x):
        return x.gather(0, row.view(-1, 1).expand(-1, x.size(1)))

    def gat_coo(x):
        return gather_coo(x, row)

    def gat_csr(x):
        return gather_csr(x, rowptr)

    t1, t2, t3, t4 = [], [], [], []
    for size in sizes:
        try:
            x = torch.randn((dim_size, size), device=args.device)

            t1 += [time_func(select, x)]
            t2 += [time_func(gather, x)]
            t3 += [time_func(gat_coo, x)]
            t4 += [time_func(gat_csr, x)]

            del x

        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise RuntimeError(e)
            torch.cuda.empty_cache()
            for t in (t1, t2, t3, t4):
                t.append(float('inf'))

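    # Bold the fastest variant for each feature size in the output.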
    ts = torch.tensor([t1, t2, t3, t4])
    winner = torch.zeros_like(ts, dtype=torch.bool)
    winner[ts.argmin(dim=0), torch.arange(len(sizes))] = 1
    winner = winner.tolist()

    name = f'{group}/{name}'
    print(f'{bold(name)} (avg row length: {avg_row_len:.2f}):')
    print('\t'.join(['       '] + [f'{size:>5}' for size in sizes]))
    print('\t'.join([bold('SELECT ')] +
                    [bold(f'{t:.5f}', f) for t, f in zip(t1, winner[0])]))
    print('\t'.join([bold('GAT    ')] +
                    [bold(f'{t:.5f}', f) for t, f in zip(t2, winner[1])]))
    print('\t'.join([bold('GAT_COO')] +
                    [bold(f'{t:.5f}', f) for t, f in zip(t3, winner[2])]))
    print('\t'.join([bold('GAT_CSR')] +
                    [bold(f'{t:.5f}', f) for t, f in zip(t4, winner[3])]))
    print()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--with_backward', action='store_true')
    parser.add_argument('--device', type=str, default='cuda')
    args = parser.parse_args()

    for _ in range(10):  # Warmup.
        torch.randn(100, 100, device=args.device).sum()
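
    # Download each dataset, check correctness, then run the timings.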
    for dataset in itertools.chain(short_rows, long_rows):
        download(dataset)
        correctness(dataset)
        timing(dataset)