Commit 2554bf09 authored by rusty1s

all select methods

parent 4a569c27
#include <torch/extension.h>

#include "compat.h"

// Writes the concatenation of the ranges [start[i], start[i] + repeat[i])
// for all i into a single output tensor, e.g., start=[0, 5], repeat=[2, 3]
// yields [0, 1, 5, 6, 7].
at::Tensor arange_interleave(at::Tensor start, at::Tensor repeat) {
  auto count = repeat.sum().DATA_PTR<int64_t>()[0];
  auto out = at::empty(count, start.options());
  auto repeat_data = repeat.DATA_PTR<int64_t>();

  AT_DISPATCH_ALL_TYPES(start.scalar_type(), "arange_interleave", [&] {
    auto start_data = start.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    int64_t i = 0;  // Running write position into `out`.
    for (int64_t start_idx = 0; start_idx < start.size(0); start_idx++) {
      scalar_t init = start_data[start_idx];
      for (int64_t rep_idx = 0; rep_idx < repeat_data[start_idx]; rep_idx++) {
        out_data[i] = init + rep_idx;
        i++;
      }
    }
  });

  return out;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("arange_interleave", &arange_interleave, "Arange Interleave (CPU)");
}
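For reference, a minimal usage sketch of what this kernel computes once the extension is built (the module name comes from `setup.py` below):

import torch
import torch_sparse.arange_interleave_cpu as ext

start = torch.tensor([0, 5])
repeat = torch.tensor([2, 3])
ext.arange_interleave(start, repeat)  # tensor([0, 1, 5, 6, 7])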
...@@ -12,8 +12,11 @@ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
    extra_compile_args += ['-DVERSION_GE_1_3']

ext_modules = [
    CppExtension('torch_sparse.arange_interleave_cpu',
                 ['cpu/arange_interleave.cpp'],
                 extra_compile_args=extra_compile_args),
    CppExtension('torch_sparse.spspmm_cpu', ['cpu/spspmm.cpp'],
                 extra_compile_args=extra_compile_args),
]
cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
...
import torch
from torch_sparse.storage import get_layout
import torch_sparse.arange_interleave_cpu as arange_interleave_cpu


def __arange_interleave__(start, repeat):
    # Concatenates the ranges [start[i], start[i] + repeat[i]) for all i.
    assert start.device == repeat.device
    assert repeat.dtype == torch.long
    assert start.dim() == 1
    assert repeat.dim() == 1
    assert start.numel() == repeat.numel()
    if start.is_cuda:
        raise NotImplementedError
    return arange_interleave_cpu.arange_interleave(start, repeat)
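Since the CUDA path is not implemented yet, a pure-PyTorch equivalent is handy as a reference; `arange_interleave_ref` is a hypothetical helper for testing, not part of this commit:

def arange_interleave_ref(start, repeat):
    # Hypothetical pure-PyTorch reference for `arange_interleave`: emits
    # start[i], start[i] + 1, ..., start[i] + repeat[i] - 1 for every i,
    # concatenated into one tensor.
    end = repeat.cumsum(0)                    # exclusive end offset per run
    run = torch.arange(repeat.numel(), device=start.device)
    run = run.repeat_interleave(repeat)       # run id of every output slot
    offset = torch.arange(int(end[-1]), device=start.device)
    offset = offset - (end - repeat)[run]     # position within each run
    return start[run] + offset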
def index_select(src, dim, idx):
    dim = src.dim() + dim if dim < 0 else dim
    assert idx.dim() == 1
    idx = idx.to(src.device)

    if dim == 0:
        (_, col), value = src.coo()
        rowcount = src.storage.rowcount
        rowptr = src.storage.rowptr

        rowcount = rowcount[idx]

        # Remap the selected rows to consecutive indices 0..len(idx)-1.
        tmp = torch.arange(rowcount.size(0), device=rowcount.device)
        row = tmp.repeat_interleave(rowcount)

        # Gather the nnz ranges of the selected rows.
        perm = __arange_interleave__(rowptr[idx], rowcount)
        col = col[perm]
        index = torch.stack([row, col], dim=0)

        if src.has_value():
            value = value[perm]

        sparse_size = torch.Size([rowcount.size(0), src.sparse_size(1)])
        storage = src.storage.__class__(index, value, sparse_size,
                                        rowcount=rowcount, is_sorted=True)

    elif dim == 1:
        colptr, row, value = src.csc()
        colcount = src.storage.colcount

        colcount = colcount[idx]

        # Remap the selected columns to consecutive indices 0..len(idx)-1.
        tmp = torch.arange(colcount.size(0), device=row.device)
        col = tmp.repeat_interleave(colcount)

        # Gather the nnz ranges of the selected columns.
        perm = __arange_interleave__(colptr[idx], colcount)
        row = row[perm]

        # Permutation that brings the CSC-ordered entries back into CSR order.
        csc2csr = (colcount.size(0) * row + col).argsort()
        index = torch.stack([row, col], dim=0)[:, csc2csr]

        if src.has_value():
            value = value[perm][csc2csr]

        sparse_size = torch.Size([src.sparse_size(0), colcount.size(0)])
        storage = src.storage.__class__(index, value, sparse_size,
                                        colcount=colcount, csc2csr=csc2csr,
                                        is_sorted=True)

    else:
        storage = src.storage.apply_value(
            lambda x: x.index_select(dim - 1, idx))

    return src.from_storage(storage)


def index_select_nnz(src, idx, layout=None):
    assert idx.dim() == 1

    if get_layout(layout) == 'csc':
        idx = idx[src.storage.csc2csr]

    index, value = src.coo()
    index = index[:, idx]
    if src.has_value():
        value = value[idx]

    # There is no other information we can maintain...
    storage = src.storage.__class__(index, value, src.sparse_size(),
                                    is_sorted=True)

    return src.from_storage(storage)
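A hedged usage sketch of the two new methods, assuming `SparseTensor` is importable from `torch_sparse.tensor` (values chosen for illustration):

import torch
from torch_sparse.tensor import SparseTensor  # assumed module path

index = torch.tensor([[0, 0, 1, 2],
                      [0, 2, 1, 0]])
value = torch.arange(4, dtype=torch.float)
mat = SparseTensor(index, value, torch.Size([3, 3]))

sub = mat.index_select(0, torch.tensor([0, 2]))   # keep rows 0 and 2 -> 2 x 3
ent = mat.index_select_nnz(torch.tensor([1, 3]))  # keep entries 1 and 3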
import torch
from torch_sparse.storage import get_layout


def masked_select(src, dim, mask):
    dim = src.dim() + dim if dim < 0 else dim
    assert mask.dim() == 1

    storage = src.storage

    if dim == 0:
        (row, col), value = src.coo()
        rowcount = src.storage.rowcount

        # Mask over the nnz entries: an entry survives iff its row survives.
        row_mask = mask[row]

        rowcount = rowcount[mask]

        # Remap the surviving rows to consecutive indices 0..num_rows-1.
        idx = torch.arange(rowcount.size(0), device=rowcount.device)
        row = idx.repeat_interleave(rowcount)
        col = col[row_mask]
        index = torch.stack([row, col], dim=0)

        if src.has_value():
            value = value[row_mask]

        sparse_size = torch.Size([rowcount.size(0), src.sparse_size(1)])
        storage = src.storage.__class__(index, value, sparse_size,
                                        rowcount=rowcount, is_sorted=True)

    elif dim == 1:
        csr2csc = src.storage.csr2csc
        row = src.storage.row[csr2csc]
        col = src.storage.col[csr2csc]
        colcount = src.storage.colcount

        # Work in CSC order: an entry survives iff its column survives.
        col_mask = mask[col]

        colcount = colcount[mask]

        # Remap the surviving columns to consecutive indices 0..num_cols-1.
        tmp = torch.arange(colcount.size(0), device=row.device)
        col = tmp.repeat_interleave(colcount)
        row = row[col_mask]

        # Permutation that brings the CSC-ordered entries back into CSR order.
        csc2csr = (colcount.size(0) * row + col).argsort()
        index = torch.stack([row, col], dim=0)[:, csc2csr]

        value = src.storage.value
        if src.has_value():
            value = value[csr2csc][col_mask][csc2csr]

        sparse_size = torch.Size([src.sparse_size(0), colcount.size(0)])
        storage = src.storage.__class__(index, value, sparse_size,
                                        colcount=colcount, csc2csr=csc2csr,
                                        is_sorted=True)

    else:
        idx = mask.nonzero().view(-1)
        storage = src.storage.apply_value(
            lambda x: x.index_select(dim - 1, idx))

    return src.from_storage(storage)


def masked_select_nnz(src, mask, layout=None):
    assert mask.dim() == 1

    if get_layout(layout) == 'csc':
        mask = mask[src.storage.csc2csr]

    index, value = src.coo()
    index = index[:, mask]
    if src.has_value():
        value = value[mask]

    # There is no other information we can maintain...
    storage = src.storage.__class__(index, value, src.sparse_size(),
                                    is_sorted=True)

    return src.from_storage(storage)
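Continuing the sketch above, `masked_select_nnz` can filter individual entries, e.g. dropping the diagonal; this example is illustrative, not from the commit:

(row, col), value = mat.coo()
no_diag = mat.masked_select_nnz(row != col)  # drop diagonal entries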
...@@ -2,9 +2,16 @@ import torch

def narrow(src, dim, start, length):
    dim = src.dim() + dim if dim < 0 else dim

    if dim == 0:
        (row, col), value = src.coo()
        rowptr = src.storage.rowptr

        # Maintain `rowcount`...
        rowcount = src.storage._rowcount
        if rowcount is not None:
            rowcount = rowcount.narrow(0, start=start, length=length)

        rowptr = rowptr.narrow(0, start=start, length=length + 1)
        row_start = rowptr[0]
...@@ -18,15 +25,22 @@ def narrow(src, dim, start, length):
        value = value.narrow(0, row_start, row_length)

        sparse_size = torch.Size([length, src.sparse_size(1)])
        storage = src.storage.__class__(index, value, sparse_size,
                                        rowcount=rowcount, rowptr=rowptr,
                                        is_sorted=True)

    elif dim == 1:
        # This is faster than accessing `csc()`, in contrast to the `dim=0`
        # case.
        (row, col), value = src.coo()
        mask = (col >= start) & (col < start + length)

        # Maintain `colcount`...
        colcount = src.storage._colcount
        if colcount is not None:
            colcount = colcount.narrow(0, start=start, length=length)

        # Maintain `colptr`...
        colptr = src.storage._colptr
        if colptr is not None:
            colptr = colptr.narrow(0, start=start, length=length + 1)
            colptr = colptr - colptr[0]
...@@ -36,11 +50,12 @@ def narrow(src, dim, start, length):
        value = value[mask]

        sparse_size = torch.Size([src.sparse_size(0), length])
        storage = src.storage.__class__(index, value, sparse_size,
                                        colcount=colcount, colptr=colptr,
                                        is_sorted=True)
    else:
        storage = src.storage.apply_value(
            lambda x: x.narrow(dim - 1, start, length))

    return src.from_storage(storage)


def select(src, dim, idx):
    return src.narrow(dim, start=idx, length=1)
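A small illustrative sketch, reusing `mat` from above: `select` is just a `narrow` of length 1, so both keep the selected dimension and return sparse matrices:

top = mat.narrow(0, start=0, length=2)  # rows 0..1 -> 2 x 3
row1 = mat.select(0, 1)                 # row 1 as a 1 x 3 sparse matrix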
...@@ -20,19 +20,26 @@ class cached_property(object):
        return value


layouts = ['coo', 'csr', 'csc']


def get_layout(layout=None):
    if layout is None:
        layout = 'coo'
        warnings.warn('`layout` argument unset, using default layout '
                      '"coo". This may lead to unexpected behaviour.')
    assert layout in layouts
    return layout


class SparseStorage(object):
    cache_keys = [
        'rowcount', 'rowptr', 'colcount', 'colptr', 'csr2csc', 'csc2csr'
    ]

    def __init__(self, index, value=None, sparse_size=None, rowcount=None,
                 rowptr=None, colcount=None, colptr=None, csr2csc=None,
                 csc2csr=None, is_sorted=False):
        assert index.dtype == torch.long
        assert index.dim() == 2 and index.size(0) == 2
...@@ -46,25 +53,37 @@ class SparseStorage(object):
        if sparse_size is None:
            sparse_size = torch.Size((index.max(dim=-1)[0] + 1).tolist())

        if rowcount is not None:
            assert rowcount.dtype == torch.long
            assert rowcount.device == index.device
            assert rowcount.dim() == 1 and rowcount.numel() == sparse_size[0]

        if rowptr is not None:
            assert rowptr.dtype == torch.long
            assert rowptr.device == index.device
            assert rowptr.dim() == 1 and rowptr.numel() - 1 == sparse_size[0]

        if colcount is not None:
            assert colcount.dtype == torch.long
            assert colcount.device == index.device
            assert colcount.dim() == 1 and colcount.numel() == sparse_size[1]

        if colptr is not None:
            assert colptr.dtype == torch.long
            assert colptr.device == index.device
            assert colptr.dim() == 1 and colptr.numel() - 1 == sparse_size[1]

        if csr2csc is not None:
            assert csr2csc.dtype == torch.long
            assert csr2csc.device == index.device
            assert csr2csc.dim() == 1
            assert csr2csc.numel() == index.size(1)

        if csc2csr is not None:
            assert csc2csr.dtype == torch.long
            assert csc2csr.device == index.device
            assert csc2csr.dim() == 1
            assert csc2csr.numel() == index.size(1)

        if not is_sorted:
            idx = sparse_size[1] * index[0] + index[1]
...@@ -73,18 +92,18 @@ class SparseStorage(object):
            perm = idx.argsort()
            index = index[:, perm]
            value = None if value is None else value[perm]
            csr2csc = None
            csc2csr = None

        self._index = index
        self._value = value
        self._sparse_size = sparse_size
        self._rowcount = rowcount
        self._rowptr = rowptr
        self._colcount = colcount
        self._colptr = colptr
        self._csr2csc = csr2csc
        self._csc2csr = csc2csr

    @property
    def index(self):
...@@ -106,27 +125,17 @@ class SparseStorage(object):
        return self._value

    def set_value_(self, value, layout=None):
        assert value.device == self._index.device
        assert value.size(0) == self._index.size(1)
        if value is not None and get_layout(layout) == 'csc':
            value = value[self.csc2csr]
        return self.apply_value_(lambda x: value)

    def set_value(self, value, layout=None):
        assert value.device == self._index.device
        assert value.size(0) == self._index.size(1)
        if value is not None and get_layout(layout) == 'csc':
            value = value[self.csc2csr]
        return self.apply_value(lambda x: value)

    def sparse_size(self, dim=None):
...@@ -137,28 +146,34 @@ class SparseStorage(object):
        self._sparse_size = sizes
        return self

    @cached_property
    def rowcount(self):
        one = torch.ones_like(self.row)
        return segment_add(one, self.row, dim=0, dim_size=self._sparse_size[0])

    @cached_property
    def rowptr(self):
        rowcount = self.rowcount
        return torch.cat([rowcount.new_zeros(1), rowcount.cumsum(0)], dim=0)

    @cached_property
    def colcount(self):
        one = torch.ones_like(self.col)
        return scatter_add(one, self.col, dim=0, dim_size=self._sparse_size[1])

    @cached_property
    def colptr(self):
        colcount = self.colcount
        return torch.cat([colcount.new_zeros(1), colcount.cumsum(0)], dim=0)

    @cached_property
    def csr2csc(self):
        idx = self._sparse_size[0] * self.col + self.row
        return idx.argsort()

    @cached_property
    def csc2csr(self):
        return self.csr2csc.argsort()

    def is_coalesced(self):
        raise NotImplementedError
...@@ -202,10 +217,12 @@ class SparseStorage(object):
            self._index,
            optional(func, self._value),
            self._sparse_size,
            self._rowcount,
            self._rowptr,
            self._colcount,
            self._colptr,
            self._csr2csc,
            self._csc2csr,
            is_sorted=True,
        )
...@@ -221,10 +238,12 @@ class SparseStorage(object):
            func(self._index),
            optional(func, self._value),
            self._sparse_size,
            optional(func, self._rowcount),
            optional(func, self._rowptr),
            optional(func, self._colcount),
            optional(func, self._colptr),
            optional(func, self._csr2csc),
            optional(func, self._csc2csr),
            is_sorted=True,
        )
...
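The cached quantities are derived from one another, as the properties above show; a small worked sketch of the `rowcount`/`rowptr` relationship (illustrative values):

import torch

row = torch.tensor([0, 0, 2, 2, 2])  # sorted CSR-order row indices
rowcount = torch.tensor([2, 0, 3])   # nnz per row (what segment_add computes)
rowptr = torch.cat([rowcount.new_zeros(1), rowcount.cumsum(0)], dim=0)
# rowptr == tensor([0, 2, 2, 5]); row i owns entries rowptr[i]:rowptr[i + 1]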
...@@ -3,21 +3,24 @@ from textwrap import indent
import torch
import scipy.sparse

from torch_sparse.storage import SparseStorage, get_layout
from torch_sparse.transpose import t
from torch_sparse.narrow import narrow
from torch_sparse.select import select
from torch_sparse.index_select import index_select, index_select_nnz
from torch_sparse.masked_select import masked_select, masked_select_nnz


class SparseTensor(object):
    def __init__(self, index, value=None, sparse_size=None, is_sorted=False):
        self.storage = SparseStorage(index, value, sparse_size,
                                     is_sorted=is_sorted)

    @classmethod
    def from_storage(self, storage):
        self = SparseTensor.__new__(SparseTensor)
        self.storage = storage
        return self

    @classmethod
...@@ -32,10 +35,10 @@ class SparseTensor(object):
        return self.__class__(index, value, mat.size()[:2], is_sorted=True)

    def __copy__(self):
        return self.from_storage(self.storage)

    def clone(self):
        return self.from_storage(self.storage.clone())

    def __deepcopy__(self, memo):
        new_sparse_tensor = self.clone()
...@@ -45,58 +48,57 @@ class SparseTensor(object):
    # Formats #################################################################

    def coo(self):
        return self.storage.index, self.storage.value

    def csr(self):
        return self.storage.rowptr, self.storage.col, self.storage.value

    def csc(self):
        perm = self.storage.csr2csc
        return (self.storage.colptr, self.storage.row[perm],
                self.storage.value[perm] if self.has_value() else None)

    # Storage inheritance #####################################################

    def has_value(self):
        return self.storage.has_value()

    def set_value_(self, value, layout=None):
        self.storage.set_value_(value, layout)
        return self

    def set_value(self, value, layout=None):
        return self.from_storage(self.storage.set_value(value, layout))

    def sparse_size(self, dim=None):
        return self.storage.sparse_size(dim)

    def sparse_resize_(self, *sizes):
        self.storage.sparse_resize_(*sizes)
        return self

    def is_coalesced(self):
        return self.storage.is_coalesced()

    def coalesce(self):
        return self.from_storage(self.storage.coalesce())

    def cached_keys(self):
        return self.storage.cached_keys()

    def fill_cache_(self, *args):
        self.storage.fill_cache_(*args)
        return self

    def clear_cache_(self, *args):
        self.storage.clear_cache_(*args)
        return self

    # Utility functions #######################################################

    def size(self, dim=None):
        size = self.sparse_size()
        size += self.storage.value.size()[1:] if self.has_value() else ()
        return size if dim is None else size[dim]

    def dim(self):
...@@ -107,7 +109,7 @@ class SparseTensor(object):
        return self.size()

    def nnz(self):
        return self.storage.index.size(1)

    def density(self):
        return self.nnz() / (self.sparse_size(0) * self.sparse_size(1))
...@@ -138,50 +140,47 @@ class SparseTensor(object):
        return index_sym.item() and value_sym

    def detach_(self):
        self.storage.apply_(lambda x: x.detach_())
        return self

    def detach(self):
        return self.from_storage(self.storage.apply(lambda x: x.detach()))

    def pin_memory(self):
        return self.from_storage(self.storage.apply(lambda x: x.pin_memory()))

    def is_pinned(self):
        return all(self.storage.map(lambda x: x.is_pinned()))

    def share_memory_(self):
        self.storage.apply_(lambda x: x.share_memory_())
        return self

    def is_shared(self):
        return all(self.storage.map(lambda x: x.is_shared()))

    @property
    def device(self):
        return self.storage.index.device

    def cpu(self):
        return self.from_storage(self.storage.apply(lambda x: x.cpu()))

    def cuda(self, device=None, non_blocking=False, **kwargs):
        storage = self.storage.apply(
            lambda x: x.cuda(device, non_blocking, **kwargs))
        return self.from_storage(storage)

    @property
    def is_cuda(self):
        return self.storage.index.is_cuda

    @property
    def dtype(self):
        return self.storage.value.dtype if self.has_value() else None

    def is_floating_point(self):
        value = self.storage.value
        return self.has_value() and torch.is_floating_point(value)

    def type(self, dtype=None, non_blocking=False, **kwargs):
...@@ -191,9 +190,10 @@ class SparseTensor(object):
        if dtype == self.dtype:
            return self

        storage = self.storage.apply_value(
            lambda x: x.type(dtype, non_blocking, **kwargs))

        return self.from_storage(storage)

    def to(self, *args, **kwargs):
        storage = None
...@@ -201,17 +201,17 @@ class SparseTensor(object):
        if 'device' in kwargs:
            device = kwargs['device']
            del kwargs['device']
            storage = self.storage.apply(lambda x: x.to(
                device, non_blocking=getattr(kwargs, 'non_blocking', False)))

        for arg in args[:]:
            if isinstance(arg, str) or isinstance(arg, torch.device):
                storage = self.storage.apply(lambda x: x.to(
                    arg, non_blocking=getattr(kwargs, 'non_blocking', False)))
                args.remove(arg)

        if storage is not None:
            self = self.from_storage(storage)

        if len(args) > 0 or len(kwargs) > 0:
            self = self.type(*args, **kwargs)
...@@ -260,16 +260,13 @@ class SparseTensor(object):
    def to_torch_sparse_coo_tensor(self, dtype=None, requires_grad=False):
        index, value = self.coo()
        return torch.sparse_coo_tensor(
            index, value if self.has_value() else torch.ones(
                self.nnz(), dtype=dtype, device=self.device), self.size(),
            device=self.device, requires_grad=requires_grad)

    def to_scipy(self, dtype=None, layout=None):
        assert self.dim() == 2
        layout = get_layout(layout)

        if not self.has_value():
            ones = torch.ones(self.nnz(), dtype=dtype).numpy()
...@@ -318,33 +315,20 @@ class SparseTensor(object):
SparseTensor.t = t
SparseTensor.narrow = narrow
SparseTensor.select = select
SparseTensor.index_select = index_select
SparseTensor.index_select_nnz = index_select_nnz
SparseTensor.masked_select = masked_select
SparseTensor.masked_select_nnz = masked_select_nnz

# def __getitem__(self, idx):
#     # Convert int and slice to index tensor
#     # Filter list into edge and sparse slice
#     raise NotImplementedError

# def set_diag(self, value):
#     raise NotImplementedError

# def __reduce(self, dim, reduce, only_nnz):
#     raise NotImplementedError
...@@ -388,7 +372,7 @@ SparseTensor.narrow = narrow
#                           '"coo". This may lead to unexpected behaviour.')
#         assert layout in ['coo', 'csr', 'csc']
#         if layout == 'csc':
#             other = other[self._arg_csc2csr]
#         if self.has_value:
#             return self.set_value(self._value + other, 'coo')
#         else:
...@@ -440,10 +424,61 @@ if __name__ == '__main__':
    dataset = Planetoid('/tmp/Cora', 'Cora')
    data = dataset[0].to(device)

    value = torch.randn((data.num_edges, 10), device=device)
    mat1 = SparseTensor(data.edge_index, value)

    index = torch.tensor([0, 2])
    mat2 = mat1.index_select(2, index)

    index = torch.randperm(data.num_nodes)[:data.num_nodes - 500]
    mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    mask[index] = True

    t = time.perf_counter()
    for _ in range(1000):
        mat2 = mat1.index_select(0, index)
    print(time.perf_counter() - t)

    t = time.perf_counter()
    for _ in range(1000):
        mat2 = mat1.masked_select(0, mask)
    print(time.perf_counter() - t)

    # mat2 = mat1.narrow(1, start=0, length=3)
    # print(mat2)

    # index = torch.randperm(data.num_nodes)
    # t = time.perf_counter()
    # for _ in range(1000):
    #     mat2 = mat1.index_select(0, index)
    # print(time.perf_counter() - t)

    # t = time.perf_counter()
    # for _ in range(1000):
    #     mat2 = mat1.index_select(1, index)
    # print(time.perf_counter() - t)

    # raise NotImplementedError

    # t = time.perf_counter()
    # for _ in range(1000):
    #     mat2 = mat1.t().index_select(0, index).t()
    # print(time.perf_counter() - t)

    # print(mat1)
    # mat1.index_select((0, 1), torch.tensor([0, 1, 2, 3, 4]))
    # print(mat3)
    # print(mat3.storage.rowcount)
    # print(mat1)

    # (row, col), value = mat1.coo()
    # mask = row < 3

    # t = time.perf_counter()
    # for _ in range(10000):
    #     mat2 = mat1.narrow(1, start=10, length=2690)
    # print(time.perf_counter() - t)

    # # print(mat1.to_dense().size())
    # print(mat1.to_torch_sparse_coo_tensor().to_dense().size())
...@@ -461,24 +496,24 @@ if __name__ == '__main__':
    # print(mat1.cached_keys())

    # print('-------- NARROW ----------')
    # t = time.perf_counter()
    # for _ in range(100):
    #     out = mat1.narrow(dim=0, start=10, length=10)
    #     # out.storage.colptr
    # print(time.perf_counter() - t)
    # print(out)
    # print(out.cached_keys())

    # t = time.perf_counter()
    # for _ in range(100):
    #     out = mat1.narrow(dim=1, start=10, length=2000)
    #     # out.storage.colptr
    # print(time.perf_counter() - t)
    # print(out)
    # print(out.cached_keys())

    # mat1 = mat1.narrow(0, start=10, length=10)
    # mat1.storage._value = torch.randn(mat1.nnz(), 20)
    # print(mat1.coo()[1].size())
    # mat1 = mat1.narrow(2, start=10, length=10)
    # print(mat1.coo()[1].size())
...
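A quick round-trip sketch of the three format accessors, again assuming `SparseTensor` is importable from `torch_sparse.tensor`:

import torch
from torch_sparse.tensor import SparseTensor  # assumed module path

index = torch.tensor([[0, 1, 1],
                      [1, 0, 2]])
mat = SparseTensor(index, torch.ones(3), torch.Size([2, 3]))

rowptr, col, value = mat.csr()      # rowptr == tensor([0, 1, 3])
colptr, row, value_csc = mat.csc()  # entries permuted into column-major order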
...@@ -28,18 +28,21 @@ def transpose(index, value, m, n, coalesced=True):
    return index, value


def t(src):
    (row, col), value = src.coo()
    csr2csc = src.storage.csr2csc

    storage = src.storage.__class__(
        index=torch.stack([col, row], dim=0)[:, csr2csc],
        value=value[csr2csc] if src.has_value() else None,
        sparse_size=src.sparse_size()[::-1],
        rowcount=src.storage._colcount,
        rowptr=src.storage._colptr,
        colcount=src.storage._rowcount,
        colptr=src.storage._rowptr,
        csr2csc=src.storage._csc2csr,
        csc2csr=csr2csc,
        is_sorted=True,
    )

    return src.from_storage(storage)
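Transposition reuses the cached column metadata as row metadata instead of recomputing it; continuing the sketch above:

mat_t = mat.t()
assert mat_t.sparse_size(0) == mat.sparse_size(1)  # 3
assert mat_t.sparse_size(1) == mat.sparse_size(0)  # 2
# mat's cached colptr/colcount now serve as mat_t's rowptr/rowcount.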