Commit b38d0b5e authored by rusty1s's avatar rusty1s
Browse files

fix test

parent c4f318ee
...@@ -9,7 +9,7 @@ from .utils import devices ...@@ -9,7 +9,7 @@ from .utils import devices
try: try:
rowptr = torch.tensor([0, 1]) rowptr = torch.tensor([0, 1])
col = torch.tensor([0]) col = torch.tensor([0])
torch.ops.torch_sparse.partition(rowptr, col, None, 1) torch.ops.torch_sparse.partition(rowptr, col, None, 1, True)
with_metis = True with_metis = True
except RuntimeError: except RuntimeError:
with_metis = False with_metis = False
......
...@@ -30,19 +30,21 @@ class SparseStorage(object): ...@@ -30,19 +30,21 @@ class SparseStorage(object):
_csr2csc: Optional[torch.Tensor] _csr2csc: Optional[torch.Tensor]
_csc2csr: Optional[torch.Tensor] _csc2csr: Optional[torch.Tensor]
def __init__(self, row: Optional[torch.Tensor] = None, def __init__(
rowptr: Optional[torch.Tensor] = None, self,
col: Optional[torch.Tensor] = None, row: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None, rowptr: Optional[torch.Tensor] = None,
sparse_sizes: Optional[Tuple[Optional[int], col: Optional[torch.Tensor] = None,
Optional[int]]] = None, value: Optional[torch.Tensor] = None,
rowcount: Optional[torch.Tensor] = None, sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None,
colptr: Optional[torch.Tensor] = None, rowcount: Optional[torch.Tensor] = None,
colcount: Optional[torch.Tensor] = None, colptr: Optional[torch.Tensor] = None,
csr2csc: Optional[torch.Tensor] = None, colcount: Optional[torch.Tensor] = None,
csc2csr: Optional[torch.Tensor] = None, csr2csc: Optional[torch.Tensor] = None,
is_sorted: bool = False, csc2csr: Optional[torch.Tensor] = None,
trust_data: bool = False): is_sorted: bool = False,
trust_data: bool = False,
):
assert row is not None or rowptr is not None assert row is not None or rowptr is not None
assert col is not None assert col is not None
...@@ -240,7 +242,8 @@ class SparseStorage(object): ...@@ -240,7 +242,8 @@ class SparseStorage(object):
csr2csc=self._csr2csc, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, csc2csr=self._csc2csr,
is_sorted=True, is_sorted=True,
trust_data=True) trust_data=True,
)
def sparse_sizes(self) -> Tuple[int, int]: def sparse_sizes(self) -> Tuple[int, int]:
return self._sparse_sizes return self._sparse_sizes
...@@ -290,7 +293,8 @@ class SparseStorage(object): ...@@ -290,7 +293,8 @@ class SparseStorage(object):
csr2csc=self._csr2csc, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, csc2csr=self._csc2csr,
is_sorted=True, is_sorted=True,
trust_data=True) trust_data=True,
)
def sparse_reshape(self, num_rows: int, num_cols: int): def sparse_reshape(self, num_rows: int, num_cols: int):
assert num_rows > 0 or num_rows == -1 assert num_rows > 0 or num_rows == -1
...@@ -313,10 +317,20 @@ class SparseStorage(object): ...@@ -313,10 +317,20 @@ class SparseStorage(object):
col = idx % num_cols col = idx % num_cols
assert row.dtype == torch.long and col.dtype == torch.long assert row.dtype == torch.long and col.dtype == torch.long
return SparseStorage(row=row, rowptr=None, col=col, value=self._value, return SparseStorage(
sparse_sizes=(num_rows, num_cols), rowcount=None, row=row,
colptr=None, colcount=None, csr2csc=None, rowptr=None,
csc2csr=None, is_sorted=True, trust_data=True) col=col,
value=self._value,
sparse_sizes=(num_rows, num_cols),
rowcount=None,
colptr=None,
colcount=None,
csr2csc=None,
csc2csr=None,
is_sorted=True,
trust_data=True,
)
def has_rowcount(self) -> bool: def has_rowcount(self) -> bool:
return self._rowcount is not None return self._rowcount is not None
...@@ -413,10 +427,20 @@ class SparseStorage(object): ...@@ -413,10 +427,20 @@ class SparseStorage(object):
ptr = torch.cat([ptr, ptr.new_full((1, ), value.size(0))]) ptr = torch.cat([ptr, ptr.new_full((1, ), value.size(0))])
value = segment_csr(value, ptr, reduce=reduce) value = segment_csr(value, ptr, reduce=reduce)
return SparseStorage(row=row, rowptr=None, col=col, value=value, return SparseStorage(
sparse_sizes=self._sparse_sizes, rowcount=None, row=row,
colptr=None, colcount=None, csr2csc=None, rowptr=None,
csc2csr=None, is_sorted=True, trust_data=True) col=col,
value=value,
sparse_sizes=self._sparse_sizes,
rowcount=None,
colptr=None,
colcount=None,
csr2csc=None,
csc2csr=None,
is_sorted=True,
trust_data=True,
)
def fill_cache_(self): def fill_cache_(self):
self.row() self.row()
...@@ -466,7 +490,8 @@ class SparseStorage(object): ...@@ -466,7 +490,8 @@ class SparseStorage(object):
csr2csc=self._csr2csc, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, csc2csr=self._csc2csr,
is_sorted=True, is_sorted=True,
trust_data=True) trust_data=True,
)
def clone(self): def clone(self):
row = self._row row = self._row
...@@ -495,11 +520,20 @@ class SparseStorage(object): ...@@ -495,11 +520,20 @@ class SparseStorage(object):
if csc2csr is not None: if csc2csr is not None:
csc2csr = csc2csr.clone() csc2csr = csc2csr.clone()
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value, return SparseStorage(
sparse_sizes=self._sparse_sizes, row=row,
rowcount=rowcount, colptr=colptr, rowptr=rowptr,
colcount=colcount, csr2csc=csr2csc, col=col,
csc2csr=csc2csr, is_sorted=True, trust_data=True) value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True,
trust_data=True,
)
def type(self, dtype: torch.dtype, non_blocking: bool = False): def type(self, dtype: torch.dtype, non_blocking: bool = False):
value = self._value value = self._value
...@@ -508,9 +542,7 @@ class SparseStorage(object): ...@@ -508,9 +542,7 @@ class SparseStorage(object):
return self return self
else: else:
return self.set_value( return self.set_value(
value.to( value.to(dtype=dtype, non_blocking=non_blocking),
dtype=dtype,
non_blocking=non_blocking),
layout='coo') layout='coo')
else: else:
return self return self
...@@ -548,11 +580,20 @@ class SparseStorage(object): ...@@ -548,11 +580,20 @@ class SparseStorage(object):
if csc2csr is not None: if csc2csr is not None:
csc2csr = csc2csr.to(device, non_blocking=non_blocking) csc2csr = csc2csr.to(device, non_blocking=non_blocking)
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value, return SparseStorage(
sparse_sizes=self._sparse_sizes, row=row,
rowcount=rowcount, colptr=colptr, rowptr=rowptr,
colcount=colcount, csr2csc=csr2csc, col=col,
csc2csr=csc2csr, is_sorted=True, trust_data=True) value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True,
trust_data=True,
)
def device_as(self, tensor: torch.Tensor, non_blocking: bool = False): def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
return self.to_device(device=tensor.device, non_blocking=non_blocking) return self.to_device(device=tensor.device, non_blocking=non_blocking)
...@@ -587,11 +628,20 @@ class SparseStorage(object): ...@@ -587,11 +628,20 @@ class SparseStorage(object):
if csc2csr is not None: if csc2csr is not None:
csc2csr = csc2csr.cuda() csc2csr = csc2csr.cuda()
return SparseStorage(row=row, rowptr=rowptr, col=new_col, value=value, return SparseStorage(
sparse_sizes=self._sparse_sizes, row=row,
rowcount=rowcount, colptr=colptr, rowptr=rowptr,
colcount=colcount, csr2csc=csr2csc, col=new_col,
csc2csr=csc2csr, is_sorted=True, trust_data=True) value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True,
trust_data=True,
)
def pin_memory(self): def pin_memory(self):
row = self._row row = self._row
...@@ -620,11 +670,20 @@ class SparseStorage(object): ...@@ -620,11 +670,20 @@ class SparseStorage(object):
if csc2csr is not None: if csc2csr is not None:
csc2csr = csc2csr.pin_memory() csc2csr = csc2csr.pin_memory()
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value, return SparseStorage(
sparse_sizes=self._sparse_sizes, row=row,
rowcount=rowcount, colptr=colptr, rowptr=rowptr,
colcount=colcount, csr2csc=csr2csc, col=col,
csc2csr=csc2csr, is_sorted=True, trust_data=True) value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True,
trust_data=True,
)
def is_pinned(self) -> bool: def is_pinned(self) -> bool:
is_pinned = True is_pinned = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment