Commit a73bb262 authored by rusty1s's avatar rusty1s
Browse files

doc

parent 6325fa72
......@@ -29,6 +29,7 @@ class ScalableGNN(torch.nn.Module):
self.pool: Optional[AsyncIOPool] = None
self._async = False
self.__out: Optional[Tensor] = None
@property
def emb_device(self):
......@@ -135,20 +136,28 @@ class ScalableGNN(torch.nn.Module):
@torch.no_grad()
def mini_inference(self, loader: SubgraphLoader) -> Tensor:
loader = [data + ({}, ) for data in loader]
# We iterate over the loader in a layer-wise fashsion.
# In order to re-use some intermediate representations, we maintain a
# `state` dictionary for each individual mini-batch.
for batch, batch_size, n_id, offset, count, state in loader:
x = batch.x.to(self.device)
adj_t = batch.adj_t.to(self.device)
loader = [sub_data + ({}, ) for sub_data in loader]
# We push the outputs of the first layer to the history:
for data, batch_size, n_id, offset, count, state in loader:
x = data.x.to(self.device)
adj_t = data.adj_t.to(self.device)
out = self.forward_layer(0, x, adj_t, state)[:batch_size]
self.pool.async_push(out, offset, count, self.histories[0].emb)
self.pool.synchronize_push()
for i in range(1, len(self.histories)):
# Pull the complete layer-wise history:
for _, batch_size, n_id, offset, count, _ in loader:
self.pool.async_pull(self.histories[i - 1].emb, offset, count,
n_id[batch_size:])
# Compute new output embeddings one-by-one and start pushing them
# to the history.
for batch, batch_size, n_id, offset, count, state in loader:
adj_t = batch.adj_t.to(self.device)
x = self.pool.synchronize_pull()[:n_id.numel()]
......@@ -157,10 +166,13 @@ class ScalableGNN(torch.nn.Module):
self.pool.free_pull()
self.pool.synchronize_push()
# We pull the histories from the last layer:
for _, batch_size, n_id, offset, count, _ in loader:
self.pool.async_pull(self.histories[-1].emb, offset, count,
n_id[batch_size:])
# And compute final output embeddings, which we write into a private
# output embedding matrix:
for batch, batch_size, n_id, offset, count, state in loader:
adj_t = batch.adj_t.to(self.device)
x = self.pool.synchronize_pull()[:n_id.numel()]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment