Unverified Commit fe8c3ce3 authored by Nick Stathas, committed by GitHub

Skip unnecessary assertions and enable non-blocking data transfers (#195)

* Uses the new `trust_data` invariant to skip assertions during construction of `SparseStorage` objects when the caller already knows the data is valid. These assertions (e.g., `int(row.max()) < M`) are blocking because calling `.max()` on a CUDA index tensor forces a device-to-host synchronization.
* Refactors the dtype and device transfer APIs to align with `torch.Tensor` while maintaining backward compatibility (`type_as` and `device_as` remain as thin wrappers around the new `type` and `to_device` methods).
* No longer constructs dummy tensors when changing dtype or device; a usage sketch follows below.
parent 88c6ceb6
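
For illustration, here is a minimal usage sketch of the new API. The `edge_index`/`value` data and the CUDA device below are assumptions made up for the example, not part of this commit:

```python
import torch
from torch_sparse import SparseTensor

# Hypothetical COO input, already sorted by (row, col) and known to fit
# inside the 3 x 3 sparse size, e.g. produced by a trusted pipeline.
edge_index = torch.tensor([[0, 0, 1, 2],
                           [1, 2, 2, 0]])
value = torch.ones(4)

# `trust_data=True` skips the `row.max() < M` / `col.max() < N` checks;
# on CUDA inputs those `.max()` calls would otherwise force a blocking
# device-to-host synchronization.
adj = SparseTensor.from_edge_index(edge_index, edge_attr=value,
                                   sparse_sizes=(3, 3), is_sorted=True,
                                   trust_data=True)

# New torch.Tensor-style transfer APIs; no dummy tensors are constructed.
adj = adj.type(dtype=torch.half, non_blocking=False)
if torch.cuda.is_available():
    adj = adj.to_device(torch.device('cuda'), non_blocking=True)

# The old methods keep working; they now forward to the new ones.
reference = torch.empty(1, dtype=torch.float)
adj = adj.type_as(reference)
```

Note that a `non_blocking=True` host-to-device copy only actually overlaps with other work when the source tensors live in pinned memory (see `pin_memory()` below).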
@@ -24,6 +24,8 @@ BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'
 WITH_METIS = True if os.getenv('WITH_METIS', '0') == '1' else False
 WITH_MTMETIS = True if os.getenv('WITH_MTMETIS', '0') == '1' else False
+WITH_SYMBOLS = True if os.getenv('WITH_SYMBOLS', '0') == '1' else False


 def get_extensions():
     extensions = []
@@ -47,7 +49,7 @@ def get_extensions():
         extra_compile_args = {'cxx': ['-O2']}
         if not os.name == 'nt':  # Not on Windows:
             extra_compile_args['cxx'] += ['-Wno-sign-compare']
-        extra_link_args = ['-s']
+        extra_link_args = [] if WITH_SYMBOLS else ['-s']

         info = parallel_info()
         if ('backend: OpenMP' in info and 'OpenMP not found' not in info
......
@@ -41,7 +41,8 @@ class SparseStorage(object):
                  colcount: Optional[torch.Tensor] = None,
                  csr2csc: Optional[torch.Tensor] = None,
                  csc2csr: Optional[torch.Tensor] = None,
-                 is_sorted: bool = False):
+                 is_sorted: bool = False,
+                 trust_data: bool = False):

         assert row is not None or rowptr is not None
         assert col is not None
@@ -62,7 +63,7 @@ class SparseStorage(object):
             if rowptr is not None:
                 assert rowptr.numel() - 1 == M
             elif row is not None and row.numel() > 0:
-                assert int(row.max()) < M
+                assert trust_data or int(row.max()) < M

         N: int = 0
         if sparse_sizes is None or sparse_sizes[1] is None:
@@ -73,7 +74,7 @@ class SparseStorage(object):
             assert _N is not None
             N = _N
             if col.numel() > 0:
-                assert int(col.max()) < N
+                assert trust_data or int(col.max()) < N

         sparse_sizes = (M, N)
@@ -163,7 +164,7 @@ class SparseStorage(object):
         return SparseStorage(row=row, rowptr=None, col=col, value=None,
                              sparse_sizes=(0, 0), rowcount=None, colptr=None,
                              colcount=None, csr2csc=None, csc2csr=None,
-                             is_sorted=True)
+                             is_sorted=True, trust_data=True)

     def has_row(self) -> bool:
         return self._row is not None
@@ -227,11 +228,19 @@ class SparseStorage(object):
             assert value.device == self._col.device
             assert value.size(0) == self._col.numel()

-        return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
-                             value=value, sparse_sizes=self._sparse_sizes,
-                             rowcount=self._rowcount, colptr=self._colptr,
-                             colcount=self._colcount, csr2csc=self._csr2csc,
-                             csc2csr=self._csc2csr, is_sorted=True)
+        return SparseStorage(
+            row=self._row,
+            rowptr=self._rowptr,
+            col=self._col,
+            value=value,
+            sparse_sizes=self._sparse_sizes,
+            rowcount=self._rowcount,
+            colptr=self._colptr,
+            colcount=self._colcount,
+            csr2csc=self._csr2csc,
+            csc2csr=self._csc2csr,
+            is_sorted=True,
+            trust_data=True)

     def sparse_sizes(self) -> Tuple[int, int]:
         return self._sparse_sizes
@@ -269,11 +278,19 @@ class SparseStorage(object):
             if colcount is not None:
                 colcount = colcount[:-diff_1]

-        return SparseStorage(row=self._row, rowptr=rowptr, col=self._col,
-                             value=self._value, sparse_sizes=sparse_sizes,
-                             rowcount=rowcount, colptr=colptr,
-                             colcount=colcount, csr2csc=self._csr2csc,
-                             csc2csr=self._csc2csr, is_sorted=True)
+        return SparseStorage(
+            row=self._row,
+            rowptr=rowptr,
+            col=self._col,
+            value=self._value,
+            sparse_sizes=sparse_sizes,
+            rowcount=rowcount,
+            colptr=colptr,
+            colcount=colcount,
+            csr2csc=self._csr2csc,
+            csc2csr=self._csc2csr,
+            is_sorted=True,
+            trust_data=True)

     def sparse_reshape(self, num_rows: int, num_cols: int):
         assert num_rows > 0 or num_rows == -1
@@ -299,7 +316,7 @@ class SparseStorage(object):
         return SparseStorage(row=row, rowptr=None, col=col, value=self._value,
                              sparse_sizes=(num_rows, num_cols), rowcount=None,
                              colptr=None, colcount=None, csr2csc=None,
-                             csc2csr=None, is_sorted=True)
+                             csc2csr=None, is_sorted=True, trust_data=True)

     def has_rowcount(self) -> bool:
         return self._rowcount is not None
@@ -399,7 +416,7 @@ class SparseStorage(object):
         return SparseStorage(row=row, rowptr=None, col=col, value=value,
                              sparse_sizes=self._sparse_sizes, rowcount=None,
                              colptr=None, colcount=None, csr2csc=None,
-                             csc2csr=None, is_sorted=True)
+                             csc2csr=None, is_sorted=True, trust_data=True)

     def fill_cache_(self):
         self.row()
@@ -437,12 +454,19 @@ class SparseStorage(object):
         return len(self.cached_keys())

     def copy(self):
-        return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
-                             value=self._value,
-                             sparse_sizes=self._sparse_sizes,
-                             rowcount=self._rowcount, colptr=self._colptr,
-                             colcount=self._colcount, csr2csc=self._csr2csc,
-                             csc2csr=self._csc2csr, is_sorted=True)
+        return SparseStorage(
+            row=self._row,
+            rowptr=self._rowptr,
+            col=self._col,
+            value=self._value,
+            sparse_sizes=self._sparse_sizes,
+            rowcount=self._rowcount,
+            colptr=self._colptr,
+            colcount=self._colcount,
+            csr2csc=self._csr2csc,
+            csc2csr=self._csc2csr,
+            is_sorted=True,
+            trust_data=True)

     def clone(self):
         row = self._row
@@ -475,53 +499,63 @@ class SparseStorage(object):
                              sparse_sizes=self._sparse_sizes,
                              rowcount=rowcount, colptr=colptr,
                              colcount=colcount, csr2csc=csr2csc,
-                             csc2csr=csc2csr, is_sorted=True)
+                             csc2csr=csc2csr, is_sorted=True, trust_data=True)

-    def type_as(self, tensor: torch.Tensor):
+    def type(self, dtype: torch.dtype, non_blocking: bool = False):
         value = self._value
         if value is not None:
-            if tensor.dtype == value.dtype:
+            if dtype == value.dtype:
                 return self
             else:
-                return self.set_value(value.type_as(tensor), layout='coo')
+                return self.set_value(
+                    value.to(
+                        dtype=dtype,
+                        non_blocking=non_blocking),
+                    layout='coo')
         else:
             return self

-    def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
-        if tensor.device == self._col.device:
+    def type_as(self, tensor: torch.Tensor, non_blocking: bool = False):
+        return self.type(dtype=tensor.dtype, non_blocking=non_blocking)
+
+    def to_device(self, device: torch.device, non_blocking: bool = False):
+        if device == self._col.device:
             return self

         row = self._row
         if row is not None:
-            row = row.to(tensor.device, non_blocking=non_blocking)
+            row = row.to(device, non_blocking=non_blocking)

         rowptr = self._rowptr
         if rowptr is not None:
-            rowptr = rowptr.to(tensor.device, non_blocking=non_blocking)
-        col = self._col.to(tensor.device, non_blocking=non_blocking)
+            rowptr = rowptr.to(device, non_blocking=non_blocking)
+        col = self._col.to(device, non_blocking=non_blocking)

         value = self._value
         if value is not None:
-            value = value.to(tensor.device, non_blocking=non_blocking)
+            value = value.to(device, non_blocking=non_blocking)

         rowcount = self._rowcount
         if rowcount is not None:
-            rowcount = rowcount.to(tensor.device, non_blocking=non_blocking)
+            rowcount = rowcount.to(device, non_blocking=non_blocking)

         colptr = self._colptr
         if colptr is not None:
-            colptr = colptr.to(tensor.device, non_blocking=non_blocking)
+            colptr = colptr.to(device, non_blocking=non_blocking)

         colcount = self._colcount
         if colcount is not None:
-            colcount = colcount.to(tensor.device, non_blocking=non_blocking)
+            colcount = colcount.to(device, non_blocking=non_blocking)

         csr2csc = self._csr2csc
         if csr2csc is not None:
-            csr2csc = csr2csc.to(tensor.device, non_blocking=non_blocking)
+            csr2csc = csr2csc.to(device, non_blocking=non_blocking)

         csc2csr = self._csc2csr
         if csc2csr is not None:
-            csc2csr = csc2csr.to(tensor.device, non_blocking=non_blocking)
+            csc2csr = csc2csr.to(device, non_blocking=non_blocking)

         return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                              sparse_sizes=self._sparse_sizes,
                              rowcount=rowcount, colptr=colptr,
                              colcount=colcount, csr2csc=csr2csc,
-                             csc2csr=csc2csr, is_sorted=True)
+                             csc2csr=csc2csr, is_sorted=True, trust_data=True)
+
+    def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
+        return self.to_device(device=tensor.device, non_blocking=non_blocking)

     def cuda(self):
         new_col = self._col.cuda()
@@ -557,7 +591,7 @@ class SparseStorage(object):
                              sparse_sizes=self._sparse_sizes,
                              rowcount=rowcount, colptr=colptr,
                              colcount=colcount, csr2csc=csr2csc,
-                             csc2csr=csc2csr, is_sorted=True)
+                             csc2csr=csc2csr, is_sorted=True, trust_data=True)

     def pin_memory(self):
         row = self._row
@@ -590,7 +624,7 @@ class SparseStorage(object):
                              sparse_sizes=self._sparse_sizes,
                              rowcount=rowcount, colptr=colptr,
                              colcount=colcount, csr2csc=csr2csc,
-                             csc2csr=csc2csr, is_sorted=True)
+                             csc2csr=csc2csr, is_sorted=True, trust_data=True)

     def is_pinned(self) -> bool:
         is_pinned = True
......
@@ -19,18 +19,32 @@ class SparseTensor(object):
                  value: Optional[torch.Tensor] = None,
                  sparse_sizes: Optional[Tuple[Optional[int],
                                               Optional[int]]] = None,
-                 is_sorted: bool = False):
-        self.storage = SparseStorage(row=row, rowptr=rowptr, col=col,
-                                     value=value, sparse_sizes=sparse_sizes,
-                                     rowcount=None, colptr=None, colcount=None,
-                                     csr2csc=None, csc2csr=None,
-                                     is_sorted=is_sorted)
+                 is_sorted: bool = False,
+                 trust_data: bool = False):
+        self.storage = SparseStorage(
+            row=row,
+            rowptr=rowptr,
+            col=col,
+            value=value,
+            sparse_sizes=sparse_sizes,
+            rowcount=None,
+            colptr=None,
+            colcount=None,
+            csr2csc=None,
+            csc2csr=None,
+            is_sorted=is_sorted,
+            trust_data=trust_data)

     @classmethod
     def from_storage(self, storage: SparseStorage):
-        out = SparseTensor(row=storage._row, rowptr=storage._rowptr,
-                           col=storage._col, value=storage._value,
-                           sparse_sizes=storage._sparse_sizes, is_sorted=True)
+        out = SparseTensor(
+            row=storage._row,
+            rowptr=storage._rowptr,
+            col=storage._col,
+            value=storage._value,
+            sparse_sizes=storage._sparse_sizes,
+            is_sorted=True,
+            trust_data=True)
         out.storage._rowcount = storage._rowcount
         out.storage._colptr = storage._colptr
         out.storage._colcount = storage._colcount
@@ -43,10 +57,11 @@ class SparseTensor(object):
                         edge_attr: Optional[torch.Tensor] = None,
                         sparse_sizes: Optional[Tuple[Optional[int],
                                                      Optional[int]]] = None,
-                        is_sorted: bool = False):
+                        is_sorted: bool = False,
+                        trust_data: bool = False):
         return SparseTensor(row=edge_index[0], rowptr=None, col=edge_index[1],
                             value=edge_attr, sparse_sizes=sparse_sizes,
-                            is_sorted=is_sorted)
+                            is_sorted=is_sorted, trust_data=trust_data)

     @classmethod
     def from_dense(self, mat: torch.Tensor, has_value: bool = True):
@@ -65,7 +80,7 @@ class SparseTensor(object):
         return SparseTensor(row=row, rowptr=None, col=col, value=value,
                             sparse_sizes=(mat.size(0), mat.size(1)),
-                            is_sorted=True)
+                            is_sorted=True, trust_data=True)

     @classmethod
     def from_torch_sparse_coo_tensor(self, mat: torch.Tensor,
@@ -80,7 +95,7 @@ class SparseTensor(object):
         return SparseTensor(row=row, rowptr=None, col=col, value=value,
                             sparse_sizes=(mat.size(0), mat.size(1)),
-                            is_sorted=True)
+                            is_sorted=True, trust_data=True)

     @classmethod
     def eye(self, M: int, N: Optional[int] = None, has_value: bool = True,
@@ -118,8 +133,14 @@ class SparseTensor(object):
             colcount[M:] = 0
         csr2csc = csc2csr = row

-        out = SparseTensor(row=row, rowptr=rowptr, col=col, value=value,
-                           sparse_sizes=(M, N), is_sorted=True)
+        out = SparseTensor(
+            row=row,
+            rowptr=rowptr,
+            col=col,
+            value=value,
+            sparse_sizes=(M, N),
+            is_sorted=True,
+            trust_data=True)
         out.storage._rowcount = rowcount
         out.storage._colptr = colptr
         out.storage._colcount = colcount
@@ -133,16 +154,24 @@ class SparseTensor(object):
     def clone(self):
         return self.from_storage(self.storage.clone())

-    def type_as(self, tensor: torch.Tensor):
+    def type(self, dtype: torch.dtype, non_blocking: bool = False):
         value = self.storage.value()
-        if value is None or tensor.dtype == value.dtype:
+        if value is None or dtype == value.dtype:
             return self
-        return self.from_storage(self.storage.type_as(tensor))
+        return self.from_storage(self.storage.type(
+            dtype=dtype, non_blocking=non_blocking))

-    def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
-        if tensor.device == self.device():
+    def type_as(self, tensor: torch.Tensor, non_blocking: bool = False):
+        return self.type(dtype=tensor.dtype, non_blocking=non_blocking)
+
+    def to_device(self, device: torch.device, non_blocking: bool = False):
+        if device == self.device():
             return self
-        return self.from_storage(self.storage.device_as(tensor, non_blocking))
+        return self.from_storage(self.storage.to_device(
+            device=device, non_blocking=non_blocking))
+
+    def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
+        return self.to_device(device=tensor.device, non_blocking=non_blocking)

     # Formats #################################################################
@@ -326,8 +355,14 @@ class SparseTensor(object):
         new_row = torch.cat([row, col], dim=0, out=perm)[idx]
         new_col = torch.cat([col, row], dim=0, out=perm)[idx]

-        out = SparseTensor(row=new_row, rowptr=None, col=new_col, value=value,
-                           sparse_sizes=(N, N), is_sorted=True)
+        out = SparseTensor(
+            row=new_row,
+            rowptr=None,
+            col=new_col,
+            value=value,
+            sparse_sizes=(N, N),
+            is_sorted=True,
+            trust_data=True)
         return out

     def detach_(self):
@@ -369,7 +404,7 @@ class SparseTensor(object):
         return self.storage.col().device

     def cpu(self):
-        return self.device_as(torch.tensor(0), non_blocking=False)
+        return self.to_device(device=torch.device('cpu'), non_blocking=False)

     def cuda(self):
         return self.from_storage(self.storage.cuda())
@@ -386,44 +421,34 @@
         return torch.is_floating_point(value) if value is not None else True

     def bfloat16(self):
-        return self.type_as(
-            torch.tensor(0, dtype=torch.bfloat16, device=self.device()))
+        return self.type(dtype=torch.bfloat16, non_blocking=False)

     def bool(self):
-        return self.type_as(
-            torch.tensor(0, dtype=torch.bool, device=self.device()))
+        return self.type(dtype=torch.bool, non_blocking=False)

     def byte(self):
-        return self.type_as(
-            torch.tensor(0, dtype=torch.uint8, device=self.device()))
+        return self.type(dtype=torch.uint8, non_blocking=False)

     def char(self):
-        return self.type_as(
-            torch.tensor(0, dtype=torch.int8, device=self.device()))
+        return self.type(dtype=torch.int8, non_blocking=False)

     def half(self):
-        return self.type_as(
-            torch.tensor(0, dtype=torch.half, device=self.device()))
+        return self.type(dtype=torch.half, non_blocking=False)

     def float(self):
-        return self.type_as(
-            torch.tensor(0, dtype=torch.float, device=self.device()))
+        return self.type(dtype=torch.float, non_blocking=False)

     def double(self):
-        return self.type_as(
-            torch.tensor(0, dtype=torch.double, device=self.device()))
+        return self.type(dtype=torch.double, non_blocking=False)

     def short(self):
-        return self.type_as(
-            torch.tensor(0, dtype=torch.short, device=self.device()))
+        return self.type(dtype=torch.short, non_blocking=False)

     def int(self):
-        return self.type_as(
-            torch.tensor(0, dtype=torch.int, device=self.device()))
+        return self.type(dtype=torch.int, non_blocking=False)

     def long(self):
-        return self.type_as(
-            torch.tensor(0, dtype=torch.long, device=self.device()))
+        return self.type(dtype=torch.long, non_blocking=False)

     # Conversions #############################################################
@@ -472,9 +497,9 @@ def to(self, *args: Optional[List[Any]],
     device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs)[:3]

     if dtype is not None:
-        self = self.type_as(torch.tensor(0., dtype=dtype))
+        self = self.type(dtype=dtype, non_blocking=non_blocking)
     if device is not None:
-        self = self.device_as(torch.tensor(0., device=device), non_blocking)
+        self = self.to_device(device=device, non_blocking=non_blocking)

     return self
......