Commit 9fe44a44 authored by rusty1s's avatar rusty1s
Browse files

test code

parent ac5d7a78
......@@ -17,8 +17,8 @@ def __is_scalar__(x):
class SparseTensor(object):
def __init__(self, index, value=None, sparse_size=None, is_sorted=False):
    """Build a ``SparseTensor`` from a COO ``index`` tensor.

    Args:
        index: tensor of shape ``[2, nnz]``; ``index[0]`` holds row and
            ``index[1]`` column coordinates (enforced by the assert).
        value: optional per-entry value tensor, forwarded to the storage.
        sparse_size: optional explicit sparse matrix size, forwarded to
            the storage.
        is_sorted: pass ``True`` when ``index`` is already sorted so the
            storage backend can skip re-sorting.
    """
    assert index.dim() == 2 and index.size(0) == 2
    # Construct the backing storage exactly once; the duplicated
    # construction in the previous revision was a leftover of a
    # formatting-only edit (old and new diff lines both present).
    self._storage = SparseStorage(index[0], index[1], value, sparse_size,
                                  is_sorted=is_sorted)
@classmethod
def from_storage(self, storage):
......@@ -184,8 +184,8 @@ class SparseTensor(object):
if self.has_value:
return self.set_value(self._value + other, 'coo')
else:
return self.set_value(
torch.full((self.nnz(), ), other + 1), 'coo')
return self.set_value(torch.full((self.nnz(), ), other + 1),
'coo')
elif torch.is_tensor(other):
if layout is None:
layout = 'coo'
......@@ -249,9 +249,7 @@ class SparseTensor(object):
return torch.sparse_coo_tensor(
index,
torch.ones_like(self._row, dtype) if value is None else value,
self.size(),
device=self.device,
requires_grad=requires_grad)
self.size(), device=self.device, requires_grad=requires_grad)
def __repr__(self):
i = ' ' * 6
......@@ -292,8 +290,8 @@ if __name__ == '__main__':
print(mat1)
mat1 = mat1.t()
mat2 = torch.sparse_coo_tensor(
data.edge_index, torch.ones(data.num_edges), device=device)
mat2 = torch.sparse_coo_tensor(data.edge_index, torch.ones(data.num_edges),
device=device)
mat2 = mat2.coalesce()
mat2 = mat2.t().coalesce()
......
import torch
from torch_scatter import scatter_add
from torch_sparse.sparse import SparseTensor
if torch.cuda.is_available():
import torch_sparse.spmm_cuda
def spmm_(sparse_mat, mat, reduce='add'):
    """Sparse-dense matrix product using the custom CUDA kernels.

    Args:
        sparse_mat: a ``SparseTensor`` providing a ``csr()`` accessor.
        mat: dense right-hand matrix; made contiguous before the kernel
            call.
        reduce: one of ``'add'``, ``'mean'``, ``'min'`` or ``'max'``.

    NOTE(review): ``torch_sparse.spmm_cuda`` is only imported when CUDA
    is available, so this presumably fails on CPU-only builds — confirm.
    """
    assert reduce in ['add', 'mean', 'min', 'max']
    assert sparse_mat.dim() == 2 and mat.dim() == 2
    assert sparse_mat.size(1) == mat.size(0)

    rowptr, col, value = sparse_mat.csr()
    dense = mat.contiguous()

    # 'min'/'max' go through the arg-tracking kernel; 'add'/'mean'
    # through the plain reduction kernel.
    if reduce == 'min' or reduce == 'max':
        return torch_sparse.spmm_cuda.spmm_arg(rowptr, col, value, dense,
                                               reduce)
    return torch_sparse.spmm_cuda.spmm(rowptr, col, value, dense, reduce)
def spmm(index, value, m, n, matrix):
"""Matrix product of sparse matrix with dense matrix.
......@@ -24,3 +44,60 @@ def spmm(index, value, m, n, matrix):
out = scatter_add(out, row, dim=0, dim_size=m)
return out
if __name__ == '__main__':
# NOTE(review): leading indentation inside this guard was lost when the
# diff was flattened to text; the code lines below are kept byte-identical.
# --- tiny correctness smoke test on a 2x3 sparse matrix ------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
row = torch.tensor([0, 0, 0, 1, 1, 1], device=device)
col = torch.tensor([0, 1, 2, 0, 1, 2], device=device)
value = torch.ones_like(col, dtype=torch.float, device=device)
# Overwritten on purpose right after creation: this exercises the
# value=None (implicit all-ones) code path of SparseTensor.
value = None
sparse_mat = SparseTensor(torch.stack([row, col], dim=0), value)
mat = torch.tensor([[1, 4], [2, 5], [3, 6]], dtype=torch.float,
device=device)
out1 = spmm_(sparse_mat, mat, reduce='add')
# Dense reference product to validate the kernel output.
out2 = sparse_mat.to_dense() @ mat
assert torch.allclose(out1, out2)
from torch_geometric.datasets import Reddit, Planetoid # noqa
import time # noqa
# Warmup
x = torch.randn((1000, 1000), device=device)
for _ in range(100):
x.sum()
# dataset = Reddit('/tmp/Reddit')
dataset = Planetoid('/tmp/PubMed', 'PubMed')
# dataset = Planetoid('/tmp/Cora', 'Cora')
data = dataset[0].to(device)
# Random dense feature matrix with 1024 channels for the benchmarks.
mat = torch.randn((data.num_nodes, 1024), device=device)
value = torch.ones(data.num_edges, device=device)
sparse_mat = SparseTensor(data.edge_index, value)
# --- benchmark 1: this repo's CUDA spmm kernel ---------------------------
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
out1 = spmm_(sparse_mat, mat, reduce='add')
# spmm_arg apparently returns a tuple (values plus arg indices); keep the
# first element — TODO confirm against the kernel signature.
out1 = out1[0] if isinstance(out1, tuple) else out1
torch.cuda.synchronize()
print('My: ', time.perf_counter() - t)
# --- benchmark 2: native torch.sparse matmul -----------------------------
sparse_mat = torch.sparse_coo_tensor(data.edge_index, value)
sparse_mat = sparse_mat.coalesce()
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
out2 = sparse_mat @ mat
torch.cuda.synchronize()
print('Torch: ', time.perf_counter() - t)
# --- benchmark 3: scatter_add-based spmm ---------------------------------
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
spmm(data.edge_index, value, data.num_nodes, data.num_nodes, mat)
torch.cuda.synchronize()
print('Scatter:', time.perf_counter() - t)
# Loose tolerance: the two implementations accumulate floats in
# different orders.
assert torch.allclose(out1, out2, atol=1e-2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment