test_padding.py 3.22 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from itertools import product

import pytest
import torch
from torch_sparse import SparseTensor

from .utils import grad_dtypes, tensor

devices = [torch.device('cuda')]


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_padded_index_select(dtype, device):
    row = torch.tensor([0, 0, 0, 0, 1, 1, 1, 2, 2, 3])
    col = torch.tensor([0, 1, 2, 3, 0, 2, 3, 1, 3, 2])
    adj = SparseTensor(row=row, col=col).to(device)
    rowptr, col, _ = adj.csr()
    rowcount = adj.storage.rowcount()
    binptr = torch.tensor([0, 3, 5], device=device)

    data = torch.ops.torch_sparse.padded_index(rowptr, col, rowcount, binptr)
    node_perm, row_perm, col_perm, mask, node_size, edge_size = data

    assert node_perm.tolist() == [2, 3, 0, 1]
    assert row_perm.tolist() == [2, 2, 3, -1, 0, 0, 0, 0, 1, 1, 1, -1]
    assert col_perm.tolist() == [1, 3, 2, -1, 0, 1, 2, 3, 0, 2, 3, -1]
    assert mask.long().tolist() == [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1]
    assert node_size == [2, 2]
    assert edge_size == [4, 8]

    x = tensor([0, 1, 2, 3], dtype, device).view(-1, 1).requires_grad_()
    fill_value = torch.tensor(0., dtype=dtype)
    out = torch.ops.torch_sparse.padded_index_select(x, col_perm, fill_value)

    assert out.flatten().tolist() == [1, 3, 2, 0, 0, 1, 2, 3, 0, 2, 3, 0]

    grad_out = tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], dtype, device)
    out.backward(grad_out.view(-1, 1))

    assert x.grad.flatten().tolist() == [12, 5, 17, 18]


@pytest.mark.parametrize('device', devices)
def test_padded_index_select_runtime(device):
    return
    from torch_geometric.datasets import Planetoid
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    dataset = Planetoid('/tmp/Planetoid', name='PubMed')
    data = dataset[0]
    row, col = data.edge_index.to(device)

    adj = SparseTensor(row=row, col=col)
    rowcount = adj.storage.rowcount().to(device)
    rowptr = adj.storage.rowptr().to(device)
    binptr = torch.tensor([0, 4, 11, 30, 50, 80, 120, 140, 2000]).to(device)

    x = torch.randn(adj.size(0), 512).to(device)

    data = torch.ops.torch_sparse.padded_index(rowptr, col, rowcount, binptr)
    node_perm, row_perm, col_perm, mask, node_sizes, edge_sizes = data

    out = torch.ops.torch_sparse.padded_index_select(x, col_perm,
                                                     torch.tensor(0.))
    outs = out.split(edge_sizes)
    for out, size in zip(outs, node_sizes):
        print(out.view(size, -1, x.size(-1)).shape)

    for i in range(110):
        if i == 10:
            start.record()
        torch.ops.torch_sparse.padded_index(rowptr, col, rowcount, binptr)
    end.record()
    torch.cuda.synchronize()
    print('padded index', start.elapsed_time(end))

    for i in range(110):
        if i == 10:
            start.record()
        out = torch.ops.torch_sparse.padded_index_select(
            x, col_perm, torch.tensor(0.))
        out.split(edge_sizes)
    end.record()
    torch.cuda.synchronize()
    print('padded index select', start.elapsed_time(end))

    for i in range(110):
        if i == 10:
            start.record()
        x.index_select(0, col)
    end.record()
    torch.cuda.synchronize()
    print('index_select', start.elapsed_time(end))