Commit 43b284f1 authored by rusty1s

clean up

parent 7636e1d1
@@ -53,19 +53,24 @@ spspmm_cpu(torch::Tensor rowptrA, torch::Tensor colA,
auto rowptrC_data = rowptrC.data_ptr<int64_t>();
rowptrC_data[0] = 0;
int64_t rowA_start = 0, rowA_end, rowB_start, rowB_end, cA, cB;
int64_t nnz = 0, row_nnz;
for (auto n = 1; n < rowptrA.numel(); n++) {
rowA_end = rowptrA_data[n];
std::vector<int64_t> mask(K, -1);
int64_t nnz = 0, row_nnz, rowA_start, rowA_end, rowB_start, rowB_end, cA, cB;
for (auto n = 0; n < rowptrA.numel() - 1; n++) {
row_nnz = 0;
for (auto eA = rowA_start; eA < rowA_end; eA++) {
for (auto eA = rowptrA_data[n]; eA < rowptrA_data[n + 1]; eA++) {
cA = colA_data[eA];
row_nnz = rowptrB_data[cA + 1] - rowptrB_data[cA];
for (auto eB = rowptrB_data[cA]; eB < rowptrB_data[cA + 1]; eB++) {
cB = colB_data[eB];
if (mask[cB] != n) {
mask[cB] = n;
row_nnz++;
}
}
}
nnz += row_nnz;
rowptrC_data[n] = nnz;
rowA_start = rowA_end;
rowptrC_data[n + 1] = nnz;
}
// Pass 2: Compute CSR entries.
......
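The spspmm_cpu hunk above replaces the old per-row upper bound with an exact non-zero count: a dense mask records which output columns were already seen in the current row, so each column of C = A @ B is counted once. Below is a rough Python paraphrase of that first pass, purely for illustration; the function name and list-based types are not part of the library, and it assumes K is the number of columns of B (and hence of C).

def count_rowptr_C(rowptrA, colA, rowptrB, colB, K):
    # Pass 1 of SpGEMM: compute the row pointer of C = A @ B in CSR form.
    # mask[c] == n means column c has already been counted for row n.
    mask = [-1] * K
    rowptrC = [0] * len(rowptrA)
    nnz = 0
    for n in range(len(rowptrA) - 1):
        row_nnz = 0
        for eA in range(rowptrA[n], rowptrA[n + 1]):
            cA = colA[eA]
            for eB in range(rowptrB[cA], rowptrB[cA + 1]):
                cB = colB[eB]
                if mask[cB] != n:
                    mask[cB] = n
                    row_nnz += 1
        nnz += row_nnz
        rowptrC[n + 1] = nnz
    return rowptrC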
@@ -27,44 +27,6 @@
} \
}()
#define AT_DISPATCH_CUSPARSE_CSR_GEMM2_BUFFER_SIZE_EXT_TYPES(TYPE, ...) \
[&] { \
switch (TYPE) { \
case torch::ScalarType::Float: { \
using scalar_t = float; \
const auto &cusparsecsrgemm2_bufferSizeExt = \
cusparseScsrgemm2_bufferSizeExt; \
return __VA_ARGS__(); \
} \
case torch::ScalarType::Double: { \
using scalar_t = double; \
const auto &cusparsecsrgemm2_bufferSizeExt = \
cusparseDcsrgemm2_bufferSizeExt; \
return __VA_ARGS__(); \
} \
default: \
AT_ERROR("Not implemented for '", toString(TYPE), "'"); \
} \
}()
#define AT_DISPATCH_CUSPARSE_CSR_GEMM2_TYPES(TYPE, ...) \
[&] { \
switch (TYPE) { \
case torch::ScalarType::Float: { \
using scalar_t = float; \
const auto &cusparsecsrgemm2 = cusparseScsrgemm2; \
return __VA_ARGS__(); \
} \
case torch::ScalarType::Double: { \
using scalar_t = double; \
const auto &cusparsecsrgemm2 = cusparseDcsrgemm2; \
return __VA_ARGS__(); \
} \
default: \
AT_ERROR("Not implemented for '", toString(TYPE), "'"); \
} \
}()
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA,
torch::optional<torch::Tensor> optional_valueA,
@@ -108,7 +70,6 @@ spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA,
scalar_type = optional_valueA.value().scalar_type();
auto handle = at::cuda::getCurrentCUDASparseHandle();
cusparseSetPointerMode(handle, CUSPARSE_POINTER_MODE_HOST);
cusparseMatDescr_t descr;
cusparseCreateMatDescr(&descr);
......
import torch
from torch_sparse import SparseStorage, SparseTensor
from typing import Dict, Any
# class MyTensor(dict):
# def __init__(self, rowptr, col):
# self['rowptr'] = rowptr
# self['col'] = col
# def rowptr(self: Dict[str, torch.Tensor]):
# return self['rowptr']
@torch.jit.script
class Foo:
rowptr: torch.Tensor
col: torch.Tensor
def __init__(self, rowptr: torch.Tensor, col: torch.Tensor):
self.rowptr = rowptr
self.col = col
class MyCell(torch.nn.Module):
def __init__(self):
super(MyCell, self).__init__()
self.linear = torch.nn.Linear(2, 4)
# def forward(self, x: torch.Tensor, ptr: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, adj: SparseTensor) -> torch.Tensor:
out, _ = torch.ops.torch_sparse_cpu.spmm(adj.storage.rowptr(),
adj.storage.col(), None, x,
'sum')
return out
# ind = torch.ops.torch_sparse_cpu.ptr2ind(ptr, ptr[-1].item())
# # ind = ptr2ind(ptr, E)
# x_j = x[ind]
# out = self.linear(x_j)
# return out
def test_jit():
my_cell = MyCell()
# x = torch.rand(3, 2)
# ptr = torch.tensor([0, 2, 4, 6])
# out = my_cell(x, ptr)
# print()
# print(out)
# traced_cell = torch.jit.trace(my_cell, (x, ptr))
# print(traced_cell)
# out = traced_cell(x, ptr)
# print(out)
x = torch.randn(3, 2)
# adj = torch.randn(3, 3)
# adj = SparseTensor.from_dense(adj)
# adj = Foo(adj.storage.rowptr, adj.storage.col)
# adj = adj.storage
rowptr = torch.tensor([0, 1, 4, 7])
col = torch.tensor([0, 0, 1, 2, 0, 1, 2])
adj = SparseTensor(rowptr=rowptr, col=col)
# scipy = adj.to_scipy(layout='csr')
# mat = SparseTensor.from_scipy(scipy)
print()
# adj = t(adj)
adj = adj.t()
adj = adj.remove_diag(k=0)
print(adj.to_dense())
adj = adj + torch.tensor([1, 2, 3]).view(1, 3)
print(adj)
print(adj.to_dense())
# print(adj.t)
# adj = {'rowptr': mat.storage.rowptr, 'col': mat.storage.col}
# foo = Foo(mat.storage.rowptr, mat.storage.col)
# adj = MyTensor(mat.storage.rowptr, mat.storage.col)
traced_cell = torch.jit.script(my_cell)
print(traced_cell)
out = traced_cell(x, adj)
print(out)
# # print(traced_cell.code)
@@ -4,32 +4,16 @@ import pytest
import torch
from torch_sparse import spspmm
from .utils import dtypes, devices, tensor
from .utils import grad_dtypes, devices, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_spspmm(dtype, device):
indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]], device=device)
valueA = tensor([1, 2, 3, 4, 5], dtype, device)
sizeA = torch.Size([3, 3])
indexB = torch.tensor([[0, 2], [1, 0]], device=device)
valueB = tensor([2, 4], dtype, device)
sizeB = torch.Size([3, 2])
indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)
assert indexC.tolist() == [[0, 1, 2], [0, 1, 1]]
assert valueC.tolist() == [8, 6, 8]
A = torch.sparse_coo_tensor(indexA, valueA, sizeA, device=device)
A = A.to_dense().requires_grad_()
B = torch.sparse_coo_tensor(indexB, valueB, sizeB, device=device)
B = B.to_dense().requires_grad_()
torch.matmul(A, B).sum().backward()
valueA = valueA.requires_grad_()
valueB = valueB.requires_grad_()
indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)
valueC.sum().backward()
assert valueA.grad.tolist() == A.grad[indexA[0], indexA[1]].tolist()
assert valueB.grad.tolist() == B.grad[indexB[0], indexB[1]].tolist()
from itertools import product
import pytest
import torch
from torch_sparse import spspmm, spmm
from .utils import dtypes, devices, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_spmm_spspmm(dtype, device):
row = torch.tensor([0, 0, 1, 2, 2], device=device)
col = torch.tensor([0, 2, 1, 0, 1], device=device)
index = torch.stack([row, col], dim=0)
value = tensor([1, 2, 4, 1, 3], dtype, device)
x = tensor([[1, 4], [2, 5], [3, 6]], dtype, device)
value = value.requires_grad_(True)
out_index, out_value = spspmm(index, value, index, value, 3, 3, 3)
out = spmm(out_index, out_value, 3, 3, x)
assert out.size() == (3, 2)
import copy
from itertools import product
import pytest
@@ -13,18 +12,18 @@ def test_storage(dtype, device):
row, col = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device)
storage = SparseStorage(row=row, col=col)
assert storage.row.tolist() == [0, 0, 1, 1]
assert storage.col.tolist() == [0, 1, 0, 1]
assert storage.value is None
assert storage.sparse_size == (2, 2)
assert storage.row().tolist() == [0, 0, 1, 1]
assert storage.col().tolist() == [0, 1, 0, 1]
assert storage.value() is None
assert storage.sparse_sizes() == (2, 2)
row, col = tensor([[0, 0, 1, 1], [1, 0, 1, 0]], torch.long, device)
value = tensor([2, 1, 4, 3], dtype, device)
storage = SparseStorage(row=row, col=col, value=value)
assert storage.row.tolist() == [0, 0, 1, 1]
assert storage.col.tolist() == [0, 1, 0, 1]
assert storage.value.tolist() == [1, 2, 3, 4]
assert storage.sparse_size == (2, 2)
assert storage.row().tolist() == [0, 0, 1, 1]
assert storage.col().tolist() == [0, 1, 0, 1]
assert storage.value().tolist() == [1, 2, 3, 4]
assert storage.sparse_sizes() == (2, 2)
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
@@ -41,7 +40,7 @@ def test_caching(dtype, device):
assert storage._colcount is None
assert storage._colptr is None
assert storage._csr2csc is None
assert storage.cached_keys() == []
assert storage.num_cached_keys() == 0
storage.fill_cache_()
assert storage._rowcount.tolist() == [2, 2]
@@ -50,16 +49,14 @@
assert storage._colptr.tolist() == [0, 2, 4]
assert storage._csr2csc.tolist() == [0, 2, 1, 3]
assert storage._csc2csr.tolist() == [0, 2, 1, 3]
assert storage.cached_keys() == [
'rowcount', 'colptr', 'colcount', 'csr2csc', 'csc2csr'
]
assert storage.num_cached_keys() == 5
storage = SparseStorage(row=row, rowptr=storage.rowptr, col=col,
value=storage.value,
sparse_size=storage.sparse_size,
rowcount=storage.rowcount, colptr=storage.colptr,
colcount=storage.colcount, csr2csc=storage.csr2csc,
csc2csr=storage.csc2csr)
storage = SparseStorage(row=row, rowptr=storage._rowptr, col=col,
value=storage._value,
sparse_sizes=storage._sparse_sizes,
rowcount=storage._rowcount, colptr=storage._colptr,
colcount=storage._colcount,
csr2csc=storage._csr2csc, csc2csr=storage._csc2csr)
assert storage._rowcount.tolist() == [2, 2]
assert storage._rowptr.tolist() == [0, 2, 4]
@@ -67,9 +64,7 @@ def test_caching(dtype, device):
assert storage._colptr.tolist() == [0, 2, 4]
assert storage._csr2csc.tolist() == [0, 2, 1, 3]
assert storage._csc2csr.tolist() == [0, 2, 1, 3]
assert storage.cached_keys() == [
'rowcount', 'colptr', 'colcount', 'csr2csc', 'csc2csr'
]
assert storage.num_cached_keys() == 5
storage.clear_cache_()
assert storage._rowcount is None
@@ -77,7 +72,7 @@ def test_caching(dtype, device):
assert storage._colcount is None
assert storage._colptr is None
assert storage._csr2csc is None
assert storage.cached_keys() == []
assert storage.num_cached_keys() == 0
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
@@ -89,45 +84,25 @@ def test_utility(dtype, device):
assert storage.has_value()
storage.set_value_(value, layout='csc')
assert storage.value.tolist() == [1, 3, 2, 4]
assert storage.value().tolist() == [1, 3, 2, 4]
storage.set_value_(value, layout='coo')
assert storage.value.tolist() == [1, 2, 3, 4]
assert storage.value().tolist() == [1, 2, 3, 4]
storage = storage.set_value(value, layout='csc')
assert storage.value.tolist() == [1, 3, 2, 4]
assert storage.value().tolist() == [1, 3, 2, 4]
storage = storage.set_value(value, layout='coo')
assert storage.value.tolist() == [1, 2, 3, 4]
assert storage.value().tolist() == [1, 2, 3, 4]
storage = storage.sparse_resize(3, 3)
assert storage.sparse_size == (3, 3)
storage = storage.sparse_resize([3, 3])
assert storage.sparse_sizes() == [3, 3]
new_storage = copy.copy(storage)
new_storage = storage.copy()
assert new_storage != storage
assert new_storage.col.data_ptr() == storage.col.data_ptr()
assert new_storage.col().data_ptr() == storage.col().data_ptr()
new_storage = storage.clone()
assert new_storage != storage
assert new_storage.col.data_ptr() != storage.col.data_ptr()
new_storage = copy.deepcopy(storage)
assert new_storage != storage
assert new_storage.col.data_ptr() != storage.col.data_ptr()
storage.apply_value_(lambda x: x + 1)
assert storage.value.tolist() == [2, 3, 4, 5]
storage = storage.apply_value(lambda x: x + 1)
assert storage.value.tolist() == [3, 4, 5, 6]
storage.apply_(lambda x: x.to(torch.long))
assert storage.col.dtype == torch.long
assert storage.value.dtype == torch.long
storage = storage.apply(lambda x: x.to(torch.long))
assert storage.col.dtype == torch.long
assert storage.value.dtype == torch.long
storage.clear_cache_()
assert storage.map(lambda x: x.numel()) == [4, 4, 4]
assert new_storage.col().data_ptr() != storage.col().data_ptr()
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
@@ -136,14 +111,14 @@ def test_coalesce(dtype, device):
value = tensor([1, 1, 1, 3, 4], dtype, device)
storage = SparseStorage(row=row, col=col, value=value)
assert storage.row.tolist() == row.tolist()
assert storage.col.tolist() == col.tolist()
assert storage.value.tolist() == value.tolist()
assert storage.row().tolist() == row.tolist()
assert storage.col().tolist() == col.tolist()
assert storage.value().tolist() == value.tolist()
assert not storage.is_coalesced()
storage = storage.coalesce()
assert storage.is_coalesced()
assert storage.row.tolist() == [0, 0, 1, 1]
assert storage.col.tolist() == [0, 1, 0, 1]
assert storage.value.tolist() == [1, 2, 3, 4]
assert storage.row().tolist() == [0, 0, 1, 1]
assert storage.col().tolist() == [0, 1, 0, 1]
assert storage.value().tolist() == [1, 2, 3, 4]
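The storage-test changes above reflect the API shift from attribute access (storage.row, storage.value, storage.sparse_size, storage.cached_keys()) to method calls (storage.row(), storage.value(), storage.sparse_sizes(), storage.num_cached_keys()). A minimal usage sketch of the updated accessors, with values taken from the tests:

import torch
from torch_sparse import SparseStorage

row = torch.tensor([0, 0, 1, 1])
col = torch.tensor([0, 1, 0, 1])
storage = SparseStorage(row=row, col=col)

storage.row().tolist()      # [0, 0, 1, 1]
storage.col().tolist()      # [0, 1, 0, 1]
storage.value()             # None (no values attached)
storage.sparse_sizes()      # (2, 2)
storage.num_cached_keys()   # 0 until fill_cache_() is called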
from .storage import SparseStorage
from .tensor import SparseTensor
from .transpose import t
from .narrow import narrow
from .select import select
from .index_select import index_select, index_select_nnz
from .masked_select import masked_select, masked_select_nnz
from .diag import remove_diag, set_diag, fill_diag
from .add import add, add_, add_nnz, add_nnz_
from .mul import mul, mul_, mul_nnz, mul_nnz_
from .reduce import sum, mean, min, max
from .matmul import matmul
from .cat import cat, cat_diag
from .convert import to_torch_sparse, from_torch_sparse, to_scipy, from_scipy
from .coalesce import coalesce
from .transpose import transpose
@@ -8,7 +22,33 @@ from .spspmm import spspmm
__version__ = '0.4.3'
__all__ = [
'__version__',
'SparseStorage',
'SparseTensor',
't',
'narrow',
'select',
'index_select',
'index_select_nnz',
'masked_select',
'masked_select_nnz',
'remove_diag',
'set_diag',
'fill_diag',
'add',
'add_',
'add_nnz',
'add_nnz_',
'mul',
'mul_',
'mul_nnz',
'mul_nnz_',
'sum',
'mean',
'min',
'max',
'matmul',
'cat',
'cat_diag',
'to_torch_sparse',
'from_torch_sparse',
'to_scipy',
@@ -18,19 +58,5 @@ __all__ = [
'eye',
'spmm',
'spspmm',
'__version__',
]
from .storage import SparseStorage
from .tensor import SparseTensor
from .transpose import t
from .narrow import narrow
from .select import select
from .index_select import index_select, index_select_nnz
from .masked_select import masked_select, masked_select_nnz
from .diag import remove_diag, set_diag, fill_diag
from .add import add, add_, add_nnz, add_nnz_
from .mul import mul, mul_, mul_nnz, mul_nnz_
from .reduce import sum, mean, min, max
from .matmul import (spmm_sum, spmm_add, spmm_mean, spmm_min, spmm_max, spmm,
spspmm_sum, spspmm_add, spspmm, matmul)
from .cat import cat, cat_diag
import torch
import torch_scatter
from torch_sparse.storage import SparseStorage
# from .unique import unique
def coalesce(index, value, m, n, op='add', fill_value=0):
def coalesce(index, value, m, n, op="add"):
"""Row-wise sorts :obj:`value` and removes duplicate entries. Duplicate
entries are removed by scattering them together. For scattering, any
operation of `"torch_scatter"<https://github.com/rusty1s/pytorch_scatter>`_
@@ -17,29 +15,11 @@ def coalesce(index, value, m, n, op='add', fill_value=0):
n (int): The second dimension of corresponding dense matrix.
op (string, optional): The scatter operation to use. (default:
:obj:`"add"`)
fill_value (int, optional): The initial fill value of scatter
operation. (default: :obj:`0`)
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
raise NotImplementedError
row, col = index
if value is None:
_, perm = unique(row * n + col)
index = torch.stack([row[perm], col[perm]], dim=0)
return index, value
uniq, inv = torch.unique(row * n + col, sorted=True, return_inverse=True)
perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
perm = inv.new_empty(uniq.size(0)).scatter_(0, inv, perm)
index = torch.stack([row[perm], col[perm]], dim=0)
op = getattr(torch_scatter, 'scatter_{}'.format(op))
value = op(value, inv, 0, None, perm.size(0), fill_value)
if isinstance(value, tuple):
value = value[0]
return index, value
storage = SparseStorage(row=index[0], col=index[1], value=value,
sparse_sizes=torch.Size([m, n]), is_sorted=False)
storage = storage.coalesce(reduce=op)
return torch.stack([storage.row(), storage.col()], dim=0), storage.value()
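Since the rewritten coalesce now just wraps SparseStorage.coalesce(reduce=op), the old fill_value argument disappears. A hedged usage sketch follows; the input values are illustrative and the expected outputs assume the "add" reduction from the signature default and the row-major sorting described in the docstring.

import torch
from torch_sparse import coalesce

index = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 2]])  # (0, 1) appears twice
value = torch.tensor([1.0, 2.0, 3.0, 4.0])

index, value = coalesce(index, value, m=2, n=3, op="add")
# index: [[0, 0, 1], [1, 2, 0]]  -- sorted row-wise, duplicates merged
# value: [3.0, 4.0, 3.0]         -- duplicate entries summed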
# import torch
from torch_scatter import scatter_add
# from torch_sparse.tensor import SparseTensor
# if torch.cuda.is_available():
# import torch_sparse.spmm_cuda
# def spmm_(sparse_mat, mat, reduce='add'):
# assert reduce in ['add', 'mean', 'min', 'max']
# assert sparse_mat.dim() == 2 and mat.dim() == 2
# assert sparse_mat.size(1) == mat.size(0)
# rowptr, col, value = sparse_mat.csr()
# mat = mat.contiguous()
# if reduce in ['add', 'mean']:
# return torch_sparse.spmm_cuda.spmm(rowptr, col, value, mat, reduce)
# else:
# return torch_sparse.spmm_cuda.spmm_arg(
# rowptr, col, value, mat, reduce)
def spmm(index, value, m, n, matrix):
"""Matrix product of sparse matrix with dense matrix.
@@ -44,60 +25,3 @@ def spmm(index, value, m, n, matrix):
out = scatter_add(out, row, dim=0, dim_size=m)
return out
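A short usage sketch of the scatter-based spmm kept in this file, mirroring the inputs of test_spmm_spspmm above; the result is the dense product of the m x n sparse matrix (given in COO form) with the dense matrix.

import torch
from torch_sparse import spmm

index = torch.tensor([[0, 0, 1, 2, 2], [0, 2, 1, 0, 1]])
value = torch.tensor([1.0, 2.0, 4.0, 1.0, 3.0])
matrix = torch.tensor([[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]])

out = spmm(index, value, 3, 3, matrix)  # A @ matrix
# out.size() == (3, 2)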
# if __name__ == '__main__':
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# row = torch.tensor([0, 0, 0, 1, 1, 1], device=device)
# col = torch.tensor([0, 1, 2, 0, 1, 2], device=device)
# value = torch.ones_like(col, dtype=torch.float, device=device)
# value = None
# sparse_mat = SparseTensor(torch.stack([row, col], dim=0), value)
# mat = torch.tensor([[1, 4], [2, 5], [3, 6]], dtype=torch.float,
# device=device)
# out1 = spmm_(sparse_mat, mat, reduce='add')
# out2 = sparse_mat.to_dense() @ mat
# assert torch.allclose(out1, out2)
# from torch_geometric.datasets import Reddit, Planetoid # noqa
# import time # noqa
# # Warmup
# x = torch.randn((1000, 1000), device=device)
# for _ in range(100):
# x.sum()
# # dataset = Reddit('/tmp/Reddit')
# dataset = Planetoid('/tmp/PubMed', 'PubMed')
# # dataset = Planetoid('/tmp/Cora', 'Cora')
# data = dataset[0].to(device)
# mat = torch.randn((data.num_nodes, 1024), device=device)
# value = torch.ones(data.num_edges, device=device)
# sparse_mat = SparseTensor(data.edge_index, value)
# torch.cuda.synchronize()
# t = time.perf_counter()
# for _ in range(100):
# out1 = spmm_(sparse_mat, mat, reduce='add')
# out1 = out1[0] if isinstance(out1, tuple) else out1
# torch.cuda.synchronize()
# print('My: ', time.perf_counter() - t)
# sparse_mat = torch.sparse_coo_tensor(data.edge_index, value)
# sparse_mat = sparse_mat.coalesce()
# torch.cuda.synchronize()
# t = time.perf_counter()
# for _ in range(100):
# out2 = sparse_mat @ mat
# torch.cuda.synchronize()
# print('Torch: ', time.perf_counter() - t)
# torch.cuda.synchronize()
# t = time.perf_counter()
# for _ in range(100):
# spmm(data.edge_index, value, data.num_nodes, data.num_nodes, mat)
# torch.cuda.synchronize()
# print('Scatter:', time.perf_counter() - t)
# assert torch.allclose(out1, out2, atol=1e-2)
import torch
from torch_sparse import transpose, to_scipy, from_scipy, coalesce
# import torch_sparse.spspmm_cpu
# if torch.cuda.is_available():
# import torch_sparse.spspmm_cuda
from torch_sparse.tensor import SparseTensor
from torch_sparse.matmul import matmul
def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False):
@@ -25,83 +21,13 @@ def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False):
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
raise NotImplementedError
if indexA.is_cuda and coalesced:
indexA, valueA = coalesce(indexA, valueA, m, k)
indexB, valueB = coalesce(indexB, valueB, k, n)
index, value = SpSpMM.apply(indexA, valueA, indexB, valueB, m, k, n)
return index.detach(), value
class SpSpMM(torch.autograd.Function):
@staticmethod
def forward(ctx, indexA, valueA, indexB, valueB, m, k, n):
indexC, valueC = mm(indexA, valueA, indexB, valueB, m, k, n)
ctx.m, ctx.k, ctx.n = m, k, n
ctx.save_for_backward(indexA, valueA, indexB, valueB, indexC)
return indexC, valueC
@staticmethod
def backward(ctx, grad_indexC, grad_valueC):
m, k = ctx.m, ctx.k
n = ctx.n
indexA, valueA, indexB, valueB, indexC = ctx.saved_tensors
grad_valueA = grad_valueB = None
if not grad_valueC.is_cuda:
if ctx.needs_input_grad[1] or ctx.needs_input_grad[3]:
grad_valueC = grad_valueC.clone()
if ctx.needs_input_grad[1]:
grad_valueA = torch_sparse.spspmm_cpu.spspmm_bw(
indexA, indexC.detach(), grad_valueC, indexB.detach(),
valueB, m, k)
if ctx.needs_input_grad[3]:
indexA, valueA = transpose(indexA, valueA, m, k)
indexC, grad_valueC = transpose(indexC, grad_valueC, m, n)
grad_valueB = torch_sparse.spspmm_cpu.spspmm_bw(
indexB, indexA.detach(), valueA, indexC.detach(),
grad_valueC, k, n)
else:
if ctx.needs_input_grad[1]:
grad_valueA = torch_sparse.spspmm_cuda.spspmm_bw(
indexA, indexC.detach(), grad_valueC.clone(),
indexB.detach(), valueB, m, k)
if ctx.needs_input_grad[3]:
indexA_T, valueA_T = transpose(indexA, valueA, m, k)
grad_indexB, grad_valueB = mm(indexA_T, valueA_T, indexC,
grad_valueC, k, m, n)
grad_valueB = lift(grad_indexB, grad_valueB, indexB, n)
return None, grad_valueA, None, grad_valueB, None, None, None
def mm(indexA, valueA, indexB, valueB, m, k, n):
assert valueA.dtype == valueB.dtype
if indexA.is_cuda:
return torch_sparse.spspmm_cuda.spspmm(indexA, valueA, indexB, valueB,
m, k, n)
A = to_scipy(indexA, valueA, m, k)
B = to_scipy(indexB, valueB, k, n)
C = A.dot(B).tocoo().tocsr().tocoo() # Force coalesce.
indexC, valueC = from_scipy(C)
return indexC, valueC
def lift(indexA, valueA, indexB, n): # pragma: no cover
idxA = indexA[0] * n + indexA[1]
idxB = indexB[0] * n + indexB[1]
max_value = max(idxA.max().item(), idxB.max().item()) + 1
valueB = valueA.new_zeros(max_value)
A = SparseTensor(row=indexA[0], col=indexA[1], value=valueA,
sparse_sizes=torch.Size([m, k]), is_sorted=not coalesced)
B = SparseTensor(row=indexB[0], col=indexB[1], value=valueB,
sparse_sizes=torch.Size([k, n]), is_sorted=not coalesced)
valueB[idxA] = valueA
valueB = valueB[idxB]
C = matmul(A, B)
row, col, value = C.coo()
return valueB
return torch.stack([row, col], dim=0), value
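The rewritten forward path wraps both operands in SparseTensor and multiplies them with matmul, while the functional wrapper keeps its old signature. A usage sketch with the inputs and expected outputs copied from the updated test_spspmm above:

import torch
from torch_sparse import spspmm

indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]])
valueA = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
indexB = torch.tensor([[0, 2], [1, 0]])
valueB = torch.tensor([2.0, 4.0])

# C = A @ B with A of size (3, 3) and B of size (3, 2).
indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)
# indexC: [[0, 1, 2], [0, 1, 1]]
# valueC: [8.0, 6.0, 8.0]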
import torch
from torch_sparse import to_scipy, from_scipy, coalesce
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
@@ -51,14 +50,14 @@ def transpose(index, value, m, n, coalesced=True):
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
if value.dim() == 1 and not value.is_cuda:
mat = to_scipy(index, value, m, n).tocsc()
(col, row), value = from_scipy(mat)
index = torch.stack([row, col], dim=0)
return index, value
row, col = index
index = torch.stack([col, row], dim=0)
row, col = col, row
if coalesced:
index, value = coalesce(index, value, n, m)
return index, value
sparse_sizes = torch.Size([n, m])
storage = SparseStorage(row=row, col=col, value=value,
sparse_sizes=sparse_sizes, is_sorted=False)
storage = storage.coalesce()
row, col, value = storage.row(), storage.col(), storage.value()
return torch.stack([row, col], dim=0), value
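A hedged usage sketch of the functional transpose after the SparseStorage-based rewrite; the input is illustrative, and the expected outputs assume the coalesced, row-major ordering the docstring and storage path imply.

import torch
from torch_sparse import transpose

index = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]])
value = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])

index_t, value_t = transpose(index, value, m=3, n=3)
# index_t: [[0, 0, 1, 1, 2], [1, 2, 0, 2, 0]]  -- coalesced transpose
# value_t: [3.0, 4.0, 1.0, 5.0, 2.0]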