Commit a1c268a5 authored by rusty1s

fix matmul

parent b6a1f005
import time
import torch
from torch_sparse import to_scipy, from_scipy
from torch_sparse import to_torch_sparse, from_torch_sparse
from torch_sparse.storage import SparseStorage
from scipy.io import loadmat


def test_convert_scipy():
@@ -24,37 +21,3 @@ def test_convert_torch_sparse():
    out = from_torch_sparse(to_torch_sparse(index, value, N, N).coalesce())
    assert out[0].tolist() == index.tolist()
    assert out[1].tolist() == value.tolist()


def test_ind2ptr():
    name = ('DIMACS10', 'citationCiteseer')[1]
    mat = loadmat(f'benchmark/{name}.mat')['Problem'][0][0][2]
    mat = mat.tocsr().tocoo()

    # CSR row pointers serve as ground truth for the row -> rowptr conversion.
    mat = mat.tocsr()
    rowptr = torch.from_numpy(mat.indptr).to(torch.long).cuda()

    # COO row/col indices serve as ground truth for the rowptr -> row conversion.
    mat = mat.tocoo()
    row = torch.from_numpy(mat.row).to(torch.long).cuda()
    col = torch.from_numpy(mat.col).to(torch.long).cuda()

    # Benchmark row -> rowptr; the cached result is cleared on every pass
    # so each iteration recomputes the conversion.
    storage = SparseStorage(row=row, col=col)
    torch.cuda.synchronize()
    t = time.perf_counter()
    for _ in range(100):
        storage.rowptr
        storage._rowptr = None
    torch.cuda.synchronize()
    print(time.perf_counter() - t)
    assert storage.rowptr.tolist() == rowptr.tolist()

    # Benchmark rowptr -> row, again clearing the cache on every pass.
    storage = SparseStorage(rowptr=rowptr, col=col)
    torch.cuda.synchronize()
    t = time.perf_counter()
    for _ in range(100):
        storage.row
        storage._row = None
    torch.cuda.synchronize()
    print(time.perf_counter() - t)
    assert storage.row.tolist() == row.tolist()
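
For reference, the two conversions this benchmark times can be written in a few lines of plain PyTorch. This is only a sketch of the semantics (helper names are hypothetical), not the library's optimized implementation:

import torch

def ind2ptr_sketch(row, num_rows):
    # rowptr is the exclusive prefix sum of per-row nonzero counts.
    counts = torch.bincount(row, minlength=num_rows)
    return torch.cat([counts.new_zeros(1), counts.cumsum(0)])

def ptr2ind_sketch(rowptr):
    # row repeats each row index by its nonzero count.
    counts = rowptr[1:] - rowptr[:-1]
    return torch.repeat_interleave(torch.arange(counts.numel()), counts)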
@@ -19,7 +19,7 @@ def test_spmm(dtype, device, reduce):
    src[2:4, :] = 0  # Remove multiple rows.
    src[:, 2:4] = 0  # Remove multiple columns.
    src = SparseTensor.from_dense(src).requires_grad_()
-   (row, col), value = src.coo()
+   row, col, value = src.coo()
    other = torch.randn((2, 8, 2), dtype=dtype, device=device,
                        requires_grad=True)
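
Note the API change running through this commit: SparseTensor.coo() now returns a flat (row, col, value) triple instead of the nested ((row, col), value) pair. A minimal sketch of the new unpacking (assuming SparseTensor is importable from torch_sparse, as in the tests):

import torch
from torch_sparse import SparseTensor

mat = SparseTensor.from_dense(torch.randn(4, 4))
row, col, value = mat.coo()  # previously: (row, col), value = mat.coo()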
@@ -22,7 +22,7 @@ def cat(tensors, dim):
    if dim == 0:
        for tensor in tensors:
-           (row, col), value = tensor.coo()
+           row, col, value = tensor.coo()
            rows += [row + sparse_size[0]]
            cols += [col]
            values += [value]
@@ -48,7 +48,7 @@ def cat(tensors, dim):
    elif dim == 1:
        for tensor in tensors:
-           (row, col), value = tensor.coo()
+           row, col, value = tensor.coo()
            rows += [row]
            cols += [col + sparse_size[1]]
            values += [value]
@@ -76,7 +76,7 @@ def cat(tensors, dim):
    elif dim == (0, 1) or dim == (1, 0):
        for tensor in tensors:
-           (row, col), value = tensor.coo()
+           row, col, value = tensor.coo()
            rows += [row + sparse_size[0]]
            cols += [col + sparse_size[1]]
            values += [value] if has_value else []
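
For intuition, cat along dim=(0, 1) places every input in its own row and column block, i.e. a block-diagonal stack, which is why both row and column indices are offset. A dense sketch of the same semantics (torch.block_diag requires a recent PyTorch and is used here only for illustration):

import torch

a, b = torch.randn(2, 3), torch.randn(3, 2)
out = torch.block_diag(a, b)  # b's rows and cols are offset by a's size
assert out.shape == (5, 5)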
@@ -40,24 +40,20 @@ class SPMM(torch.autograd.Function):
         arg_out) = ctx.saved_tensors

        invalid_arg_mask = arg_out_ind = None
-       if ctx.reduce in ['min', 'max'] and (ctx.needs_input_grad[5]
-                                            or ctx.needs_input_grad[6]):
-           invalid_arg_mask = arg_out == row.size(0)
+       if ctx.reduce in ['min', 'max'] and (ctx.needs_input_grad[3]
+                                            or ctx.needs_input_grad[4]):
+           invalid_arg_mask = arg_out == col.size(0)
            arg_out_ind = arg_out.masked_fill(invalid_arg_mask, -1)

        grad_value = None
        if ctx.needs_input_grad[3]:
-           if ctx.reduce in ['sum', 'add']:
-               grad_value = spmm(grad_out.is_cuda).spmm_val_bw(
-                   row, rowptr, col, mat, grad_out, ctx.reduce)
-           if ctx.reduce == 'mean':
+           if ctx.reduce in ['sum', 'add', 'mean']:
                grad_value = spmm(grad_out.is_cuda).spmm_val_bw(
                    row, rowptr, col, mat, grad_out, ctx.reduce)
            elif ctx.reduce in ['min', 'max']:
-               col = col[arg_out_ind.flatten()].view_as(arg_out)
-               out = mat.gather(-2, col).mul_(grad_out)
+               col_tmp = col[arg_out_ind.flatten()].view_as(arg_out)
+               out = mat.gather(-2, col_tmp).mul_(grad_out)
                out.masked_fill_(invalid_arg_mask, 0)
                grad_value = scatter_add(out.flatten(), arg_out.flatten(),
                                         dim=0, dim_size=value.numel() + 1)
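
The index change above (5/6 to 3/4) follows the positional meaning of ctx.needs_input_grad: entry i refers to the i-th argument of forward, and the gradients this backward produces sit at positions 3 (value) and 4 (mat), matching the return statement in the next hunk. A toy Function illustrating the convention (names are illustrative only):

import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor):
        # x is forward argument 0, factor is argument 1.
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # needs_input_grad[i] matches forward's i-th argument, and
        # backward returns one gradient slot per argument.
        grad_x = grad_out * ctx.factor if ctx.needs_input_grad[0] else None
        return grad_x, None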
@@ -85,8 +81,8 @@ class SPMM(torch.autograd.Function):
            else:
                value = grad_out
            value.masked_fill_(invalid_arg_mask, 0)
-           col = col[arg_out_ind.flatten()].view_as(arg_out)
-           grad_mat = scatter_add(value, col, dim=-2,
+           col_tmp = col[arg_out_ind.flatten()].view_as(arg_out)
+           grad_mat = scatter_add(value, col_tmp, dim=-2,
                                   dim_size=mat.size(-2))

        return None, None, None, grad_value, grad_mat, None, None, None, None
@@ -119,7 +115,7 @@ class SPSPMM(torch.autograd.Function):
        rowptrC = torch.from_numpy(C.indptr).to(torch.int64)
        colC = torch.from_numpy(C.indices).to(torch.int64)
        valueC = torch.from_numpy(C.data)
-       valueC = valueC.to(dtype) if dtype is not None else valueC
+       valueC = valueC.to(dtype) if dtype is not None else None
        ctx.mark_non_differentiable(rowptrC, colC)
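
For context, this hunk converts a SciPy CSR product C (with indptr, indices, and data attributes) back into tensors, so SPSPMM appears to delegate the actual sparse-sparse product to SciPy. A standalone sketch of that round trip with random inputs (shapes are arbitrary):

import torch
import scipy.sparse as sp

A = sp.random(4, 5, density=0.5, format='csr')
B = sp.random(5, 3, density=0.5, format='csr')
C = A.dot(B).tocsr()  # SciPy performs the spspmm itself.
rowptrC = torch.from_numpy(C.indptr).to(torch.int64)
colC = torch.from_numpy(C.indices).to(torch.int64)
valueC = torch.from_numpy(C.data)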
@@ -152,8 +148,8 @@ def matmul(src, other, reduce='sum'):
    rowptr, col, value = src.csr()

    row = None
-   if reduce in ['sum', 'add'] and (src.requires_grad
-                                    or other.reuqires_grad):
+   if reduce in ['sum', 'add', 'mean'] and (src.requires_grad
+                                            or other.requires_grad):
        row = src.storage.row

    rowcount = None
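
With reduce='mean' added to this branch, src.storage.row is now also materialized ahead of the mean backward pass, which the value gradient above consumes. A usage sketch (import paths are assumed from the torch_sparse layout at this commit):

import torch
from torch_sparse import SparseTensor
from torch_sparse.matmul import matmul

src = SparseTensor.from_dense(torch.randn(4, 4)).requires_grad_()
other = torch.randn(4, 2, requires_grad=True)
out = matmul(src, other, reduce='mean')
out.sum().backward()  # exercises the gradient path that needs `row`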