saint.py 2.3 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
from typing import Tuple

import torch
import numpy as np
from torch_scatter import scatter_add
from torch_sparse.tensor import SparseTensor


rusty1s's avatar
rusty1s committed
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def subgraph(src: SparseTensor,
             node_idx: torch.Tensor) -> Tuple[SparseTensor, torch.Tensor]:
    """Extract the subgraph of ``src`` induced by the nodes in ``node_idx``.

    The work is delegated to the ``torch_sparse.saint_subgraph`` custom
    operator, which receives ``node_idx`` together with the CSR row pointer
    and the COO row/column indices of ``src``.

    Args:
        src: Sparse adjacency matrix to extract the subgraph from.
        node_idx: Indices of the nodes to keep.

    Returns:
        A tuple ``(adj, edge_index)``:

        * ``adj`` is a square ``SparseTensor`` of size
          ``node_idx.size(0)`` built from the row/column indices returned by
          the operator (flagged ``is_sorted=True``, i.e. the operator's
          output is assumed to already be row-sorted).
        * ``edge_index`` is the operator's third output, which is used to
          gather the matching entries of ``src``'s value tensor (if any).
    """
    row, col, value = src.coo()
    rowptr = src.storage.rowptr()

    sub_row, sub_col, edge_index = torch.ops.torch_sparse.saint_subgraph(
        node_idx, rowptr, row, col)

    # Carry edge values over only when the input matrix actually has them.
    sub_value = None if value is None else value[edge_index]

    num_sub_nodes = node_idx.size(0)
    adj = SparseTensor(row=sub_row,
                       rowptr=None,
                       col=sub_col,
                       value=sub_value,
                       sparse_sizes=(num_sub_nodes, num_sub_nodes),
                       is_sorted=True)
    return adj, edge_index


rusty1s's avatar
rusty1s committed
31
def sample_node(src: SparseTensor,
                num_nodes: int) -> Tuple[SparseTensor, torch.Tensor]:
    """Sample nodes from a degree-based importance distribution.

    Row ``i`` of ``src`` gets the weight
    ``sum over stored entries (i, j) of (1 / in_deg(j)) ** 2``. Columns with
    zero in-degree contribute 0 instead of ``inf``. The weights are
    normalised and ``num_nodes`` draws are taken with replacement.
    Duplicate draws are then removed.

    Args:
        src: Sparse adjacency matrix to sample from.
        num_nodes: Number of draws. Because duplicates are dropped, the
            number of distinct nodes returned may be smaller.

    Returns:
        ``(src.permute(node_idx), node_idx)`` where ``node_idx`` holds the
        sorted, unique sampled node indices.
    """
    row, col, _ = src.coo()

    inv_in_deg = src.storage.colcount().to(torch.float).pow_(-1)
    # Zero-degree columns produce 1/0 == inf; neutralise them.
    inv_in_deg.masked_fill_(inv_in_deg == float('inf'), 0)

    per_edge = inv_in_deg[col]
    per_edge = per_edge * per_edge

    node_weight = scatter_add(per_edge, row, dim=0, dim_size=src.size(0))
    node_prob = node_weight / node_weight.sum()

    draws = torch.multinomial(node_prob, num_nodes, replacement=True)
    node_idx = torch.unique(draws)

    return src.permute(node_idx), node_idx


def sample_edge(src: SparseTensor,
                num_edges: int) -> Tuple[SparseTensor, torch.Tensor]:
    """Sample edges by degree-based importance and return the nodes they touch.

    Stored entry ``(i, j)`` of ``src`` is drawn with probability proportional
    to ``1 / out_deg(i) + 1 / in_deg(j)`` (GraphSAINT edge sampler).
    ``num_edges`` draws are taken with replacement. The node set is the
    union of the source and target endpoints of the sampled edges.

    Args:
        src: Sparse adjacency matrix to sample from.
        num_edges: Number of edge draws.

    Returns:
        ``(src.permute(node_idx), node_idx)`` where ``node_idx`` holds the
        sorted, unique endpoints of the sampled edges.
    """
    row, col, _ = src.coo()

    # Inverse degrees; zero-degree rows/columns would yield inf, map to 0.
    inv_out_deg = src.storage.rowcount().to(torch.float).pow_(-1)
    inv_out_deg[inv_out_deg == float('inf')] = 0
    inv_in_deg = src.storage.colcount().to(torch.float).pow_(-1)
    inv_in_deg[inv_in_deg == float('inf')] = 0

    prob = inv_out_deg[row] + inv_in_deg[col]
    prob.div_(prob.sum())

    edge_idx = prob.multinomial(num_edges, replacement=True)

    # Keep *both* endpoints of every sampled edge. Using only ``col`` would
    # silently drop the source node of each sampled edge from the subgraph.
    node_idx = torch.cat([row[edge_idx], col[edge_idx]]).unique()

    return src.permute(node_idx), node_idx


def sample_rw(src: SparseTensor, num_root_nodes: int,
              walk_length: int) -> Tuple[SparseTensor, torch.Tensor]:
    """Sample nodes by running random walks from distinct random roots.

    ``num_root_nodes`` distinct root nodes are drawn uniformly, without
    replacement, from torch's RNG. Using torch's RNG lets
    ``torch.manual_seed`` control this sampler exactly like
    ``sample_node``/``sample_edge``. From each root, the
    ``torch_sparse.random_walk`` operator runs a walk of ``walk_length``
    steps over the CSR structure of ``src``. Every node visited by any walk
    is kept.

    Args:
        src: Sparse adjacency matrix to walk on.
        num_root_nodes: Number of distinct walk start nodes; must lie in
            ``[0, src.size(0)]``.
        walk_length: Length of each random walk.

    Returns:
        ``(src.permute(node_idx), node_idx)`` where ``node_idx`` holds the
        sorted, unique nodes visited by the walks.

    Raises:
        ValueError: If ``num_root_nodes`` is negative or larger than the
            number of rows of ``src`` (same exception type the previous
            ``np.random.choice(..., replace=False)`` implementation raised).
    """
    num_nodes = src.size(0)
    if not 0 <= num_root_nodes <= num_nodes:
        raise ValueError(f'num_root_nodes must be in [0, {num_nodes}], '
                         f'got {num_root_nodes}')

    rowptr, col, _ = src.csr()

    # A random permutation prefix == uniform sampling without replacement,
    # generated directly on the target device (no numpy/host round-trip).
    start = torch.randperm(num_nodes, dtype=torch.long,
                           device=src.device())[:num_root_nodes]

    out = torch.ops.torch_sparse.random_walk(rowptr, col, start, walk_length)

    node_idx = out.flatten().unique()

    return src.permute(node_idx), node_idx
rusty1s's avatar
rusty1s committed
81
82
83
84
85


# Attach the samplers to SparseTensor under their own function names so they
# can be called as ``adj.sample_node(...)``, ``adj.sample_edge(...)`` and
# ``adj.sample_rw(...)``.
for _sampler in (sample_node, sample_edge, sample_rw):
    setattr(SparseTensor, _sampler.__name__, _sampler)
del _sampler