saint.py 702 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
from typing import Tuple

import torch
from torch_sparse.tensor import SparseTensor


rusty1s's avatar
rusty1s committed
7
8
def saint_subgraph(src: SparseTensor, node_idx: torch.Tensor
                   ) -> Tuple[SparseTensor, torch.Tensor]:
    """Restrict ``src`` to the subgraph spanned by the nodes in ``node_idx``.

    The heavy lifting is delegated to the compiled
    ``torch.ops.torch_sparse.saint_subgraph`` operator. It receives
    ``node_idx`` together with the CSR row pointer and the COO row/column
    indices of ``src``. It returns the row indices, the column indices and
    an ``edge_index`` tensor for the kept edges.

    Args:
        src: Sparse matrix to take the subgraph of.
        node_idx: Indices of the nodes that make up the subgraph.

    Returns:
        A tuple ``(out, edge_index)``:

        * ``out`` is a ``SparseTensor`` of size
          ``(node_idx.size(0), node_idx.size(0))``. It carries ``src``'s
          edge values, gathered via ``edge_index``, when ``src`` has values.
        * ``edge_index`` indexes into ``src``'s COO value array, so callers
          can gather other per-edge data the same way.

    NOTE(review): ``out`` is built with ``is_sorted=True``. This assumes the
    native op emits edges ordered by (row, col). That presumably only holds
    when ``node_idx`` is sorted ascending; confirm against the C++ kernel.
    """
    src_row, src_col, src_value = src.coo()
    src_rowptr = src.storage.rowptr()

    # Filter the edges in the native op. Given the (n, n) size used below,
    # the returned indices are presumably relabelled to positions within
    # ``node_idx``.
    sub_row, sub_col, edge_index = torch.ops.torch_sparse.saint_subgraph(
        node_idx, src_rowptr, src_row, src_col)

    # Only gather edge values when ``src`` actually stores them.
    sub_value = None if src_value is None else src_value[edge_index]

    num_nodes = node_idx.size(0)
    out = SparseTensor(row=sub_row, rowptr=None, col=sub_col,
                       value=sub_value,
                       sparse_sizes=(num_nodes, num_nodes),
                       is_sorted=True)
    return out, edge_index


rusty1s's avatar
rusty1s committed
29
# Expose the free function as a method, so ``adj.saint_subgraph(node_idx)``
# works on any SparseTensor.
setattr(SparseTensor, 'saint_subgraph', saint_subgraph)