Commit 1cb25232 authored by limm's avatar limm
Browse files

push 0.6.15 version

parent e8309f27
import torch
from torch_sparse import SparseTensor
def test_ego_k_hop_sample_adj():
rowptr = torch.tensor([0, 3, 5, 9, 10, 12, 14])
row = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 4, 4, 5, 5])
col = torch.tensor([1, 2, 3, 0, 2, 0, 1, 4, 5, 0, 2, 5, 2, 4])
_ = SparseTensor(row=row, col=col, sparse_sizes=(6, 6))
nid = torch.tensor([0, 1])
fn = torch.ops.torch_sparse.ego_k_hop_sample_adj
out = fn(rowptr, col, nid, 1, 3, False)
rowptr, col, nid, eid, ptr, root_n_id = out
assert nid.tolist() == [0, 1, 2, 3, 0, 1, 2]
assert rowptr.tolist() == [0, 3, 5, 7, 8, 10, 12, 14]
# row [0, 0, 0, 1, 1, 2, 2, 3, 4, 4, 5, 5, 6, 6]
assert col.tolist() == [1, 2, 3, 0, 2, 0, 1, 0, 5, 6, 4, 6, 4, 5]
assert eid.tolist() == [0, 1, 2, 3, 4, 5, 6, 9, 0, 1, 3, 4, 5, 6]
assert ptr.tolist() == [0, 4, 7]
assert root_n_id.tolist() == [0, 5]
from itertools import product
import pytest
from torch_sparse.tensor import SparseTensor
from .utils import dtypes, devices
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_eye(dtype, device):
mat = SparseTensor.eye(3, dtype=dtype, device=device)
assert mat.device() == device
assert mat.storage.sparse_sizes() == (3, 3)
assert mat.storage.row().tolist() == [0, 1, 2]
assert mat.storage.rowptr().tolist() == [0, 1, 2, 3]
assert mat.storage.col().tolist() == [0, 1, 2]
assert mat.storage.value().tolist() == [1, 1, 1]
assert mat.storage.value().dtype == dtype
assert mat.storage.num_cached_keys() == 0
mat = SparseTensor.eye(3, has_value=False, device=device)
assert mat.device() == device
assert mat.storage.sparse_sizes() == (3, 3)
assert mat.storage.row().tolist() == [0, 1, 2]
assert mat.storage.rowptr().tolist() == [0, 1, 2, 3]
assert mat.storage.col().tolist() == [0, 1, 2]
assert mat.storage.value() is None
assert mat.storage.num_cached_keys() == 0
mat = SparseTensor.eye(3, 4, fill_cache=True, device=device)
assert mat.device() == device
assert mat.storage.sparse_sizes() == (3, 4)
assert mat.storage.row().tolist() == [0, 1, 2]
assert mat.storage.rowptr().tolist() == [0, 1, 2, 3]
assert mat.storage.col().tolist() == [0, 1, 2]
assert mat.storage.num_cached_keys() == 5
assert mat.storage.rowcount().tolist() == [1, 1, 1]
assert mat.storage.colptr().tolist() == [0, 1, 2, 3, 3]
assert mat.storage.colcount().tolist() == [1, 1, 1, 0]
assert mat.storage.csr2csc().tolist() == [0, 1, 2]
assert mat.storage.csc2csr().tolist() == [0, 1, 2]
mat = SparseTensor.eye(4, 3, fill_cache=True, device=device)
assert mat.device() == device
assert mat.storage.sparse_sizes() == (4, 3)
assert mat.storage.row().tolist() == [0, 1, 2]
assert mat.storage.rowptr().tolist() == [0, 1, 2, 3, 3]
assert mat.storage.col().tolist() == [0, 1, 2]
assert mat.storage.num_cached_keys() == 5
assert mat.storage.rowcount().tolist() == [1, 1, 1, 0]
assert mat.storage.colptr().tolist() == [0, 1, 2, 3]
assert mat.storage.colcount().tolist() == [1, 1, 1]
assert mat.storage.csr2csc().tolist() == [0, 1, 2]
assert mat.storage.csc2csr().tolist() == [0, 1, 2]
from itertools import product
import pytest
import torch
import torch_scatter
from torch_sparse.matmul import matmul
from torch_sparse.tensor import SparseTensor
from .utils import devices, grad_dtypes, reductions
@pytest.mark.parametrize('dtype,device,reduce',
product(grad_dtypes, devices, reductions))
def test_spmm(dtype, device, reduce):
src = torch.randn((10, 8), dtype=dtype, device=device)
src[2:4, :] = 0 # Remove multiple rows.
src[:, 2:4] = 0 # Remove multiple columns.
src = SparseTensor.from_dense(src).requires_grad_()
row, col, value = src.coo()
other = torch.randn((2, 8, 2), dtype=dtype, device=device,
requires_grad=True)
src_col = other.index_select(-2, col) * value.unsqueeze(-1)
expected = torch_scatter.scatter(src_col, row, dim=-2, reduce=reduce)
if reduce == 'min':
expected[expected > 1000] = 0
if reduce == 'max':
expected[expected < -1000] = 0
grad_out = torch.randn_like(expected)
expected.backward(grad_out)
expected_grad_value = value.grad
value.grad = None
expected_grad_other = other.grad
other.grad = None
out = matmul(src, other, reduce)
out.backward(grad_out)
assert torch.allclose(expected, out, atol=1e-2)
assert torch.allclose(expected_grad_value, value.grad, atol=1e-2)
assert torch.allclose(expected_grad_other, other.grad, atol=1e-2)
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_spspmm(dtype, device):
src = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=dtype,
device=device)
src = SparseTensor.from_dense(src)
out = matmul(src, src)
assert out.sizes() == [3, 3]
assert out.has_value()
rowptr, col, value = out.csr()
assert rowptr.tolist() == [0, 1, 2, 3]
assert col.tolist() == [0, 1, 2]
assert value.tolist() == [1, 1, 1]
src.set_value_(None)
out = matmul(src, src)
assert out.sizes() == [3, 3]
assert not out.has_value()
rowptr, col, value = out.csr()
assert rowptr.tolist() == [0, 1, 2, 3]
assert col.tolist() == [0, 1, 2]
from itertools import product
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from .utils import devices
try:
rowptr = torch.tensor([0, 1])
col = torch.tensor([0])
torch.ops.torch_sparse.partition(rowptr, col, None, 1, True)
with_metis = True
except RuntimeError:
with_metis = False
@pytest.mark.skipif(not with_metis, reason='Not compiled with METIS support')
@pytest.mark.parametrize('device,weighted', product(devices, [False, True]))
def test_metis(device, weighted):
mat1 = torch.randn(6 * 6, device=device).view(6, 6)
mat2 = torch.arange(6 * 6, dtype=torch.long, device=device).view(6, 6)
mat3 = torch.ones(6 * 6, device=device).view(6, 6)
vec1 = None
vec2 = torch.rand(6, device=device)
for mat, vec in product([mat1, mat2, mat3], [vec1, vec2]):
mat = SparseTensor.from_dense(mat)
_, partptr, perm = mat.partition(num_parts=1, recursive=False,
weighted=weighted, node_weight=vec)
assert partptr.numel() == 2
assert perm.numel() == 6
_, partptr, perm = mat.partition(num_parts=2, recursive=False,
weighted=weighted, node_weight=vec)
assert partptr.numel() == 3
assert perm.numel() == 6
import torch
from torch_sparse import SparseTensor
neighbor_sample = torch.ops.torch_sparse.neighbor_sample
def test_neighbor_sample():
adj = SparseTensor.from_edge_index(torch.tensor([[0], [1]]))
colptr, row, _ = adj.csc()
# Sampling in a non-directed way should not sample in wrong direction:
out = neighbor_sample(colptr, row, torch.tensor([0]), [1], False, False)
assert out[0].tolist() == [0]
assert out[1].tolist() == []
assert out[2].tolist() == []
# Sampling should work:
out = neighbor_sample(colptr, row, torch.tensor([1]), [1], False, False)
assert out[0].tolist() == [1, 0]
assert out[1].tolist() == [1]
assert out[2].tolist() == [0]
# Sampling with more hops:
out = neighbor_sample(colptr, row, torch.tensor([1]), [1, 1], False, False)
assert out[0].tolist() == [1, 0]
assert out[1].tolist() == [1]
assert out[2].tolist() == [0]
def test_neighbor_sample_seed():
colptr = torch.tensor([0, 3, 6, 9])
row = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2])
input_nodes = torch.tensor([0, 1])
torch.manual_seed(42)
out1 = neighbor_sample(colptr, row, input_nodes, [1, 1], True, False)
torch.manual_seed(42)
out2 = neighbor_sample(colptr, row, input_nodes, [1, 1], True, False)
for data1, data2 in zip(out1, out2):
assert data1.tolist() == data2.tolist()
import torch
from torch_sparse.tensor import SparseTensor
def test_overload():
row = torch.tensor([0, 1, 1, 2, 2])
col = torch.tensor([1, 0, 2, 1, 2])
mat = SparseTensor(row=row, col=col)
other = torch.tensor([1, 2, 3]).view(3, 1)
other + mat
mat + other
other * mat
mat * other
other = torch.tensor([1, 2, 3]).view(1, 3)
other + mat
mat + other
other * mat
mat * other
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from .utils import devices, tensor
@pytest.mark.parametrize('device', devices)
def test_permute(device):
row, col = tensor([[0, 0, 1, 2, 2], [0, 1, 0, 1, 2]], torch.long, device)
value = tensor([1, 2, 3, 4, 5], torch.float, device)
adj = SparseTensor(row=row, col=col, value=value)
row, col, value = adj.permute(torch.tensor([1, 0, 2])).coo()
assert row.tolist() == [0, 1, 1, 2, 2]
assert col.tolist() == [1, 0, 1, 0, 2]
assert value.tolist() == [3, 2, 1, 4, 5]
import torch
from torch_sparse.tensor import SparseTensor
def test_saint_subgraph():
row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4])
col = torch.tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3])
adj = SparseTensor(row=row, col=col)
node_idx = torch.tensor([0, 1, 2])
adj, edge_index = adj.saint_subgraph(node_idx)
import torch
from torch_sparse import SparseTensor, sample, sample_adj
def test_sample():
row = torch.tensor([0, 0, 2, 2])
col = torch.tensor([1, 2, 0, 1])
adj = SparseTensor(row=row, col=col, sparse_sizes=(3, 3))
out = sample(adj, num_neighbors=1)
assert out.min() >= 0 and out.max() <= 2
def test_sample_adj():
row = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 4, 4, 5, 5])
col = torch.tensor([1, 2, 3, 0, 2, 0, 1, 4, 5, 0, 2, 5, 2, 4])
value = torch.arange(row.size(0))
adj_t = SparseTensor(row=row, col=col, value=value, sparse_sizes=(6, 6))
out, n_id = sample_adj(adj_t, torch.arange(2, 6), num_neighbors=-1)
assert n_id.tolist() == [2, 3, 4, 5, 0, 1]
row, col, val = out.coo()
assert row.tolist() == [0, 0, 0, 0, 1, 2, 2, 3, 3]
assert col.tolist() == [2, 3, 4, 5, 4, 0, 3, 0, 2]
assert val.tolist() == [7, 8, 5, 6, 9, 10, 11, 12, 13]
out, n_id = sample_adj(adj_t, torch.arange(2, 6), num_neighbors=2,
replace=True)
assert out.nnz() == 8
out, n_id = sample_adj(adj_t, torch.arange(2, 6), num_neighbors=2,
replace=False)
assert out.nnz() == 7 # node 3 has only one edge...
from itertools import product
import pytest
import torch
from torch_sparse import spmm
from .utils import dtypes, devices, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_spmm(dtype, device):
row = torch.tensor([0, 0, 1, 2, 2], device=device)
col = torch.tensor([0, 2, 1, 0, 1], device=device)
index = torch.stack([row, col], dim=0)
value = tensor([1, 2, 4, 1, 3], dtype, device)
x = tensor([[1, 4], [2, 5], [3, 6]], dtype, device)
out = spmm(index, value, 3, 3, x)
assert out.tolist() == [[7, 16], [8, 20], [7, 19]]
from itertools import product
import pytest
import torch
from torch_sparse import spspmm, SparseTensor
from .utils import grad_dtypes, devices, tensor
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_spspmm(dtype, device):
indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]], device=device)
valueA = tensor([1, 2, 3, 4, 5], dtype, device)
indexB = torch.tensor([[0, 2], [1, 0]], device=device)
valueB = tensor([2, 4], dtype, device)
indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)
assert indexC.tolist() == [[0, 1, 2], [0, 1, 1]]
assert valueC.tolist() == [8, 6, 8]
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_sparse_tensor_spspmm(dtype, device):
x = SparseTensor(
row=torch.tensor(
[0, 1, 1, 1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 8, 8, 9, 9],
device=device),
col=torch.tensor(
[0, 5, 10, 15, 1, 2, 3, 7, 13, 6, 9, 5, 10, 15, 11, 14, 5, 15],
device=device),
value=torch.tensor([
1, 3**-0.5, 3**-0.5, 3**-0.5, 1, 1, 1, -2**-0.5, -2**-0.5,
-2**-0.5, -2**-0.5, 6**-0.5, -6**0.5 / 3, 6**-0.5, -2**-0.5,
-2**-0.5, 2**-0.5, -2**-0.5
], dtype=dtype, device=device),
)
expected = torch.eye(10, dtype=dtype, device=device)
out = x @ x.to_dense().t()
assert torch.allclose(out, expected, atol=1e-2)
out = x @ x.t()
out = out.to_dense()
assert torch.allclose(out, expected, atol=1e-2)
from itertools import product
import pytest
import torch
from torch_sparse.storage import SparseStorage
from .utils import dtypes, devices, tensor
@pytest.mark.parametrize('device', devices)
def test_ind2ptr(device):
row = tensor([2, 2, 4, 5, 5, 6], torch.long, device)
rowptr = torch.ops.torch_sparse.ind2ptr(row, 8)
assert rowptr.tolist() == [0, 0, 0, 2, 2, 3, 5, 6, 6]
row = torch.ops.torch_sparse.ptr2ind(rowptr, 6)
assert row.tolist() == [2, 2, 4, 5, 5, 6]
row = tensor([], torch.long, device)
rowptr = torch.ops.torch_sparse.ind2ptr(row, 8)
assert rowptr.tolist() == [0, 0, 0, 0, 0, 0, 0, 0, 0]
row = torch.ops.torch_sparse.ptr2ind(rowptr, 0)
assert row.tolist() == []
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_storage(dtype, device):
row, col = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device)
storage = SparseStorage(row=row, col=col)
assert storage.row().tolist() == [0, 0, 1, 1]
assert storage.col().tolist() == [0, 1, 0, 1]
assert storage.value() is None
assert storage.sparse_sizes() == (2, 2)
row, col = tensor([[0, 0, 1, 1], [1, 0, 1, 0]], torch.long, device)
value = tensor([2, 1, 4, 3], dtype, device)
storage = SparseStorage(row=row, col=col, value=value)
assert storage.row().tolist() == [0, 0, 1, 1]
assert storage.col().tolist() == [0, 1, 0, 1]
assert storage.value().tolist() == [1, 2, 3, 4]
assert storage.sparse_sizes() == (2, 2)
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_caching(dtype, device):
row, col = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device)
storage = SparseStorage(row=row, col=col)
assert storage._row.tolist() == row.tolist()
assert storage._col.tolist() == col.tolist()
assert storage._value is None
assert storage._rowcount is None
assert storage._rowptr is None
assert storage._colcount is None
assert storage._colptr is None
assert storage._csr2csc is None
assert storage.num_cached_keys() == 0
storage.fill_cache_()
assert storage._rowcount.tolist() == [2, 2]
assert storage._rowptr.tolist() == [0, 2, 4]
assert storage._colcount.tolist() == [2, 2]
assert storage._colptr.tolist() == [0, 2, 4]
assert storage._csr2csc.tolist() == [0, 2, 1, 3]
assert storage._csc2csr.tolist() == [0, 2, 1, 3]
assert storage.num_cached_keys() == 5
storage = SparseStorage(row=row, rowptr=storage._rowptr, col=col,
value=storage._value,
sparse_sizes=storage._sparse_sizes,
rowcount=storage._rowcount, colptr=storage._colptr,
colcount=storage._colcount,
csr2csc=storage._csr2csc, csc2csr=storage._csc2csr)
assert storage._rowcount.tolist() == [2, 2]
assert storage._rowptr.tolist() == [0, 2, 4]
assert storage._colcount.tolist() == [2, 2]
assert storage._colptr.tolist() == [0, 2, 4]
assert storage._csr2csc.tolist() == [0, 2, 1, 3]
assert storage._csc2csr.tolist() == [0, 2, 1, 3]
assert storage.num_cached_keys() == 5
storage.clear_cache_()
assert storage._rowcount is None
assert storage._rowptr is not None
assert storage._colcount is None
assert storage._colptr is None
assert storage._csr2csc is None
assert storage.num_cached_keys() == 0
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_utility(dtype, device):
row, col = tensor([[0, 0, 1, 1], [1, 0, 1, 0]], torch.long, device)
value = tensor([1, 2, 3, 4], dtype, device)
storage = SparseStorage(row=row, col=col, value=value)
assert storage.has_value()
storage.set_value_(value, layout='csc')
assert storage.value().tolist() == [1, 3, 2, 4]
storage.set_value_(value, layout='coo')
assert storage.value().tolist() == [1, 2, 3, 4]
storage = storage.set_value(value, layout='csc')
assert storage.value().tolist() == [1, 3, 2, 4]
storage = storage.set_value(value, layout='coo')
assert storage.value().tolist() == [1, 2, 3, 4]
storage = storage.sparse_resize((3, 3))
assert storage.sparse_sizes() == (3, 3)
new_storage = storage.copy()
assert new_storage != storage
assert new_storage.col().data_ptr() == storage.col().data_ptr()
new_storage = storage.clone()
assert new_storage != storage
assert new_storage.col().data_ptr() != storage.col().data_ptr()
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_coalesce(dtype, device):
row, col = tensor([[0, 0, 0, 1, 1], [0, 1, 1, 0, 1]], torch.long, device)
value = tensor([1, 1, 1, 3, 4], dtype, device)
storage = SparseStorage(row=row, col=col, value=value)
assert storage.row().tolist() == row.tolist()
assert storage.col().tolist() == col.tolist()
assert storage.value().tolist() == value.tolist()
assert not storage.is_coalesced()
storage = storage.coalesce()
assert storage.is_coalesced()
assert storage.row().tolist() == [0, 0, 1, 1]
assert storage.col().tolist() == [0, 1, 0, 1]
assert storage.value().tolist() == [1, 2, 3, 4]
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_sparse_reshape(dtype, device):
row, col = tensor([[0, 1, 2, 3], [0, 1, 2, 3]], torch.long, device)
storage = SparseStorage(row=row, col=col)
storage = storage.sparse_reshape(2, 8)
assert storage.sparse_sizes() == (2, 8)
assert storage.row().tolist() == [0, 0, 1, 1]
assert storage.col().tolist() == [0, 5, 2, 7]
storage = storage.sparse_reshape(-1, 4)
assert storage.sparse_sizes() == (4, 4)
assert storage.row().tolist() == [0, 1, 2, 3]
assert storage.col().tolist() == [0, 1, 2, 3]
storage = storage.sparse_reshape(2, -1)
assert storage.sparse_sizes() == (2, 8)
assert storage.row().tolist() == [0, 0, 1, 1]
assert storage.col().tolist() == [0, 5, 2, 7]
from itertools import product
import pytest
import torch
from torch_sparse import SparseTensor
from .utils import grad_dtypes, devices
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_getitem(dtype, device):
m = 50
n = 40
k = 10
mat = torch.randn(m, n, dtype=dtype, device=device)
mat = SparseTensor.from_dense(mat)
idx1 = torch.randint(0, m, (k,), dtype=torch.long, device=device)
idx2 = torch.randint(0, n, (k,), dtype=torch.long, device=device)
bool1 = torch.zeros(m, dtype=torch.bool, device=device)
bool2 = torch.zeros(n, dtype=torch.bool, device=device)
bool1.scatter_(0, idx1, 1)
bool2.scatter_(0, idx2, 1)
# idx1 and idx2 may have duplicates
k1_bool = bool1.nonzero().size(0)
k2_bool = bool2.nonzero().size(0)
idx1np = idx1.cpu().numpy()
idx2np = idx2.cpu().numpy()
bool1np = bool1.cpu().numpy()
bool2np = bool2.cpu().numpy()
idx1list = idx1np.tolist()
idx2list = idx2np.tolist()
bool1list = bool1np.tolist()
bool2list = bool2np.tolist()
assert mat[:k, :k].sizes() == [k, k]
assert mat[..., :k].sizes() == [m, k]
assert mat[idx1, idx2].sizes() == [k, k]
assert mat[idx1np, idx2np].sizes() == [k, k]
assert mat[idx1list, idx2list].sizes() == [k, k]
assert mat[bool1, bool2].sizes() == [k1_bool, k2_bool]
assert mat[bool1np, bool2np].sizes() == [k1_bool, k2_bool]
assert mat[bool1list, bool2list].sizes() == [k1_bool, k2_bool]
assert mat[idx1].sizes() == [k, n]
assert mat[idx1np].sizes() == [k, n]
assert mat[idx1list].sizes() == [k, n]
assert mat[bool1].sizes() == [k1_bool, n]
assert mat[bool1np].sizes() == [k1_bool, n]
assert mat[bool1list].sizes() == [k1_bool, n]
@pytest.mark.parametrize('device', devices)
def test_to_symmetric(device):
row = torch.tensor([0, 0, 0, 1, 1], device=device)
col = torch.tensor([0, 1, 2, 0, 2], device=device)
value = torch.arange(1, 6, device=device)
mat = SparseTensor(row=row, col=col, value=value)
assert not mat.is_symmetric()
mat = mat.to_symmetric()
assert mat.is_symmetric()
assert mat.to_dense().tolist() == [
[2, 6, 3],
[6, 0, 5],
[3, 5, 0],
]
def test_equal():
row = torch.tensor([0, 0, 0, 1, 1])
col = torch.tensor([0, 1, 2, 0, 2])
value = torch.arange(1, 6)
matA = SparseTensor(row=row, col=col, value=value)
matB = SparseTensor(row=row, col=col, value=value)
col = torch.tensor([0, 1, 2, 0, 1])
matC = SparseTensor(row=row, col=col, value=value)
assert id(matA) != id(matB)
assert matA == matB
assert id(matA) != id(matC)
assert matA != matC
from itertools import product
import pytest
import torch
from torch_sparse import transpose
from .utils import dtypes, devices, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_transpose_matrix(dtype, device):
row = torch.tensor([1, 0, 1, 2], device=device)
col = torch.tensor([0, 1, 1, 0], device=device)
index = torch.stack([row, col], dim=0)
value = tensor([1, 2, 3, 4], dtype, device)
index, value = transpose(index, value, m=3, n=2)
assert index.tolist() == [[0, 0, 1, 1], [1, 2, 0, 1]]
assert value.tolist() == [1, 4, 2, 3]
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_transpose(dtype, device):
row = torch.tensor([1, 0, 1, 0, 2, 1], device=device)
col = torch.tensor([0, 1, 1, 1, 0, 0], device=device)
index = torch.stack([row, col], dim=0)
value = tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]], dtype,
device)
index, value = transpose(index, value, m=3, n=2)
assert index.tolist() == [[0, 0, 1, 1], [1, 2, 0, 1]]
assert value.tolist() == [[7, 9], [5, 6], [6, 8], [3, 4]]
import torch
reductions = ['sum', 'add', 'mean', 'min', 'max']
dtypes = [torch.half, torch.float, torch.double, torch.int, torch.long]
grad_dtypes = [torch.half, torch.float, torch.double]
devices = [torch.device('cpu')]
if torch.cuda.is_available():
devices += [torch.device(f'cuda:{torch.cuda.current_device()}')]
def tensor(x, dtype, device):
return None if x is None else torch.tensor(x, dtype=dtype, device=device)
...@@ -3,18 +3,18 @@ import os.path as osp ...@@ -3,18 +3,18 @@ import os.path as osp
import torch import torch
__version__ = '0.6.13' __version__ = '0.6.15'
for library in [ for library in [
'_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis', '_rw', '_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis', '_rw',
'_saint', '_sample', '_ego_sample', '_hgt_sample', '_neighbor_sample', '_saint', '_sample', '_ego_sample', '_hgt_sample', '_neighbor_sample',
'_relabel' '_relabel'
]: ]:
hip_spec = importlib.machinery.PathFinder().find_spec( cuda_spec = importlib.machinery.PathFinder().find_spec(
f'{library}_hip', [osp.dirname(__file__)]) f'{library}_cuda', [osp.dirname(__file__)])
cpu_spec = importlib.machinery.PathFinder().find_spec( cpu_spec = importlib.machinery.PathFinder().find_spec(
f'{library}_cpu', [osp.dirname(__file__)]) f'{library}_cpu', [osp.dirname(__file__)])
spec = hip_spec or cpu_spec spec = cuda_spec or cpu_spec
if spec is not None: if spec is not None:
torch.ops.load_library(spec.origin) torch.ops.load_library(spec.origin)
else: # pragma: no cover else: # pragma: no cover
...@@ -22,11 +22,20 @@ for library in [ ...@@ -22,11 +22,20 @@ for library in [
f"{osp.dirname(__file__)}") f"{osp.dirname(__file__)}")
cuda_version = torch.ops.torch_sparse.cuda_version() cuda_version = torch.ops.torch_sparse.cuda_version()
if torch.cuda.is_available() and cuda_version != -1: # pragma: no cover if torch.version.cuda is not None and cuda_version != -1: # pragma: no cover
if cuda_version < 10000: if cuda_version < 10000:
major, minor = int(str(cuda_version)[0]), int(str(cuda_version)[2]) major, minor = int(str(cuda_version)[0]), int(str(cuda_version)[2])
else: else:
major, minor = int(str(cuda_version)[0:2]), int(str(cuda_version)[3]) major, minor = int(str(cuda_version)[0:2]), int(str(cuda_version)[3])
t_major, t_minor = [int(x) for x in torch.version.cuda.split('.')]
if t_major != major:
raise RuntimeError(
f'Detected that PyTorch and torch_sparse were compiled with '
f'different CUDA versions. PyTorch has CUDA version '
f'{t_major}.{t_minor} and torch_sparse has CUDA version '
f'{major}.{minor}. Please reinstall the torch_sparse that '
f'matches your PyTorch install.')
from .storage import SparseStorage # noqa from .storage import SparseStorage # noqa
from .tensor import SparseTensor # noqa from .tensor import SparseTensor # noqa
......
...@@ -8,10 +8,13 @@ def spmm(index: Tensor, value: Tensor, m: int, n: int, ...@@ -8,10 +8,13 @@ def spmm(index: Tensor, value: Tensor, m: int, n: int,
Args: Args:
index (:class:`LongTensor`): The index tensor of sparse matrix. index (:class:`LongTensor`): The index tensor of sparse matrix.
value (:class:`Tensor`): The value tensor of sparse matrix. value (:class:`Tensor`): The value tensor of sparse matrix, either of
floating-point or integer type. Does not work for boolean and
complex number data types.
m (int): The first dimension of sparse matrix. m (int): The first dimension of sparse matrix.
n (int): The second dimension of sparse matrix. n (int): The second dimension of sparse matrix.
matrix (:class:`Tensor`): The dense matrix. matrix (:class:`Tensor`): The dense matrix of same type as
:obj:`value`.
:rtype: :class:`Tensor` :rtype: :class:`Tensor`
""" """
......
...@@ -30,19 +30,21 @@ class SparseStorage(object): ...@@ -30,19 +30,21 @@ class SparseStorage(object):
_csr2csc: Optional[torch.Tensor] _csr2csc: Optional[torch.Tensor]
_csc2csr: Optional[torch.Tensor] _csc2csr: Optional[torch.Tensor]
def __init__(self, row: Optional[torch.Tensor] = None, def __init__(
self,
row: Optional[torch.Tensor] = None,
rowptr: Optional[torch.Tensor] = None, rowptr: Optional[torch.Tensor] = None,
col: Optional[torch.Tensor] = None, col: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None,
sparse_sizes: Optional[Tuple[Optional[int], sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None,
Optional[int]]] = None,
rowcount: Optional[torch.Tensor] = None, rowcount: Optional[torch.Tensor] = None,
colptr: Optional[torch.Tensor] = None, colptr: Optional[torch.Tensor] = None,
colcount: Optional[torch.Tensor] = None, colcount: Optional[torch.Tensor] = None,
csr2csc: Optional[torch.Tensor] = None, csr2csc: Optional[torch.Tensor] = None,
csc2csr: Optional[torch.Tensor] = None, csc2csr: Optional[torch.Tensor] = None,
is_sorted: bool = False, is_sorted: bool = False,
trust_data: bool = False): trust_data: bool = False,
):
assert row is not None or rowptr is not None assert row is not None or rowptr is not None
assert col is not None assert col is not None
...@@ -240,7 +242,8 @@ class SparseStorage(object): ...@@ -240,7 +242,8 @@ class SparseStorage(object):
csr2csc=self._csr2csc, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, csc2csr=self._csc2csr,
is_sorted=True, is_sorted=True,
trust_data=True) trust_data=True,
)
def sparse_sizes(self) -> Tuple[int, int]: def sparse_sizes(self) -> Tuple[int, int]:
return self._sparse_sizes return self._sparse_sizes
...@@ -290,7 +293,8 @@ class SparseStorage(object): ...@@ -290,7 +293,8 @@ class SparseStorage(object):
csr2csc=self._csr2csc, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, csc2csr=self._csc2csr,
is_sorted=True, is_sorted=True,
trust_data=True) trust_data=True,
)
def sparse_reshape(self, num_rows: int, num_cols: int): def sparse_reshape(self, num_rows: int, num_cols: int):
assert num_rows > 0 or num_rows == -1 assert num_rows > 0 or num_rows == -1
...@@ -313,10 +317,20 @@ class SparseStorage(object): ...@@ -313,10 +317,20 @@ class SparseStorage(object):
col = idx % num_cols col = idx % num_cols
assert row.dtype == torch.long and col.dtype == torch.long assert row.dtype == torch.long and col.dtype == torch.long
return SparseStorage(row=row, rowptr=None, col=col, value=self._value, return SparseStorage(
sparse_sizes=(num_rows, num_cols), rowcount=None, row=row,
colptr=None, colcount=None, csr2csc=None, rowptr=None,
csc2csr=None, is_sorted=True, trust_data=True) col=col,
value=self._value,
sparse_sizes=(num_rows, num_cols),
rowcount=None,
colptr=None,
colcount=None,
csr2csc=None,
csc2csr=None,
is_sorted=True,
trust_data=True,
)
def has_rowcount(self) -> bool: def has_rowcount(self) -> bool:
return self._rowcount is not None return self._rowcount is not None
...@@ -413,10 +427,20 @@ class SparseStorage(object): ...@@ -413,10 +427,20 @@ class SparseStorage(object):
ptr = torch.cat([ptr, ptr.new_full((1, ), value.size(0))]) ptr = torch.cat([ptr, ptr.new_full((1, ), value.size(0))])
value = segment_csr(value, ptr, reduce=reduce) value = segment_csr(value, ptr, reduce=reduce)
return SparseStorage(row=row, rowptr=None, col=col, value=value, return SparseStorage(
sparse_sizes=self._sparse_sizes, rowcount=None, row=row,
colptr=None, colcount=None, csr2csc=None, rowptr=None,
csc2csr=None, is_sorted=True, trust_data=True) col=col,
value=value,
sparse_sizes=self._sparse_sizes,
rowcount=None,
colptr=None,
colcount=None,
csr2csc=None,
csc2csr=None,
is_sorted=True,
trust_data=True,
)
def fill_cache_(self): def fill_cache_(self):
self.row() self.row()
...@@ -466,7 +490,8 @@ class SparseStorage(object): ...@@ -466,7 +490,8 @@ class SparseStorage(object):
csr2csc=self._csr2csc, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, csc2csr=self._csc2csr,
is_sorted=True, is_sorted=True,
trust_data=True) trust_data=True,
)
def clone(self): def clone(self):
row = self._row row = self._row
...@@ -495,11 +520,20 @@ class SparseStorage(object): ...@@ -495,11 +520,20 @@ class SparseStorage(object):
if csc2csr is not None: if csc2csr is not None:
csc2csr = csc2csr.clone() csc2csr = csc2csr.clone()
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value, return SparseStorage(
row=row,
rowptr=rowptr,
col=col,
value=value,
sparse_sizes=self._sparse_sizes, sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr, rowcount=rowcount,
colcount=colcount, csr2csc=csr2csc, colptr=colptr,
csc2csr=csc2csr, is_sorted=True, trust_data=True) colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True,
trust_data=True,
)
def type(self, dtype: torch.dtype, non_blocking: bool = False): def type(self, dtype: torch.dtype, non_blocking: bool = False):
value = self._value value = self._value
...@@ -508,9 +542,7 @@ class SparseStorage(object): ...@@ -508,9 +542,7 @@ class SparseStorage(object):
return self return self
else: else:
return self.set_value( return self.set_value(
value.to( value.to(dtype=dtype, non_blocking=non_blocking),
dtype=dtype,
non_blocking=non_blocking),
layout='coo') layout='coo')
else: else:
return self return self
...@@ -548,11 +580,20 @@ class SparseStorage(object): ...@@ -548,11 +580,20 @@ class SparseStorage(object):
if csc2csr is not None: if csc2csr is not None:
csc2csr = csc2csr.to(device, non_blocking=non_blocking) csc2csr = csc2csr.to(device, non_blocking=non_blocking)
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value, return SparseStorage(
row=row,
rowptr=rowptr,
col=col,
value=value,
sparse_sizes=self._sparse_sizes, sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr, rowcount=rowcount,
colcount=colcount, csr2csc=csr2csc, colptr=colptr,
csc2csr=csc2csr, is_sorted=True, trust_data=True) colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True,
trust_data=True,
)
def device_as(self, tensor: torch.Tensor, non_blocking: bool = False): def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
return self.to_device(device=tensor.device, non_blocking=non_blocking) return self.to_device(device=tensor.device, non_blocking=non_blocking)
...@@ -587,11 +628,20 @@ class SparseStorage(object): ...@@ -587,11 +628,20 @@ class SparseStorage(object):
if csc2csr is not None: if csc2csr is not None:
csc2csr = csc2csr.cuda() csc2csr = csc2csr.cuda()
return SparseStorage(row=row, rowptr=rowptr, col=new_col, value=value, return SparseStorage(
row=row,
rowptr=rowptr,
col=new_col,
value=value,
sparse_sizes=self._sparse_sizes, sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr, rowcount=rowcount,
colcount=colcount, csr2csc=csr2csc, colptr=colptr,
csc2csr=csc2csr, is_sorted=True, trust_data=True) colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True,
trust_data=True,
)
def pin_memory(self): def pin_memory(self):
row = self._row row = self._row
...@@ -620,11 +670,20 @@ class SparseStorage(object): ...@@ -620,11 +670,20 @@ class SparseStorage(object):
if csc2csr is not None: if csc2csr is not None:
csc2csr = csc2csr.pin_memory() csc2csr = csc2csr.pin_memory()
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value, return SparseStorage(
row=row,
rowptr=rowptr,
col=col,
value=value,
sparse_sizes=self._sparse_sizes, sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr, rowcount=rowcount,
colcount=colcount, csr2csc=csr2csc, colptr=colptr,
csc2csr=csc2csr, is_sorted=True, trust_data=True) colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True,
trust_data=True,
)
def is_pinned(self) -> bool: def is_pinned(self) -> bool:
is_pinned = True is_pinned = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment