Commit 6325fa72 authored by rusty1s

add documentation

parent 727f3279
@@ -14,9 +14,9 @@ relabel_fn = torch.ops.torch_geometric_autoscale.relabel_one_hop
 class SubData(NamedTuple):
     data: Data
     batch_size: int
-    n_id: Tensor
-    offset: Tensor
-    count: Tensor
+    n_id: Tensor  # The indices of mini-batched nodes
+    offset: Tensor  # The offset of contiguous mini-batched nodes
+    count: Tensor  # The number of contiguous mini-batched nodes

     def to(self, *args, **kwargs):
         return SubData(self.data.to(*args, **kwargs), self.batch_size,
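The three documented fields are easiest to read off a toy mini-batch. A minimal sketch with made-up numbers (not taken from the library): `n_id` lists the in-batch nodes first, followed by their 1-hop neighbors, while `offset`/`count` describe where the contiguous in-batch blocks live in the global node ordering.

```python
import torch

# Hypothetical mini-batch built from two contiguous partitions, [0..3] and [8..11]:
n_id = torch.tensor([0, 1, 2, 3, 8, 9, 10, 11,  # mini-batched nodes...
                     4, 7, 12])                  # ...then their 1-hop neighbors
offset = torch.tensor([0, 8])  # global start of each contiguous in-batch block
count = torch.tensor([4, 4])   # number of nodes in each block
batch_size = 8                 # total number of in-batch nodes (= count.sum())
```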
@@ -25,8 +25,8 @@ class SubData(NamedTuple):
 class SubgraphLoader(DataLoader):
     r"""A simple subgraph loader that, given a pre-partitioned :obj:`data` object,
-    generates subgraphs (including its 1-hop neighbors) from mini-batches in
-    :obj:`ptr`."""
+    generates subgraphs from mini-batches in :obj:`ptr` (including their 1-hop
+    neighbors)."""

     def __init__(self, data: Data, ptr: Tensor, batch_size: int = 1,
                  bipartite: bool = True, log: bool = True, **kwargs):
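For orientation, `ptr` typically comes from a METIS partition of the graph. A rough usage sketch: the `metis` and `permute` helpers do ship with `torch_geometric_autoscale`, but their exact signatures here, and the pre-existing `data` object, are assumptions.

```python
from torch_geometric_autoscale import metis, permute, SubgraphLoader

# `data` is assumed to be a torch_geometric.data.Data object whose adjacency
# is stored as a torch_sparse.SparseTensor in `data.adj_t`.
perm, ptr = metis(data.adj_t, num_parts=40, log=True)  # partition boundaries
data = permute(data, perm, log=True)  # reorder nodes so parts are contiguous

# Merge 4 parts per mini-batch and visit them in random order while training:
train_loader = SubgraphLoader(data, ptr, batch_size=4, shuffle=True)
```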
@@ -91,8 +91,8 @@ class SubgraphLoader(DataLoader):
 class EvalSubgraphLoader(SubgraphLoader):
     r"""A simple subgraph loader that, given a pre-partitioned :obj:`data` object,
-    generates subgraphs (including its 1-hop neighbors) from mini-batches in
-    :obj:`ptr`.
+    generates subgraphs from mini-batches in :obj:`ptr` (including their 1-hop
+    neighbors).

     In contrast to :class:`SubgraphLoader`, this loader does not generate
     subgraphs from randomly sampled mini-batches, and should therefore only be
     used for evaluation.
@@ -102,7 +102,7 @@ class EvalSubgraphLoader(SubgraphLoader):
         ptr = ptr[::batch_size]
         if int(ptr[-1]) != data.num_nodes:
-            ptr = torch.cat([ptr, torch.tensor(data.num_nodes)], dim=0)
+            ptr = torch.cat([ptr, torch.tensor([data.num_nodes])], dim=0)

         super().__init__(data=data, ptr=ptr, batch_size=1, bipartite=bipartite,
                          log=log, shuffle=False, num_workers=0, **kwargs)
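The one-line change in the second hunk is a real bug fix: `torch.tensor(data.num_nodes)` yields a zero-dimensional tensor, which `torch.cat` refuses to concatenate, while `torch.tensor([data.num_nodes])` is one-dimensional. A self-contained illustration of the coarsening logic, with made-up boundaries:

```python
import torch

num_nodes = 16
ptr = torch.tensor([0, 4, 8, 12, 16])  # hypothetical partition boundaries

batch_size = 3
ptr = ptr[::batch_size]  # tensor([0, 12]) -- the final boundary is dropped
if int(ptr[-1]) != num_nodes:
    # torch.cat([ptr, torch.tensor(num_nodes)]) would raise
    # "zero-dimensional tensor ... cannot be concatenated":
    ptr = torch.cat([ptr, torch.tensor([num_nodes])], dim=0)
print(ptr)  # tensor([ 0, 12, 16])
```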
@@ -6,7 +6,8 @@ import torch
 from torch import Tensor
 from torch_sparse import SparseTensor

-from torch_geometric_autoscale import History, AsyncIOPool, SubgraphLoader
+from torch_geometric_autoscale import History, AsyncIOPool
+from torch_geometric_autoscale import SubgraphLoader, EvalSubgraphLoader


 class ScalableGNN(torch.nn.Module):
@@ -26,9 +27,8 @@ class ScalableGNN(torch.nn.Module):
             for _ in range(num_layers - 1)
         ])

-        self.pool = None
+        self.pool: Optional[AsyncIOPool] = None
         self._async = False
-        self.__out__ = None

     @property
     def emb_device(self):
@@ -38,15 +38,9 @@ class ScalableGNN(torch.nn.Module):
     def device(self):
         return self.histories[0]._device

-    @property
-    def _out(self):
-        if self.__out__ is None:
-            self.__out__ = torch.empty(self.num_nodes, self.out_channels,
-                                       pin_memory=True)
-        return self.__out__
-
     def _apply(self, fn: Callable) -> None:
         super(ScalableGNN, self)._apply(fn)
+        # We only initialize the AsyncIOPool in case histories are on CPU:
         if (str(self.emb_device) == 'cpu' and str(self.device)[:4] == 'cuda'
                 and self.pool_size is not None
                 and self.buffer_size is not None):
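In other words: histories live in CPU memory, and the pool is only created once the model parameters have been moved to the GPU. A hedged sketch of the triggering call; `GCN` stands in for any concrete `ScalableGNN` subclass here, and the constructor arguments are assumptions:

```python
# Moving the module to CUDA invokes `_apply`, which detects the
# CPU-histories / CUDA-parameters split and builds the AsyncIOPool:
model = GCN(num_nodes=data.num_nodes, in_channels=128, hidden_channels=256,
            out_channels=40, num_layers=2, pool_size=2, buffer_size=5000)
model = model.to('cuda')       # parameters move; histories stay pinned on CPU
assert model.pool is not None  # created during `_apply`
```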
@@ -67,13 +61,15 @@ class ScalableGNN(torch.nn.Module):
                 n_id: Optional[Tensor] = None,
                 offset: Optional[Tensor] = None,
                 count: Optional[Tensor] = None,
-                loader=None,
+                loader: EvalSubgraphLoader = None,
                 **kwargs,
                 ) -> Tensor:

         if loader is not None:
             return self.mini_inference(loader)

+        # We only perform asynchronous history transfer in case the following
+        # conditions are met:
         self._async = (self.pool is not None and batch_size is not None
                        and n_id is not None and offset is not None
                        and count is not None)
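So `forward` now has two documented call paths. A sketch of both; the positional arguments and the unpacking of loader batches are assumptions based on the `SubData` fields above:

```python
# Path 1 -- layer-wise full inference, dispatched to `mini_inference`:
out = model(loader=eval_loader)

# Path 2 -- a training step; all four metadata arguments must be provided
# (together with an initialized `model.pool`) for `self._async` to be True:
for data, batch_size, n_id, offset, count in train_loader:
    data = data.to('cuda')  # move mini-batch features/adjacency to the GPU
    out = model(data.x, data.adj_t, batch_size=batch_size,
                n_id=n_id, offset=offset, count=count)
```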
@@ -103,6 +99,7 @@ class ScalableGNN(torch.nn.Module):
                       n_id: Optional[Tensor] = None,
                       offset: Optional[Tensor] = None,
                       count: Optional[Tensor] = None) -> Tensor:
+        r"""Push and pull information from `x` to `history` and vice versa."""
         if n_id is None and x.size(0) != self.num_nodes:
             return x  # Do nothing...
@@ -122,12 +119,20 @@ class ScalableGNN(torch.nn.Module):
             h = history.pull(n_id[batch_size:])
             return torch.cat([x[:batch_size], h], dim=0)

+        else:
             out = self.pool.synchronize_pull()[:n_id.numel() - batch_size]
             self.pool.async_push(x[:batch_size], offset, count, history.emb)
             out = torch.cat([x[:batch_size], out], dim=0)
             self.pool.free_pull()
             return out

+    @property
+    def _out(self):
+        if self.__out is None:
+            self.__out = torch.empty(self.num_nodes, self.out_channels,
+                                     pin_memory=True)
+        return self.__out
+
     @torch.no_grad()
     def mini_inference(self, loader: SubgraphLoader) -> Tensor:
         loader = [data + ({}, ) for data in loader]
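The relocated `_out` property lazily allocates a pinned (page-locked) full-graph output buffer; pinned memory is what makes non-blocking GPU-to-CPU copies possible during layer-wise inference. A minimal, self-contained illustration of that mechanism (sizes made up; requires a CUDA device):

```python
import torch

out = torch.empty(1000, 64, pin_memory=True)  # pinned CPU buffer
h = torch.randn(1000, 64, device='cuda')      # layer output on the GPU
out.copy_(h, non_blocking=True)  # async copy; overlaps with GPU compute
torch.cuda.synchronize()         # wait before reading `out` on the CPU
```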
...
@@ -14,8 +14,8 @@ class AsyncIOPool(torch.nn.Module):
         super(AsyncIOPool, self).__init__()

         self.pool_size = pool_size
-        self.embedding_dim = embedding_dim
         self.buffer_size = buffer_size
+        self.embedding_dim = embedding_dim

         self._device = torch.device('cpu')
         self._pull_queue = []
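The reordering above simply makes the assignments follow the constructor's parameter order. For reference, a hedged construction example (values are made up; the parameter order is inferred from the assignments):

```python
from torch_geometric_autoscale import AsyncIOPool

# pool_size: number of concurrent buffer slots / CUDA streams,
# buffer_size: max rows per transfer, embedding_dim: feature dimension.
pool = AsyncIOPool(pool_size=2, buffer_size=5000, embedding_dim=256)
```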
@@ -61,6 +61,7 @@ class AsyncIOPool(torch.nn.Module):
     @torch.no_grad()
     def async_pull(self, src: Tensor, offset: Optional[Tensor],
                    count: Optional[Tensor], index: Tensor) -> None:
+        # Start pulling `src` at `[offset, count]` and `index` positions:
         self._pull_index = (self._pull_index + 1) % self.pool_size
         data = (self._pull_index, src, offset, count, index)
         self._pull_queue.append(data)
@@ -76,6 +77,7 @@ class AsyncIOPool(torch.nn.Module):
     @torch.no_grad()
     def synchronize_pull(self) -> Tensor:
+        # Synchronize the next pull command:
         idx = self._pull_queue[0][0]
         synchronize()
         torch.cuda.synchronize(self._pull_stream(idx))
@@ -83,17 +85,19 @@ class AsyncIOPool(torch.nn.Module):
     @torch.no_grad()
     def free_pull(self) -> None:
+        # Free the buffer space and start pulling from the remaining queue:
         self._pull_queue.pop(0)
         if len(self._pull_queue) >= self.pool_size:
             data = self._pull_queue[self.pool_size - 1]
             idx, src, offset, count, index = data
             self._async_pull(idx, src, offset, count, index)
-        if len(self._pull_queue) == 0:
+        elif len(self._pull_queue) == 0:
             self._pull_index = -1

     @torch.no_grad()
     def async_push(self, src: Tensor, offset: Tensor, count: Tensor,
                    dst: Tensor) -> None:
+        # Start pushing `src` at `[offset, count]` positions to `dst`:
         self._push_index = (self._push_index + 1) % self.pool_size
         self.synchronize_push(self._push_index)
         self._push_cache[self._push_index] = src
@@ -102,6 +106,7 @@ class AsyncIOPool(torch.nn.Module):
     @torch.no_grad()
     def synchronize_push(self, idx: Optional[int] = None) -> None:
+        # Synchronize the push command of stream `idx`, or all commands:
         if idx is None:
             for idx in range(self.pool_size):
                 self.synchronize_push(idx)
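Taken together, the newly commented methods form a fixed per-layer cycle. A hedged sketch of the intended call order, mirroring how `push_and_pull` uses the pool above (the `None` offset/count in the pull and the slicing are assumptions based on that code):

```python
# Schedule an index-based pull of out-of-batch history embeddings:
pool.async_pull(history.emb, None, None, n_id[batch_size:])
# ... GPU compute for the current layer overlaps with the transfer ...

h = pool.synchronize_pull()[:n_id.numel() - batch_size]      # wait for it
pool.async_push(x[:batch_size], offset, count, history.emb)  # write back
out = torch.cat([x[:batch_size], h], dim=0)
pool.free_pull()  # recycle the buffer slot; start the next queued pull
# ... at the end of the epoch / before reading histories again:
pool.synchronize_push()
```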
...