Commit f3469f1a authored by rusty1s

test storage

parent f3b7fb50
@@ -10,7 +10,6 @@ from .utils import dtypes, devices, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_sparse_add(dtype, device):
-    print()
    index = tensor([[0, 0, 1], [0, 1, 2]], torch.long, device)
    mat1 = SparseTensor(index)
import copy
from itertools import product

import pytest
import torch
from torch_sparse.storage import SparseStorage

from .utils import dtypes, devices, tensor

@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_storage(dtype, device):
    index = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device)
    storage = SparseStorage(index)
    assert storage.index.tolist() == index.tolist()
    assert storage.row.tolist() == [0, 0, 1, 1]
    assert storage.col.tolist() == [0, 1, 0, 1]
    assert storage.value is None
    assert storage.sparse_size() == (2, 2)

    index = tensor([[0, 0, 1, 1], [1, 0, 1, 0]], torch.long, device)
    value = tensor([2, 1, 4, 3], dtype, device)
    storage = SparseStorage(index, value)
    assert storage.index.tolist() == [[0, 0, 1, 1], [0, 1, 0, 1]]
    assert storage.row.tolist() == [0, 0, 1, 1]
    assert storage.col.tolist() == [0, 1, 0, 1]
    assert storage.value.tolist() == [1, 2, 3, 4]
    assert storage.sparse_size() == (2, 2)
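The second construction above hands the constructor an unsorted index; the expected reordering of value from [2, 1, 4, 3] to [1, 2, 3, 4] matches a sort by the row-major key sparse_size[1] * row + col that appears in the storage.py hunk further down. A standalone sketch of that reordering (plain torch, not the library code):

import torch

index = torch.tensor([[0, 0, 1, 1], [1, 0, 1, 0]])   # unsorted COO index
value = torch.tensor([2, 1, 4, 3])
perm = (index[0] * 2 + index[1]).argsort()           # row-major key, num_cols == 2
assert index[:, perm].tolist() == [[0, 0, 1, 1], [0, 1, 0, 1]]
assert value[perm].tolist() == [1, 2, 3, 4]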

@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_caching(dtype, device):
    index = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device)
    storage = SparseStorage(index)

    assert storage._index.tolist() == index.tolist()
    assert storage._value is None
    assert storage._rowcount is None
    assert storage._rowptr is None
    assert storage._colcount is None
    assert storage._colptr is None
    assert storage._csr2csc is None
    assert storage.cached_keys() == []

    storage.fill_cache_()
    assert storage._rowcount.tolist() == [2, 2]
    assert storage._rowptr.tolist() == [0, 2, 4]
    assert storage._colcount.tolist() == [2, 2]
    assert storage._colptr.tolist() == [0, 2, 4]
    assert storage._csr2csc.tolist() == [0, 2, 1, 3]
    assert storage._csc2csr.tolist() == [0, 2, 1, 3]
    assert storage.cached_keys() == [
        'rowcount', 'rowptr', 'colcount', 'colptr', 'csr2csc', 'csc2csr'
    ]

    storage = SparseStorage(index, storage.value, storage.sparse_size(),
                            storage.rowcount, storage.rowptr, storage.colcount,
                            storage.colptr, storage.csr2csc, storage.csc2csr)
    assert storage._rowcount.tolist() == [2, 2]
    assert storage._rowptr.tolist() == [0, 2, 4]
    assert storage._colcount.tolist() == [2, 2]
    assert storage._colptr.tolist() == [0, 2, 4]
    assert storage._csr2csc.tolist() == [0, 2, 1, 3]
    assert storage._csc2csr.tolist() == [0, 2, 1, 3]
    assert storage.cached_keys() == [
        'rowcount', 'rowptr', 'colcount', 'colptr', 'csr2csc', 'csc2csr'
    ]

    storage.clear_cache_()
    assert storage._rowcount is None
    assert storage._rowptr is None
    assert storage._colcount is None
    assert storage._colptr is None
    assert storage._csr2csc is None
    assert storage.cached_keys() == []
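For reference, the cached quantities asserted in test_caching can be reproduced from the sorted COO index with plain torch ops. This is only an illustration of the expected numbers, not torch_sparse's implementation:

import torch

row, col = torch.tensor([0, 0, 1, 1]), torch.tensor([0, 1, 0, 1])
num_rows = num_cols = 2

rowcount = torch.bincount(row, minlength=num_rows)                # [2, 2]
rowptr = torch.cat([rowcount.new_zeros(1), rowcount.cumsum(0)])   # [0, 2, 4]
colcount = torch.bincount(col, minlength=num_cols)                # [2, 2]
colptr = torch.cat([colcount.new_zeros(1), colcount.cumsum(0)])   # [0, 2, 4]
csr2csc = (col * num_rows + row).argsort()                        # [0, 2, 1, 3]
csc2csr = csr2csc.argsort()                                       # its inverse, [0, 2, 1, 3]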

@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_utility(dtype, device):
    index = tensor([[0, 0, 1, 1], [1, 0, 1, 0]], torch.long, device)
    value = tensor([1, 2, 3, 4], dtype, device)
    storage = SparseStorage(index, value)
    assert storage.has_value()

    storage.set_value_(value, layout='csc')
    assert storage.value.tolist() == [1, 3, 2, 4]
    storage.set_value_(value, layout='coo')
    assert storage.value.tolist() == [1, 2, 3, 4]
    storage = storage.set_value(value, layout='csc')
    assert storage.value.tolist() == [1, 3, 2, 4]
    storage = storage.set_value(value, layout='coo')
    assert storage.value.tolist() == [1, 2, 3, 4]

    storage.sparse_resize_(3, 3)
    assert storage.sparse_size() == (3, 3)

    new_storage = copy.copy(storage)
    assert new_storage != storage
    assert new_storage.index.data_ptr() == storage.index.data_ptr()

    new_storage = storage.clone()
    assert new_storage != storage
    assert new_storage.index.data_ptr() != storage.index.data_ptr()

    new_storage = copy.deepcopy(storage)
    assert new_storage != storage
    assert new_storage.index.data_ptr() != storage.index.data_ptr()

    storage.apply_value_(lambda x: x + 1)
    assert storage.value.tolist() == [2, 3, 4, 5]
    storage = storage.apply_value(lambda x: x + 1)
    assert storage.value.tolist() == [3, 4, 5, 6]

    storage.apply_(lambda x: x.to(torch.long))
    assert storage.index.dtype == torch.long
    assert storage.value.dtype == torch.long
    storage = storage.apply(lambda x: x.to(torch.long))
    assert storage.index.dtype == torch.long
    assert storage.value.dtype == torch.long

    storage.clear_cache_()
    assert storage.map(lambda x: x.numel()) == [8, 4]
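The layout='csc' expectations in test_utility are consistent with the passed value being interpreted in column-major order and permuted back to the storage's row-major order through the cached csc2csr permutation. A minimal check with plain torch, under that assumption:

import torch

value = torch.tensor([1, 2, 3, 4])     # values given in CSC (column-major) order
csc2csr = torch.tensor([0, 2, 1, 3])   # cached permutation for this 2x2 pattern
assert value[csc2csr].tolist() == [1, 3, 2, 4]   # order the storage reports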

@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_coalesce(dtype, device):
    pass
@@ -3,8 +3,8 @@ import torch
dtypes = [torch.float]
devices = [torch.device('cpu')]
-if torch.cuda.is_available():
-    devices += [torch.device('cuda:{}'.format(torch.cuda.current_device()))]
+# if torch.cuda.is_available():
+# devices += [torch.device('cuda:{}'.format(torch.cuda.current_device()))]

def tensor(x, dtype, device):
@@ -38,17 +38,9 @@ class SparseStorage(object):
        'rowcount', 'rowptr', 'colcount', 'colptr', 'csr2csc', 'csc2csr'
    ]

-    def __init__(self,
-                 index,
-                 value=None,
-                 sparse_size=None,
-                 rowcount=None,
-                 rowptr=None,
-                 colcount=None,
-                 colptr=None,
-                 csr2csc=None,
-                 csc2csr=None,
-                 is_sorted=False):
+    def __init__(self, index, value=None, sparse_size=None, rowcount=None,
+                 rowptr=None, colcount=None, colptr=None, csr2csc=None,
+                 csc2csr=None, is_sorted=False):
        assert index.dtype == torch.long
        assert index.dim() == 2 and index.size(0) == 2
@@ -97,7 +89,7 @@ class SparseStorage(object):
        if not is_sorted:
            idx = sparse_size[1] * index[0] + index[1]
            # Only sort if necessary...
-            if (idx <= torch.cat([idx.new_zeros(1), idx[:-1]], dim=0)).any():
+            if (idx < torch.cat([idx.new_zeros(1), idx[:-1]], dim=0)).any():
                perm = idx.argsort()
                index = index[:, perm]
                value = None if value is None else value[perm]
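Relaxing the comparison from <= to < means an index that is already sorted but contains consecutive duplicate keys (or starts with a zero key, which always compared equal to the prepended zero) no longer triggers a re-sort; only a strict decrease counts as unsorted. A small standalone check of the predicate:

import torch

idx = torch.tensor([0, 1, 1, 3])                        # sorted, contains a duplicate
prev = torch.cat([idx.new_zeros(1), idx[:-1]], dim=0)
assert not (idx < prev).any()   # new predicate: treated as sorted, argsort skipped
assert (idx <= prev).any()      # old predicate would have forced a re-sort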
@@ -164,7 +156,7 @@ class SparseStorage(object):
    def sparse_resize_(self, *sizes):
        assert len(sizes) == 2
-        self._sparse_size == sizes
+        self._sparse_size = sizes
        return self

    @cached_property
@@ -269,7 +261,7 @@ class SparseStorage(object):
        self._index = func(self._index)
        self._value = optional(func, self._value)
        for key in self.cached_keys():
-            setattr(self, f'_{key}', func, getattr(self, f'_{key}'))
+            setattr(self, f'_{key}', func(getattr(self, f'_{key}')))
        return self

    def apply(self, func):
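The corrected call applies func to each cached tensor before storing it back; the previous four-argument setattr would raise a TypeError as soon as any cache entry existed. A short sketch of the intended behaviour, using only calls exercised by the tests above:

import torch
from torch_sparse.storage import SparseStorage

index = torch.tensor([[0, 0, 1, 1], [0, 1, 0, 1]])
storage = SparseStorage(index)
storage.fill_cache_()                        # populate rowptr, colptr, csr2csc, ...
storage.apply_(lambda x: x.to(torch.long))   # now also maps every cached tensor
assert storage._rowptr.dtype == torch.long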
@@ -292,34 +284,3 @@ class SparseStorage(object):
            data += [func(self.value)]
        data += [func(getattr(self, f'_{key}')) for key in self.cached_keys()]
        return data
-
-
-if __name__ == '__main__':
-    from torch_geometric.datasets import Reddit, Planetoid  # noqa
-    import time  # noqa
-    import copy  # noqa
-
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-    # dataset = Reddit('/tmp/Reddit')
-    dataset = Planetoid('/tmp/Cora', 'Cora')
-    data = dataset[0].to(device)
-    edge_index = data.edge_index
-
-    storage = SparseStorage(edge_index, is_sorted=True)
-
-    t = time.perf_counter()
-    storage.fill_cache_()
-    print(time.perf_counter() - t)
-
-    t = time.perf_counter()
-    storage.clear_cache_()
-    storage.fill_cache_()
-    print(time.perf_counter() - t)
-
-    print(storage)
-    # storage = storage.clone()
-    # print(storage)
-    storage = copy.copy(storage)
-    print(storage)
-    print(id(storage))
-    storage = copy.deepcopy(storage)
-    print(storage)
-
-    storage.fill_cache_()
-    storage.clear_cache_()