Unverified Commit 6b4d97f1 authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Merge pull request #48 from rusty1s/graph_saint

Graph saint
parents a1ae9033 a597b822
#include "rw_cpu.h"
#include "utils.h"
torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length) {
CHECK_CPU(rowptr);
CHECK_CPU(col);
CHECK_CPU(start);
CHECK_INPUT(rowptr.dim() == 1);
CHECK_INPUT(col.dim() == 1);
CHECK_INPUT(start.dim() == 1);
auto rand = torch::rand({start.size(0), walk_length},
start.options().dtype(torch::kFloat));
auto L = walk_length + 1;
auto out = torch::full({start.size(0), L}, -1, start.options());
auto rowptr_data = rowptr.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto start_data = start.data_ptr<int64_t>();
auto rand_data = rand.data_ptr<float>();
auto out_data = out.data_ptr<int64_t>();
for (auto n = 0; n < start.size(0); n++) {
auto cur = start_data[n];
out_data[n * L] = cur;
int64_t row_start, row_end;
for (auto l = 0; l < walk_length; l++) {
row_start = rowptr_data[cur];
row_end = rowptr_data[cur + 1];
cur = col_data[row_start + int64_t(rand_data[n * walk_length + l] *
(row_end - row_start))];
out_data[n * L + l + 1] = cur;
}
}
return out;
}
#pragma once
#include <torch/extension.h>
torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length);
#include "saint_cpu.h"
#include "utils.h"
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
subgraph_cpu(torch::Tensor idx, torch::Tensor rowptr, torch::Tensor row,
torch::Tensor col) {
CHECK_CPU(idx);
CHECK_CPU(rowptr);
CHECK_CPU(col);
CHECK_INPUT(idx.dim() == 1);
CHECK_INPUT(rowptr.dim() == 1);
CHECK_INPUT(col.dim() == 1);
auto assoc = torch::full({rowptr.size(0) - 1}, -1, idx.options());
assoc.index_copy_(0, idx, torch::arange(idx.size(0), idx.options()));
auto idx_data = idx.data_ptr<int64_t>();
auto rowptr_data = rowptr.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto assoc_data = assoc.data_ptr<int64_t>();
std::vector<int64_t> rows, cols, indices;
int64_t v, w, w_new, row_start, row_end;
for (int64_t v_new = 0; v_new < idx.size(0); v_new++) {
v = idx_data[v_new];
row_start = rowptr_data[v];
row_end = rowptr_data[v + 1];
for (int64_t j = row_start; j < row_end; j++) {
w = col_data[j];
w_new = assoc_data[w];
if (w_new > -1) {
rows.push_back(v_new);
cols.push_back(w_new);
indices.push_back(j);
}
}
}
int64_t length = rows.size();
row = torch::from_blob(rows.data(), {length}, row.options()).clone();
col = torch::from_blob(cols.data(), {length}, row.options()).clone();
idx = torch::from_blob(indices.data(), {length}, row.options()).clone();
return std::make_tuple(row, col, idx);
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
subgraph_cpu(torch::Tensor idx, torch::Tensor rowptr, torch::Tensor row,
torch::Tensor col);
#include "rw_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
__global__ void uniform_random_walk_kernel(const int64_t *rowptr,
const int64_t *col,
const int64_t *start,
const float *rand, int64_t *out,
int64_t walk_length, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
int64_t cur = start[thread_idx];
out[thread_idx] = cur;
int64_t row_start, row_end;
for (int64_t l = 0; l < walk_length; l++) {
row_start = rowptr[cur], row_end = rowptr[cur + 1];
cur = col[row_start +
int64_t(rand[l * numel + thread_idx] * (row_end - row_start))];
out[(l + 1) * numel + thread_idx] = cur;
}
}
}
torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length) {
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
CHECK_CUDA(start);
cudaSetDevice(rowptr.get_device());
CHECK_INPUT(rowptr.dim() == 1);
CHECK_INPUT(col.dim() == 1);
CHECK_INPUT(start.dim() == 1);
auto rand = torch::rand({walk_length, start.size(0)},
start.options().dtype(torch::kFloat));
auto out = torch::full({walk_length + 1, start.size(0)}, -1, start.options());
auto stream = at::cuda::getCurrentCUDAStream();
uniform_random_walk_kernel<<<BLOCKS(start.numel()), THREADS, 0, stream>>>(
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
start.data_ptr<int64_t>(), rand.data_ptr<float>(),
out.data_ptr<int64_t>(), walk_length, start.numel());
return out.t().contiguous();
}
#pragma once
#include <torch/extension.h>
torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length);
#include <Python.h>
#include <torch/script.h>
#include "cpu/rw_cpu.h"
#ifdef WITH_CUDA
#include "cuda/rw_cuda.h"
#endif
#ifdef _WIN32
PyMODINIT_FUNC PyInit__rw(void) { return NULL; }
#endif
torch::Tensor random_walk(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
return random_walk_cuda(rowptr, col, start, walk_length);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return random_walk_cpu(rowptr, col, start, walk_length);
}
}
static auto registry =
torch::RegisterOperators().op("torch_sparse::random_walk", &random_walk);
#include <Python.h>
#include <torch/script.h>
#include "cpu/saint_cpu.h"
#ifdef _WIN32
PyMODINIT_FUNC PyInit__saint(void) { return NULL; }
#endif
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
subgraph(torch::Tensor idx, torch::Tensor rowptr, torch::Tensor row,
torch::Tensor col) {
if (idx.device().is_cuda()) {
#ifdef WITH_CUDA
AT_ERROR("No CUDA version supported");
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return subgraph_cpu(idx, rowptr, row, col);
}
}
static auto registry =
torch::RegisterOperators().op("torch_sparse::saint_subgraph", &subgraph);
import torch
from torch_sparse.tensor import SparseTensor
def test_saint_subgraph():
row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4])
col = torch.tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3])
adj = SparseTensor(row=row, col=col)
node_idx = torch.tensor([0, 1, 2])
adj, edge_index = adj.saint_subgraph(node_idx)
...@@ -8,7 +8,8 @@ expected_torch_version = (1, 4) ...@@ -8,7 +8,8 @@ expected_torch_version = (1, 4)
try: try:
for library in [ for library in [
'_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis' '_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis',
'_rw', '_saint'
]: ]:
torch.ops.load_library(importlib.machinery.PathFinder().find_spec( torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
library, [osp.dirname(__file__)]).origin) library, [osp.dirname(__file__)]).origin)
...@@ -54,7 +55,9 @@ from .mul import mul, mul_, mul_nnz, mul_nnz_ # noqa ...@@ -54,7 +55,9 @@ from .mul import mul, mul_, mul_nnz, mul_nnz_ # noqa
from .reduce import sum, mean, min, max # noqa from .reduce import sum, mean, min, max # noqa
from .matmul import matmul # noqa from .matmul import matmul # noqa
from .cat import cat, cat_diag # noqa from .cat import cat, cat_diag # noqa
from .rw import random_walk # noqa
from .metis import partition # noqa from .metis import partition # noqa
from .saint import saint_subgraph # noqa
from .convert import to_torch_sparse, from_torch_sparse # noqa from .convert import to_torch_sparse, from_torch_sparse # noqa
from .convert import to_scipy, from_scipy # noqa from .convert import to_scipy, from_scipy # noqa
...@@ -94,7 +97,9 @@ __all__ = [ ...@@ -94,7 +97,9 @@ __all__ = [
'matmul', 'matmul',
'cat', 'cat',
'cat_diag', 'cat_diag',
'random_walk',
'partition', 'partition',
'saint_subgraph',
'to_torch_sparse', 'to_torch_sparse',
'from_torch_sparse', 'from_torch_sparse',
'to_scipy', 'to_scipy',
......
import torch
from torch_sparse.tensor import SparseTensor
def random_walk(src: SparseTensor, start: torch.Tensor,
walk_length: int) -> torch.Tensor:
rowptr, col, _ = src.csr()
return torch.ops.torch_sparse.random_walk(rowptr, col, start, walk_length)
SparseTensor.random_walk = random_walk
from typing import Tuple
import torch
from torch_sparse.tensor import SparseTensor
def saint_subgraph(src: SparseTensor, node_idx: torch.Tensor
) -> Tuple[SparseTensor, torch.Tensor]:
row, col, value = src.coo()
rowptr = src.storage.rowptr()
data = torch.ops.torch_sparse.saint_subgraph(node_idx, rowptr, row, col)
row, col, edge_index = data
if value is not None:
value = value[edge_index]
out = SparseTensor(
row=row,
rowptr=None,
col=col,
value=value,
sparse_sizes=(node_idx.size(0), node_idx.size(0)),
is_sorted=True)
return out, edge_index
SparseTensor.saint_subgraph = saint_subgraph
...@@ -12,16 +12,24 @@ from torch_sparse.utils import is_scalar ...@@ -12,16 +12,24 @@ from torch_sparse.utils import is_scalar
class SparseTensor(object): class SparseTensor(object):
storage: SparseStorage storage: SparseStorage
def __init__(self, row: Optional[torch.Tensor] = None, def __init__(self,
row: Optional[torch.Tensor] = None,
rowptr: Optional[torch.Tensor] = None, rowptr: Optional[torch.Tensor] = None,
col: Optional[torch.Tensor] = None, col: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None,
sparse_sizes: Optional[Tuple[int, int]] = None, sparse_sizes: Optional[Tuple[int, int]] = None,
is_sorted: bool = False): is_sorted: bool = False):
self.storage = SparseStorage(row=row, rowptr=rowptr, col=col, self.storage = SparseStorage(
value=value, sparse_sizes=sparse_sizes, row=row,
rowcount=None, colptr=None, colcount=None, rowptr=rowptr,
csr2csc=None, csc2csr=None, col=col,
value=value,
sparse_sizes=sparse_sizes,
rowcount=None,
colptr=None,
colcount=None,
csr2csc=None,
csc2csr=None,
is_sorted=is_sorted) is_sorted=is_sorted)
@classmethod @classmethod
...@@ -45,12 +53,17 @@ class SparseTensor(object): ...@@ -45,12 +53,17 @@ class SparseTensor(object):
if has_value: if has_value:
value = mat[row, col] value = mat[row, col]
return SparseTensor(row=row, rowptr=None, col=col, value=value, return SparseTensor(
row=row,
rowptr=None,
col=col,
value=value,
sparse_sizes=(mat.size(0), mat.size(1)), sparse_sizes=(mat.size(0), mat.size(1)),
is_sorted=True) is_sorted=True)
@classmethod @classmethod
def from_torch_sparse_coo_tensor(self, mat: torch.Tensor, def from_torch_sparse_coo_tensor(self,
mat: torch.Tensor,
has_value: bool = True): has_value: bool = True):
mat = mat.coalesce() mat = mat.coalesce()
index = mat._indices() index = mat._indices()
...@@ -60,13 +73,20 @@ class SparseTensor(object): ...@@ -60,13 +73,20 @@ class SparseTensor(object):
if has_value: if has_value:
value = mat._values() value = mat._values()
return SparseTensor(row=row, rowptr=None, col=col, value=value, return SparseTensor(
row=row,
rowptr=None,
col=col,
value=value,
sparse_sizes=(mat.size(0), mat.size(1)), sparse_sizes=(mat.size(0), mat.size(1)),
is_sorted=True) is_sorted=True)
@classmethod @classmethod
def eye(self, M: int, N: Optional[int] = None, def eye(self,
options: Optional[torch.Tensor] = None, has_value: bool = True, M: int,
N: Optional[int] = None,
options: Optional[torch.Tensor] = None,
has_value: bool = True,
fill_cache: bool = False): fill_cache: bool = False):
N = M if N is None else N N = M if N is None else N
...@@ -84,8 +104,8 @@ class SparseTensor(object): ...@@ -84,8 +104,8 @@ class SparseTensor(object):
value: Optional[torch.Tensor] = None value: Optional[torch.Tensor] = None
if has_value: if has_value:
if options is not None: if options is not None:
value = torch.ones(row.numel(), dtype=options.dtype, value = torch.ones(
device=row.device) row.numel(), dtype=options.dtype, device=row.device)
else: else:
value = torch.ones(row.numel(), device=row.device) value = torch.ones(row.numel(), device=row.device)
...@@ -108,9 +128,17 @@ class SparseTensor(object): ...@@ -108,9 +128,17 @@ class SparseTensor(object):
csr2csc = csc2csr = row csr2csc = csc2csr = row
storage: SparseStorage = SparseStorage( storage: SparseStorage = SparseStorage(
row=row, rowptr=rowptr, col=col, value=value, sparse_sizes=(M, N), row=row,
rowcount=rowcount, colptr=colptr, colcount=colcount, rowptr=rowptr,
csr2csc=csr2csc, csc2csr=csc2csr, is_sorted=True) col=col,
value=value,
sparse_sizes=(M, N),
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True)
self = SparseTensor.__new__(SparseTensor) self = SparseTensor.__new__(SparseTensor)
self.storage = storage self.storage = storage
...@@ -153,12 +181,14 @@ class SparseTensor(object): ...@@ -153,12 +181,14 @@ class SparseTensor(object):
def has_value(self) -> bool: def has_value(self) -> bool:
return self.storage.has_value() return self.storage.has_value()
def set_value_(self, value: Optional[torch.Tensor], def set_value_(self,
value: Optional[torch.Tensor],
layout: Optional[str] = None): layout: Optional[str] = None):
self.storage.set_value_(value, layout) self.storage.set_value_(value, layout)
return self return self
def set_value(self, value: Optional[torch.Tensor], def set_value(self,
value: Optional[torch.Tensor],
layout: Optional[str] = None): layout: Optional[str] = None):
return self.from_storage(self.storage.set_value(value, layout)) return self.from_storage(self.storage.set_value(value, layout))
...@@ -187,23 +217,31 @@ class SparseTensor(object): ...@@ -187,23 +217,31 @@ class SparseTensor(object):
# Utility functions ####################################################### # Utility functions #######################################################
def fill_value_(self, fill_value: float, def fill_value_(self,
fill_value: float,
options: Optional[torch.Tensor] = None): options: Optional[torch.Tensor] = None):
if options is not None: if options is not None:
value = torch.full((self.nnz(), ), fill_value, dtype=options.dtype, value = torch.full((self.nnz(), ),
fill_value,
dtype=options.dtype,
device=self.device()) device=self.device())
else: else:
value = torch.full((self.nnz(), ), fill_value, value = torch.full((self.nnz(), ),
fill_value,
device=self.device()) device=self.device())
return self.set_value_(value, layout='coo') return self.set_value_(value, layout='coo')
def fill_value(self, fill_value: float, def fill_value(self,
fill_value: float,
options: Optional[torch.Tensor] = None): options: Optional[torch.Tensor] = None):
if options is not None: if options is not None:
value = torch.full((self.nnz(), ), fill_value, dtype=options.dtype, value = torch.full((self.nnz(), ),
fill_value,
dtype=options.dtype,
device=self.device()) device=self.device())
else: else:
value = torch.full((self.nnz(), ), fill_value, value = torch.full((self.nnz(), ),
fill_value,
device=self.device()) device=self.device())
return self.set_value(value, layout='coo') return self.set_value(value, layout='coo')
...@@ -270,8 +308,13 @@ class SparseTensor(object): ...@@ -270,8 +308,13 @@ class SparseTensor(object):
N = max(self.size(0), self.size(1)) N = max(self.size(0), self.size(1))
out = SparseTensor(row=row, rowptr=None, col=col, value=value, out = SparseTensor(
sparse_sizes=(N, N), is_sorted=False) row=row,
rowptr=None,
col=col,
value=value,
sparse_sizes=(N, N),
is_sorted=False)
out = out.coalesce(reduce) out = out.coalesce(reduce)
return out return out
...@@ -294,7 +337,8 @@ class SparseTensor(object): ...@@ -294,7 +337,8 @@ class SparseTensor(object):
else: else:
return False return False
def requires_grad_(self, requires_grad: bool = True, def requires_grad_(self,
requires_grad: bool = True,
options: Optional[torch.Tensor] = None): options: Optional[torch.Tensor] = None):
if requires_grad and not self.has_value(): if requires_grad and not self.has_value():
self.fill_value_(1., options=options) self.fill_value_(1., options=options)
...@@ -315,8 +359,8 @@ class SparseTensor(object): ...@@ -315,8 +359,8 @@ class SparseTensor(object):
if value is not None: if value is not None:
return value return value
else: else:
return torch.tensor(0., dtype=torch.float, return torch.tensor(
device=self.storage.col().device) 0., dtype=torch.float, device=self.storage.col().device)
def device(self): def device(self):
return self.storage.col().device return self.storage.col().device
...@@ -324,7 +368,8 @@ class SparseTensor(object): ...@@ -324,7 +368,8 @@ class SparseTensor(object):
def cpu(self): def cpu(self):
return self.device_as(torch.tensor(0.), non_blocking=False) return self.device_as(torch.tensor(0.), non_blocking=False)
def cuda(self, options: Optional[torch.Tensor] = None, def cuda(self,
options: Optional[torch.Tensor] = None,
non_blocking: bool = False): non_blocking: bool = False):
if options is not None: if options is not None:
return self.device_as(options, non_blocking) return self.device_as(options, non_blocking)
...@@ -387,19 +432,19 @@ class SparseTensor(object): ...@@ -387,19 +432,19 @@ class SparseTensor(object):
row, col, value = self.coo() row, col, value = self.coo()
if value is not None: if value is not None:
mat = torch.zeros(self.sizes(), dtype=value.dtype, mat = torch.zeros(
device=self.device()) self.sizes(), dtype=value.dtype, device=self.device())
elif options is not None: elif options is not None:
mat = torch.zeros(self.sizes(), dtype=options.dtype, mat = torch.zeros(
device=self.device()) self.sizes(), dtype=options.dtype, device=self.device())
else: else:
mat = torch.zeros(self.sizes(), device=self.device()) mat = torch.zeros(self.sizes(), device=self.device())
if value is not None: if value is not None:
mat[row, col] = value mat[row, col] = value
else: else:
mat[row, col] = torch.ones(self.nnz(), dtype=mat.dtype, mat[row, col] = torch.ones(
device=mat.device) self.nnz(), dtype=mat.dtype, device=mat.device)
return mat return mat
...@@ -409,8 +454,8 @@ class SparseTensor(object): ...@@ -409,8 +454,8 @@ class SparseTensor(object):
index = torch.stack([row, col], dim=0) index = torch.stack([row, col], dim=0)
if value is None: if value is None:
if options is not None: if options is not None:
value = torch.ones(self.nnz(), dtype=options.dtype, value = torch.ones(
device=self.device()) self.nnz(), dtype=options.dtype, device=self.device())
else: else:
value = torch.ones(self.nnz(), device=self.device()) value = torch.ones(self.nnz(), device=self.device())
...@@ -434,7 +479,7 @@ def is_shared(self: SparseTensor) -> bool: ...@@ -434,7 +479,7 @@ def is_shared(self: SparseTensor) -> bool:
def to(self, *args: Optional[List[Any]], def to(self, *args: Optional[List[Any]],
**kwargs: Optional[Dict[str, Any]]) -> SparseTensor: **kwargs: Optional[Dict[str, Any]]) -> SparseTensor:
device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs) device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs)
if dtype is not None: if dtype is not None:
self = self.type_as(torch.tensor(0., dtype=dtype)) self = self.type_as(torch.tensor(0., dtype=dtype))
...@@ -515,8 +560,8 @@ SparseTensor.__repr__ = __repr__ ...@@ -515,8 +560,8 @@ SparseTensor.__repr__ = __repr__
# Scipy Conversions ########################################################### # Scipy Conversions ###########################################################
ScipySparseMatrix = Union[scipy.sparse.coo_matrix, scipy.sparse.csr_matrix, ScipySparseMatrix = Union[scipy.sparse.coo_matrix, scipy.sparse.
scipy.sparse.csc_matrix] csr_matrix, scipy.sparse.csc_matrix]
@torch.jit.ignore @torch.jit.ignore
...@@ -535,16 +580,25 @@ def from_scipy(mat: ScipySparseMatrix, has_value: bool = True) -> SparseTensor: ...@@ -535,16 +580,25 @@ def from_scipy(mat: ScipySparseMatrix, has_value: bool = True) -> SparseTensor:
value = torch.from_numpy(mat.data) value = torch.from_numpy(mat.data)
sparse_sizes = mat.shape[:2] sparse_sizes = mat.shape[:2]
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value, storage = SparseStorage(
sparse_sizes=sparse_sizes, rowcount=None, row=row,
colptr=colptr, colcount=None, csr2csc=None, rowptr=rowptr,
csc2csr=None, is_sorted=True) col=col,
value=value,
sparse_sizes=sparse_sizes,
rowcount=None,
colptr=colptr,
colcount=None,
csr2csc=None,
csc2csr=None,
is_sorted=True)
return SparseTensor.from_storage(storage) return SparseTensor.from_storage(storage)
@torch.jit.ignore @torch.jit.ignore
def to_scipy(self: SparseTensor, layout: Optional[str] = None, def to_scipy(self: SparseTensor,
layout: Optional[str] = None,
dtype: Optional[torch.dtype] = None) -> ScipySparseMatrix: dtype: Optional[torch.dtype] = None) -> ScipySparseMatrix:
assert self.dim() == 2 assert self.dim() == 2
layout = get_layout(layout) layout = get_layout(layout)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment