Commit b38d0b5e authored by rusty1s's avatar rusty1s
Browse files

fix test

parent c4f318ee
......@@ -9,7 +9,7 @@ from .utils import devices
try:
rowptr = torch.tensor([0, 1])
col = torch.tensor([0])
torch.ops.torch_sparse.partition(rowptr, col, None, 1)
torch.ops.torch_sparse.partition(rowptr, col, None, 1, True)
with_metis = True
except RuntimeError:
with_metis = False
......
......@@ -30,19 +30,21 @@ class SparseStorage(object):
_csr2csc: Optional[torch.Tensor]
_csc2csr: Optional[torch.Tensor]
def __init__(self, row: Optional[torch.Tensor] = None,
rowptr: Optional[torch.Tensor] = None,
col: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
sparse_sizes: Optional[Tuple[Optional[int],
Optional[int]]] = None,
rowcount: Optional[torch.Tensor] = None,
colptr: Optional[torch.Tensor] = None,
colcount: Optional[torch.Tensor] = None,
csr2csc: Optional[torch.Tensor] = None,
csc2csr: Optional[torch.Tensor] = None,
is_sorted: bool = False,
trust_data: bool = False):
def __init__(
self,
row: Optional[torch.Tensor] = None,
rowptr: Optional[torch.Tensor] = None,
col: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None,
rowcount: Optional[torch.Tensor] = None,
colptr: Optional[torch.Tensor] = None,
colcount: Optional[torch.Tensor] = None,
csr2csc: Optional[torch.Tensor] = None,
csc2csr: Optional[torch.Tensor] = None,
is_sorted: bool = False,
trust_data: bool = False,
):
assert row is not None or rowptr is not None
assert col is not None
......@@ -240,7 +242,8 @@ class SparseStorage(object):
csr2csc=self._csr2csc,
csc2csr=self._csc2csr,
is_sorted=True,
trust_data=True)
trust_data=True,
)
def sparse_sizes(self) -> Tuple[int, int]:
return self._sparse_sizes
......@@ -290,7 +293,8 @@ class SparseStorage(object):
csr2csc=self._csr2csc,
csc2csr=self._csc2csr,
is_sorted=True,
trust_data=True)
trust_data=True,
)
def sparse_reshape(self, num_rows: int, num_cols: int):
assert num_rows > 0 or num_rows == -1
......@@ -313,10 +317,20 @@ class SparseStorage(object):
col = idx % num_cols
assert row.dtype == torch.long and col.dtype == torch.long
return SparseStorage(row=row, rowptr=None, col=col, value=self._value,
sparse_sizes=(num_rows, num_cols), rowcount=None,
colptr=None, colcount=None, csr2csc=None,
csc2csr=None, is_sorted=True, trust_data=True)
return SparseStorage(
row=row,
rowptr=None,
col=col,
value=self._value,
sparse_sizes=(num_rows, num_cols),
rowcount=None,
colptr=None,
colcount=None,
csr2csc=None,
csc2csr=None,
is_sorted=True,
trust_data=True,
)
def has_rowcount(self) -> bool:
return self._rowcount is not None
......@@ -413,10 +427,20 @@ class SparseStorage(object):
ptr = torch.cat([ptr, ptr.new_full((1, ), value.size(0))])
value = segment_csr(value, ptr, reduce=reduce)
return SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=self._sparse_sizes, rowcount=None,
colptr=None, colcount=None, csr2csc=None,
csc2csr=None, is_sorted=True, trust_data=True)
return SparseStorage(
row=row,
rowptr=None,
col=col,
value=value,
sparse_sizes=self._sparse_sizes,
rowcount=None,
colptr=None,
colcount=None,
csr2csc=None,
csc2csr=None,
is_sorted=True,
trust_data=True,
)
def fill_cache_(self):
self.row()
......@@ -466,7 +490,8 @@ class SparseStorage(object):
csr2csc=self._csr2csc,
csc2csr=self._csc2csr,
is_sorted=True,
trust_data=True)
trust_data=True,
)
def clone(self):
row = self._row
......@@ -495,11 +520,20 @@ class SparseStorage(object):
if csc2csr is not None:
csc2csr = csc2csr.clone()
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True, trust_data=True)
return SparseStorage(
row=row,
rowptr=rowptr,
col=col,
value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True,
trust_data=True,
)
def type(self, dtype: torch.dtype, non_blocking: bool = False):
value = self._value
......@@ -508,9 +542,7 @@ class SparseStorage(object):
return self
else:
return self.set_value(
value.to(
dtype=dtype,
non_blocking=non_blocking),
value.to(dtype=dtype, non_blocking=non_blocking),
layout='coo')
else:
return self
......@@ -548,11 +580,20 @@ class SparseStorage(object):
if csc2csr is not None:
csc2csr = csc2csr.to(device, non_blocking=non_blocking)
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True, trust_data=True)
return SparseStorage(
row=row,
rowptr=rowptr,
col=col,
value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True,
trust_data=True,
)
def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
return self.to_device(device=tensor.device, non_blocking=non_blocking)
......@@ -587,11 +628,20 @@ class SparseStorage(object):
if csc2csr is not None:
csc2csr = csc2csr.cuda()
return SparseStorage(row=row, rowptr=rowptr, col=new_col, value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True, trust_data=True)
return SparseStorage(
row=row,
rowptr=rowptr,
col=new_col,
value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True,
trust_data=True,
)
def pin_memory(self):
row = self._row
......@@ -620,11 +670,20 @@ class SparseStorage(object):
if csc2csr is not None:
csc2csr = csc2csr.pin_memory()
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True, trust_data=True)
return SparseStorage(
row=row,
rowptr=rowptr,
col=col,
value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True,
trust_data=True,
)
def is_pinned(self) -> bool:
is_pinned = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment