.. _guide-minibatch-sparse:

6.5 Training GNN with DGL Sparse
---------------------------------

This tutorial demonstrates how to use the DGL Sparse library to sample on a
graph and train a model. It trains and tests a GraphSAGE model, using the
sparse ``sample`` and ``compact`` operators to extract submatrices from the
full adjacency matrix.

Training a GNN with DGL Sparse is quite similar to
:ref:`guide-minibatch-node-classification-sampler`. The major differences are
the customized sampler and the matrix that represents the graph.

We have already customized a sampler in
:ref:`guide-minibatch-customizing-neighborhood-sampler`. In this tutorial, we
will customize another sampler with the DGL Sparse library, as shown below.

.. code:: python

    import torch
    import dgl.graphbolt as gb
    import dgl.sparse as dglsp
    from dgl.graphbolt import SubgraphSampler
    from torch.utils.data import functional_datapipe


    @functional_datapipe("sample_sparse_neighbor")
    class SparseNeighborSampler(SubgraphSampler):
        def __init__(self, datapipe, matrix, fanouts):
            super().__init__(datapipe)
            self.matrix = matrix
            # Convert fanouts to a list of tensors, reversing their order:
            # sampling starts from the seed nodes (the outputs of the last
            # layer) and walks backward toward the input layer.
            self.fanouts = []
            for fanout in fanouts:
                if not isinstance(fanout, torch.Tensor):
                    fanout = torch.LongTensor([int(fanout)])
                self.fanouts.insert(0, fanout)

        def _sample_subgraphs(self, seeds):
            sampled_matrices = []
            src = seeds

            #####################################################################
            # (HIGHLIGHT) Use the sparse sample operator to perform random
            # sampling on the neighboring nodes of the seed nodes. The sparse
            # compact operator is then employed to compact and relabel the
            # sampled matrix, yielding the compacted matrix and the relabel
            # index.
            #####################################################################
            for fanout in self.fanouts:
                # Sample neighbors.
                sampled_matrix = self.matrix.sample(1, fanout, ids=src).coalesce()
                # Compact the sampled matrix.
                compacted_mat, row_ids = sampled_matrix.compact(0)
                # Prepend so the final list is ordered from the input layer to
                # the output layer, matching the order of the model's layers.
                sampled_matrices.insert(0, compacted_mat)
                src = row_ids

            return src, sampled_matrices
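
To build intuition for the two operators, here is a small toy sketch (with
made-up data, reusing the imports above) that mirrors the calls in the
sampler: ``sample`` draws up to ``fanout`` neighbors for each seed column,
and ``compact`` relabels the row dimension so only the sampled rows remain.

.. code:: python

    # Toy graph (hypothetical data): edges stored column-wise, so column j
    # lists the in-neighbors of node j.
    indptr = torch.tensor([0, 2, 3, 4])
    indices = torch.tensor([1, 2, 2, 0])
    A = dglsp.from_csc(indptr, indices, shape=(3, 3))

    seeds = torch.tensor([0, 2])
    fanout = torch.LongTensor([2])
    # Sample up to 2 in-neighbors for each seed column.
    sub = A.sample(1, fanout, ids=seeds).coalesce()
    # Relabel rows; ``row_ids`` maps the new row numbering back to the
    # original node IDs.
    compacted, row_ids = sub.compact(0)
    print(compacted.shape, row_ids)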

Another major difference is the matrix that represents the graph. Previously
we used :class:`~dgl.graphbolt.FusedCSCSamplingGraph` for sampling. In this
tutorial, we use :class:`~dgl.sparse.SparseMatrix` to represent the graph.

.. code:: python

    dataset = gb.BuiltinDataset("ogbn-products").load()
    g = dataset.graph
    # Create the sparse adjacency matrix from the graph's CSC representation.
    N = g.num_nodes
    A = dglsp.from_csc(g.csc_indptr, g.indices, shape=(N, N))
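
For reference, here is a tiny hand-built sketch (hypothetical data) of how
``dglsp.from_csc`` interprets its inputs: column ``j`` of the matrix holds
the rows ``indices[indptr[j]:indptr[j+1]]``, so rows index source nodes and
columns index destination nodes.

.. code:: python

    # Toy example: edges 0->1, 0->2, 1->2 in CSC form.
    indptr = torch.tensor([0, 0, 1, 3])  # one entry per column, plus one
    indices = torch.tensor([0, 0, 1])    # row (source node) IDs
    A_toy = dglsp.from_csc(indptr, indices, shape=(3, 3))
    print(A_toy.to_dense())
    # tensor([[0., 1., 1.],
    #         [0., 0., 1.],
    #         [0., 0., 0.]])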


The remaining code is almost the same as in the node classification tutorial.

To use this sampler with :class:`~dgl.graphbolt.DataLoader`:

.. code:: python

    # ``ids`` holds the training seed items, e.g. the task's train set.
    datapipe = gb.ItemSampler(ids, batch_size=1024)
    # Sample neighbors with the customized sparse sampler,
    # e.g. ``fanouts = [10, 10, 10]`` for the three-layer model below.
    datapipe = datapipe.sample_sparse_neighbor(A, fanouts)
    # Use graphbolt to fetch node features.
    datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
    datapipe = datapipe.copy_to(device)
    dataloader = gb.DataLoader(datapipe)
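
To see what the pipeline yields, you can pull a single mini-batch and inspect
it (a quick sketch; the attribute names follow the mini-batch objects used in
the training loop below).

.. code:: python

    # Inspect one mini-batch from the dataloader.
    data = next(iter(dataloader))
    print(data.node_features["feat"].shape)  # input features of the batch
    print(len(data.sampled_subgraphs))       # one compacted matrix per layer
    print(data.labels.shape)                 # labels of the seed nodes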

The model definition is shown below:

.. code:: python

    class SAGEConv(nn.Module):
        r"""GraphSAGE layer from `Inductive Representation Learning on
        Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__
        """

        def __init__(
            self,
            in_feats,
            out_feats,
        ):
            super(SAGEConv, self).__init__()
            self._in_src_feats, self._in_dst_feats = in_feats, in_feats
            self._out_feats = out_feats

            self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=False)
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=True)
            self.reset_parameters()

        def reset_parameters(self):
            gain = nn.init.calculate_gain("relu")
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
            nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

        def forward(self, A, feat):
            feat_src = feat
            # The first ``A.shape[1]`` rows of ``feat`` correspond to the
            # destination (seed) nodes.
            feat_dst = feat[: A.shape[1]]

            # Aggregator type: mean.
            srcdata = self.fc_neigh(feat_src)
            # Normalize each destination column by its degree.
            D_hat = dglsp.diag(A.sum(0)) ** -1
            A_div = A @ D_hat
            # Aggregate neighbor features via sparse-dense matrix product.
            dstdata = A_div.T @ srcdata

            rst = self.fc_self(feat_dst) + dstdata
            return rst


    class SAGE(nn.Module):
        def __init__(self, in_size, hid_size, out_size):
            super().__init__()
            self.layers = nn.ModuleList()
            # Three-layer GraphSAGE-gcn.
            self.layers.append(SAGEConv(in_size, hid_size))
            self.layers.append(SAGEConv(hid_size, hid_size))
            self.layers.append(SAGEConv(hid_size, out_size))
            self.dropout = nn.Dropout(0.5)
            self.hid_size = hid_size
            self.out_size = out_size

        def forward(self, sampled_matrices, x):
            hidden_x = x
            for layer_idx, (layer, sampled_matrix) in enumerate(
                zip(self.layers, sampled_matrices)
            ):
                hidden_x = layer(sampled_matrix, hidden_x)
                if layer_idx != len(self.layers) - 1:
                    hidden_x = F.relu(hidden_x)
                    hidden_x = self.dropout(hidden_x)
            return hidden_x
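
As a quick sanity check on the normalization in ``SAGEConv.forward``, the
following toy computation (made-up numbers) shows that
``(A @ D_hat).T @ srcdata`` averages the features of each destination node's
sampled neighbors.

.. code:: python

    # Two source nodes, one destination column; both sources are neighbors.
    indptr = torch.tensor([0, 2])
    indices = torch.tensor([0, 1])
    A = dglsp.from_csc(indptr, indices, shape=(2, 1))

    srcdata = torch.tensor([[1.0, 3.0], [3.0, 5.0]])
    D_hat = dglsp.diag(A.sum(0)) ** -1  # inverse destination degrees
    dstdata = (A @ D_hat).T @ srcdata
    print(dstdata)  # tensor([[2., 4.]]) -- the mean of the two source rows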


Launch training:

.. code:: python

    features = dataset.feature
    # Create GraphSAGE model.
    in_size = features.size("node", None, "feat")[0]
    num_classes = dataset.tasks[0].metadata["num_classes"]
    out_size = num_classes
    model = SAGE(in_size, 256, out_size).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)

    for epoch in range(10):
        model.train()
        total_loss = 0
        for it, data in enumerate(dataloader):
            node_feature = data.node_features["feat"].float()
            blocks = data.sampled_subgraphs
            y = data.labels
            y_hat = model(blocks, node_feature)
            loss = F.cross_entropy(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch:03d} | Loss {total_loss / (it + 1):.4f}")
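
The tutorial also tests the model. A minimal evaluation sketch (assuming a
``test_dataloader`` built the same way as above, but over the test item set)
could look like this:

.. code:: python

    model.eval()
    correct = total = 0
    with torch.no_grad():
        for data in test_dataloader:
            node_feature = data.node_features["feat"].float()
            y = data.labels
            y_hat = model(data.sampled_subgraphs, node_feature)
            correct += (y_hat.argmax(dim=1) == y).sum().item()
            total += y.numel()
    print(f"Test accuracy: {correct / total:.4f}")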

For more details, please refer to the `full example
<https://github.com/dmlc/dgl/blob/master/examples/sampling/graphbolt/sparse/graphsage.py>`__.