Commit 1cb25232 authored by limm's avatar limm

push 0.6.15 version

parent e8309f27
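# --- Tests: ego-network k-hop sampling (torch.ops.torch_sparse.ego_k_hop_sample_adj) ---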
import torch
from torch_sparse import SparseTensor


def test_ego_k_hop_sample_adj():
    rowptr = torch.tensor([0, 3, 5, 9, 10, 12, 14])
    row = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 4, 4, 5, 5])
    col = torch.tensor([1, 2, 3, 0, 2, 0, 1, 4, 5, 0, 2, 5, 2, 4])
    _ = SparseTensor(row=row, col=col, sparse_sizes=(6, 6))

    nid = torch.tensor([0, 1])
    fn = torch.ops.torch_sparse.ego_k_hop_sample_adj
    out = fn(rowptr, col, nid, 1, 3, False)
    rowptr, col, nid, eid, ptr, root_n_id = out

    assert nid.tolist() == [0, 1, 2, 3, 0, 1, 2]
    assert rowptr.tolist() == [0, 3, 5, 7, 8, 10, 12, 14]
    # row [0, 0, 0, 1, 1, 2, 2, 3, 4, 4, 5, 5, 6, 6]
    assert col.tolist() == [1, 2, 3, 0, 2, 0, 1, 0, 5, 6, 4, 6, 4, 5]
    assert eid.tolist() == [0, 1, 2, 3, 4, 5, 6, 9, 0, 1, 3, 4, 5, 6]
    assert ptr.tolist() == [0, 4, 7]
    assert root_n_id.tolist() == [0, 5]
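# --- Tests: SparseTensor.eye ---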
from itertools import product

import pytest
from torch_sparse.tensor import SparseTensor

from .utils import dtypes, devices


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_eye(dtype, device):
    mat = SparseTensor.eye(3, dtype=dtype, device=device)
    assert mat.device() == device
    assert mat.storage.sparse_sizes() == (3, 3)
    assert mat.storage.row().tolist() == [0, 1, 2]
    assert mat.storage.rowptr().tolist() == [0, 1, 2, 3]
    assert mat.storage.col().tolist() == [0, 1, 2]
    assert mat.storage.value().tolist() == [1, 1, 1]
    assert mat.storage.value().dtype == dtype
    assert mat.storage.num_cached_keys() == 0

    mat = SparseTensor.eye(3, has_value=False, device=device)
    assert mat.device() == device
    assert mat.storage.sparse_sizes() == (3, 3)
    assert mat.storage.row().tolist() == [0, 1, 2]
    assert mat.storage.rowptr().tolist() == [0, 1, 2, 3]
    assert mat.storage.col().tolist() == [0, 1, 2]
    assert mat.storage.value() is None
    assert mat.storage.num_cached_keys() == 0

    mat = SparseTensor.eye(3, 4, fill_cache=True, device=device)
    assert mat.device() == device
    assert mat.storage.sparse_sizes() == (3, 4)
    assert mat.storage.row().tolist() == [0, 1, 2]
    assert mat.storage.rowptr().tolist() == [0, 1, 2, 3]
    assert mat.storage.col().tolist() == [0, 1, 2]
    assert mat.storage.num_cached_keys() == 5
    assert mat.storage.rowcount().tolist() == [1, 1, 1]
    assert mat.storage.colptr().tolist() == [0, 1, 2, 3, 3]
    assert mat.storage.colcount().tolist() == [1, 1, 1, 0]
    assert mat.storage.csr2csc().tolist() == [0, 1, 2]
    assert mat.storage.csc2csr().tolist() == [0, 1, 2]

    mat = SparseTensor.eye(4, 3, fill_cache=True, device=device)
    assert mat.device() == device
    assert mat.storage.sparse_sizes() == (4, 3)
    assert mat.storage.row().tolist() == [0, 1, 2]
    assert mat.storage.rowptr().tolist() == [0, 1, 2, 3, 3]
    assert mat.storage.col().tolist() == [0, 1, 2]
    assert mat.storage.num_cached_keys() == 5
    assert mat.storage.rowcount().tolist() == [1, 1, 1, 0]
    assert mat.storage.colptr().tolist() == [0, 1, 2, 3]
    assert mat.storage.colcount().tolist() == [1, 1, 1]
    assert mat.storage.csr2csc().tolist() == [0, 1, 2]
    assert mat.storage.csc2csr().tolist() == [0, 1, 2]
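# --- Tests: sparse-dense and sparse-sparse matrix multiplication (matmul) ---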
from itertools import product

import pytest
import torch
import torch_scatter
from torch_sparse.matmul import matmul
from torch_sparse.tensor import SparseTensor

from .utils import devices, grad_dtypes, reductions


@pytest.mark.parametrize('dtype,device,reduce',
                         product(grad_dtypes, devices, reductions))
def test_spmm(dtype, device, reduce):
    src = torch.randn((10, 8), dtype=dtype, device=device)
    src[2:4, :] = 0  # Remove multiple rows.
    src[:, 2:4] = 0  # Remove multiple columns.
    src = SparseTensor.from_dense(src).requires_grad_()
    row, col, value = src.coo()

    other = torch.randn((2, 8, 2), dtype=dtype, device=device,
                        requires_grad=True)

    src_col = other.index_select(-2, col) * value.unsqueeze(-1)
    expected = torch_scatter.scatter(src_col, row, dim=-2, reduce=reduce)
    if reduce == 'min':
        expected[expected > 1000] = 0
    if reduce == 'max':
        expected[expected < -1000] = 0

    grad_out = torch.randn_like(expected)

    expected.backward(grad_out)
    expected_grad_value = value.grad
    value.grad = None
    expected_grad_other = other.grad
    other.grad = None

    out = matmul(src, other, reduce)
    out.backward(grad_out)

    assert torch.allclose(expected, out, atol=1e-2)
    assert torch.allclose(expected_grad_value, value.grad, atol=1e-2)
    assert torch.allclose(expected_grad_other, other.grad, atol=1e-2)


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_spspmm(dtype, device):
    src = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=dtype,
                       device=device)
    src = SparseTensor.from_dense(src)

    out = matmul(src, src)
    assert out.sizes() == [3, 3]
    assert out.has_value()
    rowptr, col, value = out.csr()
    assert rowptr.tolist() == [0, 1, 2, 3]
    assert col.tolist() == [0, 1, 2]
    assert value.tolist() == [1, 1, 1]

    src.set_value_(None)
    out = matmul(src, src)
    assert out.sizes() == [3, 3]
    assert not out.has_value()
    rowptr, col, value = out.csr()
    assert rowptr.tolist() == [0, 1, 2, 3]
    assert col.tolist() == [0, 1, 2]
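# --- Tests: METIS graph partitioning ---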
from itertools import product

import pytest
import torch
from torch_sparse.tensor import SparseTensor

from .utils import devices

try:
    rowptr = torch.tensor([0, 1])
    col = torch.tensor([0])
    torch.ops.torch_sparse.partition(rowptr, col, None, 1, True)
    with_metis = True
except RuntimeError:
    with_metis = False


@pytest.mark.skipif(not with_metis, reason='Not compiled with METIS support')
@pytest.mark.parametrize('device,weighted', product(devices, [False, True]))
def test_metis(device, weighted):
    mat1 = torch.randn(6 * 6, device=device).view(6, 6)
    mat2 = torch.arange(6 * 6, dtype=torch.long, device=device).view(6, 6)
    mat3 = torch.ones(6 * 6, device=device).view(6, 6)

    vec1 = None
    vec2 = torch.rand(6, device=device)

    for mat, vec in product([mat1, mat2, mat3], [vec1, vec2]):
        mat = SparseTensor.from_dense(mat)

        _, partptr, perm = mat.partition(num_parts=1, recursive=False,
                                         weighted=weighted, node_weight=vec)
        assert partptr.numel() == 2
        assert perm.numel() == 6

        _, partptr, perm = mat.partition(num_parts=2, recursive=False,
                                         weighted=weighted, node_weight=vec)
        assert partptr.numel() == 3
        assert perm.numel() == 6
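# --- Tests: neighbor sampling ---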
import torch
from torch_sparse import SparseTensor

neighbor_sample = torch.ops.torch_sparse.neighbor_sample


def test_neighbor_sample():
    adj = SparseTensor.from_edge_index(torch.tensor([[0], [1]]))
    colptr, row, _ = adj.csc()

    # Sampling in a non-directed way should not sample in the wrong direction:
    out = neighbor_sample(colptr, row, torch.tensor([0]), [1], False, False)
    assert out[0].tolist() == [0]
    assert out[1].tolist() == []
    assert out[2].tolist() == []

    # Sampling should work:
    out = neighbor_sample(colptr, row, torch.tensor([1]), [1], False, False)
    assert out[0].tolist() == [1, 0]
    assert out[1].tolist() == [1]
    assert out[2].tolist() == [0]

    # Sampling with more hops:
    out = neighbor_sample(colptr, row, torch.tensor([1]), [1, 1], False, False)
    assert out[0].tolist() == [1, 0]
    assert out[1].tolist() == [1]
    assert out[2].tolist() == [0]


def test_neighbor_sample_seed():
    colptr = torch.tensor([0, 3, 6, 9])
    row = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2])
    input_nodes = torch.tensor([0, 1])

    torch.manual_seed(42)
    out1 = neighbor_sample(colptr, row, input_nodes, [1, 1], True, False)

    torch.manual_seed(42)
    out2 = neighbor_sample(colptr, row, input_nodes, [1, 1], True, False)

    for data1, data2 in zip(out1, out2):
        assert data1.tolist() == data2.tolist()
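# --- Tests: operator overloads (+, *) between dense tensors and SparseTensor ---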
import torch
from torch_sparse.tensor import SparseTensor


def test_overload():
    row = torch.tensor([0, 1, 1, 2, 2])
    col = torch.tensor([1, 0, 2, 1, 2])
    mat = SparseTensor(row=row, col=col)

    other = torch.tensor([1, 2, 3]).view(3, 1)
    other + mat
    mat + other
    other * mat
    mat * other

    other = torch.tensor([1, 2, 3]).view(1, 3)
    other + mat
    mat + other
    other * mat
    mat * other
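# --- Tests: permute ---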
import pytest
import torch
from torch_sparse.tensor import SparseTensor

from .utils import devices, tensor


@pytest.mark.parametrize('device', devices)
def test_permute(device):
    row, col = tensor([[0, 0, 1, 2, 2], [0, 1, 0, 1, 2]], torch.long, device)
    value = tensor([1, 2, 3, 4, 5], torch.float, device)
    adj = SparseTensor(row=row, col=col, value=value)

    row, col, value = adj.permute(torch.tensor([1, 0, 2])).coo()
    assert row.tolist() == [0, 1, 1, 2, 2]
    assert col.tolist() == [1, 0, 1, 0, 2]
    assert value.tolist() == [3, 2, 1, 4, 5]
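# --- Tests: GraphSAINT subgraph sampling ---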
import torch
from torch_sparse.tensor import SparseTensor


def test_saint_subgraph():
    row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4])
    col = torch.tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3])
    adj = SparseTensor(row=row, col=col)
    node_idx = torch.tensor([0, 1, 2])

    adj, edge_index = adj.saint_subgraph(node_idx)
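# --- Tests: random sampling (sample, sample_adj) ---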
import torch
from torch_sparse import SparseTensor, sample, sample_adj


def test_sample():
    row = torch.tensor([0, 0, 2, 2])
    col = torch.tensor([1, 2, 0, 1])
    adj = SparseTensor(row=row, col=col, sparse_sizes=(3, 3))
    out = sample(adj, num_neighbors=1)
    assert out.min() >= 0 and out.max() <= 2


def test_sample_adj():
    row = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 4, 4, 5, 5])
    col = torch.tensor([1, 2, 3, 0, 2, 0, 1, 4, 5, 0, 2, 5, 2, 4])
    value = torch.arange(row.size(0))
    adj_t = SparseTensor(row=row, col=col, value=value, sparse_sizes=(6, 6))

    out, n_id = sample_adj(adj_t, torch.arange(2, 6), num_neighbors=-1)
    assert n_id.tolist() == [2, 3, 4, 5, 0, 1]
    row, col, val = out.coo()
    assert row.tolist() == [0, 0, 0, 0, 1, 2, 2, 3, 3]
    assert col.tolist() == [2, 3, 4, 5, 4, 0, 3, 0, 2]
    assert val.tolist() == [7, 8, 5, 6, 9, 10, 11, 12, 13]

    out, n_id = sample_adj(adj_t, torch.arange(2, 6), num_neighbors=2,
                           replace=True)
    assert out.nnz() == 8

    out, n_id = sample_adj(adj_t, torch.arange(2, 6), num_neighbors=2,
                           replace=False)
    assert out.nnz() == 7  # node 3 has only one edge...
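# --- Tests: functional spmm interface ---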
from itertools import product

import pytest
import torch
from torch_sparse import spmm

from .utils import dtypes, devices, tensor


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_spmm(dtype, device):
    row = torch.tensor([0, 0, 1, 2, 2], device=device)
    col = torch.tensor([0, 2, 1, 0, 1], device=device)
    index = torch.stack([row, col], dim=0)
    value = tensor([1, 2, 4, 1, 3], dtype, device)
    x = tensor([[1, 4], [2, 5], [3, 6]], dtype, device)

    out = spmm(index, value, 3, 3, x)
    assert out.tolist() == [[7, 16], [8, 20], [7, 19]]
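# --- Tests: functional spspmm interface ---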
from itertools import product

import pytest
import torch
from torch_sparse import spspmm, SparseTensor

from .utils import grad_dtypes, devices, tensor


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_spspmm(dtype, device):
    indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]], device=device)
    valueA = tensor([1, 2, 3, 4, 5], dtype, device)
    indexB = torch.tensor([[0, 2], [1, 0]], device=device)
    valueB = tensor([2, 4], dtype, device)

    indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)
    assert indexC.tolist() == [[0, 1, 2], [0, 1, 1]]
    assert valueC.tolist() == [8, 6, 8]


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_sparse_tensor_spspmm(dtype, device):
    x = SparseTensor(
        row=torch.tensor(
            [0, 1, 1, 1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 8, 8, 9, 9],
            device=device),
        col=torch.tensor(
            [0, 5, 10, 15, 1, 2, 3, 7, 13, 6, 9, 5, 10, 15, 11, 14, 5, 15],
            device=device),
        value=torch.tensor([
            1, 3**-0.5, 3**-0.5, 3**-0.5, 1, 1, 1, -2**-0.5, -2**-0.5,
            -2**-0.5, -2**-0.5, 6**-0.5, -6**0.5 / 3, 6**-0.5, -2**-0.5,
            -2**-0.5, 2**-0.5, -2**-0.5
        ], dtype=dtype, device=device),
    )

    expected = torch.eye(10, dtype=dtype, device=device)

    out = x @ x.to_dense().t()
    assert torch.allclose(out, expected, atol=1e-2)

    out = x @ x.t()
    out = out.to_dense()
    assert torch.allclose(out, expected, atol=1e-2)
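# --- Tests: SparseStorage internals (caching, coalescing, reshape) ---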
from itertools import product

import pytest
import torch
from torch_sparse.storage import SparseStorage

from .utils import dtypes, devices, tensor


@pytest.mark.parametrize('device', devices)
def test_ind2ptr(device):
    row = tensor([2, 2, 4, 5, 5, 6], torch.long, device)
    rowptr = torch.ops.torch_sparse.ind2ptr(row, 8)
    assert rowptr.tolist() == [0, 0, 0, 2, 2, 3, 5, 6, 6]
    row = torch.ops.torch_sparse.ptr2ind(rowptr, 6)
    assert row.tolist() == [2, 2, 4, 5, 5, 6]

    row = tensor([], torch.long, device)
    rowptr = torch.ops.torch_sparse.ind2ptr(row, 8)
    assert rowptr.tolist() == [0, 0, 0, 0, 0, 0, 0, 0, 0]
    row = torch.ops.torch_sparse.ptr2ind(rowptr, 0)
    assert row.tolist() == []


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_storage(dtype, device):
    row, col = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device)

    storage = SparseStorage(row=row, col=col)
    assert storage.row().tolist() == [0, 0, 1, 1]
    assert storage.col().tolist() == [0, 1, 0, 1]
    assert storage.value() is None
    assert storage.sparse_sizes() == (2, 2)

    row, col = tensor([[0, 0, 1, 1], [1, 0, 1, 0]], torch.long, device)
    value = tensor([2, 1, 4, 3], dtype, device)
    storage = SparseStorage(row=row, col=col, value=value)
    assert storage.row().tolist() == [0, 0, 1, 1]
    assert storage.col().tolist() == [0, 1, 0, 1]
    assert storage.value().tolist() == [1, 2, 3, 4]
    assert storage.sparse_sizes() == (2, 2)


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_caching(dtype, device):
    row, col = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device)
    storage = SparseStorage(row=row, col=col)

    assert storage._row.tolist() == row.tolist()
    assert storage._col.tolist() == col.tolist()
    assert storage._value is None

    assert storage._rowcount is None
    assert storage._rowptr is None
    assert storage._colcount is None
    assert storage._colptr is None
    assert storage._csr2csc is None
    assert storage.num_cached_keys() == 0

    storage.fill_cache_()

    assert storage._rowcount.tolist() == [2, 2]
    assert storage._rowptr.tolist() == [0, 2, 4]
    assert storage._colcount.tolist() == [2, 2]
    assert storage._colptr.tolist() == [0, 2, 4]
    assert storage._csr2csc.tolist() == [0, 2, 1, 3]
    assert storage._csc2csr.tolist() == [0, 2, 1, 3]
    assert storage.num_cached_keys() == 5

    storage = SparseStorage(row=row, rowptr=storage._rowptr, col=col,
                            value=storage._value,
                            sparse_sizes=storage._sparse_sizes,
                            rowcount=storage._rowcount, colptr=storage._colptr,
                            colcount=storage._colcount,
                            csr2csc=storage._csr2csc, csc2csr=storage._csc2csr)

    assert storage._rowcount.tolist() == [2, 2]
    assert storage._rowptr.tolist() == [0, 2, 4]
    assert storage._colcount.tolist() == [2, 2]
    assert storage._colptr.tolist() == [0, 2, 4]
    assert storage._csr2csc.tolist() == [0, 2, 1, 3]
    assert storage._csc2csr.tolist() == [0, 2, 1, 3]
    assert storage.num_cached_keys() == 5

    storage.clear_cache_()

    assert storage._rowcount is None
    assert storage._rowptr is not None
    assert storage._colcount is None
    assert storage._colptr is None
    assert storage._csr2csc is None
    assert storage.num_cached_keys() == 0


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_utility(dtype, device):
    row, col = tensor([[0, 0, 1, 1], [1, 0, 1, 0]], torch.long, device)
    value = tensor([1, 2, 3, 4], dtype, device)
    storage = SparseStorage(row=row, col=col, value=value)

    assert storage.has_value()

    storage.set_value_(value, layout='csc')
    assert storage.value().tolist() == [1, 3, 2, 4]
    storage.set_value_(value, layout='coo')
    assert storage.value().tolist() == [1, 2, 3, 4]

    storage = storage.set_value(value, layout='csc')
    assert storage.value().tolist() == [1, 3, 2, 4]
    storage = storage.set_value(value, layout='coo')
    assert storage.value().tolist() == [1, 2, 3, 4]

    storage = storage.sparse_resize((3, 3))
    assert storage.sparse_sizes() == (3, 3)

    new_storage = storage.copy()
    assert new_storage != storage
    assert new_storage.col().data_ptr() == storage.col().data_ptr()

    new_storage = storage.clone()
    assert new_storage != storage
    assert new_storage.col().data_ptr() != storage.col().data_ptr()


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_coalesce(dtype, device):
    row, col = tensor([[0, 0, 0, 1, 1], [0, 1, 1, 0, 1]], torch.long, device)
    value = tensor([1, 1, 1, 3, 4], dtype, device)
    storage = SparseStorage(row=row, col=col, value=value)

    assert storage.row().tolist() == row.tolist()
    assert storage.col().tolist() == col.tolist()
    assert storage.value().tolist() == value.tolist()

    assert not storage.is_coalesced()
    storage = storage.coalesce()
    assert storage.is_coalesced()

    assert storage.row().tolist() == [0, 0, 1, 1]
    assert storage.col().tolist() == [0, 1, 0, 1]
    assert storage.value().tolist() == [1, 2, 3, 4]


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_sparse_reshape(dtype, device):
    row, col = tensor([[0, 1, 2, 3], [0, 1, 2, 3]], torch.long, device)
    storage = SparseStorage(row=row, col=col)

    storage = storage.sparse_reshape(2, 8)
    assert storage.sparse_sizes() == (2, 8)
    assert storage.row().tolist() == [0, 0, 1, 1]
    assert storage.col().tolist() == [0, 5, 2, 7]

    storage = storage.sparse_reshape(-1, 4)
    assert storage.sparse_sizes() == (4, 4)
    assert storage.row().tolist() == [0, 1, 2, 3]
    assert storage.col().tolist() == [0, 1, 2, 3]

    storage = storage.sparse_reshape(2, -1)
    assert storage.sparse_sizes() == (2, 8)
    assert storage.row().tolist() == [0, 0, 1, 1]
    assert storage.col().tolist() == [0, 5, 2, 7]
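# --- Tests: SparseTensor indexing, symmetry, equality ---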
from itertools import product

import pytest
import torch
from torch_sparse import SparseTensor

from .utils import grad_dtypes, devices


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_getitem(dtype, device):
    m = 50
    n = 40
    k = 10
    mat = torch.randn(m, n, dtype=dtype, device=device)
    mat = SparseTensor.from_dense(mat)

    idx1 = torch.randint(0, m, (k,), dtype=torch.long, device=device)
    idx2 = torch.randint(0, n, (k,), dtype=torch.long, device=device)
    bool1 = torch.zeros(m, dtype=torch.bool, device=device)
    bool2 = torch.zeros(n, dtype=torch.bool, device=device)
    bool1.scatter_(0, idx1, 1)
    bool2.scatter_(0, idx2, 1)

    # idx1 and idx2 may have duplicates:
    k1_bool = bool1.nonzero().size(0)
    k2_bool = bool2.nonzero().size(0)

    idx1np = idx1.cpu().numpy()
    idx2np = idx2.cpu().numpy()
    bool1np = bool1.cpu().numpy()
    bool2np = bool2.cpu().numpy()

    idx1list = idx1np.tolist()
    idx2list = idx2np.tolist()
    bool1list = bool1np.tolist()
    bool2list = bool2np.tolist()

    assert mat[:k, :k].sizes() == [k, k]
    assert mat[..., :k].sizes() == [m, k]

    assert mat[idx1, idx2].sizes() == [k, k]
    assert mat[idx1np, idx2np].sizes() == [k, k]
    assert mat[idx1list, idx2list].sizes() == [k, k]

    assert mat[bool1, bool2].sizes() == [k1_bool, k2_bool]
    assert mat[bool1np, bool2np].sizes() == [k1_bool, k2_bool]
    assert mat[bool1list, bool2list].sizes() == [k1_bool, k2_bool]

    assert mat[idx1].sizes() == [k, n]
    assert mat[idx1np].sizes() == [k, n]
    assert mat[idx1list].sizes() == [k, n]

    assert mat[bool1].sizes() == [k1_bool, n]
    assert mat[bool1np].sizes() == [k1_bool, n]
    assert mat[bool1list].sizes() == [k1_bool, n]


@pytest.mark.parametrize('device', devices)
def test_to_symmetric(device):
    row = torch.tensor([0, 0, 0, 1, 1], device=device)
    col = torch.tensor([0, 1, 2, 0, 2], device=device)
    value = torch.arange(1, 6, device=device)

    mat = SparseTensor(row=row, col=col, value=value)
    assert not mat.is_symmetric()

    mat = mat.to_symmetric()
    assert mat.is_symmetric()
    assert mat.to_dense().tolist() == [
        [2, 6, 3],
        [6, 0, 5],
        [3, 5, 0],
    ]


def test_equal():
    row = torch.tensor([0, 0, 0, 1, 1])
    col = torch.tensor([0, 1, 2, 0, 2])
    value = torch.arange(1, 6)
    matA = SparseTensor(row=row, col=col, value=value)
    matB = SparseTensor(row=row, col=col, value=value)
    col = torch.tensor([0, 1, 2, 0, 1])
    matC = SparseTensor(row=row, col=col, value=value)

    assert id(matA) != id(matB)
    assert matA == matB

    assert id(matA) != id(matC)
    assert matA != matC
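# --- Tests: transpose ---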
from itertools import product

import pytest
import torch
from torch_sparse import transpose

from .utils import dtypes, devices, tensor


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_transpose_matrix(dtype, device):
    row = torch.tensor([1, 0, 1, 2], device=device)
    col = torch.tensor([0, 1, 1, 0], device=device)
    index = torch.stack([row, col], dim=0)
    value = tensor([1, 2, 3, 4], dtype, device)

    index, value = transpose(index, value, m=3, n=2)
    assert index.tolist() == [[0, 0, 1, 1], [1, 2, 0, 1]]
    assert value.tolist() == [1, 4, 2, 3]


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_transpose(dtype, device):
    row = torch.tensor([1, 0, 1, 0, 2, 1], device=device)
    col = torch.tensor([0, 1, 1, 1, 0, 0], device=device)
    index = torch.stack([row, col], dim=0)
    value = tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]], dtype,
                   device)

    index, value = transpose(index, value, m=3, n=2)
    assert index.tolist() == [[0, 0, 1, 1], [1, 2, 0, 1]]
    assert value.tolist() == [[7, 9], [5, 6], [6, 8], [3, 4]]
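# --- Shared test utilities (dtypes, devices, tensor constructor) ---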
import torch

reductions = ['sum', 'add', 'mean', 'min', 'max']

dtypes = [torch.half, torch.float, torch.double, torch.int, torch.long]
grad_dtypes = [torch.half, torch.float, torch.double]

devices = [torch.device('cpu')]
if torch.cuda.is_available():
    devices += [torch.device(f'cuda:{torch.cuda.current_device()}')]


def tensor(x, dtype, device):
    return None if x is None else torch.tensor(x, dtype=dtype, device=device)
torch_sparse/__init__.py
@@ -3,18 +3,18 @@ import os.path as osp

import torch

-__version__ = '0.6.13'
+__version__ = '0.6.15'

for library in [
        '_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis', '_rw',
        '_saint', '_sample', '_ego_sample', '_hgt_sample', '_neighbor_sample',
        '_relabel'
]:
    hip_spec = importlib.machinery.PathFinder().find_spec(
        f'{library}_hip', [osp.dirname(__file__)])
    cuda_spec = importlib.machinery.PathFinder().find_spec(
        f'{library}_cuda', [osp.dirname(__file__)])
    cpu_spec = importlib.machinery.PathFinder().find_spec(
        f'{library}_cpu', [osp.dirname(__file__)])
-    spec = hip_spec or cpu_spec
+    spec = cuda_spec or cpu_spec
    if spec is not None:
        torch.ops.load_library(spec.origin)
    else:  # pragma: no cover
@@ -22,11 +22,20 @@ for library in [
                          f"{osp.dirname(__file__)}")

cuda_version = torch.ops.torch_sparse.cuda_version()
-if torch.cuda.is_available() and cuda_version != -1:  # pragma: no cover
+if torch.version.cuda is not None and cuda_version != -1:  # pragma: no cover
    if cuda_version < 10000:
        major, minor = int(str(cuda_version)[0]), int(str(cuda_version)[2])
    else:
        major, minor = int(str(cuda_version)[0:2]), int(str(cuda_version)[3])
    t_major, t_minor = [int(x) for x in torch.version.cuda.split('.')]

+    if t_major != major:
+        raise RuntimeError(
+            f'Detected that PyTorch and torch_sparse were compiled with '
+            f'different CUDA versions. PyTorch has CUDA version '
+            f'{t_major}.{t_minor} and torch_sparse has CUDA version '
+            f'{major}.{minor}. Please reinstall the torch_sparse that '
+            f'matches your PyTorch install.')

from .storage import SparseStorage  # noqa
from .tensor import SparseTensor  # noqa
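For reference, torch.ops.torch_sparse.cuda_version() returns the CUDA release packed as major * 1000 + minor * 10 (e.g. 9020 for 9.2, 10020 for 10.2), which is what the string slicing above decodes. A standalone sketch of that decoding logic (the helper name decode_cuda_version is illustrative, not part of the library):

def decode_cuda_version(cuda_version: int) -> tuple:
    # CUDA_VERSION-style integers: 9020 -> (9, 2), 10020 -> (10, 2).
    s = str(cuda_version)
    if cuda_version < 10000:
        return int(s[0]), int(s[2])
    return int(s[0:2]), int(s[3])

assert decode_cuda_version(9020) == (9, 2)
assert decode_cuda_version(10020) == (10, 2)
assert decode_cuda_version(11030) == (11, 3)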
torch_sparse/spmm.py
@@ -8,10 +8,13 @@ def spmm(index: Tensor, value: Tensor, m: int, n: int,
    Args:
        index (:class:`LongTensor`): The index tensor of sparse matrix.
-        value (:class:`Tensor`): The value tensor of sparse matrix.
+        value (:class:`Tensor`): The value tensor of sparse matrix, either of
+            floating-point or integer type. Does not work for boolean and
+            complex number data types.
        m (int): The first dimension of sparse matrix.
        n (int): The second dimension of sparse matrix.
-        matrix (:class:`Tensor`): The dense matrix.
+        matrix (:class:`Tensor`): The dense matrix of same type as
+            :obj:`value`.

    :rtype: :class:`Tensor`
    """
torch_sparse/storage.py
@@ -30,19 +30,21 @@ class SparseStorage(object):
    _csr2csc: Optional[torch.Tensor]
    _csc2csr: Optional[torch.Tensor]

-    def __init__(self, row: Optional[torch.Tensor] = None,
-                 rowptr: Optional[torch.Tensor] = None,
-                 col: Optional[torch.Tensor] = None,
-                 value: Optional[torch.Tensor] = None,
-                 sparse_sizes: Optional[Tuple[Optional[int],
-                                              Optional[int]]] = None,
-                 rowcount: Optional[torch.Tensor] = None,
-                 colptr: Optional[torch.Tensor] = None,
-                 colcount: Optional[torch.Tensor] = None,
-                 csr2csc: Optional[torch.Tensor] = None,
-                 csc2csr: Optional[torch.Tensor] = None,
-                 is_sorted: bool = False,
-                 trust_data: bool = False):
+    def __init__(
+        self,
+        row: Optional[torch.Tensor] = None,
+        rowptr: Optional[torch.Tensor] = None,
+        col: Optional[torch.Tensor] = None,
+        value: Optional[torch.Tensor] = None,
+        sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None,
+        rowcount: Optional[torch.Tensor] = None,
+        colptr: Optional[torch.Tensor] = None,
+        colcount: Optional[torch.Tensor] = None,
+        csr2csc: Optional[torch.Tensor] = None,
+        csc2csr: Optional[torch.Tensor] = None,
+        is_sorted: bool = False,
+        trust_data: bool = False,
+    ):
        assert row is not None or rowptr is not None
        assert col is not None
@@ -240,7 +242,8 @@ class SparseStorage(object):
            csr2csc=self._csr2csc,
            csc2csr=self._csc2csr,
            is_sorted=True,
-            trust_data=True)
+            trust_data=True,
+        )

    def sparse_sizes(self) -> Tuple[int, int]:
        return self._sparse_sizes
@@ -290,7 +293,8 @@ class SparseStorage(object):
            csr2csc=self._csr2csc,
            csc2csr=self._csc2csr,
            is_sorted=True,
-            trust_data=True)
+            trust_data=True,
+        )

    def sparse_reshape(self, num_rows: int, num_cols: int):
        assert num_rows > 0 or num_rows == -1
@@ -313,10 +317,20 @@ class SparseStorage(object):
        col = idx % num_cols
        assert row.dtype == torch.long and col.dtype == torch.long

-        return SparseStorage(row=row, rowptr=None, col=col, value=self._value,
-                             sparse_sizes=(num_rows, num_cols), rowcount=None,
-                             colptr=None, colcount=None, csr2csc=None,
-                             csc2csr=None, is_sorted=True, trust_data=True)
+        return SparseStorage(
+            row=row,
+            rowptr=None,
+            col=col,
+            value=self._value,
+            sparse_sizes=(num_rows, num_cols),
+            rowcount=None,
+            colptr=None,
+            colcount=None,
+            csr2csc=None,
+            csc2csr=None,
+            is_sorted=True,
+            trust_data=True,
+        )

    def has_rowcount(self) -> bool:
        return self._rowcount is not None
@@ -413,10 +427,20 @@ class SparseStorage(object):
        ptr = torch.cat([ptr, ptr.new_full((1, ), value.size(0))])
        value = segment_csr(value, ptr, reduce=reduce)

-        return SparseStorage(row=row, rowptr=None, col=col, value=value,
-                             sparse_sizes=self._sparse_sizes, rowcount=None,
-                             colptr=None, colcount=None, csr2csc=None,
-                             csc2csr=None, is_sorted=True, trust_data=True)
+        return SparseStorage(
+            row=row,
+            rowptr=None,
+            col=col,
+            value=value,
+            sparse_sizes=self._sparse_sizes,
+            rowcount=None,
+            colptr=None,
+            colcount=None,
+            csr2csc=None,
+            csc2csr=None,
+            is_sorted=True,
+            trust_data=True,
+        )

    def fill_cache_(self):
        self.row()
@@ -466,7 +490,8 @@ class SparseStorage(object):
            csr2csc=self._csr2csc,
            csc2csr=self._csc2csr,
            is_sorted=True,
-            trust_data=True)
+            trust_data=True,
+        )

    def clone(self):
        row = self._row
@@ -495,11 +520,20 @@ class SparseStorage(object):
        if csc2csr is not None:
            csc2csr = csc2csr.clone()
-        return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
-                             sparse_sizes=self._sparse_sizes,
-                             rowcount=rowcount, colptr=colptr,
-                             colcount=colcount, csr2csc=csr2csc,
-                             csc2csr=csc2csr, is_sorted=True, trust_data=True)
+        return SparseStorage(
+            row=row,
+            rowptr=rowptr,
+            col=col,
+            value=value,
+            sparse_sizes=self._sparse_sizes,
+            rowcount=rowcount,
+            colptr=colptr,
+            colcount=colcount,
+            csr2csc=csr2csc,
+            csc2csr=csc2csr,
+            is_sorted=True,
+            trust_data=True,
+        )

    def type(self, dtype: torch.dtype, non_blocking: bool = False):
        value = self._value
@@ -508,9 +542,7 @@ class SparseStorage(object):
                return self
            else:
                return self.set_value(
-                    value.to(
-                        dtype=dtype,
-                        non_blocking=non_blocking),
+                    value.to(dtype=dtype, non_blocking=non_blocking),
                    layout='coo')
        else:
            return self
@@ -548,11 +580,20 @@ class SparseStorage(object):
        if csc2csr is not None:
            csc2csr = csc2csr.to(device, non_blocking=non_blocking)
-        return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
-                             sparse_sizes=self._sparse_sizes,
-                             rowcount=rowcount, colptr=colptr,
-                             colcount=colcount, csr2csc=csr2csc,
-                             csc2csr=csc2csr, is_sorted=True, trust_data=True)
+        return SparseStorage(
+            row=row,
+            rowptr=rowptr,
+            col=col,
+            value=value,
+            sparse_sizes=self._sparse_sizes,
+            rowcount=rowcount,
+            colptr=colptr,
+            colcount=colcount,
+            csr2csc=csr2csc,
+            csc2csr=csc2csr,
+            is_sorted=True,
+            trust_data=True,
+        )

    def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
        return self.to_device(device=tensor.device, non_blocking=non_blocking)
@@ -587,11 +628,20 @@ class SparseStorage(object):
        if csc2csr is not None:
            csc2csr = csc2csr.cuda()
-        return SparseStorage(row=row, rowptr=rowptr, col=new_col, value=value,
-                             sparse_sizes=self._sparse_sizes,
-                             rowcount=rowcount, colptr=colptr,
-                             colcount=colcount, csr2csc=csr2csc,
-                             csc2csr=csc2csr, is_sorted=True, trust_data=True)
+        return SparseStorage(
+            row=row,
+            rowptr=rowptr,
+            col=new_col,
+            value=value,
+            sparse_sizes=self._sparse_sizes,
+            rowcount=rowcount,
+            colptr=colptr,
+            colcount=colcount,
+            csr2csc=csr2csc,
+            csc2csr=csc2csr,
+            is_sorted=True,
+            trust_data=True,
+        )

    def pin_memory(self):
        row = self._row
@@ -620,11 +670,20 @@ class SparseStorage(object):
        if csc2csr is not None:
            csc2csr = csc2csr.pin_memory()
-        return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
-                             sparse_sizes=self._sparse_sizes,
-                             rowcount=rowcount, colptr=colptr,
-                             colcount=colcount, csr2csc=csr2csc,
-                             csc2csr=csc2csr, is_sorted=True, trust_data=True)
+        return SparseStorage(
+            row=row,
+            rowptr=rowptr,
+            col=col,
+            value=value,
+            sparse_sizes=self._sparse_sizes,
+            rowcount=rowcount,
+            colptr=colptr,
+            colcount=colcount,
+            csr2csc=csr2csc,
+            csc2csr=csc2csr,
+            is_sorted=True,
+            trust_data=True,
+        )

    def is_pinned(self) -> bool:
        is_pinned = True