import warnings

import torch
import torch_scatter
from torch_scatter import scatter_add, segment_add


def optional(func, src):
    # Apply `func` only if `src` is not `None`.
    return func(src) if src is not None else src


class cached_property(object):
    # Lightweight cached property: the computed value is stored in the
    # `_<name>` attribute and reused on subsequent accesses.
    def __init__(self, func):
        self.func = func

    def __get__(self, obj, cls):
        value = getattr(obj, f'_{self.func.__name__}', None)
        if value is None:
            value = self.func(obj)
            setattr(obj, f'_{self.func.__name__}', value)
        return value


layouts = ['coo', 'csr', 'csc']


def get_layout(layout=None):
    if layout is None:
        layout = 'coo'
        warnings.warn('`layout` argument unset, using default layout '
                      '"coo". This may lead to unexpected behaviour.')
    assert layout in layouts
    return layout
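# Illustrative sketch (not part of the original module): `SparseStorage`
# below keeps COO entries in row-major (CSR) order by linearizing each
# (row, col) pair into a single key `num_cols * row + col` and sorting by it.
# The helper name `_csr_order_example` is hypothetical.
def _csr_order_example():
    index = torch.tensor([[1, 0, 0], [0, 1, 0]])  # COO (row, col) pairs
    num_cols = 2
    idx = num_cols * index[0] + index[1]  # linear keys: [2, 1, 0]
    perm = idx.argsort()  # permutation into row-major order
    return index[:, perm]  # reordered to rows [0, 0, 1], cols [0, 1, 0]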
class SparseStorage(object):
    cache_keys = [
        'rowcount', 'rowptr', 'colcount', 'colptr', 'csr2csc', 'csc2csr'
    ]

    def __init__(self, index, value=None, sparse_size=None, rowcount=None,
                 rowptr=None, colcount=None, colptr=None, csr2csc=None,
                 csc2csr=None, is_sorted=False):
        assert index.dtype == torch.long
        assert index.dim() == 2 and index.size(0) == 2
        index = index.contiguous()

        if value is not None:
            assert value.device == index.device
            assert value.size(0) == index.size(1)
            value = value.contiguous()

        if sparse_size is None:
            sparse_size = torch.Size((index.max(dim=-1)[0] + 1).tolist())

        if rowcount is not None:
            assert rowcount.dtype == torch.long
            assert rowcount.device == index.device
            assert rowcount.dim() == 1 and rowcount.numel() == sparse_size[0]

        if rowptr is not None:
            assert rowptr.dtype == torch.long
            assert rowptr.device == index.device
            assert rowptr.dim() == 1 and rowptr.numel() - 1 == sparse_size[0]

        if colcount is not None:
            assert colcount.dtype == torch.long
            assert colcount.device == index.device
            assert colcount.dim() == 1 and colcount.numel() == sparse_size[1]

        if colptr is not None:
            assert colptr.dtype == torch.long
            assert colptr.device == index.device
            assert colptr.dim() == 1 and colptr.numel() - 1 == sparse_size[1]

        if csr2csc is not None:
            assert csr2csc.dtype == torch.long
            assert csr2csc.device == index.device
            assert csr2csc.dim() == 1
            assert csr2csc.numel() == index.size(1)

        if csc2csr is not None:
            assert csc2csr.dtype == torch.long
            assert csc2csr.device == index.device
            assert csc2csr.dim() == 1
            assert csc2csr.numel() == index.size(1)

        if not is_sorted:
            idx = sparse_size[1] * index[0] + index[1]
            # Only sort if necessary: compare each linear key against its
            # predecessor (sentinel -1, matching `is_coalesced`, so a leading
            # zero key does not trigger a redundant sort).
            if (idx <= torch.cat([idx.new_full((1, ), -1), idx[:-1]],
                                 dim=0)).any():
                perm = idx.argsort()
                index = index[:, perm]
                value = None if value is None else value[perm]
                csr2csc = None
                csc2csr = None

        self._index = index
        self._value = value
        self._sparse_size = sparse_size
        self._rowcount = rowcount
        self._rowptr = rowptr
        self._colcount = colcount
        self._colptr = colptr
        self._csr2csc = csr2csc
        self._csc2csr = csc2csr

    @property
    def index(self):
        return self._index

    @property
    def row(self):
        return self._index[0]

    @property
    def col(self):
        return self._index[1]

    def has_value(self):
        return self._value is not None

    @property
    def value(self):
        return self._value

    def set_value_(self, value, layout=None):
        # Guard the checks so that clearing the value with `None` is allowed.
        if value is not None:
            assert value.device == self._index.device
            assert value.size(0) == self._index.size(1)
            if get_layout(layout) == 'csc':
                value = value[self.csc2csr]
        self._value = value
        return self

    def set_value(self, value, layout=None):
        if value is not None:
            assert value.device == self._index.device
            assert value.size(0) == self._index.size(1)
            if get_layout(layout) == 'csc':
                value = value[self.csc2csr]
        return self.__class__(
            self._index,
            value,
            self._sparse_size,
            self._rowcount,
            self._rowptr,
            self._colcount,
            self._colptr,
            self._csr2csc,
            self._csc2csr,
            is_sorted=True,
        )

    def sparse_size(self, dim=None):
        return self._sparse_size if dim is None else self._sparse_size[dim]

    def sparse_resize_(self, *sizes):
        assert len(sizes) == 2
        self._sparse_size = torch.Size(sizes)
        return self

    @cached_property
    def rowcount(self):
        one = torch.ones_like(self.row)
        return segment_add(one, self.row, dim=0,
                           dim_size=self._sparse_size[0])

    @cached_property
    def rowptr(self):
        rowcount = self.rowcount
        return torch.cat([rowcount.new_zeros(1), rowcount.cumsum(0)], dim=0)

    @cached_property
    def colcount(self):
        one = torch.ones_like(self.col)
        return scatter_add(one, self.col, dim=0,
                           dim_size=self._sparse_size[1])

    @cached_property
    def colptr(self):
        colcount = self.colcount
        return torch.cat([colcount.new_zeros(1), colcount.cumsum(0)], dim=0)

    @cached_property
    def csr2csc(self):
        # Column-major linear keys yield the CSR -> CSC permutation.
        idx = self._sparse_size[0] * self.col + self.row
        return idx.argsort()

    @cached_property
    def csc2csr(self):
        return self.csr2csc.argsort()

    def is_coalesced(self):
        idx = self.sparse_size(1) * self.row + self.col
        mask = idx > torch.cat([idx.new_full((1, ), -1), idx[:-1]], dim=0)
        return mask.all().item()

    def coalesce(self, reduce='add'):
        idx = self.sparse_size(1) * self.row + self.col
        mask = idx > torch.cat([idx.new_full((1, ), -1), idx[:-1]], dim=0)

        if mask.all():  # Already coalesced.
            return self

        index = self.index[:, mask]

        value = self.value
        if self.has_value():
            assert reduce in ['add', 'mean', 'min', 'max']
            idx = mask.cumsum(0) - 1
            op = getattr(torch_scatter, f'scatter_{reduce}')
            value = op(value, idx, dim=0, dim_size=idx[-1].item() + 1)
            # `scatter_min`/`scatter_max` return `(value, argindex)` tuples.
            value = value[0] if isinstance(value, tuple) else value

        return self.__class__(index, value, self.sparse_size(),
                              is_sorted=True)

    def cached_keys(self):
        return [
            key for key in self.cache_keys
            if getattr(self, f'_{key}', None) is not None
        ]

    def fill_cache_(self, *args):
        for arg in args or self.cache_keys:
            getattr(self, arg)
        return self

    def clear_cache_(self, *args):
        for arg in args or self.cache_keys:
            setattr(self, f'_{arg}', None)
        return self

    def __copy__(self):
        return self.apply(lambda x: x)

    def clone(self):
        return self.apply(lambda x: x.clone())

    def __deepcopy__(self, memo):
        new_storage = self.clone()
        memo[id(self)] = new_storage
        return new_storage

    def apply_value_(self, func):
        self._value = optional(func, self._value)
        return self

    def apply_value(self, func):
        return self.__class__(
            self._index,
            optional(func, self._value),
            self._sparse_size,
            self._rowcount,
            self._rowptr,
            self._colcount,
            self._colptr,
            self._csr2csc,
            self._csc2csr,
            is_sorted=True,
        )

    def apply_(self, func):
        self._index = func(self._index)
        self._value = optional(func, self._value)
        for key in self.cached_keys():
            setattr(self, f'_{key}', func(getattr(self, f'_{key}')))
        return self

    def apply(self, func):
        return self.__class__(
            func(self._index),
            optional(func, self._value),
            self._sparse_size,
            optional(func, self._rowcount),
            optional(func, self._rowptr),
            optional(func, self._colcount),
            optional(func, self._colptr),
            optional(func, self._csr2csc),
            optional(func, self._csc2csr),
            is_sorted=True,
        )

    def map(self, func):
        data = [func(self.index)]
        if self.has_value():
            data += [func(self.value)]
        data += [func(getattr(self, f'_{key}')) for key in self.cached_keys()]
        return data
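# Minimal usage sketch (illustrative, not part of the module): coalescing a
# storage that holds a duplicate entry at (0, 1). The helper name
# `_coalesce_example` is hypothetical.
def _coalesce_example():
    index = torch.tensor([[0, 0, 1], [1, 1, 0]])
    value = torch.tensor([1.0, 2.0, 3.0])
    storage = SparseStorage(index, value)
    assert not storage.is_coalesced()
    out = storage.coalesce(reduce='add')
    # The duplicate values at (0, 1) are summed: `out.value` is [3.0, 3.0].
    return out.index, out.value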
if __name__ == '__main__':
    from torch_geometric.datasets import Reddit, Planetoid  # noqa
    import time  # noqa
    import copy  # noqa

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # dataset = Reddit('/tmp/Reddit')
    dataset = Planetoid('/tmp/Cora', 'Cora')
    data = dataset[0].to(device)
    edge_index = data.edge_index

    storage = SparseStorage(edge_index, is_sorted=True)

    t = time.perf_counter()
    storage.fill_cache_()
    print(time.perf_counter() - t)

    t = time.perf_counter()
    storage.clear_cache_()
    storage.fill_cache_()
    print(time.perf_counter() - t)

    print(storage)
    # storage = storage.clone()
    # print(storage)
    storage = copy.copy(storage)
    print(storage)
    print(id(storage))
    storage = copy.deepcopy(storage)
    print(storage)
    storage.fill_cache_()
    storage.clear_cache_()
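    # Illustrative additions to the demo (not in the original script): inspect
    # which conversion tensors are cached and move the storage to CPU via
    # `apply`, which rebuilds the storage without resorting.
    storage.fill_cache_()
    print(storage.cached_keys())  # expected: all six cache keys
    cpu_storage = storage.apply(lambda x: x.cpu())
    print(cpu_storage.index.device)  # cpu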