Commit a1c268a5 authored by rusty1s

fix matmul

parent b6a1f005
-import time
 import torch
 from torch_sparse import to_scipy, from_scipy
 from torch_sparse import to_torch_sparse, from_torch_sparse
-from torch_sparse.storage import SparseStorage
-from scipy.io import loadmat


 def test_convert_scipy():
@@ -24,37 +21,3 @@ def test_convert_torch_sparse():
     out = from_torch_sparse(to_torch_sparse(index, value, N, N).coalesce())
     assert out[0].tolist() == index.tolist()
     assert out[1].tolist() == value.tolist()
-
-
-def test_ind2ptr():
-    name = ('DIMACS10', 'citationCiteseer')[1]
-    mat = loadmat(f'benchmark/{name}.mat')['Problem'][0][0][2]
-    mat = mat.tocsr().tocoo()
-
-    mat = mat.tocsr()
-    rowptr = torch.from_numpy(mat.indptr).to(torch.long).cuda()
-    mat = mat.tocoo()
-    row = torch.from_numpy(mat.row).to(torch.long).cuda()
-    col = torch.from_numpy(mat.col).to(torch.long).cuda()
-
-    storage = SparseStorage(row=row, col=col)
-    torch.cuda.synchronize()
-    t = time.perf_counter()
-    for _ in range(100):
-        storage.rowptr
-        storage._rowptr = None
-    torch.cuda.synchronize()
-    print(time.perf_counter() - t)
-    assert storage.rowptr.tolist() == rowptr.tolist()
-
-    storage = SparseStorage(rowptr=rowptr, col=col)
-    torch.cuda.synchronize()
-    t = time.perf_counter()
-    for _ in range(100):
-        storage.row
-        storage._row = None
-    torch.cuda.synchronize()
-    print(time.perf_counter() - t)
-    assert storage.row.tolist() == row.tolist()
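The removed test_ind2ptr benchmark above times SparseStorage's lazy row <-> rowptr conversions on a DIMACS10 matrix. For reference, a minimal plain-PyTorch sketch of that conversion; this is illustrative only, not the library's dedicated kernels:

import torch

# COO row indices (sorted), as in SparseStorage(row=row, col=col).
row = torch.tensor([0, 0, 1, 3])
num_rows = 4

# ind2ptr: histogram the row indices, then prefix-sum into a CSR rowptr.
rowptr = torch.zeros(num_rows + 1, dtype=torch.long)
rowptr[1:] = torch.bincount(row, minlength=num_rows).cumsum(0)
assert rowptr.tolist() == [0, 2, 3, 3, 4]

# ptr2ind: repeat each row index by its entry count to recover COO rows.
counts = rowptr[1:] - rowptr[:-1]
row_back = torch.repeat_interleave(torch.arange(num_rows), counts)
assert row_back.tolist() == row.tolist()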
@@ -19,7 +19,7 @@ def test_spmm(dtype, device, reduce):
     src[2:4, :] = 0  # Remove multiple rows.
     src[:, 2:4] = 0  # Remove multiple columns.
     src = SparseTensor.from_dense(src).requires_grad_()
-    (row, col), value = src.coo()
+    row, col, value = src.coo()
     other = torch.randn((2, 8, 2), dtype=dtype, device=device,
                         requires_grad=True)
...
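Both the test and the library code in this commit switch from the nested ((row, col), value) return of coo() to a flat 3-tuple. A minimal sketch of the new unpacking, assuming the SparseTensor.from_dense constructor used in the test above; the shapes are illustrative:

import torch
from torch_sparse import SparseTensor

dense = torch.randn(8, 8)
dense[dense < 0.5] = 0               # sparsify an arbitrary matrix
src = SparseTensor.from_dense(dense)

row, col, value = src.coo()          # flat 3-tuple, no nested unpacking
assert row.numel() == col.numel() == value.numel()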
@@ -22,7 +22,7 @@ def cat(tensors, dim):
     if dim == 0:
         for tensor in tensors:
-            (row, col), value = tensor.coo()
+            row, col, value = tensor.coo()
             rows += [row + sparse_size[0]]
             cols += [col]
             values += [value]
@@ -48,7 +48,7 @@ def cat(tensors, dim):
     elif dim == 1:
         for tensor in tensors:
-            (row, col), value = tensor.coo()
+            row, col, value = tensor.coo()
             rows += [row]
             cols += [col + sparse_size[1]]
             values += [value]
@@ -76,7 +76,7 @@ def cat(tensors, dim):
     elif dim == (0, 1) or dim == (1, 0):
         for tensor in tensors:
-            (row, col), value = tensor.coo()
+            row, col, value = tensor.coo()
             rows += [row + sparse_size[0]]
             cols += [col + sparse_size[1]]
             values += [value] if has_value else []
...
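The three cat() branches above offset row indices (dim=0), column indices (dim=1), or both (dim=(0, 1), a block-diagonal stack). A hedged usage sketch of the block-diagonal case, assuming cat, sizes(), and nnz() are exposed at package level as elsewhere in the library:

import torch
from torch_sparse import SparseTensor, cat

a = SparseTensor.from_dense(torch.eye(2))
b = SparseTensor.from_dense(torch.ones(3, 3))

# b's row and column indices are shifted by a's sparse size.
out = cat([a, b], dim=(0, 1))
assert out.sizes() == [5, 5]
assert out.nnz() == 2 + 9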
@@ -40,24 +40,20 @@ class SPMM(torch.autograd.Function):
         arg_out) = ctx.saved_tensors

         invalid_arg_mask = arg_out_ind = None
-        if ctx.reduce in ['min', 'max'] and (ctx.needs_input_grad[5]
-                                             or ctx.needs_input_grad[6]):
-            invalid_arg_mask = arg_out == row.size(0)
+        if ctx.reduce in ['min', 'max'] and (ctx.needs_input_grad[3]
+                                             or ctx.needs_input_grad[4]):
+            invalid_arg_mask = arg_out == col.size(0)
             arg_out_ind = arg_out.masked_fill(invalid_arg_mask, -1)

         grad_value = None
         if ctx.needs_input_grad[3]:
-            if ctx.reduce in ['sum', 'add']:
-                grad_value = spmm(grad_out.is_cuda).spmm_val_bw(
-                    row, rowptr, col, mat, grad_out, ctx.reduce)
-            if ctx.reduce == 'mean':
+            if ctx.reduce in ['sum', 'add', 'mean']:
                 grad_value = spmm(grad_out.is_cuda).spmm_val_bw(
                     row, rowptr, col, mat, grad_out, ctx.reduce)
             elif ctx.reduce in ['min', 'max']:
-                col = col[arg_out_ind.flatten()].view_as(arg_out)
-                out = mat.gather(-2, col).mul_(grad_out)
+                col_tmp = col[arg_out_ind.flatten()].view_as(arg_out)
+                out = mat.gather(-2, col_tmp).mul_(grad_out)
                 out.masked_fill_(invalid_arg_mask, 0)
                 grad_value = scatter_add(out.flatten(), arg_out.flatten(),
                                          dim=0, dim_size=value.numel() + 1)
@@ -85,8 +81,8 @@ class SPMM(torch.autograd.Function):
             else:
                 value = grad_out
             value.masked_fill_(invalid_arg_mask, 0)
-            col = col[arg_out_ind.flatten()].view_as(arg_out)
-            grad_mat = scatter_add(value, col, dim=-2,
+            col_tmp = col[arg_out_ind.flatten()].view_as(arg_out)
+            grad_mat = scatter_add(value, col_tmp, dim=-2,
                                    dim_size=mat.size(-2))

         return None, None, None, grad_value, grad_mat, None, None, None, None
@@ -119,7 +115,7 @@ class SPSPMM(torch.autograd.Function):
         rowptrC = torch.from_numpy(C.indptr).to(torch.int64)
         colC = torch.from_numpy(C.indices).to(torch.int64)
         valueC = torch.from_numpy(C.data)
-        valueC = valueC.to(dtype) if dtype is not None else valueC
+        valueC = valueC.to(dtype) if dtype is not None else None

         ctx.mark_non_differentiable(rowptrC, colC)
@@ -152,7 +148,7 @@ def matmul(src, other, reduce='sum'):
     rowptr, col, value = src.csr()

     row = None
-    if reduce in ['sum', 'add'] and (src.requires_grad
-                                     or other.requires_grad):
+    if reduce in ['sum', 'add', 'mean'] and (src.requires_grad
+                                             or other.requires_grad):
         row = src.storage.row
...
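The matmul change above is the core of the fix: the row vector needed by the value-gradient kernel is now gathered for reduce='mean' as well, and the backward pass handles 'mean' in the same branch as 'sum'/'add'. A hedged sketch of the call this enables, assuming matmul is importable from the package as in this diff; the shapes and random data are illustrative:

import torch
from torch_sparse import SparseTensor, matmul

src = SparseTensor.from_dense(torch.randn(4, 5)).requires_grad_()
other = torch.randn(5, 3, requires_grad=True)

out = matmul(src, other, reduce='mean')  # previously missed the 'mean' case
out.sum().backward()                     # both inputs now receive gradients
assert other.grad is not None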