Commit 6325fa72 authored by rusty1s

add documentation

parent 727f3279
@@ -14,9 +14,9 @@ relabel_fn = torch.ops.torch_geometric_autoscale.relabel_one_hop
 class SubData(NamedTuple):
     data: Data
     batch_size: int
-    n_id: Tensor
-    offset: Tensor
-    count: Tensor
+    n_id: Tensor  # The indices of mini-batched nodes
+    offset: Tensor  # The offset of contiguous mini-batched nodes
+    count: Tensor  # The number of contiguous mini-batched nodes
 
     def to(self, *args, **kwargs):
         return SubData(self.data.to(*args, **kwargs), self.batch_size,
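As a reader's aid (not part of the commit), here is a minimal sketch of how the three annotated fields fit together, with made-up values:

```python
import torch

# Hypothetical values, for illustration only. `n_id` lists the global
# indices of all nodes in the mini-batch, target nodes first, followed
# by their 1-hop neighbors. `offset` and `count` encode the target
# nodes as runs of contiguous indices, which lets the pool copy whole
# memory slices instead of gathering single rows.
n_id = torch.tensor([4, 5, 6, 10, 11, 2, 8])  # 5 targets + 2 neighbors
offset = torch.tensor([4, 10])  # start of each contiguous run
count = torch.tensor([3, 2])    # run lengths: [4, 5, 6] and [10, 11]

# Expanding the runs recovers the contiguous (target) part of `n_id`:
runs = torch.cat([torch.arange(o, o + c)
                  for o, c in zip(offset.tolist(), count.tolist())])
assert torch.equal(runs, n_id[:5])
```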
@@ -25,8 +25,8 @@ class SubData(NamedTuple):
 
 class SubgraphLoader(DataLoader):
     r"""A simple subgraph loader that, given a pre-partitioned :obj:`data` object,
-    generates subgraphs (including its 1-hop neighbors) from mini-batches in
-    :obj:`ptr`."""
+    generates subgraphs from mini-batches in :obj:`ptr` (including their 1-hop
+    neighbors)."""
 
     def __init__(self, data: Data, ptr: Tensor, batch_size: int = 1,
                  bipartite: bool = True, log: bool = True, **kwargs):
@@ -91,8 +91,8 @@ class SubgraphLoader(DataLoader):
 
 class EvalSubgraphLoader(SubgraphLoader):
     r"""A simple subgraph loader that, given a pre-partitioned :obj:`data` object,
-    generates subgraphs (including its 1-hop neighbors) from mini-batches in
-    :obj:`ptr`.
+    generates subgraphs from mini-batches in :obj:`ptr` (including their 1-hop
+    neighbors).
 
     In contrast to :class:`SubgraphLoader`, this loader does not generate
     subgraphs from randomly sampled mini-batches, and should therefore only be
     used for evaluation.
@@ -102,7 +102,7 @@ class EvalSubgraphLoader(SubgraphLoader):
         ptr = ptr[::batch_size]
         if int(ptr[-1]) != data.num_nodes:
-            ptr = torch.cat([ptr, torch.tensor(data.num_nodes)], dim=0)
+            ptr = torch.cat([ptr, torch.tensor([data.num_nodes])], dim=0)
 
         super().__init__(data=data, ptr=ptr, batch_size=1, bipartite=bipartite,
                          log=log, shuffle=False, num_workers=0, **kwargs)
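The one-line change in this hunk is a real bug fix rather than a cosmetic edit: `torch.cat` requires all inputs to have the same number of dimensions, and `torch.tensor(data.num_nodes)` produces a zero-dimensional scalar. A standalone reproduction:

```python
import torch

ptr = torch.tensor([0, 4, 8])

# New form: wrapping the scalar in a list yields a 1-dimensional tensor.
print(torch.cat([ptr, torch.tensor([12])], dim=0))  # tensor([ 0,  4,  8, 12])

# Old form: a 0-dimensional tensor cannot be concatenated and raises
# "zero-dimensional tensor ... cannot be concatenated".
try:
    torch.cat([ptr, torch.tensor(12)], dim=0)
except RuntimeError as e:
    print(e)
```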
@@ -6,7 +6,8 @@ import torch
 from torch import Tensor
 from torch_sparse import SparseTensor
 
-from torch_geometric_autoscale import History, AsyncIOPool, SubgraphLoader
+from torch_geometric_autoscale import History, AsyncIOPool
+from torch_geometric_autoscale import SubgraphLoader, EvalSubgraphLoader
 
 
 class ScalableGNN(torch.nn.Module):
@@ -26,9 +27,8 @@ class ScalableGNN(torch.nn.Module):
             for _ in range(num_layers - 1)
         ])
 
-        self.pool = None
+        self.pool: Optional[AsyncIOPool] = None
         self._async = False
-        self.__out__ = None
 
     @property
     def emb_device(self):
@@ -38,15 +38,9 @@ class ScalableGNN(torch.nn.Module):
     def device(self):
         return self.histories[0]._device
 
-    @property
-    def _out(self):
-        if self.__out__ is None:
-            self.__out__ = torch.empty(self.num_nodes, self.out_channels,
-                                       pin_memory=True)
-        return self.__out__
-
     def _apply(self, fn: Callable) -> None:
         super(ScalableGNN, self)._apply(fn)
+        # We only initialize the AsyncIOPool in case histories are on CPU:
         if (str(self.emb_device) == 'cpu' and str(self.device)[:4] == 'cuda'
                 and self.pool_size is not None
                 and self.buffer_size is not None):
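A side note on placement: the new comment lands in `_apply` because `Module.to()`, `.cuda()` and `.cpu()` all route through `Module._apply`, so the CPU-history/GPU-compute check re-runs on every device change. A minimal sketch of that hook, independent of the commit:

```python
import torch

class DeviceAware(torch.nn.Module):
    def _apply(self, fn):
        # `Module.to()` / `.cuda()` / `.cpu()` all call `_apply`, so this
        # is the natural place to react to device changes.
        super()._apply(fn)
        print('device may have changed; re-evaluate pool condition')
        return self

model = DeviceAware()
model = model.to('cpu')  # prints once: `_apply` ran
```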
@@ -67,13 +61,15 @@ class ScalableGNN(torch.nn.Module):
             n_id: Optional[Tensor] = None,
             offset: Optional[Tensor] = None,
             count: Optional[Tensor] = None,
-            loader=None,
+            loader: EvalSubgraphLoader = None,
             **kwargs,
     ) -> Tensor:
 
         if loader is not None:
             return self.mini_inference(loader)
 
+        # We only perform asynchronous history transfer in case the following
+        # conditions are met:
         self._async = (self.pool is not None and batch_size is not None
                        and n_id is not None and offset is not None
                        and count is not None)
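The new comment documents a five-way conjunction. A reduced, self-contained restatement of the gate (with `use_async` as a hypothetical helper, not part of the commit):

```python
from typing import Any, Optional

def use_async(pool: Optional[Any], batch_size: Optional[int],
              n_id: Optional[Any], offset: Optional[Any],
              count: Optional[Any]) -> bool:
    # Mirrors the expression above: the pool and every piece of
    # mini-batch metadata must be present; otherwise the model falls
    # back to synchronous history access.
    return (pool is not None and batch_size is not None
            and n_id is not None and offset is not None
            and count is not None)

assert use_async(object(), 128, [0], [0], [1])
assert not use_async(None, 128, [0], [0], [1])       # no pool -> sync
assert not use_async(object(), 128, None, [0], [1])  # no n_id -> sync
```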
@@ -103,6 +99,7 @@ class ScalableGNN(torch.nn.Module):
                       n_id: Optional[Tensor] = None,
                       offset: Optional[Tensor] = None,
                       count: Optional[Tensor] = None) -> Tensor:
+        r"""Push and pull information from `x` to `history` and vice versa."""
 
         if n_id is None and x.size(0) != self.num_nodes:
             return x  # Do nothing...
@@ -122,12 +119,20 @@
             h = history.pull(n_id[batch_size:])
             return torch.cat([x[:batch_size], h], dim=0)
 
         else:
             out = self.pool.synchronize_pull()[:n_id.numel() - batch_size]
             self.pool.async_push(x[:batch_size], offset, count, history.emb)
             out = torch.cat([x[:batch_size], out], dim=0)
             self.pool.free_pull()
             return out
 
+    @property
+    def _out(self):
+        if self.__out is None:
+            self.__out = torch.empty(self.num_nodes, self.out_channels,
+                                     pin_memory=True)
+        return self.__out
+
     @torch.no_grad()
     def mini_inference(self, loader: SubgraphLoader) -> Tensor:
         loader = [data + ({}, ) for data in loader]
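The relocated `_out` property lazily allocates a page-locked (pinned) staging buffer for full-graph inference. Pinned host memory is what makes `non_blocking=True` device-to-host copies truly asynchronous; a small standalone demonstration, guarded so it also runs without a GPU:

```python
import torch

# Lazily allocate a page-locked CPU buffer, as the property above does;
# pinning only makes sense (and only works) when CUDA is available.
out = torch.empty(1000, 64, pin_memory=torch.cuda.is_available())

if torch.cuda.is_available():
    x = torch.randn(1000, 64, device='cuda')
    # With a pinned destination, this copy can overlap with subsequent
    # kernels instead of blocking the host.
    out.copy_(x, non_blocking=True)
    torch.cuda.synchronize()  # wait before reading `out` on the CPU
```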
@@ -14,8 +14,8 @@ class AsyncIOPool(torch.nn.Module):
         super(AsyncIOPool, self).__init__()
         self.pool_size = pool_size
-        self.embedding_dim = embedding_dim
         self.buffer_size = buffer_size
+        self.embedding_dim = embedding_dim
 
         self._device = torch.device('cpu')
         self._pull_queue = []
@@ -61,6 +61,7 @@ class AsyncIOPool(torch.nn.Module):
     @torch.no_grad()
     def async_pull(self, src: Tensor, offset: Optional[Tensor],
                    count: Optional[Tensor], index: Tensor) -> None:
+        # Start pulling `src` at `[offset, count]` and `index` positions:
         self._pull_index = (self._pull_index + 1) % self.pool_size
         data = (self._pull_index, src, offset, count, index)
         self._pull_queue.append(data)
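The first line of `async_pull`'s body advances a ring-buffer index: the pool cycles through a fixed set of `pool_size` slots, bounding both memory use and the number of in-flight transfers. The pattern in isolation:

```python
# Ring-buffer indexing as used by the pool: with `pool_size` slots, the
# (k+1)-th request reuses slot (k % pool_size), so at most `pool_size`
# transfers are ever in flight and no new memory is allocated per batch.
pool_size = 4
index = -1  # matches the pool's initial `_pull_index`

slots = []
for request in range(10):
    index = (index + 1) % pool_size
    slots.append(index)

print(slots)  # [0, 1, 2, 3, 0, 1, 2, 3, 0, 1]
```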
@@ -76,6 +77,7 @@ class AsyncIOPool(torch.nn.Module):
     @torch.no_grad()
     def synchronize_pull(self) -> Tensor:
+        # Synchronize the next pull command:
         idx = self._pull_queue[0][0]
         synchronize()
         torch.cuda.synchronize(self._pull_stream(idx))
@@ -83,17 +85,19 @@ class AsyncIOPool(torch.nn.Module):
     @torch.no_grad()
     def free_pull(self) -> None:
+        # Free the buffer space and start pulling from remaining queue:
         self._pull_queue.pop(0)
         if len(self._pull_queue) >= self.pool_size:
             data = self._pull_queue[self.pool_size - 1]
             idx, src, offset, count, index = data
             self._async_pull(idx, src, offset, count, index)
-        if len(self._pull_queue) == 0:
+        elif len(self._pull_queue) == 0:
             self._pull_index = -1
 
     @torch.no_grad()
     def async_push(self, src: Tensor, offset: Tensor, count: Tensor,
                    dst: Tensor) -> None:
+        # Start pushing `src` to `[offset, count]` and `index` positions in `dst`:
         self._push_index = (self._push_index + 1) % self.pool_size
         self.synchronize_push(self._push_index)
         self._push_cache[self._push_index] = src
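The `if` to `elif` change in `free_pull` is subtle but sound: after `pop(0)`, the queue is either still long enough to start prefetching the request that just entered the `pool_size` window, or (mutually exclusively) fully drained, in which case the ring index resets. A pure-Python simulation of that lifecycle, using hypothetical stand-ins:

```python
# A pure-Python simulation of the pull-queue lifecycle (hypothetical
# stand-ins, no CUDA involved). Popping the finished request either
# lets the next request enter the window of `pool_size` in-flight
# transfers, or finds the queue drained and resets the ring index.
pool_size = 2
pull_queue = ['req0', 'req1', 'req2', 'req3']
pull_index = 1  # slots 0 and 1 are already pulling req0 and req1

def free_pull():
    global pull_index
    pull_queue.pop(0)
    if len(pull_queue) >= pool_size:
        print('prefetch', pull_queue[pool_size - 1])
    elif len(pull_queue) == 0:
        pull_index = -1  # fully drained: reset, as in the `elif` above
        print('queue drained, index reset')

for _ in range(4):
    free_pull()
# -> prefetch req2 / prefetch req3 / (nothing) / queue drained, index reset
```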
@@ -102,6 +106,7 @@ class AsyncIOPool(torch.nn.Module):
     @torch.no_grad()
     def synchronize_push(self, idx: Optional[int] = None) -> None:
+        # Synchronize the push command of stream `idx` or all commands:
         if idx is None:
             for idx in range(self.pool_size):
                 self.synchronize_push(idx)
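Synchronization here is per slot: each pool slot owns its own CUDA stream, so waiting on one push or pull does not stall the others. An illustrative sketch using the public stream API (this exact usage is an assumption, not code from the commit):

```python
import torch

# Illustrative only: a dedicated stream per pool slot lets transfers
# run concurrently with the default (compute) stream; synchronizing a
# slot's stream is then a targeted wait for just that transfer.
if torch.cuda.is_available():
    stream = torch.cuda.Stream()
    src = torch.randn(1000, 64, device='cuda')
    dst = torch.empty(1000, 64, pin_memory=True)
    with torch.cuda.stream(stream):
        dst.copy_(src, non_blocking=True)  # enqueued on the side stream
    stream.synchronize()  # wait for this transfer only
```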