Commit fa763bac authored by rusty1s's avatar rusty1s
Browse files

add implementation cpu

parent f00ca88b
import time
from itertools import product from itertools import product
from scipy.io import loadmat
import numpy as np
import pytest import pytest
import torch import torch
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
from torch_sparse.add import add from torch_sparse.add import sparse_add
from .utils import dtypes, devices, tensor from .utils import dtypes, devices, tensor
devices = ['cpu']
dtypes = [torch.float]
@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_sparse_add(dtype, device): def test_sparse_add(dtype, device):
index = tensor([[0, 0, 1], [0, 1, 2]], torch.long, device) name = ('DIMACS10', 'citationCiteseer')[1]
mat1 = SparseTensor(index) mat_scipy = loadmat(f'benchmark/{name}.mat')['Problem'][0][0][2].tocsr()
mat = SparseTensor.from_scipy(mat_scipy)
mat1 = mat[:, 0:100000]
mat2 = mat[:, 100000:200000]
print(mat1.shape)
print(mat2.shape)
# 0.0159 to beat
t = time.perf_counter()
mat = sparse_add(mat1, mat2)
print(time.perf_counter() - t)
print(mat.nnz())
mat1 = mat_scipy[:, 0:100000]
mat2 = mat_scipy[:, 100000:200000]
t = time.perf_counter()
mat = mat1 + mat2
print(time.perf_counter() - t)
print(mat.nnz)
# mat1 + mat2
# mat1 = mat1.tocoo()
# mat2 = mat2.tocoo()
# row1, col1 = mat1.row, mat1.col
# row2, col2 = mat2.row, mat2.col
# idx1 = row1 * 100000 + col1
# idx2 = row2 * 100000 + col2
# t = time.perf_counter()
# np.union1d(idx1, idx2)
# print(time.perf_counter() - t)
# index = tensor([[0, 0, 1], [0, 1, 2]], torch.long, device)
# mat1 = SparseTensor(index)
# print()
# print(mat1.to_dense())
index = tensor([[0, 0, 1, 2], [0, 1, 1, 0]], torch.long, device) # index = tensor([[0, 0, 1, 2], [0, 1, 1, 0]], torch.long, device)
mat2 = SparseTensor(index) # mat2 = SparseTensor(index)
# print(mat2.to_dense())
add(mat1, mat2) # add(mat1, mat2)
import torch import torch
from torch_scatter import gather_csr
def union(mat1, mat2): def sparse_add(matA, matB):
offset = mat1.nnz() + 1 nnzA, nnzB = matA.nnz(), matB.nnz()
value1 = torch.ones(mat1.nnz(), dtype=torch.long, device=mat2.device) valA = torch.full((nnzA, ), 1, dtype=torch.uint8, device=matA.device)
value2 = value1.new_full((mat2.nnz(), ), offset) valB = torch.full((nnzB, ), 2, dtype=torch.uint8, device=matB.device)
size = max(mat1.size(0), mat2.size(0)), max(mat1.size(1), mat2.size(1))
if not mat1.is_cuda: if matA.is_cuda:
mat1 = mat1.set_value(value1, layout='coo').to_scipy(layout='csr') pass
mat1.resize(*size) else:
matA_ = matA.set_value(valA, layout='csr').to_scipy(layout='csr')
matB_ = matB.set_value(valB, layout='csr').to_scipy(layout='csr')
matC_ = matA_ + matB_
rowptr = torch.from_numpy(matC_.indptr).to(torch.long)
matC_ = matC_.tocoo()
row = torch.from_numpy(matC_.row).to(torch.long)
col = torch.from_numpy(matC_.col).to(torch.long)
index = torch.stack([row, col], dim=0)
valC_ = torch.from_numpy(matC_.data)
mat2 = mat2.set_value(value2, layout='coo').to_scipy(layout='csr') value = None
mat2.resize(*size) if matA.has_value() or matB.has_value():
maskA, maskB = valC_ != 2, valC_ >= 2
out = mat1 + mat2 size = matA.size() if matA.dim() >= matB.dim() else matA.size()
rowptr = torch.from_numpy(out.indptr).to(torch.long) size = (valC_.size(0), ) + size[2:]
out = out.tocoo()
row = torch.from_numpy(out.row).to(torch.long) value = torch.zeros(size, dtype=matA.dtype, device=matA.device)
col = torch.from_numpy(out.col).to(torch.long) value[maskA] += matA.storage.value if matA.has_value() else 1
value = torch.from_numpy(out.data) value[maskB] += matB.storage.value if matB.has_value() else 1
else:
raise NotImplementedError
mask1 = value % offset > 0 storage = matA.storage.__class__(index, value, matA.sparse_size(),
mask2 = value >= offset rowptr=rowptr, is_sorted=True)
return rowptr, torch.stack([row, col], dim=0), mask1, mask2 return matA.__class__.from_storage(storage)
def add(src, other): def add(src, other):
...@@ -35,19 +43,21 @@ def add(src, other): ...@@ -35,19 +43,21 @@ def add(src, other):
elif torch.is_tensor(other): elif torch.is_tensor(other):
(row, col), value = src.coo() (row, col), value = src.coo()
if other.size(0) == src.size(0) and other.size(1) == 1: if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
val = other.squeeze(1).repeat_interleave( other = gather_csr(other.squeeze(1), src.storage.rowptr)
row, 0) + (value if src.has_value() else 1) value = other.add_(src.storage.value if src.has_value() else 1)
if other.size(0) == 1 and other.size(1) == src.size(1): return src.set_value(value, layout='csr')
val = other.squeeze(0)[col] + (value if src.has_value() else 1)
else: if other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise...
raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,' other = other.squeeze(0)[col]
f' ...) or (1, {src.size(1)}, ...), but got size ' value = other.add_(src.storage.value if src.has_value() else 1)
f'{other.size()}.') return src.set_value(value, layout='coo')
return src.set_value(val, layout='coo')
raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
f' ...) or (1, {src.size(1)}, ...), but got size '
f'{other.size()}.')
elif isinstance(other, src.__class__): elif isinstance(other, src.__class__):
rowptr, index, src_offset, other_offset = union(src, other)
raise NotImplementedError raise NotImplementedError
raise ValueError('Argument `other` needs to be of type `int`, `float`, ' raise ValueError('Argument `other` needs to be of type `int`, `float`, '
...@@ -55,21 +65,71 @@ def add(src, other): ...@@ -55,21 +65,71 @@ def add(src, other):
def add_(src, other): def add_(src, other):
pass if isinstance(other, int) or isinstance(other, float):
return add_nnz_(src, other)
elif torch.is_tensor(other):
(row, col), value = src.coo()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
other = gather_csr(other.squeeze(1), src.storage.rowptr)
if src.has_value():
value = src.storage.value.add_(other)
else:
value = other.add_(1)
return src.set_value_(value, layout='csr')
if other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise...
other = other.squeeze(0)[col]
if src.has_value():
value = src.storage.value.add_(other)
else:
value = other.add_(1)
return src.set_value_(value, layout='coo')
raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
f' ...) or (1, {src.size(1)}, ...), but got size '
f'{other.size()}.')
elif isinstance(other, src.__class__):
raise NotImplementedError
raise ValueError('Argument `other` needs to be of type `int`, `float`, '
'`torch.tensor` or `torch_sparse.SparseTensor`.')
def add_nnz(src, other, layout=None): def add_nnz(src, other, layout=None):
if isinstance(other, int) or isinstance(other, float): if isinstance(other, int) or isinstance(other, float):
return src.set_value(src.storage.value + if src.has_value():
other if src.has_value() else torch.full(( value = src.storage.value + other
src.nnz(), ), 1 + other, device=src.device)) else:
elif torch.is_tensor(other): value = torch.full((src.nnz(), ), 1 + other, device=src.device)
return src.set_value(src.storage.value + return src.set_value(value, layout='coo')
other if src.has_value() else other + 1)
if torch.is_tensor(other):
if src.has_value():
value = src.storage.value + other
else:
value = other + 1
return src.set_value(value, layout='coo')
raise ValueError('Argument `other` needs to be of type `int`, `float` or ' raise ValueError('Argument `other` needs to be of type `int`, `float` or '
'`torch.tensor`.') '`torch.tensor`.')
def add_nnz_(src, other, layout=None): def add_nnz_(src, other, layout=None):
pass if isinstance(other, int) or isinstance(other, float):
if src.has_value():
value = src.storage.value.add_(other)
else:
value = torch.full((src.nnz(), ), 1 + other, device=src.device)
return src.set_value_(value, layout='coo')
if torch.is_tensor(other):
if src.has_value():
value = src.storage.value.add_(other)
else:
value = other + 1 # No inplace operation possible.
return src.set_value_(value, layout='coo')
raise ValueError('Argument `other` needs to be of type `int`, `float` or '
'`torch.tensor`.')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment