Commit f59fe649 authored by rusty1s's avatar rusty1s
Browse files

beginning of torch script support

parent c4484dbb
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.storage import SparseStorage
from typing import Dict, Any
# class MyTensor(dict):
# def __init__(self, rowptr, col):
# self['rowptr'] = rowptr
# self['col'] = col
# def rowptr(self: Dict[str, torch.Tensor]):
# return self['rowptr']
@torch.jit.script
class Foo:
    """TorchScript-compatible container holding a CSR (rowptr, col) pair.

    NOTE(review): scratch class from the initial TorchScript experiments;
    the class-level tensor annotations are what allow ``torch.jit.script``
    to compile it.
    """

    rowptr: torch.Tensor
    col: torch.Tensor

    def __init__(self, rowptr: torch.Tensor, col: torch.Tensor):
        # Straight attribute assignment; types come from the annotations above.
        self.rowptr = rowptr
        self.col = col
class MyCell(torch.nn.Module):
    """Toy cell used to exercise TorchScript compilation with torch_sparse.

    ``forward`` sum-aggregates node features over a sparse adjacency via the
    custom ``torch_sparse_cpu.spmm`` operator.
    """

    def __init__(self):
        super(MyCell, self).__init__()
        # NOTE(review): this layer is never used by forward() below — it looks
        # left over from the traced ptr2ind experiment; confirm before removing.
        self.linear = torch.nn.Linear(2, 4)

    def forward(self, x: torch.Tensor, adj: SparseStorage) -> torch.Tensor:
        # The custom op returns (output, argmax); argmax is only meaningful for
        # min/max reductions, so it is discarded here. ``None`` stands for an
        # unweighted (value-less) sparse matrix.
        aggregated, _argmax = torch.ops.torch_sparse_cpu.spmm(
            adj.rowptr(), adj.col(), None, x, 'sum')
        return aggregated
def test_jit():
    """Smoke-test: script-compile MyCell and run it on a 3-node dense graph.

    NOTE(review): relies on ``print`` for inspection and makes no assertions;
    it only verifies that ``torch.jit.script`` accepts the module and that the
    scripted forward pass runs.
    """
    cell = MyCell()
    features = torch.randn(3, 2)

    # CSR layout of a fully-connected 3-node graph: each of the three rows
    # holds the same three columns 0..2.
    crow = torch.tensor([0, 3, 6, 9])
    cols = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2])
    storage = SparseStorage(rowptr=crow, col=cols)

    scripted = torch.jit.script(cell)
    print(scripted)

    result = scripted(features, storage)
    print(result)
import torch import torch
import torch_scatter import torch_scatter
from .unique import unique # from .unique import unique
def coalesce(index, value, m, n, op='add', fill_value=0): def coalesce(index, value, m, n, op='add', fill_value=0):
...@@ -22,6 +22,7 @@ def coalesce(index, value, m, n, op='add', fill_value=0): ...@@ -22,6 +22,7 @@ def coalesce(index, value, m, n, op='add', fill_value=0):
:rtype: (:class:`LongTensor`, :class:`Tensor`) :rtype: (:class:`LongTensor`, :class:`Tensor`)
""" """
raise NotImplementedError
row, col = index row, col = index
......
import torch import torch
from torch_sparse.utils import ext
def remove_diag(src, k=0): def remove_diag(src, k=0):
row, col, value = src.coo() row, col, value = src.coo()
...@@ -39,8 +37,13 @@ def set_diag(src, values=None, k=0): ...@@ -39,8 +37,13 @@ def set_diag(src, values=None, k=0):
row, col, value = src.coo() row, col, value = src.coo()
mask = ext(row.is_cuda).non_diag_mask(row, col, src.size(0), src.size(1), if row.is_cuda:
k) mask = torch.ops.torch_sparse_cuda.non_diag_mask(
row, col, src.size(0), src.size(1), k)
else:
mask = torch.ops.torch_sparse_cpu.non_diag_mask(
row, col, src.size(0), src.size(1), k)
inv_mask = ~mask inv_mask = ~mask
start, num_diag = -k if k < 0 else 0, mask.numel() - row.numel() start, num_diag = -k if k < 0 else 0, mask.numel() - row.numel()
......
import torch import torch
import scipy.sparse import scipy.sparse
from torch_scatter import scatter_add from torch_scatter import scatter_add
from torch_sparse.utils import ext
ext = None
class SPMM(torch.autograd.Function): class SPMM(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, row, rowptr, col, value, mat, rowcount, colptr, csr2csc, def forward(ctx, row, rowptr, col, value, mat, rowcount, colptr, csr2csc,
reduce): reduce):
out, arg_out = ext(mat.is_cuda).spmm(rowptr, col, value, mat, reduce) if mat.is_cuda:
out, arg_out = torch.ops.torch_sparse_cuda.spmm(
rowptr, col, value, mat, reduce)
else:
out, arg_out = torch.ops.torch_sparse_cpu.spmm(
rowptr, col, value, mat, reduce)
ctx.reduce = reduce ctx.reduce = reduce
ctx.save_for_backward(row, rowptr, col, value, mat, rowcount, colptr, ctx.save_for_backward(row, rowptr, col, value, mat, rowcount, colptr,
......
import torch import torch
from torch_sparse import transpose, to_scipy, from_scipy, coalesce from torch_sparse import transpose, to_scipy, from_scipy, coalesce
import torch_sparse.spspmm_cpu # import torch_sparse.spspmm_cpu
if torch.cuda.is_available(): # if torch.cuda.is_available():
import torch_sparse.spspmm_cuda # import torch_sparse.spspmm_cuda
def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False): def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False):
...@@ -25,6 +25,7 @@ def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False): ...@@ -25,6 +25,7 @@ def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False):
:rtype: (:class:`LongTensor`, :class:`Tensor`) :rtype: (:class:`LongTensor`, :class:`Tensor`)
""" """
raise NotImplementedError
if indexA.is_cuda and coalesced: if indexA.is_cuda and coalesced:
indexA, valueA = coalesce(indexA, valueA, m, k) indexA, valueA = coalesce(indexA, valueA, m, k)
indexB, valueB = coalesce(indexB, valueB, k, n) indexB, valueB = coalesce(indexB, valueB, k, n)
......
This diff is collapsed.
from typing import Any
import torch import torch
try:
from typing_extensions import Final # noqa
except ImportError:
from torch.jit import Final # noqa
torch.ops.load_library('torch_sparse/convert_cpu.so') torch.ops.load_library('torch_sparse/convert_cpu.so')
torch.ops.load_library('torch_sparse/diag_cpu.so') torch.ops.load_library('torch_sparse/diag_cpu.so')
torch.ops.load_library('torch_sparse/spmm_cpu.so') torch.ops.load_library('torch_sparse/spmm_cpu.so')
...@@ -14,10 +21,5 @@ except OSError as e: ...@@ -14,10 +21,5 @@ except OSError as e:
raise e raise e
def ext(is_cuda): def is_scalar(other: Any) -> bool:
name = 'torch_sparse_cuda' if is_cuda else 'torch_sparse_cpu'
return getattr(torch.ops, name)
def is_scalar(other):
return isinstance(other, int) or isinstance(other, float) return isinstance(other, int) or isinstance(other, float)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment