Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
9f444cd7
Unverified
Commit
9f444cd7
authored
Dec 01, 2023
by
Rhett Ying
Committed by
GitHub
Dec 01, 2023
Browse files
[doc] update user guide 6.6 (#6656)
parent
3ea930d0
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
65 additions
and
47 deletions
+65
-47
docs/source/guide/minibatch-inference.rst
docs/source/guide/minibatch-inference.rst
+65
-47
No files found.
docs/source/guide/minibatch-inference.rst
View file @
9f444cd7
...
...
@@ -39,63 +39,81 @@ Implementing Offline Inference
Consider the two-layer GCN we have mentioned in Section 6.1
:ref:`guide-minibatch-node-classification-model`. The way
to implement offline inference still involves using
:class:`~dgl.dataloading.neighbor.MultiLayerFullNeighborSampler`, but sampling for
only one layer at a time. Note that offline inference is implemented as
a method of the GNN module because the computation on one layer depends
on how messages are aggregated and combined as well.
:class:`~dgl.graphbolt.NeighborSampler`, but sampling for
only one layer at a time.
.. code:: python
class StochasticTwoLayerGCN(nn.Module):
def __init__(self, in_features, hidden_features, out_features):
datapipe = gb.ItemSampler(all_nodes_set, batch_size=1024, shuffle=True)
datapipe = datapipe.sample_neighbor(g, [-1]) # 1 layers.
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
datapipe = datapipe.to_dgl()
datapipe = datapipe.copy_to(device)
dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)
Note that offline inference is implemented as a method of the GNN module
because the computation on one layer depends on how messages are aggregated
and combined as well.
.. code:: python
class SAGE(nn.Module):
def __init__(self, in_size, hidden_size, out_size):
super().__init__()
self.hidden_features = hidden_features
self.out_features = out_features
self.conv1 = dgl.nn.GraphConv(in_features, hidden_features)
self.conv2 = dgl.nn.GraphConv(hidden_features, out_features)
self.n_layers = 2
self.layers = nn.ModuleList()
# Three-layer GraphSAGE-mean.
self.layers.append(dglnn.SAGEConv(in_size, hidden_size, "mean"))
self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
self.layers.append(dglnn.SAGEConv(hidden_size, out_size, "mean"))
self.dropout = nn.Dropout(0.5)
self.hidden_size = hidden_size
self.out_size = out_size
def forward(self, blocks, x):
x_dst = x[:blocks[0].number_of_dst_nodes()]
x = F.relu(self.conv1(blocks[0], (x, x_dst)))
x_dst = x[:blocks[1].number_of_dst_nodes()]
x = F.relu(self.conv2(blocks[1], (x, x_dst)))
return x
def inference(self, g, x, batch_size, device):
hidden_x = x
for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
hidden_x = layer(block, hidden_x)
is_last_layer = layer_idx == len(self.layers) - 1
if not is_last_layer:
hidden_x = F.relu(hidden_x)
hidden_x = self.dropout(hidden_x)
return hidden_x
def inference(self, graph, features, dataloader, device):
"""
Offline inference with this module
"""
feature = features.read("node", None, "feat")
# Compute representations layer by layer
for l, layer in enumerate([self.conv1, self.conv2]):
y = torch.zeros(g.num_nodes(),
self.hidden_features
if l != self.n_layers - 1
else self.out_features)
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.NodeDataLoader(
g, torch.arange(g.num_nodes()), sampler,
batch_size=batch_size,
shuffle=True,
drop_last=False)
# Within a layer, iterate over nodes in batches
for input_nodes, output_nodes, blocks in dataloader:
block = blocks[0]
# Copy the features of necessary input nodes to GPU
h = x[input_nodes].to(device)
# Compute output. Note that this computation is the same
# but only for a single layer.
h_dst = h[:block.number_of_dst_nodes()]
h = F.relu(layer(block, (h, h_dst)))
# Copy to output back to CPU.
y[output_nodes] = h.cpu()
x = y
for layer_idx, layer in enumerate(self.layers):
is_last_layer = layer_idx == len(self.layers) - 1
y = torch.empty(
graph.total_num_nodes,
self.out_size if is_last_layer else self.hidden_size,
dtype=torch.float32,
device=buffer_device,
pin_memory=pin_memory,
)
feature = feature.to(device)
for step, data in tqdm(enumerate(dataloader)):
x = feature[data.input_nodes]
hidden_x = layer(data.blocks[0], x) # len(blocks) = 1
if not is_last_layer:
hidden_x = F.relu(hidden_x)
hidden_x = self.dropout(hidden_x)
# By design, our output nodes are contiguous.
y[
data.output_nodes[0] : data.output_nodes[-1] + 1
] = hidden_x.to(device)
feature = y
return y
Note that for the purpose of computing evaluation metric on the
validation set for model selection we usually don’t have to compute
exact offline inference. The reason is that we need to compute the
...
...
@@ -105,7 +123,7 @@ of unlabeled data. Neighborhood sampling will work fine for model
selection and validation.
One can see
`GraphSAGE <https://github.com/dmlc/dgl/blob/master/examples/
pytorch/graphsage/train_sampling
.py>`__
`GraphSAGE <https://github.com/dmlc/dgl/blob/master/examples/
sampling/graphbolt/node_classification
.py>`__
and
`RGCN <https://github.com/dmlc/dgl/blob/master/examples/
pytorch
/rgcn
-
hetero
/entity_classify_mb
.py>`__
`RGCN <https://github.com/dmlc/dgl/blob/master/examples/
sampling/graphbolt
/rgcn
/
hetero
_rgcn
.py>`__
for examples of offline inference.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment