Unverified Commit fe8c3ce3 authored by Nick Stathas's avatar Nick Stathas Committed by GitHub
Browse files

Skip unnecessary assertions and enable non-blocking data transfers (#195)

* Uses the `trust_data` invariant to skip blocking assertions, when unnecessary, during construction of `SparseStorage` objects.
* Refactors the dtype and device transfer APIs to align with `torch.Tensor` while maintaining backward compatibility.
* No longer constructs dummy tensors when changing dtype or device.
parent 88c6ceb6
...@@ -24,6 +24,8 @@ BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1' ...@@ -24,6 +24,8 @@ BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'
WITH_METIS = True if os.getenv('WITH_METIS', '0') == '1' else False WITH_METIS = True if os.getenv('WITH_METIS', '0') == '1' else False
WITH_MTMETIS = True if os.getenv('WITH_MTMETIS', '0') == '1' else False WITH_MTMETIS = True if os.getenv('WITH_MTMETIS', '0') == '1' else False
WITH_SYMBOLS = True if os.getenv('WITH_SYMBOLS', '0') == '1' else False
def get_extensions(): def get_extensions():
extensions = [] extensions = []
...@@ -47,7 +49,7 @@ def get_extensions(): ...@@ -47,7 +49,7 @@ def get_extensions():
extra_compile_args = {'cxx': ['-O2']} extra_compile_args = {'cxx': ['-O2']}
if not os.name == 'nt': # Not on Windows: if not os.name == 'nt': # Not on Windows:
extra_compile_args['cxx'] += ['-Wno-sign-compare'] extra_compile_args['cxx'] += ['-Wno-sign-compare']
extra_link_args = ['-s'] extra_link_args = [] if WITH_SYMBOLS else ['-s']
info = parallel_info() info = parallel_info()
if ('backend: OpenMP' in info and 'OpenMP not found' not in info if ('backend: OpenMP' in info and 'OpenMP not found' not in info
......
...@@ -41,7 +41,8 @@ class SparseStorage(object): ...@@ -41,7 +41,8 @@ class SparseStorage(object):
colcount: Optional[torch.Tensor] = None, colcount: Optional[torch.Tensor] = None,
csr2csc: Optional[torch.Tensor] = None, csr2csc: Optional[torch.Tensor] = None,
csc2csr: Optional[torch.Tensor] = None, csc2csr: Optional[torch.Tensor] = None,
is_sorted: bool = False): is_sorted: bool = False,
trust_data: bool = False):
assert row is not None or rowptr is not None assert row is not None or rowptr is not None
assert col is not None assert col is not None
...@@ -62,7 +63,7 @@ class SparseStorage(object): ...@@ -62,7 +63,7 @@ class SparseStorage(object):
if rowptr is not None: if rowptr is not None:
assert rowptr.numel() - 1 == M assert rowptr.numel() - 1 == M
elif row is not None and row.numel() > 0: elif row is not None and row.numel() > 0:
assert int(row.max()) < M assert trust_data or int(row.max()) < M
N: int = 0 N: int = 0
if sparse_sizes is None or sparse_sizes[1] is None: if sparse_sizes is None or sparse_sizes[1] is None:
...@@ -73,7 +74,7 @@ class SparseStorage(object): ...@@ -73,7 +74,7 @@ class SparseStorage(object):
assert _N is not None assert _N is not None
N = _N N = _N
if col.numel() > 0: if col.numel() > 0:
assert int(col.max()) < N assert trust_data or int(col.max()) < N
sparse_sizes = (M, N) sparse_sizes = (M, N)
...@@ -163,7 +164,7 @@ class SparseStorage(object): ...@@ -163,7 +164,7 @@ class SparseStorage(object):
return SparseStorage(row=row, rowptr=None, col=col, value=None, return SparseStorage(row=row, rowptr=None, col=col, value=None,
sparse_sizes=(0, 0), rowcount=None, colptr=None, sparse_sizes=(0, 0), rowcount=None, colptr=None,
colcount=None, csr2csc=None, csc2csr=None, colcount=None, csr2csc=None, csc2csr=None,
is_sorted=True) is_sorted=True, trust_data=True)
def has_row(self) -> bool: def has_row(self) -> bool:
return self._row is not None return self._row is not None
...@@ -227,11 +228,19 @@ class SparseStorage(object): ...@@ -227,11 +228,19 @@ class SparseStorage(object):
assert value.device == self._col.device assert value.device == self._col.device
assert value.size(0) == self._col.numel() assert value.size(0) == self._col.numel()
return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col, return SparseStorage(
value=value, sparse_sizes=self._sparse_sizes, row=self._row,
rowcount=self._rowcount, colptr=self._colptr, rowptr=self._rowptr,
colcount=self._colcount, csr2csc=self._csr2csc, col=self._col,
csc2csr=self._csc2csr, is_sorted=True) value=value,
sparse_sizes=self._sparse_sizes,
rowcount=self._rowcount,
colptr=self._colptr,
colcount=self._colcount,
csr2csc=self._csr2csc,
csc2csr=self._csc2csr,
is_sorted=True,
trust_data=True)
def sparse_sizes(self) -> Tuple[int, int]: def sparse_sizes(self) -> Tuple[int, int]:
return self._sparse_sizes return self._sparse_sizes
...@@ -269,11 +278,19 @@ class SparseStorage(object): ...@@ -269,11 +278,19 @@ class SparseStorage(object):
if colcount is not None: if colcount is not None:
colcount = colcount[:-diff_1] colcount = colcount[:-diff_1]
return SparseStorage(row=self._row, rowptr=rowptr, col=self._col, return SparseStorage(
value=self._value, sparse_sizes=sparse_sizes, row=self._row,
rowcount=rowcount, colptr=colptr, rowptr=rowptr,
colcount=colcount, csr2csc=self._csr2csc, col=self._col,
csc2csr=self._csc2csr, is_sorted=True) value=self._value,
sparse_sizes=sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=self._csr2csc,
csc2csr=self._csc2csr,
is_sorted=True,
trust_data=True)
def sparse_reshape(self, num_rows: int, num_cols: int): def sparse_reshape(self, num_rows: int, num_cols: int):
assert num_rows > 0 or num_rows == -1 assert num_rows > 0 or num_rows == -1
...@@ -299,7 +316,7 @@ class SparseStorage(object): ...@@ -299,7 +316,7 @@ class SparseStorage(object):
return SparseStorage(row=row, rowptr=None, col=col, value=self._value, return SparseStorage(row=row, rowptr=None, col=col, value=self._value,
sparse_sizes=(num_rows, num_cols), rowcount=None, sparse_sizes=(num_rows, num_cols), rowcount=None,
colptr=None, colcount=None, csr2csc=None, colptr=None, colcount=None, csr2csc=None,
csc2csr=None, is_sorted=True) csc2csr=None, is_sorted=True, trust_data=True)
def has_rowcount(self) -> bool: def has_rowcount(self) -> bool:
return self._rowcount is not None return self._rowcount is not None
...@@ -399,7 +416,7 @@ class SparseStorage(object): ...@@ -399,7 +416,7 @@ class SparseStorage(object):
return SparseStorage(row=row, rowptr=None, col=col, value=value, return SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=self._sparse_sizes, rowcount=None, sparse_sizes=self._sparse_sizes, rowcount=None,
colptr=None, colcount=None, csr2csc=None, colptr=None, colcount=None, csr2csc=None,
csc2csr=None, is_sorted=True) csc2csr=None, is_sorted=True, trust_data=True)
def fill_cache_(self): def fill_cache_(self):
self.row() self.row()
...@@ -437,12 +454,19 @@ class SparseStorage(object): ...@@ -437,12 +454,19 @@ class SparseStorage(object):
return len(self.cached_keys()) return len(self.cached_keys())
def copy(self): def copy(self):
return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col, return SparseStorage(
row=self._row,
rowptr=self._rowptr,
col=self._col,
value=self._value, value=self._value,
sparse_sizes=self._sparse_sizes, sparse_sizes=self._sparse_sizes,
rowcount=self._rowcount, colptr=self._colptr, rowcount=self._rowcount,
colcount=self._colcount, csr2csc=self._csr2csc, colptr=self._colptr,
csc2csr=self._csc2csr, is_sorted=True) colcount=self._colcount,
csr2csc=self._csr2csc,
csc2csr=self._csc2csr,
is_sorted=True,
trust_data=True)
def clone(self): def clone(self):
row = self._row row = self._row
...@@ -475,53 +499,63 @@ class SparseStorage(object): ...@@ -475,53 +499,63 @@ class SparseStorage(object):
sparse_sizes=self._sparse_sizes, sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr, rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc, colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True) csc2csr=csc2csr, is_sorted=True, trust_data=True)
def type_as(self, tensor: torch.Tensor): def type(self, dtype: torch.dtype, non_blocking: bool = False):
value = self._value value = self._value
if value is not None: if value is not None:
if tensor.dtype == value.dtype: if dtype == value.dtype:
return self return self
else: else:
return self.set_value(value.type_as(tensor), layout='coo') return self.set_value(
value.to(
dtype=dtype,
non_blocking=non_blocking),
layout='coo')
else: else:
return self return self
def device_as(self, tensor: torch.Tensor, non_blocking: bool = False): def type_as(self, tensor: torch.Tensor, non_blocking: bool = False):
if tensor.device == self._col.device: return self.type(dtype=tensor.dtype, non_blocking=non_blocking)
def to_device(self, device: torch.device, non_blocking: bool = False):
if device == self._col.device:
return self return self
row = self._row row = self._row
if row is not None: if row is not None:
row = row.to(tensor.device, non_blocking=non_blocking) row = row.to(device, non_blocking=non_blocking)
rowptr = self._rowptr rowptr = self._rowptr
if rowptr is not None: if rowptr is not None:
rowptr = rowptr.to(tensor.device, non_blocking=non_blocking) rowptr = rowptr.to(device, non_blocking=non_blocking)
col = self._col.to(tensor.device, non_blocking=non_blocking) col = self._col.to(device, non_blocking=non_blocking)
value = self._value value = self._value
if value is not None: if value is not None:
value = value.to(tensor.device, non_blocking=non_blocking) value = value.to(device, non_blocking=non_blocking)
rowcount = self._rowcount rowcount = self._rowcount
if rowcount is not None: if rowcount is not None:
rowcount = rowcount.to(tensor.device, non_blocking=non_blocking) rowcount = rowcount.to(device, non_blocking=non_blocking)
colptr = self._colptr colptr = self._colptr
if colptr is not None: if colptr is not None:
colptr = colptr.to(tensor.device, non_blocking=non_blocking) colptr = colptr.to(device, non_blocking=non_blocking)
colcount = self._colcount colcount = self._colcount
if colcount is not None: if colcount is not None:
colcount = colcount.to(tensor.device, non_blocking=non_blocking) colcount = colcount.to(device, non_blocking=non_blocking)
csr2csc = self._csr2csc csr2csc = self._csr2csc
if csr2csc is not None: if csr2csc is not None:
csr2csc = csr2csc.to(tensor.device, non_blocking=non_blocking) csr2csc = csr2csc.to(device, non_blocking=non_blocking)
csc2csr = self._csc2csr csc2csr = self._csc2csr
if csc2csr is not None: if csc2csr is not None:
csc2csr = csc2csr.to(tensor.device, non_blocking=non_blocking) csc2csr = csc2csr.to(device, non_blocking=non_blocking)
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value, return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=self._sparse_sizes, sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr, rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc, colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True) csc2csr=csc2csr, is_sorted=True, trust_data=True)
def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
return self.to_device(device=tensor.device, non_blocking=non_blocking)
def cuda(self): def cuda(self):
new_col = self._col.cuda() new_col = self._col.cuda()
...@@ -557,7 +591,7 @@ class SparseStorage(object): ...@@ -557,7 +591,7 @@ class SparseStorage(object):
sparse_sizes=self._sparse_sizes, sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr, rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc, colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True) csc2csr=csc2csr, is_sorted=True, trust_data=True)
def pin_memory(self): def pin_memory(self):
row = self._row row = self._row
...@@ -590,7 +624,7 @@ class SparseStorage(object): ...@@ -590,7 +624,7 @@ class SparseStorage(object):
sparse_sizes=self._sparse_sizes, sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr, rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc, colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True) csc2csr=csc2csr, is_sorted=True, trust_data=True)
def is_pinned(self) -> bool: def is_pinned(self) -> bool:
is_pinned = True is_pinned = True
......
...@@ -19,18 +19,32 @@ class SparseTensor(object): ...@@ -19,18 +19,32 @@ class SparseTensor(object):
value: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None,
sparse_sizes: Optional[Tuple[Optional[int], sparse_sizes: Optional[Tuple[Optional[int],
Optional[int]]] = None, Optional[int]]] = None,
is_sorted: bool = False): is_sorted: bool = False,
self.storage = SparseStorage(row=row, rowptr=rowptr, col=col, trust_data: bool = False):
value=value, sparse_sizes=sparse_sizes, self.storage = SparseStorage(
rowcount=None, colptr=None, colcount=None, row=row,
csr2csc=None, csc2csr=None, rowptr=rowptr,
is_sorted=is_sorted) col=col,
value=value,
sparse_sizes=sparse_sizes,
rowcount=None,
colptr=None,
colcount=None,
csr2csc=None,
csc2csr=None,
is_sorted=is_sorted,
trust_data=trust_data)
@classmethod @classmethod
def from_storage(self, storage: SparseStorage): def from_storage(self, storage: SparseStorage):
out = SparseTensor(row=storage._row, rowptr=storage._rowptr, out = SparseTensor(
col=storage._col, value=storage._value, row=storage._row,
sparse_sizes=storage._sparse_sizes, is_sorted=True) rowptr=storage._rowptr,
col=storage._col,
value=storage._value,
sparse_sizes=storage._sparse_sizes,
is_sorted=True,
trust_data=True)
out.storage._rowcount = storage._rowcount out.storage._rowcount = storage._rowcount
out.storage._colptr = storage._colptr out.storage._colptr = storage._colptr
out.storage._colcount = storage._colcount out.storage._colcount = storage._colcount
...@@ -43,10 +57,11 @@ class SparseTensor(object): ...@@ -43,10 +57,11 @@ class SparseTensor(object):
edge_attr: Optional[torch.Tensor] = None, edge_attr: Optional[torch.Tensor] = None,
sparse_sizes: Optional[Tuple[Optional[int], sparse_sizes: Optional[Tuple[Optional[int],
Optional[int]]] = None, Optional[int]]] = None,
is_sorted: bool = False): is_sorted: bool = False,
trust_data: bool = False):
return SparseTensor(row=edge_index[0], rowptr=None, col=edge_index[1], return SparseTensor(row=edge_index[0], rowptr=None, col=edge_index[1],
value=edge_attr, sparse_sizes=sparse_sizes, value=edge_attr, sparse_sizes=sparse_sizes,
is_sorted=is_sorted) is_sorted=is_sorted, trust_data=trust_data)
@classmethod @classmethod
def from_dense(self, mat: torch.Tensor, has_value: bool = True): def from_dense(self, mat: torch.Tensor, has_value: bool = True):
...@@ -65,7 +80,7 @@ class SparseTensor(object): ...@@ -65,7 +80,7 @@ class SparseTensor(object):
return SparseTensor(row=row, rowptr=None, col=col, value=value, return SparseTensor(row=row, rowptr=None, col=col, value=value,
sparse_sizes=(mat.size(0), mat.size(1)), sparse_sizes=(mat.size(0), mat.size(1)),
is_sorted=True) is_sorted=True, trust_data=True)
@classmethod @classmethod
def from_torch_sparse_coo_tensor(self, mat: torch.Tensor, def from_torch_sparse_coo_tensor(self, mat: torch.Tensor,
...@@ -80,7 +95,7 @@ class SparseTensor(object): ...@@ -80,7 +95,7 @@ class SparseTensor(object):
return SparseTensor(row=row, rowptr=None, col=col, value=value, return SparseTensor(row=row, rowptr=None, col=col, value=value,
sparse_sizes=(mat.size(0), mat.size(1)), sparse_sizes=(mat.size(0), mat.size(1)),
is_sorted=True) is_sorted=True, trust_data=True)
@classmethod @classmethod
def eye(self, M: int, N: Optional[int] = None, has_value: bool = True, def eye(self, M: int, N: Optional[int] = None, has_value: bool = True,
...@@ -118,8 +133,14 @@ class SparseTensor(object): ...@@ -118,8 +133,14 @@ class SparseTensor(object):
colcount[M:] = 0 colcount[M:] = 0
csr2csc = csc2csr = row csr2csc = csc2csr = row
out = SparseTensor(row=row, rowptr=rowptr, col=col, value=value, out = SparseTensor(
sparse_sizes=(M, N), is_sorted=True) row=row,
rowptr=rowptr,
col=col,
value=value,
sparse_sizes=(M, N),
is_sorted=True,
trust_data=True)
out.storage._rowcount = rowcount out.storage._rowcount = rowcount
out.storage._colptr = colptr out.storage._colptr = colptr
out.storage._colcount = colcount out.storage._colcount = colcount
...@@ -133,16 +154,24 @@ class SparseTensor(object): ...@@ -133,16 +154,24 @@ class SparseTensor(object):
def clone(self): def clone(self):
return self.from_storage(self.storage.clone()) return self.from_storage(self.storage.clone())
def type_as(self, tensor: torch.Tensor): def type(self, dtype: torch.dtype, non_blocking: bool = False):
value = self.storage.value() value = self.storage.value()
if value is None or tensor.dtype == value.dtype: if value is None or dtype == value.dtype:
return self return self
return self.from_storage(self.storage.type_as(tensor)) return self.from_storage(self.storage.type(
dtype=dtype, non_blocking=non_blocking))
def device_as(self, tensor: torch.Tensor, non_blocking: bool = False): def type_as(self, tensor: torch.Tensor, non_blocking: bool = False):
if tensor.device == self.device(): return self.type(dtype=tensor.dtype, non_blocking=non_blocking)
def to_device(self, device: torch.device, non_blocking: bool = False):
if device == self.device():
return self return self
return self.from_storage(self.storage.device_as(tensor, non_blocking)) return self.from_storage(self.storage.to_device(
device=device, non_blocking=non_blocking))
def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
return self.to_device(device=tensor.device, non_blocking=non_blocking)
# Formats ################################################################# # Formats #################################################################
...@@ -326,8 +355,14 @@ class SparseTensor(object): ...@@ -326,8 +355,14 @@ class SparseTensor(object):
new_row = torch.cat([row, col], dim=0, out=perm)[idx] new_row = torch.cat([row, col], dim=0, out=perm)[idx]
new_col = torch.cat([col, row], dim=0, out=perm)[idx] new_col = torch.cat([col, row], dim=0, out=perm)[idx]
out = SparseTensor(row=new_row, rowptr=None, col=new_col, value=value, out = SparseTensor(
sparse_sizes=(N, N), is_sorted=True) row=new_row,
rowptr=None,
col=new_col,
value=value,
sparse_sizes=(N, N),
is_sorted=True,
trust_data=True)
return out return out
def detach_(self): def detach_(self):
...@@ -369,7 +404,7 @@ class SparseTensor(object): ...@@ -369,7 +404,7 @@ class SparseTensor(object):
return self.storage.col().device return self.storage.col().device
def cpu(self): def cpu(self):
return self.device_as(torch.tensor(0), non_blocking=False) return self.to_device(device=torch.device('cpu'), non_blocking=False)
def cuda(self): def cuda(self):
return self.from_storage(self.storage.cuda()) return self.from_storage(self.storage.cuda())
...@@ -386,44 +421,34 @@ class SparseTensor(object): ...@@ -386,44 +421,34 @@ class SparseTensor(object):
return torch.is_floating_point(value) if value is not None else True return torch.is_floating_point(value) if value is not None else True
def bfloat16(self): def bfloat16(self):
return self.type_as( return self.type(dtype=torch.bfloat16, non_blocking=False)
torch.tensor(0, dtype=torch.bfloat16, device=self.device()))
def bool(self): def bool(self):
return self.type_as( return self.type(dtype=torch.bool, non_blocking=False)
torch.tensor(0, dtype=torch.bool, device=self.device()))
def byte(self): def byte(self):
return self.type_as( return self.type(dtype=torch.uint8, non_blocking=False)
torch.tensor(0, dtype=torch.uint8, device=self.device()))
def char(self): def char(self):
return self.type_as( return self.type(dtype=torch.int8, non_blocking=False)
torch.tensor(0, dtype=torch.int8, device=self.device()))
def half(self): def half(self):
return self.type_as( return self.type(dtype=torch.half, non_blocking=False)
torch.tensor(0, dtype=torch.half, device=self.device()))
def float(self): def float(self):
return self.type_as( return self.type(dtype=torch.float, non_blocking=False)
torch.tensor(0, dtype=torch.float, device=self.device()))
def double(self): def double(self):
return self.type_as( return self.type(dtype=torch.double, non_blocking=False)
torch.tensor(0, dtype=torch.double, device=self.device()))
def short(self): def short(self):
return self.type_as( return self.type(dtype=torch.short, non_blocking=False)
torch.tensor(0, dtype=torch.short, device=self.device()))
def int(self): def int(self):
return self.type_as( return self.type(dtype=torch.int, non_blocking=False)
torch.tensor(0, dtype=torch.int, device=self.device()))
def long(self): def long(self):
return self.type_as( return self.type(dtype=torch.long, non_blocking=False)
torch.tensor(0, dtype=torch.long, device=self.device()))
# Conversions ############################################################# # Conversions #############################################################
...@@ -472,9 +497,9 @@ def to(self, *args: Optional[List[Any]], ...@@ -472,9 +497,9 @@ def to(self, *args: Optional[List[Any]],
device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs)[:3] device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs)[:3]
if dtype is not None: if dtype is not None:
self = self.type_as(torch.tensor(0., dtype=dtype)) self = self.type(dtype=dtype, non_blocking=non_blocking)
if device is not None: if device is not None:
self = self.device_as(torch.tensor(0., device=device), non_blocking) self = self.to_device(device=device, non_blocking=non_blocking)
return self return self
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment