Commit 8b77e547 authored by rusty1s's avatar rusty1s
Browse files

added python wrapper

parent 631df924
...@@ -2,7 +2,7 @@ from itertools import product ...@@ -2,7 +2,7 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch_sparse import SparseTensor from torch_sparse import SparseTensor, padded_index_select
from .utils import grad_dtypes, tensor from .utils import grad_dtypes, tensor
...@@ -14,11 +14,9 @@ def test_padded_index_select(dtype, device): ...@@ -14,11 +14,9 @@ def test_padded_index_select(dtype, device):
row = torch.tensor([0, 0, 0, 0, 1, 1, 1, 2, 2, 3]) row = torch.tensor([0, 0, 0, 0, 1, 1, 1, 2, 2, 3])
col = torch.tensor([0, 1, 2, 3, 0, 2, 3, 1, 3, 2]) col = torch.tensor([0, 1, 2, 3, 0, 2, 3, 1, 3, 2])
adj = SparseTensor(row=row, col=col).to(device) adj = SparseTensor(row=row, col=col).to(device)
rowptr, col, _ = adj.csr()
rowcount = adj.storage.rowcount()
binptr = torch.tensor([0, 3, 5], device=device) binptr = torch.tensor([0, 3, 5], device=device)
data = torch.ops.torch_sparse.padded_index(rowptr, col, rowcount, binptr) data = adj.padded_index(binptr)
node_perm, row_perm, col_perm, mask, node_size, edge_size = data node_perm, row_perm, col_perm, mask, node_size, edge_size = data
assert node_perm.tolist() == [2, 3, 0, 1] assert node_perm.tolist() == [2, 3, 0, 1]
...@@ -29,21 +27,21 @@ def test_padded_index_select(dtype, device): ...@@ -29,21 +27,21 @@ def test_padded_index_select(dtype, device):
assert edge_size == [4, 8] assert edge_size == [4, 8]
x = tensor([0, 1, 2, 3], dtype, device).view(-1, 1).requires_grad_() x = tensor([0, 1, 2, 3], dtype, device).view(-1, 1).requires_grad_()
fill_value = torch.tensor(0., dtype=dtype) x_j = padded_index_select(x, col_perm)
out = torch.ops.torch_sparse.padded_index_select(x, col_perm, fill_value)
assert out.flatten().tolist() == [1, 3, 2, 0, 0, 1, 2, 3, 0, 2, 3, 0] assert x_j.flatten().tolist() == [1, 3, 2, 0, 0, 1, 2, 3, 0, 2, 3, 0]
grad_out = tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], dtype, device) grad_out = tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], dtype, device)
out.backward(grad_out.view(-1, 1)) x_j.backward(grad_out.view(-1, 1))
assert x.grad.flatten().tolist() == [12, 5, 17, 18] assert x.grad.flatten().tolist() == [12, 5, 17, 18]
@pytest.mark.parametrize('device', devices) def test_padded_index_select_runtime():
def test_padded_index_select_runtime(device):
return return
from torch_geometric.datasets import Planetoid from torch_geometric.datasets import Planetoid
device = torch.device('cuda')
start = torch.cuda.Event(enable_timing=True) start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True)
......
...@@ -58,6 +58,7 @@ from .cat import cat, cat_diag # noqa ...@@ -58,6 +58,7 @@ from .cat import cat, cat_diag # noqa
from .rw import random_walk # noqa from .rw import random_walk # noqa
from .metis import partition # noqa from .metis import partition # noqa
from .saint import saint_subgraph # noqa from .saint import saint_subgraph # noqa
from .padding import padded_index, padded_index_select # noqa
from .convert import to_torch_sparse, from_torch_sparse # noqa from .convert import to_torch_sparse, from_torch_sparse # noqa
from .convert import to_scipy, from_scipy # noqa from .convert import to_scipy, from_scipy # noqa
...@@ -100,6 +101,8 @@ __all__ = [ ...@@ -100,6 +101,8 @@ __all__ = [
'random_walk', 'random_walk',
'partition', 'partition',
'saint_subgraph', 'saint_subgraph',
'padded_index',
'padded_index_select',
'to_torch_sparse', 'to_torch_sparse',
'from_torch_sparse', 'from_torch_sparse',
'to_scipy', 'to_scipy',
......
...@@ -5,9 +5,8 @@ from torch_sparse.tensor import SparseTensor ...@@ -5,9 +5,8 @@ from torch_sparse.tensor import SparseTensor
from torch_sparse.permute import permute from torch_sparse.permute import permute
def partition( def partition(src: SparseTensor, num_parts: int, recursive: bool = False
src: SparseTensor, num_parts: int, recursive: bool = False ) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
rowptr, col = src.storage.rowptr().cpu(), src.storage.col().cpu() rowptr, col = src.storage.rowptr().cpu(), src.storage.col().cpu()
cluster = torch.ops.torch_sparse.partition(rowptr, col, num_parts, cluster = torch.ops.torch_sparse.partition(rowptr, col, num_parts,
...@@ -21,5 +20,4 @@ def partition( ...@@ -21,5 +20,4 @@ def partition(
return out, partptr, perm return out, partptr, perm
SparseTensor.partition = lambda self, num_parts, recursive=False: partition( SparseTensor.partition = partition
self, num_parts, recursive)
from typing import Tuple, List
import torch
from torch_sparse.tensor import SparseTensor
def padded_index(src: SparseTensor, binptr: torch.Tensor
                 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
                            torch.Tensor, List[int], List[int]]:
    """Thin Python wrapper around the ``torch_sparse.padded_index`` C++ op.

    Feeds the CSR representation of *src* (rowptr, col, rowcount) together
    with the bin boundaries *binptr* to the custom op and returns its result
    unchanged: a 6-tuple ``(node_perm, row_perm, col_perm, mask, node_size,
    edge_size)`` — names taken from how callers unpack it.

    NOTE(review): the exact semantics of each tuple element are defined in
    the C++ extension; not verifiable from this file alone.
    """
    return torch.ops.torch_sparse.padded_index(src.storage.rowptr(),
                                               src.storage.col(),
                                               src.storage.rowcount(), binptr)
def padded_index_select(src: torch.Tensor, index: torch.Tensor,
                        fill_value: float = 0.) -> torch.Tensor:
    """Select rows of *src* at *index* via the ``torch_sparse`` C++ op.

    The scalar *fill_value* is converted to a zero-dim tensor matching
    ``src.dtype`` before dispatch, since the underlying op expects a tensor
    argument. Padded slots are presumably filled with this value — defined
    by the C++ extension, not visible here.
    """
    fill = torch.tensor(fill_value, dtype=src.dtype)
    return torch.ops.torch_sparse.padded_index_select(src, index, fill)
# Attach padded_index as a method so callers can write adj.padded_index(binptr)
# (the SparseTensor instance becomes the `src` argument).
SparseTensor.padded_index = padded_index
...@@ -15,13 +15,9 @@ def saint_subgraph(src: SparseTensor, node_idx: torch.Tensor ...@@ -15,13 +15,9 @@ def saint_subgraph(src: SparseTensor, node_idx: torch.Tensor
if value is not None: if value is not None:
value = value[edge_index] value = value[edge_index]
out = SparseTensor( out = SparseTensor(row=row, rowptr=None, col=col, value=value,
row=row, sparse_sizes=(node_idx.size(0), node_idx.size(0)),
rowptr=None, is_sorted=True)
col=col,
value=value,
sparse_sizes=(node_idx.size(0), node_idx.size(0)),
is_sorted=True)
return out, edge_index return out, edge_index
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment