Commit d8a8ebab authored by rusty1s

add remove diag

parent 072a17c4
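
This commit adds a `remove_diag` method to `SparseTensor`: a new diag module with the implementation plus a test, a rename of the private `__reduce__` reduction helper to `reduction`, and removal of the old ad-hoc benchmarking block at the bottom of tensor.py. A minimal usage sketch of the new method, mirroring the test below (illustration only, not part of the diff):

# Illustrative sketch, not part of the commit.
import torch
from torch_sparse.tensor import SparseTensor

index = torch.tensor([[0, 0, 1, 2], [0, 1, 2, 2]])
value = torch.tensor([1, 2, 3, 4])

mat = SparseTensor(index, value)   # 3 x 3 matrix with four non-zeros
out = mat.remove_diag()            # drops the diagonal entries (0, 0) and (2, 2)

index, val = out.coo()
print(index.tolist(), val.tolist())  # [[0, 1], [1, 2]] [2, 3]
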
from itertools import product
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from .utils import dtypes, devices, tensor


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_remove_diag(dtype, device):
    index = tensor([[0, 0, 1, 2], [0, 1, 2, 2]], torch.long, device)
    value = tensor([1, 2, 3, 4], dtype, device)

    mat = SparseTensor(index, value)
    mat.fill_cache_()

    mat = mat.remove_diag()
    index, value = mat.coo()

    assert index.tolist() == [[0, 1], [1, 2]]
    assert value.tolist() == [2, 3]
    assert len(mat.cached_keys()) == 2
    assert mat.storage.rowcount.tolist() == [1, 1, 0]
    assert mat.storage.colcount.tolist() == [0, 1, 1]
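
For reference, the dense view of the matrix used in this test makes the expected result easy to check by hand (illustration only, not part of the diff):

# Illustrative sketch, not part of the commit.
import torch

dense = torch.tensor([[1, 2, 0],
                      [0, 0, 3],
                      [0, 0, 4]])

# Zeroing the main diagonal removes the 1 at (0, 0) and the 4 at (2, 2),
# leaving 2 at (0, 1) and 3 at (1, 2) -- exactly the asserts above.
print(dense - torch.diag(torch.diagonal(dense)))
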
import torch


def add_diag(src, value=None, k=0):
    # Placeholder: adding a (scaled) diagonal is not implemented yet.
    pass


def remove_diag(src, k=0):
    index, value = src.coo()
    row, col = index

    # Entries on the k-th diagonal are the ones to drop.
    mask = row == col if k == 0 else row == (col + k)
    inv_mask = ~mask

    index = index[:, inv_mask]
    if src.has_value():
        value = value[inv_mask]

    # Keep the cached row/col counts valid by decrementing the rows and
    # columns that lost their diagonal entry.
    rowcount = None
    if src.storage.has_rowcount():
        rowcount = src.storage.rowcount.clone()
        rowcount[row[mask]] -= 1

    colcount = None
    if src.storage.has_colcount():
        colcount = src.storage.colcount.clone()
        colcount[col[mask]] -= 1

    storage = src.storage.__class__(index, value,
                                    sparse_size=src.sparse_size(),
                                    rowcount=rowcount, colcount=colcount,
                                    is_sorted=True)
    return src.__class__.from_storage(storage)
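
To make the cache bookkeeping above concrete, here is a small standalone sketch of the masking and of the rowcount decrement for the main diagonal (k == 0), using plain tensors and the indices from the test (illustration only, not part of the diff):

# Illustrative sketch, not part of the commit.
import torch

row = torch.tensor([0, 0, 1, 2])
col = torch.tensor([0, 1, 2, 2])

mask = row == col          # True for the diagonal entries (0, 0) and (2, 2)
inv_mask = ~mask
print(row[inv_mask].tolist(), col[inv_mask].tolist())  # [0, 1] [1, 2]

rowcount = torch.tensor([2, 1, 1])   # entries per row before removal
rowcount[row[mask]] -= 1             # rows 0 and 2 each lose one entry
print(rowcount.tolist())             # [1, 1, 0]

Since a coalesced matrix holds at most one entry per (row, col) pair, every row and column can lose at most one diagonal entry, so a single in-place decrement per masked index is sufficient.
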
@@ -3,7 +3,7 @@ import torch_scatter
 from torch_scatter import segment_csr


-def __reduce__(src, dim=None, reduce='add', deterministic=False):
+def reduction(src, dim=None, reduce='add', deterministic=False):
     assert reduce in ['add', 'mean', 'min', 'max']

     if dim is None and src.has_value():
@@ -84,16 +84,16 @@ def __reduce__(src, dim=None, reduce='add', deterministic=False):

 def sum(src, dim=None, deterministic=False):
-    return __reduce__(src, dim, reduce='add', deterministic=deterministic)
+    return reduction(src, dim, reduce='add', deterministic=deterministic)


 def mean(src, dim=None, deterministic=False):
-    return __reduce__(src, dim, reduce='mean', deterministic=deterministic)
+    return reduction(src, dim, reduce='mean', deterministic=deterministic)


 def min(src, dim=None, deterministic=False):
-    return __reduce__(src, dim, reduce='min', deterministic=deterministic)
+    return reduction(src, dim, reduce='min', deterministic=deterministic)


 def max(src, dim=None, deterministic=False):
-    return __reduce__(src, dim, reduce='max', deterministic=deterministic)
+    return reduction(src, dim, reduce='max', deterministic=deterministic)
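
The renamed reduction helper builds on torch_scatter.segment_csr, which the module imports above and which reduces value segments delimited by a CSR-style pointer. A rough sketch of a row-wise sum for the 3 x 3 test matrix, ignoring the dim handling and min/max bookkeeping that the real function adds (illustration only, not part of the diff):

# Illustrative sketch, not part of the commit.
import torch
from torch_scatter import segment_csr

rowptr = torch.tensor([0, 2, 3, 4])     # CSR row pointer: rows own 2, 1, 1 entries
value = torch.tensor([1., 2., 3., 4.])

print(segment_csr(value, rowptr, reduce='sum').tolist())  # [3.0, 3.0, 4.0]
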
@@ -11,7 +11,7 @@ from torch_sparse.select import select
 from torch_sparse.index_select import index_select, index_select_nnz
 from torch_sparse.masked_select import masked_select, masked_select_nnz
 import torch_sparse.reduce
-from torch_sparse.add import add, add_nnz
+from torch_sparse.diag import remove_diag


 class SparseTensor(object):
@@ -390,12 +390,14 @@ SparseTensor.index_select = index_select
 SparseTensor.index_select_nnz = index_select_nnz
 SparseTensor.masked_select = masked_select
 SparseTensor.masked_select_nnz = masked_select_nnz
+SparseTensor.reduction = torch_sparse.reduce.reduction
 SparseTensor.sum = torch_sparse.reduce.sum
 SparseTensor.mean = torch_sparse.reduce.mean
 SparseTensor.min = torch_sparse.reduce.min
 SparseTensor.max = torch_sparse.reduce.max
-SparseTensor.add = add
-SparseTensor.add_nnz = add_nnz
+SparseTensor.remove_diag = remove_diag
+# SparseTensor.add = add
+# SparseTensor.add_nnz = add_nnz
-# def remove_diag(self):
-#     raise NotImplementedError
@@ -455,193 +457,3 @@ SparseTensor.add_nnz = add_nnz
# def div_(self, layout=None):
# raise NotImplementedError
if __name__ == '__main__':
    from torch_geometric.datasets import Reddit, Planetoid  # noqa
    import time  # noqa

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # device = 'cpu'

    dataset = Reddit('/tmp/Reddit')
    # dataset = Planetoid('/tmp/PubMed', 'PubMed')
    data = dataset[0].to(device)

    value = torch.randn((data.num_edges, 8), device=device)
    mat = SparseTensor(data.edge_index, value)
    print(mat)

    t = time.perf_counter()
    torch.cuda.synchronize()
    out = mat.sum(dim=1)
    torch.cuda.synchronize()
    print(time.perf_counter() - t)
    print(out.size())
# perm = torch.arange(data.num_nodes)
# perm = torch.randperm(data.num_nodes)
# mat1 = SparseTensor(torch.tensor([[0, 1], [0, 1]]))
# mat2 = SparseTensor(torch.tensor([[0, 0, 1], [0, 1, 0]]))
# add(mat1, mat2)
# # print(mat2)
# raise NotImplementedError
# for _ in range(10):
# x = torch.randn(1000, 1000, device=device).sum()
# torch.cuda.synchronize()
# t = time.perf_counter()
# for _ in range(100):
# mat[perm]
# torch.cuda.synchronize()
# print(time.perf_counter() - t)
# index = torch.tensor([
# [0, 1, 1, 2, 2],
# [1, 2, 2, 2, 3],
# ])
# value = torch.tensor([1, 2, 3, 4, 5])
# mat = SparseTensor(index, value)
# print(mat)
# print(mat.coalesce())
# index = torch.tensor([0, 1, 2])
# mask = torch.zeros(data.num_nodes, dtype=torch.bool)
# mask[:3] = True
# print(mat[1].size())
# print(mat[1, 1].size())
# print(mat[..., -1].size())
# print(mat[:10, ..., -1].size())
# print(mat[:, -1].size())
# print(mat[1, :, -1].size())
# print(mat[1:4, 1:4].size())
# print(mat[index].size())
# print(mat[index, index].size())
# print(mat[mask, index].size())
# mat[::-1]
# mat[::2]
# mat1 = SparseTensor.from_dense(mat1.to_dense())
# print(mat1)
# mat = SparseTensor.from_torch_sparse_coo_tensor(
# mat1.to_torch_sparse_coo_tensor())
# mat = SparseTensor.from_scipy(mat.to_scipy(layout='csc'))
# print(mat)
# index = torch.tensor([0, 2])
# mat2 = mat1.index_select(2, index)
# index = torch.randperm(data.num_nodes)[:data.num_nodes - 500]
# mask = torch.zeros(data.num_nodes, dtype=torch.bool)
# mask[index] = True
# t = time.perf_counter()
# for _ in range(1000):
# mat2 = mat1.index_select(0, index)
# print(time.perf_counter() - t)
# t = time.perf_counter()
# for _ in range(1000):
# mat2 = mat1.masked_select(0, mask)
# print(time.perf_counter() - t)
# mat2 = mat1.narrow(1, start=0, length=3)
# print(mat2)
# index = torch.randperm(data.num_nodes)
# t = time.perf_counter()
# for _ in range(1000):
# mat2 = mat1.index_select(0, index)
# print(time.perf_counter() - t)
# t = time.perf_counter()
# for _ in range(1000):
# mat2 = mat1.index_select(1, index)
# print(time.perf_counter() - t)
# raise NotImplementedError
# t = time.perf_counter()
# for _ in range(1000):
# mat2 = mat1.t().index_select(0, index).t()
# print(time.perf_counter() - t)
# print(mat1)
# mat1.index_select((0, 1), torch.tensor([0, 1, 2, 3, 4]))
# print(mat3)
# print(mat3.storage.rowcount)
# print(mat1)
# (row, col), value = mat1.coo()
# mask = row < 3
# t = time.perf_counter()
# for _ in range(10000):
# mat2 = mat1.narrow(1, start=10, length=2690)
# print(time.perf_counter() - t)
# # print(mat1.to_dense().size())
# print(mat1.to_torch_sparse_coo_tensor().to_dense().size())
# print(mat1.to_scipy(layout='coo').todense().shape)
# print(mat1.to_scipy(layout='csr').todense().shape)
# print(mat1.to_scipy(layout='csc').todense().shape)
# print(mat1.is_quadratic())
# print(mat1.is_symmetric())
# print(mat1.cached_keys())
# mat1 = mat1.t()
# print(mat1.cached_keys())
# mat1 = mat1.t()
# print(mat1.cached_keys())
# print('-------- NARROW ----------')
# t = time.perf_counter()
# for _ in range(100):
# out = mat1.narrow(dim=0, start=10, length=10)
# # out.storage.colptr
# print(time.perf_counter() - t)
# print(out)
# print(out.cached_keys())
# t = time.perf_counter()
# for _ in range(100):
# out = mat1.narrow(dim=1, start=10, length=2000)
# # out.storage.colptr
# print(time.perf_counter() - t)
# print(out)
# print(out.cached_keys())
# mat1 = mat1.narrow(0, start=10, length=10)
# mat1.storage._value = torch.randn(mat1.nnz(), 20)
# print(mat1.coo()[1].size())
# mat1 = mat1.narrow(2, start=10, length=10)
# print(mat1.coo()[1].size())
# mat1 = mat1.t()
# mat2 = torch.sparse_coo_tensor(data.edge_index, torch.ones(data.num_edges),
# device=device)
# mat2 = mat2.coalesce()
# mat2 = mat2.t().coalesce()
# index1, value1 = mat1.coo()
# index2, value2 = mat2._indices(), mat2._values()
# assert torch.allclose(index1, index2)
# out1 = mat1.to_dense()
# out2 = mat2.to_dense()
# assert torch.allclose(out1, out2)
# out = 2 + mat1
# print(out)
# # mat1[1]
# # mat1[1, 1]
# # mat1[..., -1]
# # mat1[:, -1]
# # mat1[1:4, 1:4]
# # mat1[torch.tensor([0, 1, 2])]