.. _guide-minibatch-inference:

6.6 Exact Offline Inference on Large Graphs
------------------------------------------------------

:ref:`(中文版) <guide_cn-minibatch-inference>`

Both subgraph sampling and neighborhood sampling reduce the memory and
time cost of training GNNs on GPUs. At inference time, however, it is
usually better to aggregate over all neighbors so as to remove the
randomness introduced by sampling. Full-graph forward propagation is
often infeasible on a GPU due to limited memory and slow on a CPU due
to limited compute. This section introduces how to perform full-graph
forward propagation with limited GPU memory via minibatching and
neighborhood sampling.

The inference algorithm differs from the training algorithm: the
representations of all nodes must be computed layer by layer, starting
from the first layer. Specifically, for a given layer, we compute the
output representations of all nodes in minibatches. The consequence is
that the inference algorithm has an outer loop iterating over the
layers and an inner loop iterating over the minibatches of nodes. In
contrast, the training algorithm has an outer loop over the minibatches
of nodes and an inner loop over the layers, for both neighborhood
sampling and message passing.
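
In pseudocode, the two loop orders contrast as follows. This is only a
schematic sketch: ``train_dataloader``, ``layer_dataloaders``, ``loss_fn``,
``opt``, and ``layer_output_size`` are hypothetical placeholders rather than
GraphBolt APIs, and the minibatch attributes (``blocks``, ``input_nodes``,
``output_nodes``, ``node_features``, ``labels``) mirror the ones used in the
code later in this section.

.. code:: python

    import torch

    # Training: outer loop over minibatches of seed nodes,
    # inner loop over all layers of the model.
    def train_one_epoch(model, train_dataloader, loss_fn, opt):
        for data in train_dataloader:
            hidden_x = data.node_features["feat"]
            for layer, block in zip(model.layers, data.blocks):
                hidden_x = layer(block, hidden_x)
            loss = loss_fn(hidden_x, data.labels)
            opt.zero_grad()
            loss.backward()
            opt.step()

    # Inference: outer loop over layers, inner loop over minibatches
    # that together cover every node exactly once.
    def infer_all_nodes(model, layer_dataloaders, feature, num_nodes):
        for layer, dataloader in zip(model.layers, layer_dataloaders):
            next_feature = torch.empty(num_nodes, layer_output_size(layer))
            for data in dataloader:
                next_feature[data.output_nodes] = layer(
                    data.blocks[0], feature[data.input_nodes]
                )
            feature = next_feature  # this layer's outputs feed the next layer
        return feature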

The following animation shows how the computation proceeds (note that
for every layer only the first three minibatches are drawn).

.. figure:: https://data.dgl.ai/asset/image/guide_6_6_0.gif
   :alt: Imgur

Implementing Offline Inference
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Consider the three-layer GraphSAGE model we have defined in Section 6.1
:ref:`guide-minibatch-node-classification-model`. Offline inference still
uses :class:`~dgl.graphbolt.NeighborSampler`, but with a fan-out of ``-1``
(all neighbors) and sampling for only one layer at a time.

.. code:: python

    # `g` is the sampling graph, `feature` the feature store, and `device`
    # the target device, as in the previous sections.
    # Do not shuffle: the inference code below relies on each minibatch
    # covering a contiguous range of node IDs.
    datapipe = gb.ItemSampler(all_nodes_set, batch_size=1024, shuffle=False)
    datapipe = datapipe.sample_neighbor(g, [-1])  # Fanout -1: all neighbors, one layer.
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    datapipe = datapipe.to_dgl()
    datapipe = datapipe.copy_to(device)
    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)

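The ``all_nodes_set`` above is an item set covering every node ID in the
graph, so that one pass over the dataloader computes a layer's output for
all nodes. A minimal sketch of constructing it by hand, assuming ``g`` is
the sampling graph used above; the item name ``"seed_nodes"`` is what this
sampling pipeline expects, and datasets loaded through GraphBolt may already
provide such a set directly:

.. code:: python

    import torch
    import dgl.graphbolt as gb

    # One item per node ID, covering the whole graph.
    all_nodes_set = gb.ItemSet(
        torch.arange(g.total_num_nodes), names="seed_nodes"
    )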

Note that offline inference is implemented as a method of the GNN module,
because the computation for a single layer still depends on how messages
are aggregated and combined.

.. code:: python

    import dgl.nn as dglnn
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from tqdm import tqdm


    class SAGE(nn.Module):
        def __init__(self, in_size, hidden_size, out_size):
            super().__init__()
            self.layers = nn.ModuleList()
            # Three-layer GraphSAGE-mean.
            self.layers.append(dglnn.SAGEConv(in_size, hidden_size, "mean"))
            self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
            self.layers.append(dglnn.SAGEConv(hidden_size, out_size, "mean"))
            self.dropout = nn.Dropout(0.5)
            self.hidden_size = hidden_size
            self.out_size = out_size

        def forward(self, blocks, x):
            hidden_x = x
            for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
                hidden_x = layer(block, hidden_x)
                is_last_layer = layer_idx == len(self.layers) - 1
                if not is_last_layer:
                    hidden_x = F.relu(hidden_x)
                    hidden_x = self.dropout(hidden_x)
            return hidden_x
    
        def inference(self, graph, features, dataloader, device):
            """
            Offline inference with this module
            """
            feature = features.read("node", None, "feat")

            buffer_device = torch.device("cpu")
            # Keep the full outputs on the CPU and pin them so that
            # CPU-to-GPU transfers are faster when the model runs on a GPU.
            pin_memory = buffer_device != device
            # Compute representations layer by layer
            for layer_idx, layer in enumerate(self.layers):
                is_last_layer = layer_idx == len(self.layers) - 1

                y = torch.empty(
                    graph.total_num_nodes,
                    self.out_size if is_last_layer else self.hidden_size,
                    dtype=torch.float32,
                    device=buffer_device,
                    pin_memory=pin_memory,
                )
                feature = feature.to(device)

                for step, data in tqdm(enumerate(dataloader)):
                    x = feature[data.input_nodes]
                    hidden_x = layer(data.blocks[0], x)  # len(blocks) = 1
                    if not is_last_layer:
                        hidden_x = F.relu(hidden_x)
                        hidden_x = self.dropout(hidden_x)
                    # By design, our output nodes are contiguous.
                    y[
                        data.output_nodes[0] : data.output_nodes[-1] + 1
                    ] = hidden_x.to(buffer_device)
                feature = y

            return y

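A sketch of how this method might be invoked after training, reusing ``g``,
the feature store ``feature``, the ``dataloader`` over ``all_nodes_set``,
and ``device`` from the first code block (the variable names are
illustrative):

.. code:: python

    model.eval()  # disables dropout inside `inference`
    with torch.no_grad():
        # Layer-wise, exact representations for every node in the graph.
        representations = model.inference(g, feature, dataloader, device)
        predictions = representations.argmax(dim=1)

Calling ``model.eval()`` matters here because ``inference`` applies
``self.dropout`` between layers; in evaluation mode the dropout is a no-op,
so the result is deterministic.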

Note that exact offline inference is usually unnecessary when computing
evaluation metrics on the validation set for model selection. The reason
is that exact inference computes the representation of every single node
at every single layer, which is very costly, especially in the
semi-supervised regime with a lot of unlabeled data. Neighborhood sampling
works fine for model selection and validation, as sketched below.
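
For instance, a sketch of sampled validation (assuming a ``valid_dataloader``
built with the same fan-out-limited pipeline as training, and that each
minibatch carries ``node_features`` and ``labels``):

.. code:: python

    model.eval()
    correct = total = 0
    with torch.no_grad():
        for data in valid_dataloader:
            logits = model(data.blocks, data.node_features["feat"])
            correct += (logits.argmax(dim=1) == data.labels).sum().item()
            total += data.labels.shape[0]
    print(f"Sampled validation accuracy: {correct / total:.4f}")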

See
`GraphSAGE <https://github.com/dmlc/dgl/blob/master/examples/sampling/graphbolt/node_classification.py>`__
and
`RGCN <https://github.com/dmlc/dgl/blob/master/examples/sampling/graphbolt/rgcn/hetero_rgcn.py>`__
for complete examples of offline inference.