Commit fff381c5 authored by rusty1s's avatar rusty1s
Browse files

added saint extract_adj method

parent 92b1e639
#include "saint_cpu.h"
#include "utils.h"
// Extracts the subgraph of the CSR matrix (rowptr, col) induced by the node
// set `idx` (GraphSAINT-style sampling).
//
// Args:
//   idx:    1-dim int64 tensor of selected (original) node ids.
//   rowptr: CSR row pointer of the full graph (1-dim int64).
//   row:    COO row vector of the full graph (only used for its options here).
//   col:    CSR/COO column indices of the full graph (1-dim int64).
//
// Returns (row, col, edge_index) of the induced subgraph, where `row`/`col`
// are relabeled to [0, idx.size(0)) and `edge_index` holds, for every kept
// edge, its position in the original `col` array (useful to gather edge
// values afterwards).
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
subgraph_cpu(torch::Tensor idx, torch::Tensor rowptr, torch::Tensor row,
             torch::Tensor col) {
  CHECK_CPU(idx);
  CHECK_CPU(rowptr);
  CHECK_CPU(row); // Fix: `row` was not checked although its options are used.
  CHECK_CPU(col);
  CHECK_INPUT(idx.dim() == 1);
  CHECK_INPUT(rowptr.dim() == 1);
  CHECK_INPUT(col.dim() == 1);

  // assoc[v] == new id of original node v, or -1 if v is not in `idx`.
  auto assoc = torch::full({rowptr.size(0) - 1}, -1, idx.options());
  assoc.index_copy_(0, idx, torch::arange(idx.size(0), idx.options()));

  auto idx_data = idx.data_ptr<int64_t>();
  auto rowptr_data = rowptr.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();
  auto assoc_data = assoc.data_ptr<int64_t>();

  std::vector<int64_t> rows, cols, indices;
  int64_t v, w, w_new, row_start, row_end;
  // Walk the neighborhood of every sampled node; keep edges whose endpoint
  // is also sampled. Iterating v_new in order keeps the output row-sorted.
  for (int64_t v_new = 0; v_new < idx.size(0); v_new++) {
    v = idx_data[v_new];
    row_start = rowptr_data[v], row_end = rowptr_data[v + 1];

    for (int64_t j = row_start; j < row_end; j++) {
      w = col_data[j];
      w_new = assoc_data[w];
      if (w_new > -1) {
        rows.push_back(v_new);
        cols.push_back(w_new);
        indices.push_back(j); // original edge position, for value gathering
      }
    }
  }

  int64_t length = rows.size();
  // Copy the accumulated vectors into tensors; `.clone()` detaches the
  // result from the (stack-owned) std::vector storage.
  row = torch::from_blob(rows.data(), {length}, row.options()).clone();
  col = torch::from_blob(cols.data(), {length}, row.options()).clone();
  idx = torch::from_blob(indices.data(), {length}, row.options()).clone();

  return std::make_tuple(row, col, idx);
}
#pragma once
#include <torch/extension.h>
// CPU kernel: induced-subgraph extraction for GraphSAINT sampling.
// Returns (row, col, edge_index) of the subgraph induced by `idx`,
// where edge_index maps each kept edge back to its original position.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
subgraph_cpu(torch::Tensor idx, torch::Tensor rowptr, torch::Tensor row,
torch::Tensor col);
#include <Python.h>
#include <torch/script.h>
#include "cpu/saint_cpu.h"
#ifdef _WIN32
// On Windows, a loadable extension module must export a PyInit_<name>
// symbol even though all functionality is registered via TorchScript ops.
PyMODINIT_FUNC PyInit__saint(void) { return NULL; }
#endif
// Device dispatcher for `torch_sparse::saint_subgraph`: routes CPU tensors
// to the CPU kernel and rejects CUDA tensors (no CUDA kernel exists).
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
subgraph(torch::Tensor idx, torch::Tensor rowptr, torch::Tensor row,
         torch::Tensor col) {
  if (!idx.device().is_cuda())
    return subgraph_cpu(idx, rowptr, row, col);
#ifdef WITH_CUDA
  // Built with CUDA, but this operator has no CUDA implementation.
  AT_ERROR("No CUDA version supported");
#else
  AT_ERROR("Not compiled with CUDA support");
#endif
}
// Register the dispatcher as a TorchScript custom op so it is callable from
// Python via `torch.ops.torch_sparse.saint_subgraph`.
static auto registry =
torch::RegisterOperators().op("torch_sparse::saint_subgraph", &subgraph);
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.saint import subgraph
from .utils import devices
@pytest.mark.parametrize('device', devices)
def test_subgraph(device):
    """Induced subgraph on nodes {0, 1, 2} keeps exactly the edges whose
    endpoints both lie in that set — the first six entries of the COO list."""
    row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4])
    col = torch.tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3])
    adj = SparseTensor(row=row, col=col).to(device)

    node_idx = torch.tensor([0, 1, 2])
    adj, edge_index = subgraph(adj, node_idx)

    # Fix: the original test called `subgraph` but asserted nothing.
    assert adj.sizes() == [3, 3]
    # Edges (0,1),(0,2),(1,0),(1,2),(2,0),(2,1) survive; edge (2,3) at
    # position 6 is dropped, so the surviving original positions are 0..5.
    assert edge_index.tolist() == [0, 1, 2, 3, 4, 5]
    sub_row, sub_col, _ = adj.coo()
    assert sub_row.tolist() == [0, 0, 1, 1, 2, 2]
    assert sub_col.tolist() == [1, 2, 0, 2, 0, 1]
@pytest.mark.parametrize('device', devices)
def test_sample_node(device):
row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4])
......
......@@ -9,7 +9,7 @@ expected_torch_version = (1, 4)
try:
for library in [
'_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis',
'_rw'
'_rw', '_saint'
]:
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
library, [osp.dirname(__file__)]).origin)
......
......@@ -6,6 +6,28 @@ from torch_scatter import scatter_add
from torch_sparse.tensor import SparseTensor
def subgraph(src: SparseTensor,
             node_idx: torch.Tensor) -> Tuple[SparseTensor, torch.Tensor]:
    """Extracts the subgraph of :obj:`src` induced by :obj:`node_idx`.

    Returns the relabeled sparse adjacency together with the positions of
    the surviving edges in the original edge list (so callers can map edge
    attributes across).
    """
    row, col, value = src.coo()
    rowptr = src.storage.rowptr()

    sub_row, sub_col, edge_index = torch.ops.torch_sparse.saint_subgraph(
        node_idx, rowptr, row, col)

    if value is not None:
        # Gather the values of the edges that survived the sampling.
        value = value[edge_index]

    num_nodes = node_idx.size(0)
    # The C++ kernel emits rows in ascending order, hence `is_sorted=True`.
    out = SparseTensor(row=sub_row, rowptr=None, col=sub_col, value=value,
                       sparse_sizes=(num_nodes, num_nodes), is_sorted=True)
    return out, edge_index
def sample_node(src: SparseTensor,
num_nodes: int) -> Tuple[SparseTensor, torch.Tensor]:
row, col, _ = src.coo()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment