scatter_segment.py 6.42 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
# flake8: noqa

rusty1s's avatar
rusty1s committed
3
4
5
6
import time
import os.path as osp
import itertools

rusty1s's avatar
rusty1s committed
7
import argparse
rusty1s's avatar
rusty1s committed
8
9
import wget
import torch
rusty1s's avatar
rusty1s committed
10
from scipy.io import loadmat
rusty1s's avatar
rusty1s committed
11

rusty1s's avatar
rusty1s committed
12
import torch_scatter
rusty1s's avatar
rusty1s committed
13
14
from torch_scatter import scatter_add, scatter_mean, scatter_min, scatter_max
from torch_scatter import segment_coo, segment_csr
rusty1s's avatar
rusty1s committed
15
16
17
18
19
20
21
22

# Number of timed calls per measurement (used by time_func).
iters = 20
# Feature sizes benchmarked per matrix; size 1 is run as a 1-D tensor.
sizes = [1] + [2 ** k for k in range(4, 10)]

# (group, name) pairs of matrices fetched from sparse.tamu.edu, split
# (presumably) by short vs. long average row length.
short_rows = [('DIMACS10', 'citationCiteseer'), ('SNAP', 'web-Stanford')]
long_rows = [('Janna', 'StocF-1465'), ('GHS_psdef', 'ldoor')]
rusty1s's avatar
rusty1s committed
27
28


rusty1s's avatar
rusty1s committed
29
30
31
32
33
34
35
def download(dataset):
    """Download the ``.mat`` file for a single dataset unless it exists.

    Previously this ignored ``dataset`` and walked every known matrix on each
    call; now only the requested matrix is considered.

    Args:
        dataset: ``(group, name)`` pair identifying a matrix under
            ``https://sparse.tamu.edu/mat/{group}/{name}.mat``.

    Returns:
        The local file name ``{name}.mat`` (in the current directory).
    """
    group, name = dataset
    path = f'{name}.mat'
    if not osp.exists(path):
        print(f'Downloading {group}/{name}:')
        # Save explicitly to ``path`` so the file written is exactly the one
        # the existence check above looks for.
        wget.download(f'https://sparse.tamu.edu/mat/{group}/{name}.mat',
                      out=path)
        print('')  # wget's progress bar does not end with a newline.
    return path
rusty1s's avatar
rusty1s committed
36
37
38
39
40
41
42
43
44
45


def bold(text, flag=True):
    """Wrap ``text`` in ANSI bold escape codes when ``flag`` is truthy.

    With a falsy ``flag`` the input is returned unchanged (not stringified).
    """
    if not flag:
        return text
    return '\033[1m' + f'{text}' + '\033[0m'


def _assert_all_close(ref, *others, atol=1e-4):
    """Assert that every tensor in ``others`` matches ``ref`` within ``atol``."""
    for other in others:
        assert torch.allclose(ref, other, atol=atol)


@torch.no_grad()
def correctness(dataset):
    """Check that scatter and segment reductions agree on one matrix.

    Loads ``{name}.mat`` for ``dataset = (group, name)``. For every feature
    size in ``sizes``, compares ``scatter_{add,mean,min,max}`` against
    ``segment_coo`` and ``segment_csr`` with the same reduction.

    Raises:
        AssertionError: if any two implementations disagree.

    A size that hits a ``RuntimeError`` (e.g. out of device memory) is
    skipped with a printed notice instead of being silently treated as a pass.
    """
    group, name = dataset
    mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
    row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
    dim_size = rowptr.size(0) - 1

    for size in sizes:
        try:
            x = torch.randn((row.size(0), size), device=args.device)
            x = x.squeeze(-1) if size == 1 else x

            out1 = scatter_add(x, row, dim=0, dim_size=dim_size)
            out2 = segment_coo(x, row, dim_size=dim_size, reduce='add')
            out3 = segment_csr(x, rowptr, reduce='add')
            _assert_all_close(out1, out2, out3)

            out1 = scatter_mean(x, row, dim=0, dim_size=dim_size)
            out2 = segment_coo(x, row, dim_size=dim_size, reduce='mean')
            out3 = segment_csr(x, rowptr, reduce='mean')
            _assert_all_close(out1, out2, out3)

            # Make all values <= 0 (presumably so that seeding scatter_min
            # with a zero tensor cannot change the minimum of non-empty rows).
            x = x.abs_().mul_(-1)

            # NOTE(review): the 4th positional argument is a zero tensor used
            # as scatter_min's initial/out tensor -- confirm against the
            # installed torch_scatter signature. Arg outputs are not compared.
            out1, _ = scatter_min(x, row, 0, torch.zeros_like(out1))
            out2, _ = segment_coo(x, row, reduce='min')
            out3, _ = segment_csr(x, rowptr, reduce='min')
            _assert_all_close(out1, out2, out3)

            # Make all values >= 0, mirroring the min case above.
            x = x.abs_()

            out1, _ = scatter_max(x, row, 0, torch.zeros_like(out1))
            out2, _ = segment_coo(x, row, reduce='max')
            out3, _ = segment_csr(x, rowptr, reduce='max')
            _assert_all_close(out1, out2, out3)

        except RuntimeError as e:
            torch.cuda.empty_cache()
            print(f'Skipping correctness check for {group}/{name} '
                  f'(size={size}): {type(e).__name__}')


rusty1s's avatar
rusty1s committed
91
@torch.no_grad()
def time_func(func, x, num_iters=None):
    """Return the wall-clock time of calling ``func(x)`` repeatedly.

    Args:
        func: Callable applied to ``x``; its return value is discarded.
        x: Tensor passed to ``func`` on every call.
        num_iters: Number of calls to time. ``None`` (the default) uses the
            module-level ``iters`` setting.

    Returns:
        Total elapsed seconds for all calls, or ``float('inf')`` if a
        ``RuntimeError`` (e.g. CUDA out of memory) was raised.
    """
    n = iters if num_iters is None else num_iters

    def sync():
        # CUDA kernels run asynchronously, so the clock must be bracketed by
        # device syncs. Only do this for CUDA inputs: on CPU-only setups
        # torch.cuda.synchronize() raises, which broke ``--device cpu``.
        if x.is_cuda:
            torch.cuda.synchronize()

    try:
        sync()
        t = time.perf_counter()
        for _ in range(n):
            func(x)
        sync()
        return time.perf_counter() - t
    except RuntimeError:
        torch.cuda.empty_cache()
        return float('inf')


rusty1s's avatar
rusty1s committed
105
106
107
108
@torch.no_grad()
def timing(dataset):
    """Benchmark sparse and dense reductions on one matrix and print a table.

    Loads ``{name}.mat`` for ``dataset = (group, name)``. For every feature
    size in ``sizes``, times the ``args.reduce`` reduction implemented as:

    - SCA_ROW: scatter over row indices in CSR order.
    - SCA_COL: scatter over the same indices randomly permuted.
    - SEG_COO: ``segment_coo``.
    - SEG_CSR: ``segment_csr``.
    - DENSE1 / DENSE2: ``torch.<args.dense_reduce>`` over a padded dense
      tensor, reducing dim -2 and dim -1 respectively.

    The fastest variant per size is printed in bold. Failed runs show as
    ``inf``; a size where every variant failed has no bold entry.
    """
    group, name = dataset
    mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
    row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
    # Same indices in random order, to time scatter on unsorted input.
    row_perm = row[torch.randperm(row.size(0), device=row.device)]
    dim_size = rowptr.size(0) - 1
    avg_row_len = row.size(0) / dim_size
    # Padded per-row length used by the dense baselines.
    dense_len = int(avg_row_len + 1)

    scatter_fn = getattr(torch_scatter, f'scatter_{args.reduce}')
    dense_fn = getattr(torch, args.dense_reduce)

    sca_row = lambda x: scatter_fn(x, row, dim=0, dim_size=dim_size)
    sca_col = lambda x: scatter_fn(x, row_perm, dim=0, dim_size=dim_size)
    seg_coo = lambda x: segment_coo(x, row, reduce=args.reduce)
    seg_csr = lambda x: segment_csr(x, rowptr, reduce=args.reduce)
    dense1 = lambda x: dense_fn(x, dim=-2)
    dense2 = lambda x: dense_fn(x, dim=-1)

    labels = ['SCA_ROW', 'SCA_COL', 'SEG_COO', 'SEG_CSR', 'DENSE1 ', 'DENSE2 ']
    results = [[] for _ in labels]  # results[i][j]: variant i at sizes[j]
    inf = float('inf')

    for size in sizes:
        # Gather all timings for this size before appending, so every
        # variant's list always holds exactly one entry per size.
        try:
            x = torch.randn((row.size(0), size), device=args.device)
            x = x.squeeze(-1) if size == 1 else x
            sparse = [time_func(func, x)
                      for func in (sca_row, sca_col, seg_coo, seg_csr)]
            del x
        except RuntimeError:  # e.g. out of memory while allocating ``x``
            torch.cuda.empty_cache()
            sparse = [inf] * 4

        try:
            x = torch.randn((dim_size, dense_len, size), device=args.device)
            d1 = time_func(dense1, x)
            # Reinterpret the same storage so a dim of length ``dense_len``
            # is last for DENSE2.
            x = x.view(dim_size, size, dense_len)
            dense = [d1, time_func(dense2, x)]
            del x
        except RuntimeError:
            torch.cuda.empty_cache()
            dense = [inf] * 2

        for variant_times, t in zip(results, sparse + dense):
            variant_times.append(t)

    ts = torch.tensor(results)
    winner = torch.zeros_like(ts, dtype=torch.bool)
    winner[ts.argmin(dim=0), torch.arange(len(sizes))] = True
    # argmin picks row 0 when a whole column is inf; that is not a winner.
    winner &= torch.isfinite(ts)

    title = f'{group}/{name}'
    print(f'{bold(title)} (avg row length: {avg_row_len:.2f}):')
    print('\t'.join(['       '] + [f'{size:>5}' for size in sizes]))
    for label, times, flags in zip(labels, results, winner.tolist()):
        cells = [bold(f'{t:.5f}', f) for t, f in zip(times, flags)]
        print('\t'.join([bold(label)] + cells))
    print()


rusty1s's avatar
rusty1s committed
181
if __name__ == '__main__':
    cli = argparse.ArgumentParser()
    cli.add_argument('--reduce', type=str, required=True,
                     choices=['add', 'mean', 'min', 'max'])
    cli.add_argument('--device', type=str, default='cuda')
    # ``args`` stays a module-level global: correctness() and timing() read it.
    args = cli.parse_args()
    # The dense baselines call torch.<dense_reduce>; torch names 'add' 'sum'.
    args.dense_reduce = {'add': 'sum'}.get(args.reduce, args.reduce)

    # Warmup: run a few small kernels on the device before any timing.
    for _ in range(10):
        torch.randn(100, 100, device=args.device).sum()

    for dataset in [*short_rows, *long_rows]:
        download(dataset)
        correctness(dataset)
        timing(dataset)