test_segment.py 3.71 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import time
rusty1s's avatar
rusty1s committed
2
3
4
5
from itertools import product

import pytest
import torch
rusty1s's avatar
rusty1s committed
6
from torch_scatter import segment_add, scatter_add
rusty1s's avatar
rusty1s committed
7
from torch_scatter.segment import segment_add_csr, segment_add_coo
rusty1s's avatar
rusty1s committed
8
9
10
11
12
13
14
15
16
17

from .utils import tensor

dtypes = [torch.float]
# The benchmark below calls torch.cuda.synchronize(), so the tests are only
# parametrized over CUDA. On a CPU-only machine the device list is empty and
# pytest reports the parametrized tests as skipped instead of erroring.
devices = [torch.device('cuda')] if torch.cuda.is_available() else []


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_forward(dtype, device):
    """Check that ``segment_add`` sums runs of equal values in a sorted index.

    Index value 2 never occurs, so its output slot is expected to be zero.
    """
    src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
    index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)

    out = segment_add(src, index, dim=0)
    # 0 -> 1 + 2, 1 -> 3 + 4 + 5, 2 -> (empty), 3 -> 6
    assert out.tolist() == [3, 12, 0, 6]


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_forward2(dtype, device):
    """Check that ``segment_add_csr`` reduces rows of a 2-D ``src`` via ``indptr``.

    Segment ``i`` covers rows ``indptr[i]:indptr[i + 1]``; the segment
    ``5:5`` is empty and is expected to produce a zero row.
    """
    src = tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], dtype,
                 device)
    indptr = tensor([0, 2, 5, 5, 6], torch.long, device)

    out = segment_add_csr(src, indptr)
    assert out.tolist() == [[4, 6], [21, 24], [0, 0], [11, 12]]

    # TODO: the original also sketched (commented out) a non-contiguous
    # ``indptr`` variant and the COO counterpart
    # ``segment_add_coo(src, tensor([0, 0, 1, 1, 1, 3], ...))``; enable and
    # assert the same result once those paths are verified for this input.


def _time_cuda(fn, runs=100):
    """Call ``fn`` ``runs`` times and return ``(last_result, seconds)``.

    The CUDA device is synchronized before the clock starts and after the
    last call, so queued asynchronous work is included in the measurement.
    """
    torch.cuda.synchronize()
    start = time.perf_counter()
    result = None
    for _ in range(runs):
        result = fn()
    torch.cuda.synchronize()
    return result, time.perf_counter() - start


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_benchmark(dtype, device):
    """Time scatter vs. segment reductions on PubMed and check they agree.

    Requires ``torch_geometric`` (the test is skipped when it is missing) and
    reads the dataset from ``/tmp/PubMed`` (presumably downloading it there
    on first use -- depends on torch_geometric's Planetoid).
    """
    # Skip rather than error out when the optional dependency is missing.
    datasets = pytest.importorskip('torch_geometric.datasets')
    # Planetoid('/tmp/Cora', 'Cora') or Reddit('/tmp/Reddit') can be swapped
    # in here to benchmark other graph sizes.
    data = datasets.Planetoid('/tmp/PubMed', 'PubMed')[0].to(device)
    num_nodes = data.num_nodes
    row, col = data.edge_index
    print(row.size(0) / num_nodes)  # average number of edges per node

    # Warmup
    for _ in range(10):
        torch.randn(100, 100, device=device).sum()

    # 1-D input: one value per edge.
    x = torch.randn(data.num_edges, dtype=dtype, device=device)

    out1, t = _time_cuda(
        lambda: scatter_add(x, row, dim=0, dim_size=num_nodes))
    print('Scatter Row', t)

    _, t = _time_cuda(lambda: scatter_add(x, col, dim=0, dim_size=num_nodes))
    print('Scatter Col', t)

    out2, t = _time_cuda(
        lambda: segment_add(x, row, dim=0, dim_size=num_nodes))
    print('Thrust', t)
    assert torch.allclose(out1, out2, atol=1e-2)

    # Build the CSR pointer from per-node edge counts. Passing ``dim_size``
    # keeps a slot for every node, so ``rowptr`` has ``num_nodes + 1``
    # entries even when the highest-numbered nodes have no outgoing edges
    # (otherwise the CSR output would be shorter than ``out1``).
    # NOTE(review): CSR assumes ``row`` is sorted -- confirm for the dataset.
    rowcount = segment_add(torch.ones_like(row), row, dim=0,
                           dim_size=num_nodes)
    rowptr = torch.cat([rowcount.new_zeros(1), rowcount.cumsum(0)], dim=0)

    out3, t = _time_cuda(lambda: segment_add_csr(x, rowptr))
    print('CSR', t)
    assert torch.allclose(out1, out3, atol=1e-2)

    out4, t = _time_cuda(
        lambda: segment_add_coo(x, row, dim_size=num_nodes))
    print('COO', t)
    assert torch.allclose(out1, out4, atol=1e-2)

    # 2-D input: 1024 values per edge.
    x = torch.randn((data.num_edges, 1024), dtype=dtype, device=device)

    out5, t = _time_cuda(
        lambda: scatter_add(x, row, dim=0, dim_size=num_nodes))
    print('Scatter Row + Dim', t)

    _, t = _time_cuda(lambda: scatter_add(x, col, dim=0, dim_size=num_nodes))
    print('Scatter Col + Dim', t)

    out6, t = _time_cuda(lambda: segment_add_csr(x, rowptr))
    print('CSR + Dim', t)
    assert torch.allclose(out5, out6, atol=1e-2)