Commit 8b77e547 authored by rusty1s's avatar rusty1s
Browse files

added python wrapper

parent 631df924
......@@ -2,7 +2,7 @@ from itertools import product
import pytest
import torch
from torch_sparse import SparseTensor
from torch_sparse import SparseTensor, padded_index_select
from .utils import grad_dtypes, tensor
......@@ -14,11 +14,9 @@ def test_padded_index_select(dtype, device):
row = torch.tensor([0, 0, 0, 0, 1, 1, 1, 2, 2, 3])
col = torch.tensor([0, 1, 2, 3, 0, 2, 3, 1, 3, 2])
adj = SparseTensor(row=row, col=col).to(device)
rowptr, col, _ = adj.csr()
rowcount = adj.storage.rowcount()
binptr = torch.tensor([0, 3, 5], device=device)
data = torch.ops.torch_sparse.padded_index(rowptr, col, rowcount, binptr)
data = adj.padded_index(binptr)
node_perm, row_perm, col_perm, mask, node_size, edge_size = data
assert node_perm.tolist() == [2, 3, 0, 1]
......@@ -29,21 +27,21 @@ def test_padded_index_select(dtype, device):
assert edge_size == [4, 8]
x = tensor([0, 1, 2, 3], dtype, device).view(-1, 1).requires_grad_()
fill_value = torch.tensor(0., dtype=dtype)
out = torch.ops.torch_sparse.padded_index_select(x, col_perm, fill_value)
x_j = padded_index_select(x, col_perm)
assert out.flatten().tolist() == [1, 3, 2, 0, 0, 1, 2, 3, 0, 2, 3, 0]
assert x_j.flatten().tolist() == [1, 3, 2, 0, 0, 1, 2, 3, 0, 2, 3, 0]
grad_out = tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], dtype, device)
out.backward(grad_out.view(-1, 1))
x_j.backward(grad_out.view(-1, 1))
assert x.grad.flatten().tolist() == [12, 5, 17, 18]
@pytest.mark.parametrize('device', devices)
def test_padded_index_select_runtime(device):
def test_padded_index_select_runtime():
return
from torch_geometric.datasets import Planetoid
device = torch.device('cuda')
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
......
......@@ -58,6 +58,7 @@ from .cat import cat, cat_diag # noqa
from .rw import random_walk # noqa
from .metis import partition # noqa
from .saint import saint_subgraph # noqa
from .padding import padded_index, padded_index_select # noqa
from .convert import to_torch_sparse, from_torch_sparse # noqa
from .convert import to_scipy, from_scipy # noqa
......@@ -100,6 +101,8 @@ __all__ = [
'random_walk',
'partition',
'saint_subgraph',
'padded_index',
'padded_index_select',
'to_torch_sparse',
'from_torch_sparse',
'to_scipy',
......
......@@ -5,9 +5,8 @@ from torch_sparse.tensor import SparseTensor
from torch_sparse.permute import permute
def partition(
src: SparseTensor, num_parts: int, recursive: bool = False
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
def partition(src: SparseTensor, num_parts: int, recursive: bool = False
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
rowptr, col = src.storage.rowptr().cpu(), src.storage.col().cpu()
cluster = torch.ops.torch_sparse.partition(rowptr, col, num_parts,
......@@ -21,5 +20,4 @@ def partition(
return out, partptr, perm
SparseTensor.partition = lambda self, num_parts, recursive=False: partition(
self, num_parts, recursive)
SparseTensor.partition = partition
from typing import Tuple, List
import torch
from torch_sparse.tensor import SparseTensor
def padded_index(src: SparseTensor, binptr: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.
Tensor, List[int], List[int]]:
return torch.ops.torch_sparse.padded_index(src.storage.rowptr(),
src.storage.col(),
src.storage.rowcount(), binptr)
def padded_index_select(src: torch.Tensor, index: torch.Tensor,
fill_value: float = 0.) -> torch.Tensor:
fill_value = torch.tensor(fill_value, dtype=src.dtype)
return torch.ops.torch_sparse.padded_index_select(src, index, fill_value)
SparseTensor.padded_index = padded_index
......@@ -15,13 +15,9 @@ def saint_subgraph(src: SparseTensor, node_idx: torch.Tensor
if value is not None:
value = value[edge_index]
out = SparseTensor(
row=row,
rowptr=None,
col=col,
value=value,
sparse_sizes=(node_idx.size(0), node_idx.size(0)),
is_sorted=True)
out = SparseTensor(row=row, rowptr=None, col=col, value=value,
sparse_sizes=(node_idx.size(0), node_idx.size(0)),
is_sorted=True)
return out, edge_index
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment