Commit f2d76760 authored by rusty1s's avatar rusty1s
Browse files

add docstring

parent d43ed0b9
......@@ -11,6 +11,33 @@ from torch_geometric_autoscale import SubgraphLoader, EvalSubgraphLoader
class ScalableGNN(torch.nn.Module):
r"""An abstract class for implementing scalable GNNs via historical
embeddings.
This class will take care of initializing :obj:`num_layers - 1` historical
embeddings, and provides a convenient interface to push recent node
embeddings to the history and pulling embeddings from the history.
In case historical embeddings are stored on the CPU, they will reside
inside pinned memory, which allows for an asynchronous memory transfers of
histories.
For this, this class maintains a :class:`AsyncIOPool` object that
implements the underlying mechanisms of asynchronous memory transfers.
Args:
num_nodes (int): The number of nodes in the graph.
hidden_channels (int): The number of hidden channels of the model.
As a current restriction, all intermediate node embeddings need to
utilize the same number of features.
num_layers (int): The number of layers of the model.
pool_size (int, optional): The number of pinned CPU buffers for pulling
histories and transfering them to GPU.
Needs to be set in order to make use of asynchronous memory
transfers. (default: :obj:`None`)
buffer_size (int, optional): The size of pinned CPU buffers, i.e. the
maximum number of out-of-mini-batch nodes pulled at once.
Needs to be set in order to make use of asynchronous memory
transfers.
transfers. (default: :obj:`None`)
"""
def __init__(self, num_nodes: int, hidden_channels: int, num_layers: int,
pool_size: Optional[int] = None,
buffer_size: Optional[int] = None, device=None):
......@@ -65,6 +92,39 @@ class ScalableGNN(torch.nn.Module):
loader: EvalSubgraphLoader = None,
**kwargs,
) -> Tensor:
r"""Extends the call of forward propagation by immediately start
pulling historical embeddings for each layer asynchronously.
After forward propogation, pushing node embeddings to histories will be
synchronized.
For example, given a mini-batch with
:obj:`n_id = [0, 1, 5, 6, 7, 3, 4]`, where the first 5 nodes
represent the mini-batched nodes, and nodes :obj:`3` and :obj:`4`
denote out-of-mini-batched nodes (i.e. the 1-hop neighbors of the
mini-batch that are not included in the current mini-batch), then
other input arguments are given as:
.. code-block:: python
batch_size = 5
offset = [0, 2, 5]
count = [2, 3]
Args:
x (Tensor, optional): Node feature matrix. (default: :obj:`None`)
adj_t (SparseTensor, optional) The sparse adjacency matrix.
(default: :obj:`None`)
batch_size (int, optional): The in-mini-batch size of nodes.
(default: :obj:`None`)
n_id (Tensor, optional): The global indices of mini-batched and
out-of-mini-batched nodes. (default: :obj:`None`)
offset (Tensor, optional): The offset of mini-batched nodes inside
a utilize a contiguous memory layout. (default: :obj:`None`)
count (Tensor, optional): The number of mini-batched nodes inside a
contiguous memory layout. (default: :obj:`None`)
loader (EvalSubgraphLoader, optional): A subgraph loader used for
evaluating the given GNN in a layer-wise fashsion.
"""
if loader is not None:
return self.mini_inference(loader)
......@@ -100,7 +160,8 @@ class ScalableGNN(torch.nn.Module):
n_id: Optional[Tensor] = None,
offset: Optional[Tensor] = None,
count: Optional[Tensor] = None) -> Tensor:
r"""Push and pull information from `x` to `history` and vice versa."""
r"""Pushes and pulls information from :obj:`x` to :obj:`history` and
vice versa."""
if n_id is None and x.size(0) != self.num_nodes:
return x # Do nothing...
......@@ -136,6 +197,9 @@ class ScalableGNN(torch.nn.Module):
@torch.no_grad()
def mini_inference(self, loader: SubgraphLoader) -> Tensor:
r"""An implementation of a layer-wise evaluation of GNNs.
For each layer, :meth:`forward_layer` will be called."""
# We iterate over the loader in a layer-wise fashsion.
# In order to re-use some intermediate representations, we maintain a
# `state` dictionary for each individual mini-batch.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment