# test_degree_padding.py (2.91 KB) -- committed by rusty1s.
import pytest
import torch
from torch_sparse import SparseTensor
from torch_geometric.datasets import Planetoid

# The benchmark in this module times kernels with CUDA events, so it can only
# run on a CUDA device. Parametrizing over an empty list makes pytest skip the
# test on machines without a GPU instead of failing with a CUDA error.
devices = [torch.device('cuda')] if torch.cuda.is_available() else []


def _cuda_elapsed_ms(fn, iterations=100, warmup=2):
    """Time repeated calls to ``fn`` with CUDA events.

    ``fn`` is first called ``warmup`` times untimed, to keep one-off costs
    out of the measurement. It is then called ``iterations`` times between
    two recorded CUDA events. The device is synchronized before the events
    are read.

    Args:
        fn: Zero-argument callable that enqueues the work to measure.
        iterations: Number of timed calls.
        warmup: Number of untimed calls made before timing starts.

    Returns:
        Total elapsed time of the ``iterations`` timed calls, in milliseconds
        (the value of ``torch.cuda.Event.elapsed_time``).
    """
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(warmup):
        fn()
    start.record()
    for _ in range(iterations):
        fn()
    end.record()
    # Kernel launches are asynchronous; wait for them before reading events.
    torch.cuda.synchronize()
    return start.elapsed_time(end)


@pytest.mark.skipif(not torch.cuda.is_available(),
                    reason='requires CUDA (timing uses CUDA events)')
@pytest.mark.parametrize('device', devices)
def test_bin_assignment(device):
    """Exercise and benchmark ``bin_assignment`` and ``padded_index_select``.

    This is a smoke test plus benchmark, not a correctness test; it makes no
    assertions. It first prints the permutations computed for a tiny
    hand-written example. It then prints one timing per operation below,
    each the total milliseconds over 100 calls after 2 warm-up calls:

    1. ``bin_assignment`` on the PubMed graph,
    2. a plain ``index_select`` gather over all edges, as a baseline,
    3. ``padded_index_select`` for each degree bin, in bin order.

    The PubMed graph is loaded through PyG's ``Planetoid`` from
    ``/tmp/Planetoid``; PyG fetches it there if it is missing.
    """
    # Tiny example. Each bin_strategy row presumably holds the [low, high]
    # degree bounds of one bin -- confirm against the op.
    toy_rowcount = torch.tensor([2, 3, 6, 4, 5, 7, 8, 1], device=device)
    toy_bins = torch.tensor([[1, 4], [5, 8]], device=device)
    toy_perms = torch.ops.torch_sparse.bin_assignment(toy_rowcount, toy_bins)
    print()
    print(toy_perms)
    # TODO(review): turn this print into assertions once the expected
    # per-bin output (and its ordering) of bin_assignment is confirmed.

    data = Planetoid('/tmp/Planetoid', name='PubMed')[0]
    row, col = data.edge_index
    adj = SparseTensor(row=row, col=col)
    rowcount = adj.storage.rowcount().to(device)
    rowptr = adj.storage.rowptr().to(device)
    col = col.to(device)

    bin_strategy = torch.tensor([[1, 4], [5, 13], [14, 22]], device=device)
    # The upper bound of each bin is passed to padded_index_select as its
    # padding length (presumably the max row length in that bin -- confirm).
    bin_count = bin_strategy[:, 1].tolist()

    perms = torch.ops.torch_sparse.bin_assignment(rowcount, bin_strategy)
    print(_cuda_elapsed_ms(
        lambda: torch.ops.torch_sparse.bin_assignment(rowcount, bin_strategy)))

    print('-------------')

    x = torch.randn(data.num_nodes, 512, device=device)

    # Baseline: gather one feature row per edge. The result is discarded, so
    # every timed call reads from the same (num_nodes, 512) input.
    print(_cuda_elapsed_ms(lambda: x.index_select(0, col)))

    # One timing per bin. Indexing perms (rather than zipping) keeps an
    # IndexError if the op returns fewer permutations than there are bins.
    for i, count in enumerate(bin_count):

        def run(perm=perms[i], count=count):
            return torch.ops.torch_sparse.padded_index_select(
                x, rowptr, col, perm, count)

        print(_cuda_elapsed_ms(run))