Unverified Commit 9f444cd7 authored by Rhett Ying, committed by GitHub

[doc] update user guide 6.6 (#6656)

Implementing Offline Inference
------------------------------

Consider the three-layer GraphSAGE model we have mentioned in Section 6.1
:ref:`guide-minibatch-node-classification-model`. The way
to implement offline inference still involves using
:class:`~dgl.graphbolt.NeighborSampler`, but sampling for
only one layer at a time.

.. code:: python

    # shuffle=False keeps the output nodes of each batch contiguous, which
    # the layer-wise inference loop below relies on.
    datapipe = gb.ItemSampler(all_nodes_set, batch_size=1024, shuffle=False)
    datapipe = datapipe.sample_neighbor(g, [-1])  # -1 means all neighbors; 1 layer.
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    datapipe = datapipe.to_dgl()
    datapipe = datapipe.copy_to(device)
    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)

Note that offline inference is implemented as a method of the GNN module
because the computation on one layer depends on how messages are aggregated
and combined as well.

.. code:: python

    class SAGE(nn.Module):
        def __init__(self, in_size, hidden_size, out_size):
            super().__init__()
            self.layers = nn.ModuleList()
            # Three-layer GraphSAGE-mean.
            self.layers.append(dglnn.SAGEConv(in_size, hidden_size, "mean"))
            self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
            self.layers.append(dglnn.SAGEConv(hidden_size, out_size, "mean"))
            self.dropout = nn.Dropout(0.5)
            self.hidden_size = hidden_size
            self.out_size = out_size

        def forward(self, blocks, x):
            hidden_x = x
            for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
                hidden_x = layer(block, hidden_x)
                is_last_layer = layer_idx == len(self.layers) - 1
                if not is_last_layer:
                    hidden_x = F.relu(hidden_x)
                    hidden_x = self.dropout(hidden_x)
            return hidden_x

        def inference(self, graph, features, dataloader, device):
            """
            Offline inference with this module.
            """
            feature = features.read("node", None, "feat")

            # The output buffer of each layer lives on CPU; pin its memory
            # for faster transfer back from the GPU when one is used.
            buffer_device = torch.device("cpu")
            pin_memory = buffer_device != device

            # Compute representations layer by layer.
            for layer_idx, layer in enumerate(self.layers):
                is_last_layer = layer_idx == len(self.layers) - 1
                y = torch.empty(
                    graph.total_num_nodes,
                    self.out_size if is_last_layer else self.hidden_size,
                    dtype=torch.float32,
                    device=buffer_device,
                    pin_memory=pin_memory,
                )
                feature = feature.to(device)

                # Within a layer, iterate over nodes in batches.
                for step, data in tqdm(enumerate(dataloader)):
                    x = feature[data.input_nodes]
                    hidden_x = layer(data.blocks[0], x)  # len(blocks) = 1
                    if not is_last_layer:
                        hidden_x = F.relu(hidden_x)
                        hidden_x = self.dropout(hidden_x)
                    # By design, our output nodes are contiguous, so a single
                    # slice assignment writes the batch results back.
                    y[
                        data.output_nodes[0] : data.output_nodes[-1] + 1
                    ] = hidden_x.to(buffer_device)
                feature = y

            return y

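With the model in place, offline inference is a matter of calling the method
with the one-layer dataloader constructed in the first snippet. A minimal
sketch, assuming a trained ``model`` and the ``g``, ``feature``,
``dataloader``, and ``device`` objects from above:

.. code:: python

    # Hypothetical driver code: compute the final-layer representation of
    # every node, then turn it into class predictions.
    model.eval()
    with torch.no_grad():
        pred = model.inference(g, feature, dataloader, device)
    pred = pred.argmax(dim=1)  # Predicted class per node.
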
Note that for the purpose of computing the evaluation metric on the
validation set for model selection we usually don't have to compute
exact offline inference. The reason is that we would need to compute the
representation of every single node on every single layer, which is usually
very costly on the validation set, especially in the semi-supervised regime
with a large amount of unlabeled data. Neighborhood sampling will work fine
for model selection and validation.

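For instance, a sampled validation pass can reuse the ordinary multi-layer
dataloader from training. A minimal sketch, assuming a ``valid_dataloader``
built the same way as the training one (a hypothetical name) and minibatches
in the DGL format produced by ``to_dgl()``:

.. code:: python

    # Estimate validation accuracy with neighborhood sampling instead of
    # exact offline inference.
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for data in valid_dataloader:
            x = data.node_features["feat"]
            y_hat = model(data.blocks, x)
            correct += (y_hat.argmax(dim=1) == data.labels).sum().item()
            total += data.labels.numel()
    print(f"Validation accuracy: {correct / total:.4f}")
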
One can see
`GraphSAGE <https://github.com/dmlc/dgl/blob/master/examples/sampling/graphbolt/node_classification.py>`__
and
`RGCN <https://github.com/dmlc/dgl/blob/master/examples/sampling/graphbolt/rgcn/hetero_rgcn.py>`__
for examples of offline inference.