test_degree_padding2.py 2.31 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import pytest
import torch
from torch_sparse import SparseTensor
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import degree

# The benchmark below relies on CUDA tensors and ``torch.cuda.Event`` timing,
# so only parametrize over a GPU when one is present. With an empty list,
# pytest skips the parametrized test instead of erroring on CPU-only machines.
devices = [torch.device('cuda')] if torch.cuda.is_available() else []


def _cuda_time(fn, *args, warmup=10, iters=100):
    """Call ``fn(*args)`` ``warmup + iters`` times and time the last ``iters``.

    A CUDA start event is recorded once the warm-up calls are done and an end
    event after the final call. The device is then synchronized so that
    ``elapsed_time`` can be read.

    Returns:
        ``(result, elapsed_ms)``: the return value of the final call and the
        elapsed time between the two events (milliseconds, as reported by
        ``torch.cuda.Event.elapsed_time``).

    Raises:
        ValueError: if ``iters`` is not positive (the start event would never
            be recorded).
    """
    if iters < 1:
        raise ValueError('iters must be positive')
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    result = None
    for i in range(warmup + iters):
        if i == warmup:
            start.record()
        result = fn(*args)
    end.record()
    # Events are recorded asynchronously; wait before reading them.
    torch.cuda.synchronize()
    return result, start.elapsed_time(end)


def _padded_index_select_all(x, rowptr, col, perms, lengths, fill_value):
    """Run ``padded_index_select`` once per bin and return the last output.

    Only the output of the final ``(perm, length)`` bin is returned. With a
    single bin, that covers every row selected by ``perms[0]``.
    """
    out = None
    for perm, length in zip(perms, lengths):
        out, _ = torch.ops.torch_sparse.padded_index_select(
            x, rowptr, col, perm, length, fill_value)
    return out


@pytest.mark.parametrize('device', devices)
def test_padded_index_select(device):
    """Benchmark ``padded_index_select`` against plain ``index_select``.

    Builds a synthetic graph over the PubMed node set in which every node has
    exactly 4 edges to uniformly random targets. The rows are binned with
    ``torch_sparse.bin_assignment``. For several feature sizes the test prints
    the CUDA time of 100 timed calls (after 10 warm-up calls) of:

    1. ``padded_index_select`` over every bin,
    2. ``x.index_select(0, row)``,
    3. ``x.index_select(0, col)``,

    followed by whether the padded output equals the ``col`` gather.
    Results are only printed; nothing is asserted.
    """
    dataset = Planetoid('/tmp/Planetoid', name='PubMed')
    # Only the node count of the real dataset is used; its edges are not.
    num_nodes = dataset[0].num_nodes

    # Synthetic graph: node i gets exactly 4 edges (i, random target).
    row = torch.arange(num_nodes, device=device).repeat_interleave(4)
    col = torch.randint(0, num_nodes, (row.size(0), ), device=device)

    adj = SparseTensor(row=row, col=col)
    rowcount = adj.storage.rowcount().to(device)
    rowptr = adj.storage.rowptr().to(device)
    bin_strategy = torch.tensor([[1, 4]], device=device)
    # Alternative multi-bin strategy. NOTE: with several bins, the final
    # allclose check only sees the last bin's output and would need updating.
    # bin_strategy = torch.tensor([[1, 5], [6, 12], [13, 19], [20, 30]],
    #                             device=device)
    perms = torch.ops.torch_sparse.bin_assignment(rowcount, bin_strategy)
    lengths = bin_strategy[:, 1].view(-1).tolist()
    print(lengths)

    # Degree statistics of the synthetic graph.
    deg = degree(row, dtype=torch.long)
    print(deg.size(), deg.min(), deg.float().mean(), deg.max())
    bins = torch.bincount(deg)
    print(bins)
    nonzero = bins.nonzero().flatten()
    print(nonzero)
    print(bins[nonzero])

    # Allocated once, so the timed region does not include a tensor
    # construction on every call.
    fill_value = torch.tensor(0.)

    for dim in [32, 64, 128, 256, 512, 1024]:
        print(f'--- Dim: {dim} ---')
        x = torch.randn(adj.size(0), dim, device=device)

        out1, elapsed = _cuda_time(_padded_index_select_all, x, rowptr, col,
                                   perms, lengths, fill_value)
        print(elapsed)

        _, elapsed = _cuda_time(x.index_select, 0, row)
        print(elapsed)

        out3, elapsed = _cuda_time(x.index_select, 0, col)
        print(elapsed)

        print(torch.allclose(out1.view(-1, dim), out3))