.. _guide-training:

Training Graph Neural Networks
==============================

Overview
--------

This chapter discusses how to train a graph neural network for node
classification, edge classification, link prediction, and graph
classification on small graph(s), using the message passing methods
introduced in :ref:`guide-message-passing` and the neural network modules
introduced in :ref:`guide-nn`.

This chapter assumes that your graph as well as all of its node and edge
features can fit in GPU memory; see :ref:`guide-minibatch` if they cannot.

The following text assumes that the graph(s) and node/edge features are
already prepared. If you plan to use the dataset DGL provides or other
compatible ``DGLDataset`` as is described in :ref:`guide-data-pipeline`, you can
get the graph for a single-graph dataset with something like

.. code:: python

    import dgl
    
    dataset = dgl.data.CiteseerGraphDataset()
    graph = dataset[0]

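If the graph and all of its features do fit on a GPU as assumed, a minimal
sketch of moving everything onto the device looks like the following
(``DGLGraph.to`` carries the feature tensors stored in ``ndata``/``edata``
along with the graph structure; a CUDA device is assumed to be available).

.. code:: python

    import torch

    # Move the graph, together with the features stored on it, to GPU if one
    # is available.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    graph = graph.to(device)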

Note: In this chapter, we will use PyTorch as the backend.

Heterogeneous Graphs
~~~~~~~~~~~~~~~~~~~~

Sometimes you would like to work on heterogeneous graphs. Here we take a
synthetic heterogeneous graph as an example for demonstrating node
classification, edge classification, and link prediction tasks.

The synthetic heterogeneous graph ``hetero_graph`` has these edge types:

-  ``('user', 'follow', 'user')``
-  ``('user', 'followed-by', 'user')``
-  ``('user', 'click', 'item')``
-  ``('item', 'clicked-by', 'user')``
-  ``('user', 'dislike', 'item')``
-  ``('item', 'disliked-by', 'user')``

.. code:: python

    import numpy as np
    import torch
    
    n_users = 1000
    n_items = 500
    n_follows = 3000
    n_clicks = 5000
    n_dislikes = 500
    n_hetero_features = 10
    n_user_classes = 5
    n_max_clicks = 10
    
    follow_src = np.random.randint(0, n_users, n_follows)
    follow_dst = np.random.randint(0, n_users, n_follows)
    click_src = np.random.randint(0, n_users, n_clicks)
    click_dst = np.random.randint(0, n_items, n_clicks)
    dislike_src = np.random.randint(0, n_users, n_dislikes)
    dislike_dst = np.random.randint(0, n_items, n_dislikes)
    
    hetero_graph = dgl.heterograph({
        ('user', 'follow', 'user'): (follow_src, follow_dst),
        ('user', 'followed-by', 'user'): (follow_dst, follow_src),
        ('user', 'click', 'item'): (click_src, click_dst),
        ('item', 'clicked-by', 'user'): (click_dst, click_src),
        ('user', 'dislike', 'item'): (dislike_src, dislike_dst),
        ('item', 'disliked-by', 'user'): (dislike_dst, dislike_src)})
    
    hetero_graph.nodes['user'].data['feature'] = torch.randn(n_users, n_hetero_features)
    hetero_graph.nodes['item'].data['feature'] = torch.randn(n_items, n_hetero_features)
    hetero_graph.nodes['user'].data['label'] = torch.randint(0, n_user_classes, (n_users,))
    hetero_graph.edges['click'].data['label'] = torch.randint(1, n_max_clicks, (n_clicks,)).float()
    # randomly generate training masks on user nodes and click edges
    hetero_graph.nodes['user'].data['train_mask'] = torch.zeros(n_users, dtype=torch.bool).bernoulli(0.6)
    hetero_graph.edges['click'].data['train_mask'] = torch.zeros(n_clicks, dtype=torch.bool).bernoulli(0.6)

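As a quick sanity check (not part of the original data preparation), you can
inspect the constructed graph; the numbers follow from the constants above.

.. code:: python

    print(hetero_graph.ntypes)                    # node types, e.g. ['item', 'user']
    print(hetero_graph.etypes)                    # the six relation names
    print(hetero_graph.number_of_nodes('user'))   # 1000
    print(hetero_graph.number_of_edges('click'))  # 5000
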
.. _guide-training-node-classification:

Node Classification/Regression
------------------------------

One of the most popular and widely adopted tasks for graph neural
networks is node classification, where each node in the
training/validation/test set is assigned a ground truth category from a
set of predefined categories. Node regression is similar, where each
node in the training/validation/test set is assigned a ground truth
number.

Overview
~~~~~~~~

To classify nodes, a graph neural network performs the message passing
discussed in :ref:`guide-message-passing` to utilize not only the node's own
features, but also its neighboring node and edge features. Message
passing can be repeated for multiple rounds to incorporate information
from a larger neighborhood.

Writing neural network model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

DGL provides a few built-in graph convolution modules that can perform
one round of message passing. In this guide, we choose
:class:`dgl.nn.pytorch.SAGEConv` (also available in MXNet and Tensorflow),
the graph convolution module for GraphSAGE.

Usually for deep learning models on graphs we need a multi-layer graph
neural network, where we do multiple rounds of message passing. This can
be achieved by stacking graph convolution modules as follows.

.. code:: python

    # Construct a two-layer GNN model
    import dgl.nn as dglnn
    import torch.nn as nn
    import torch.nn.functional as F
    class SAGE(nn.Module):
        def __init__(self, in_feats, hid_feats, out_feats):
            super().__init__()
            self.conv1 = dglnn.SAGEConv(
                in_feats=in_feats, out_feats=hid_feats, aggregator_type='mean')
            self.conv2 = dglnn.SAGEConv(
                in_feats=hid_feats, out_feats=out_feats, aggregator_type='mean')
      
        def forward(self, graph, inputs):
            # inputs are features of nodes
            h = self.conv1(graph, inputs)
            h = F.relu(h)
            h = self.conv2(graph, h)
            return h

Note that you can use the model above not only for node classification,
but also for obtaining hidden node representations for other downstream
tasks such as
:ref:`guide-training-edge-classification`,
:ref:`guide-training-link-prediction`, or
:ref:`guide-training-graph-classification`.

For a complete list of built-in graph convolution modules, please refer
to :ref:`apinn`.

For more details on how DGL neural network modules work and how to write
a custom neural network module with message passing, please refer to the
example in :ref:`guide-nn`.

Training loop
~~~~~~~~~~~~~

Training on the full graph simply involves a forward propagation of the
model defined above, and computing the loss by comparing the prediction
against ground truth labels on the training nodes.

This section uses a DGL built-in dataset
:class:`dgl.data.CiteseerGraphDataset` to
show a training loop. The node features
and labels are stored on its graph instance, and the
training-validation-test split is also stored on the graph as boolean
masks. This is similar to what you have seen in :ref:`guide-data-pipeline`.

.. code:: python

    node_features = graph.ndata['feat']
    node_labels = graph.ndata['label']
    train_mask = graph.ndata['train_mask']
    valid_mask = graph.ndata['val_mask']
    test_mask = graph.ndata['test_mask']
    n_features = node_features.shape[1]
    n_labels = int(node_labels.max().item() + 1)

The following is an example of evaluating your model by accuracy.

.. code:: python

    def evaluate(model, graph, features, labels, mask):
        model.eval()
        with torch.no_grad():
            logits = model(graph, features)
            logits = logits[mask]
            labels = labels[mask]
            _, indices = torch.max(logits, dim=1)
            correct = torch.sum(indices == labels)
            return correct.item() * 1.0 / len(labels)

You can then write the training loop as follows.

.. code:: python

    model = SAGE(in_feats=n_features, hid_feats=100, out_feats=n_labels)
    opt = torch.optim.Adam(model.parameters())
    
    for epoch in range(10):
        model.train()
        # forward propagation by using all nodes
        logits = model(graph, node_features)
        # compute loss
        loss = F.cross_entropy(logits[train_mask], node_labels[train_mask])
        # compute validation accuracy
        acc = evaluate(model, graph, node_features, node_labels, valid_mask)
        # backward propagation
        opt.zero_grad()
        loss.backward()
        opt.step()
        print(loss.item())
    
        # Save model if necessary.  Omitted in this example.

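After training, you can, for example, reuse the same ``evaluate`` function to
report accuracy on the test split.

.. code:: python

    # Evaluate the trained model on the held-out test nodes.
    test_acc = evaluate(model, graph, node_features, node_labels, test_mask)
    print('Test accuracy:', test_acc)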

`GraphSAGE <https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/train_full.py>`__
provides an end-to-end homogeneous graph node classification example.
You can see the corresponding model implementation in the
``GraphSAGE`` class in the example, with an adjustable number of layers,
dropout probabilities, and customizable aggregation functions and
nonlinearities.

.. _guide-training-rgcn-node-classification:

Heterogeneous graph
~~~~~~~~~~~~~~~~~~~

If your graph is heterogeneous, you may want to gather messages from
neighbors along all edge types. You can use the module
:class:`dgl.nn.pytorch.HeteroGraphConv` (also available in MXNet and Tensorflow),
which performs message passing on every edge type with its own graph
convolution module and then combines the results for each node type.

The following code will define a heterogeneous graph convolution module
that first performs a separate graph convolution on each edge type, then
sums the message aggregations on each edge type as the final result for
all node types.

.. code:: python

    # Define a Heterograph Conv model
    import dgl.nn as dglnn
    
    class RGCN(nn.Module):
        def __init__(self, in_feats, hid_feats, out_feats, rel_names):
            super().__init__()
            
            self.conv1 = dglnn.HeteroGraphConv({
                rel: dglnn.GraphConv(in_feats, hid_feats)
                for rel in rel_names}, aggregate='sum')
            self.conv2 = dglnn.HeteroGraphConv({
                rel: dglnn.GraphConv(hid_feats, out_feats)
                for rel in rel_names}, aggregate='sum')
      
        def forward(self, graph, inputs):
            # inputs are features of nodes
            h = self.conv1(graph, inputs)
            h = {k: F.relu(v) for k, v in h.items()}
            h = self.conv2(graph, h)
            return h

``dgl.nn.HeteroGraphConv`` takes in a dictionary of node types and node
feature tensors as input, and returns another dictionary of node types
and node features.

Given the user and item features from the example above, you can create the
model and prepare the inputs as follows.

.. code:: python

    model = RGCN(n_hetero_features, 20, n_user_classes, hetero_graph.etypes)
    user_feats = hetero_graph.nodes['user'].data['feature']
    item_feats = hetero_graph.nodes['item'].data['feature']
    labels = hetero_graph.nodes['user'].data['label']
    train_mask = hetero_graph.nodes['user'].data['train_mask']

One can simply perform a forward propagation as follows:

.. code:: python

    node_features = {'user': user_feats, 'item': item_feats}
    h_dict = model(hetero_graph, {'user': user_feats, 'item': item_feats})
    h_user = h_dict['user']
    h_item = h_dict['item']

The training loop is the same as the one for homogeneous graphs, except that
now you have a dictionary of node representations from which you compute
the predictions. For instance, if you are only predicting the ``user``
nodes, you can just extract the ``user`` node embeddings from the
returned dictionary:

.. code:: python

    opt = torch.optim.Adam(model.parameters())
    
    for epoch in range(5):
        model.train()
        # forward propagation by using all nodes and extracting the user embeddings
        logits = model(hetero_graph, node_features)['user']
        # compute loss
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])
        # Compute validation accuracy.  Omitted in this example.
        # backward propagation
        opt.zero_grad()
        loss.backward()
        opt.step()
        print(loss.item())
    
        # Save model if necessary.  Omitted in the example.

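The validation accuracy computation omitted above can be sketched as follows;
since this synthetic example has no separate validation split, we simply
evaluate on the ``user`` nodes outside the random training mask.

.. code:: python

    # A minimal sketch of accuracy on user nodes outside the training mask.
    model.eval()
    with torch.no_grad():
        logits = model(hetero_graph, node_features)['user']
        pred = logits.argmax(1)
        eval_mask = ~train_mask
        acc = (pred[eval_mask] == labels[eval_mask]).float().mean().item()
    print('Accuracy on held-out user nodes:', acc)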

DGL provides an end-to-end example of
`RGCN <https://github.com/dmlc/dgl/blob/master/examples/pytorch/rgcn-hetero/entity_classify.py>`__
for node classification. You can see the definition of heterogeneous
graph convolution in ``RelGraphConvLayer`` in the `model implementation
file <https://github.com/dmlc/dgl/blob/master/examples/pytorch/rgcn-hetero/model.py>`__.

.. _guide-training-edge-classification:

Edge Classification/Regression
------------------------------

Sometimes you wish to predict the attributes on the edges of the graph,
or even whether an edge exists or not between two given nodes. In that
case, you would like to have an *edge classification/regression* model.

Here we generate a random graph for edge prediction as a demonstration.

.. code:: python

    src = np.random.randint(0, 100, 500)
    dst = np.random.randint(0, 100, 500)
    # make it symmetric
    edge_pred_graph = dgl.graph((np.concatenate([src, dst]), np.concatenate([dst, src])))
    # synthetic node and edge features, as well as edge labels
    edge_pred_graph.ndata['feature'] = torch.randn(100, 10)
    edge_pred_graph.edata['feature'] = torch.randn(1000, 10)
    edge_pred_graph.edata['label'] = torch.randn(1000)
    # synthetic train-validation-test splits
    edge_pred_graph.edata['train_mask'] = torch.zeros(1000, dtype=torch.bool).bernoulli(0.6)

Overview
~~~~~~~~

From the previous section you have learned how to do node classification
with a multilayer GNN. The same technique can be applied for computing a
hidden representation of any node. The prediction on edges can then be
derived from the representation of their incident nodes.

The most common case of computing the prediction on an edge is to
express it as a parameterized function of the representation of its
incident nodes, and optionally the features on the edge itself.

Model Implementation Difference from Node Classification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Assuming that you compute the node representation with the model from
the previous section, you only need to write another component that
computes the edge prediction with the
:meth:`~dgl.DGLHeteroGraph.apply_edges` method.

For instance, if you would like to compute a score for each edge for
edge regression, the following code computes the dot product of incident
node representations on each edge.

.. code:: python

    import dgl.function as fn
    class DotProductPredictor(nn.Module):
        def forward(self, graph, h):
            # h contains the node representations computed from the GNN above.
            with graph.local_scope():
                graph.ndata['h'] = h
                graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
                return graph.edata['score']

One can also write a prediction function that predicts a vector for each
edge with an MLP. Such a vector can be used in further downstream tasks,
e.g. as the logits of a categorical distribution.

.. code:: python

    class MLPPredictor(nn.Module):
        def __init__(self, in_features, out_classes):
            super().__init__()
            self.W = nn.Linear(in_features * 2, out_classes)
    
        def apply_edges(self, edges):
            h_u = edges.src['h']
            h_v = edges.dst['h']
            score = self.W(torch.cat([h_u, h_v], 1))
            return {'score': score}
    
        def forward(self, graph, h):
            # h contains the node representations computed from the GNN above.
            with graph.local_scope():
                graph.ndata['h'] = h
                graph.apply_edges(self.apply_edges)
                return graph.edata['score']

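As mentioned in the overview, the features on the edge itself can optionally
enter the score as well. The following is a small sketch of this idea; the
class name is ours, and it assumes an edge feature stored under the
``'feature'`` key, as in the synthetic ``edge_pred_graph`` above.

.. code:: python

    class MLPEdgeFeatPredictor(nn.Module):
        """Sketch: an MLP predictor that also uses the edge's own features."""
        def __init__(self, node_feats, edge_feats, out_classes):
            super().__init__()
            self.W = nn.Linear(node_feats * 2 + edge_feats, out_classes)

        def apply_edges(self, edges):
            # Concatenate source node, destination node, and edge features.
            x = torch.cat([edges.src['h'], edges.dst['h'], edges.data['feature']], 1)
            return {'score': self.W(x)}

        def forward(self, graph, h):
            with graph.local_scope():
                graph.ndata['h'] = h
                graph.apply_edges(self.apply_edges)
                return graph.edata['score']
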
Training loop
~~~~~~~~~~~~~

Given the node representation computation model and an edge predictor
model, we can easily write a full-graph training loop where we compute
the prediction on all edges.

The following example takes ``SAGE`` from the previous section as the node
representation computation model and ``DotProductPredictor`` as the edge
predictor model.

.. code:: python

    class Model(nn.Module):
        def __init__(self, in_features, hidden_features, out_features):
            super().__init__()
            self.sage = SAGE(in_features, hidden_features, out_features)
            self.pred = DotProductPredictor()
        def forward(self, g, x):
            h = self.sage(g, x)
            return self.pred(g, h)

In this example, we also assume that the training/validation/test edge
sets are identified by boolean masks on edges. This example also does
not include early stopping and model saving.

.. code:: python

    node_features = edge_pred_graph.ndata['feature']
    edge_label = edge_pred_graph.edata['label']
    train_mask = edge_pred_graph.edata['train_mask']
    model = Model(10, 20, 5)
    opt = torch.optim.Adam(model.parameters())
    for epoch in range(10):
        pred = model(edge_pred_graph, node_features)
        # pred has shape (E, 1) from u_dot_v; squeeze it to match the label shape
        loss = ((pred[train_mask].squeeze(-1) - edge_label[train_mask]) ** 2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        print(loss.item())

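Since the synthetic data only carries a training mask, a rough way to
evaluate is to report the mean squared error on the remaining edges.

.. code:: python

    # A minimal sketch: MSE on the edges outside the training mask.
    model.eval()
    with torch.no_grad():
        pred = model(edge_pred_graph, node_features)
        eval_mask = ~train_mask
        mse = ((pred[eval_mask].squeeze(-1) - edge_label[eval_mask]) ** 2).mean().item()
    print('MSE on held-out edges:', mse)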

Heterogeneous graph
~~~~~~~~~~~~~~~~~~~

Edge classification on heterogeneous graphs is not very different from
that on homogeneous graphs. If you wish to perform edge classification
on one edge type, you only need to compute the node representation for
all node types, and predict on that edge type with the
:meth:`~dgl.DGLHeteroGraph.apply_edges` method.

For example, to make ``DotProductPredictor`` work on one edge type of a
heterogeneous graph, you only need to specify the edge type in the
``apply_edges`` method.

.. code:: python

    class HeteroDotProductPredictor(nn.Module):
        def forward(self, graph, h, etype):
            # h contains the node representations for each edge type computed from
            # the GNN above.
            with graph.local_scope():
                graph.ndata['h'] = h   # assigns 'h' of all node types in one shot
                graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
                return graph.edges[etype].data['score']

You can similarly adapt the ``MLPPredictor`` from before to take an edge type
argument.

.. code:: python

    class MLPPredictor(nn.Module):
        def __init__(self, in_features, out_classes):
            super().__init__()
            self.W = nn.Linear(in_features * 2, out_classes)
    
        def apply_edges(self, edges):
            h_u = edges.src['h']
            h_v = edges.dst['h']
            score = self.W(torch.cat([h_u, h_v], 1))
            return {'score': score}
    
        def forward(self, graph, h, etype):
            # h contains the node representations computed from the GNN above.
            with graph.local_scope():
                graph.ndata['h'] = h   # assigns 'h' of all node types in one shot
                graph.apply_edges(self.apply_edges, etype=etype)
                return graph.edges[etype].data['score']

The end-to-end model that predicts a score for each edge on a single
edge type will look like this:

.. code:: python

    class Model(nn.Module):
        def __init__(self, in_features, hidden_features, out_features, rel_names):
            super().__init__()
            self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
            self.pred = HeteroDotProductPredictor()
        def forward(self, g, x, etype):
            h = self.sage(g, x)
            return self.pred(g, h, etype)

Using the model simply involves feeding the model a dictionary of node
types and features.

.. code:: python

    model = Model(10, 20, 5, hetero_graph.etypes)
    user_feats = hetero_graph.nodes['user'].data['feature']
    item_feats = hetero_graph.nodes['item'].data['feature']
    label = hetero_graph.edges['click'].data['label']
    train_mask = hetero_graph.edges['click'].data['train_mask']
    node_features = {'user': user_feats, 'item': item_feats}

Then the training loop looks almost the same as that on a homogeneous
graph. For instance, if you wish to predict the edge labels on the edge type
``click``, then you can simply do

.. code:: python

    opt = torch.optim.Adam(model.parameters())
    for epoch in range(10):
        pred = model(hetero_graph, node_features, 'click')
        # pred has shape (E, 1); squeeze it to match the label shape
        loss = ((pred[train_mask].squeeze(-1) - label[train_mask]) ** 2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        print(loss.item())


Predicting Edge Type of an Existing Edge on a Heterogeneous Graph
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Sometimes you may want to predict which type an existing edge belongs
to.

For instance, given the heterogeneous graph above, your task is: given an
edge connecting a user and an item, predict whether the user would
``click`` or ``dislike`` the item.

This is a simplified version of rating prediction, which is common in
recommendation literature.

You can use a heterogeneous graph convolution network to obtain the node
representations. For instance, you can still use the RGCN above for this
purpose.

To predict the type of an edge, you can simply repurpose the
``HeteroDotProductPredictor`` above so that it takes in another graph
with only one edge type that merges all the edge types to be
predicted, and emits the score of each type for every edge.

In the example here, you will need a graph that has two node types
``user`` and ``item``, and one single edge type that merges all the
edge types going from ``user`` to ``item``, i.e. ``click`` and ``dislike``.
This can be conveniently created using
:meth:`relation slicing <dgl.DGLHeteroGraph.__getitem__>`.

.. code:: python

    dec_graph = hetero_graph['user', :, 'item']

Since the statement above also returns the original edge types as a
feature named ``dgl.ETYPE``, we can use that as labels.

.. code:: python

    edge_label = dec_graph.edata[dgl.ETYPE]

Given the graph above as input to the edge type predictor module, you
can write your predictor module as follows.

.. code:: python

    class HeteroMLPPredictor(nn.Module):
        def __init__(self, in_dims, n_classes):
            super().__init__()
            self.W = nn.Linear(in_dims * 2, n_classes)
    
        def apply_edges(self, edges):
            x = torch.cat([edges.src['h'], edges.dst['h']], 1)
            y = self.W(x)
            return {'score': y}
    
        def forward(self, graph, h):
            # h contains the node representations for each edge type computed from
            # the GNN above.
            with graph.local_scope():
                graph.ndata['h'] = h   # assigns 'h' of all node types in one shot
                graph.apply_edges(self.apply_edges)
                return graph.edata['score']

The model that combines the node representation module and the edge type
predictor module is the following:

.. code:: python

    class Model(nn.Module):
        def __init__(self, in_features, hidden_features, out_features, rel_names):
            super().__init__()
            self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
            self.pred = HeteroMLPPredictor(out_features, len(rel_names))
        def forward(self, g, x, dec_graph):
            h = self.sage(g, x)
            return self.pred(dec_graph, h)

The training loop is then simply the following:

.. code:: python

    model = Model(10, 20, 5, hetero_graph.etypes)
    user_feats = hetero_graph.nodes['user'].data['feature']
    item_feats = hetero_graph.nodes['item'].data['feature']
    node_features = {'user': user_feats, 'item': item_feats}
    
    opt = torch.optim.Adam(model.parameters())
    for epoch in range(10):
        logits = model(hetero_graph, node_features, dec_graph)
        loss = F.cross_entropy(logits, edge_label)
        opt.zero_grad()
        loss.backward()
        opt.step()
        print(loss.item())


DGL provides `Graph Convolutional Matrix
Completion <https://github.com/dmlc/dgl/tree/master/examples/pytorch/gcmc>`__
as an example of rating prediction, which is formulated by predicting
the type of an existing edge on a heterogeneous graph. The node
representation module in the `model implementation
file <https://github.com/dmlc/dgl/tree/master/examples/pytorch/gcmc>`__
is called ``GCMCLayer``. The edge type predictor module is called
``BiDecoder``. Both of them are more complicated than the setting
described here.

.. _guide-training-link-prediction:

Link Prediction
---------------

In some other settings you may want to predict whether an edge exists
between two given nodes or not. Such a model is called a *link prediction*
model.

Overview
~~~~~~~~

A GNN-based link prediction model represents the likelihood of
connectivity between two nodes :math:`u` and :math:`v` as a function of
:math:`\boldsymbol{h}_u^{(L)}` and :math:`\boldsymbol{h}_v^{(L)}`, their
node representation computed from the multi-layer GNN.

.. math::


   y_{u,v} = \phi(\boldsymbol{h}_u^{(L)}, \boldsymbol{h}_v^{(L)})

In this section we refer to :math:`y_{u,v}` as the *score* between node
:math:`u` and node :math:`v`.

Training a link prediction model involves comparing the scores between
nodes connected by an edge against the scores between an arbitrary pair
of nodes. For example, given an edge connecting :math:`u` and :math:`v`,
we encourage the score between node :math:`u` and :math:`v` to be higher
than the score between node :math:`u` and a sampled node :math:`v'` from
an arbitrary *noise* distribution :math:`v' \sim P_n(v)`. Such
methodology is called *negative sampling*.

There are lots of loss functions that can achieve the behavior above if
minimized. A non-exhaustive list includes:

-  Cross-entropy loss:
   :math:`\mathcal{L} = - \log \sigma (y_{u,v}) - \sum_{v_i \sim P_n(v), i=1,\dots,k}\log \left[ 1 - \sigma (y_{u,v_i})\right]`
-  BPR loss:
   :math:`\mathcal{L} = \sum_{v_i \sim P_n(v), i=1,\dots,k} - \log \sigma (y_{u,v} - y_{u,v_i})`
-  Margin loss:
   :math:`\mathcal{L} = \sum_{v_i \sim P_n(v), i=1,\dots,k} \max(0, M - y_{u, v} + y_{u, v_i})`,
   where :math:`M` is a constant hyperparameter.

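For concreteness, here is a minimal sketch of the cross-entropy loss above,
written against the ``pos_score`` and ``neg_score`` tensors produced by the
models later in this section (the training loops below use the margin loss
instead).

.. code:: python

    def compute_ce_loss(pos_score, neg_score):
        # Positive pairs are labeled 1 and negative pairs are labeled 0.
        scores = torch.cat([pos_score, neg_score]).squeeze(-1)
        labels = torch.cat([torch.ones(pos_score.shape[0]),
                            torch.zeros(neg_score.shape[0])])
        return F.binary_cross_entropy_with_logits(scores, labels)
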
You may find this idea familiar if you know what `implicit
feedback <https://arxiv.org/ftp/arxiv/papers/1205/1205.2618.pdf>`__ or
`noise-contrastive
estimation <http://proceedings.mlr.press/v9/gutmann10a/gutmann10a.pdf>`__
is.

Model Implementation Difference from Edge Classification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The neural network model to compute the score between :math:`u` and
:math:`v` is identical to the edge regression model described above.

Here is an example of using dot product to compute the scores on edges.

.. code:: python

    class DotProductPredictor(nn.Module):
        def forward(self, graph, h):
            # h contains the node representations computed from the GNN above.
            with graph.local_scope():
                graph.ndata['h'] = h
                graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
                return graph.edata['score']

Training loop
~~~~~~~~~~~~~

Because our score prediction model operates on graphs, we need to
express the negative examples as another graph. The graph will contain
all negative node pairs as edges.

The following shows an example of expressing negative examples as a
graph. Each edge :math:`(u,v)` gets :math:`k` negative examples
:math:`(u,v_i)` where :math:`v_i` is sampled from a uniform
distribution.

.. code:: python

    def construct_negative_graph(graph, k):
        src, dst = graph.edges()
    
        neg_src = src.repeat_interleave(k)
        neg_dst = torch.randint(0, graph.number_of_nodes(), (len(src) * k,))
        return dgl.graph((neg_src, neg_dst), num_nodes=graph.number_of_nodes())

The model that predicts edge scores is the same as that of edge
classification/regression.

.. code:: python

    class Model(nn.Module):
        def __init__(self, in_features, hidden_features, out_features):
            super().__init__()
            self.sage = SAGE(in_features, hidden_features, out_features)
            self.pred = DotProductPredictor()
        def forward(self, g, neg_g, x):
            h = self.sage(g, x)
            return self.pred(g, h), self.pred(neg_g, h)

The training loop then repeatedly constructs the negative graph and
computes loss.

.. code:: python

    def compute_loss(pos_score, neg_score):
        # Margin loss: the positive score should exceed each negative score by a margin of 1
        n_edges = pos_score.shape[0]
        return (1 - pos_score + neg_score.view(n_edges, -1)).clamp(min=0).mean()
    
    node_features = graph.ndata['feat']
    n_features = node_features.shape[1]
    k = 5
    model = Model(n_features, 100, 100)
    opt = torch.optim.Adam(model.parameters())
    for epoch in range(10):
        negative_graph = construct_negative_graph(graph, k)
        pos_score, neg_score = model(graph, negative_graph, node_features)
        loss = compute_loss(pos_score, neg_score)
        opt.zero_grad()
        loss.backward()
        opt.step()
        print(loss.item())


After training, the node representation can be obtained via

.. code:: python

    node_embeddings = model.sage(graph, node_features)

There are multiple ways of using the node embeddings. Examples include
training downstream classifiers, or doing nearest neighbor search or
maximum inner product search for relevant entity recommendation.

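For instance, a rough sketch of maximum inner product search with the trained
embeddings (the query node index here is arbitrary):

.. code:: python

    # Score every node against a query node by inner product and take the top-k.
    query = node_embeddings[0]            # embedding of node 0
    scores = node_embeddings @ query      # inner products with all nodes
    topk_scores, topk_nodes = torch.topk(scores, k=10)
    print(topk_nodes)                     # the 10 most similar nodes (including node 0 itself)
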
Heterogeneous graphs
~~~~~~~~~~~~~~~~~~~~

Link prediction on heterogeneous graphs is not very different from that
on homogeneous graphs. The following assumes that we are predicting on
one edge type, and it is easy to extend it to multiple edge types.

For example, you can reuse the ``HeteroDotProductPredictor`` above for
computing the scores of the edges of an edge type for link prediction.

.. code:: python

    class HeteroDotProductPredictor(nn.Module):
        def forward(self, graph, h, etype):
            # h contains the node representations for each edge type computed from
            # the GNN above.
            with graph.local_scope():
                graph.ndata['h'] = h
                graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
                return graph.edges[etype].data['score']

To perform negative sampling, you can similarly construct a negative graph
for the edge type on which you perform link prediction.

.. code:: python

    def construct_negative_graph(graph, k, etype):
        utype, _, vtype = etype
        src, dst = graph.edges(etype=etype)
        neg_src = src.repeat_interleave(k)
        neg_dst = torch.randint(0, graph.number_of_nodes(vtype), (len(src) * k,))
        return dgl.heterograph(
            {etype: (neg_src, neg_dst)},
            num_nodes_dict={ntype: graph.number_of_nodes(ntype) for ntype in graph.ntypes})

The model is a bit different from that in edge classification on
heterogeneous graphs since you need to specify the edge type on which you
perform link prediction.

.. code:: python

    class Model(nn.Module):
        def __init__(self, in_features, hidden_features, out_features, rel_names):
            super().__init__()
            self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
            self.pred = HeteroDotProductPredictor()
        def forward(self, g, neg_g, x, etype):
            h = self.sage(g, x)
            return self.pred(g, h, etype), self.pred(neg_g, h, etype)

The training loop is similar to that of homogeneous graphs.

.. code:: python

    def compute_loss(pos_score, neg_score):
        # Margin loss: the positive score should exceed each negative score by a margin of 1
        n_edges = pos_score.shape[0]
        return (1 - pos_score + neg_score.view(n_edges, -1)).clamp(min=0).mean()
    
    k = 5
    model = Model(10, 20, 5, hetero_graph.etypes)
    user_feats = hetero_graph.nodes['user'].data['feature']
    item_feats = hetero_graph.nodes['item'].data['feature']
    node_features = {'user': user_feats, 'item': item_feats}
    opt = torch.optim.Adam(model.parameters())
    for epoch in range(10):
        negative_graph = construct_negative_graph(hetero_graph, k, ('user', 'click', 'item'))
        pos_score, neg_score = model(hetero_graph, negative_graph, node_features, ('user', 'click', 'item'))
        loss = compute_loss(pos_score, neg_score)
        opt.zero_grad()
        loss.backward()
        opt.step()
        print(loss.item())

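As in the homogeneous case, the trained node representations can be extracted
afterwards; for a heterogeneous graph they come back as a dictionary keyed by
node type.

.. code:: python

    node_embeddings = model.sage(hetero_graph, node_features)
    user_embeddings = node_embeddings['user']
    item_embeddings = node_embeddings['item']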

.. _guide-training-graph-classification:

Graph Classification
--------------------

Instead of a single big graph, sometimes we might have the data in the
form of multiple graphs, for example a list of different types of
communities of people. By characterizing the friendship among people in
the same community with a graph, we get a list of graphs to classify. In
this scenario, a graph classification model could help identify the type
of the community, i.e. classify each graph based on its structure and
overall information.

Overview
~~~~~~~~

The major difference between graph classification and node
classification or link prediction is that the prediction result
characterizes the property of the entire input graph. We perform
message passing over nodes/edges just as in the previous tasks, but we also
need to aggregate a graph-level representation.

The graph classification proceeds as follows:

.. figure:: https://data.dgl.ai/tutorial/batch/graph_classifier.png
   :alt: Graph Classification Process

   Graph Classification Process

From left to right, the common practice is:

-  Prepare the graphs into a batch of graphs
-  Perform message passing on the batched graph to update node/edge features
-  Aggregate node/edge features into a graph-level representation
-  Classify each graph with a classification head

Batch of Graphs
^^^^^^^^^^^^^^^

Usually a graph classification task trains on a lot of graphs, and it
will be very inefficient if we use only one graph at a time when
training the model. Borrowing the idea of mini-batch training from
common deep learning practice, we can build a batch of multiple graphs
and send them together for one training iteration.

In DGL, we can build a single batched graph from a list of graphs. This
batched graph can simply be used as a single large graph, with separate
components representing the corresponding original small graphs.

.. figure:: https://data.dgl.ai/tutorial/batch/batch.png
   :alt: Batched Graph

   Batched Graph

Graph Readout
^^^^^^^^^^^^^

Every graph in the data may have its own unique structure, as well as its
node and edge features. In order to make a single prediction, we usually
aggregate and summarize the possibly abundant information. This
type of operation is called *readout*. Common readout aggregations include
summation, average, maximum, or minimum over all node or edge features.

Given a graph :math:`g`, we can define the average readout aggregation
as

.. math:: h_g = \frac{1}{|\mathcal{V}|}\sum_{v\in \mathcal{V}}h_v

In DGL, the corresponding function call is :func:`dgl.readout_nodes` with
``op='mean'`` (the default ``op`` is summation).

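For example, here is a tiny sketch of the average readout on a three-node
graph (:func:`dgl.mean_nodes` is an equivalent shortcut).

.. code:: python

    g = dgl.graph(([0, 1], [1, 2]))
    g.ndata['h'] = torch.tensor([1., 2., 3.])
    dgl.readout_nodes(g, 'h', op='mean')
    # tensor([2.])  # (1 + 2 + 3) / 3
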
Once :math:`h_g` is available, we can pass it through an MLP layer for
classification output.

Writing neural network model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The input to the model is the batched graph with node and edge features.
One thing to note is that the node and edge features in the batched graph
have no batch dimension, so a little special care should be taken in the
model:

Computation on a batched graph
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Next, we discuss the computational properties of a batched graph.

First, different graphs in a batch are entirely separated, i.e. there are
no edges connecting any two graphs. With this nice property, all message
passing functions still produce the same results as on the individual graphs.

Second, the readout function on a batched graph will be conducted over
each graph separately. Assuming the batch size is :math:`B` and the
feature to be aggregated has dimension :math:`D`, the shape of the
readout result will be :math:`(B, D)`.

.. code:: python

    g1 = dgl.graph(([0, 1], [1, 0]))
    g1.ndata['h'] = torch.tensor([1., 2.])
    g2 = dgl.graph(([0, 1], [1, 2]))
    g2.ndata['h'] = torch.tensor([1., 2., 3.])
    
    dgl.readout_nodes(g1, 'h')
    # tensor([3.])  # 1 + 2
    
    bg = dgl.batch([g1, g2])
    dgl.readout_nodes(bg, 'h')
    # tensor([3., 6.])  # [1 + 2, 1 + 2 + 3]

Finally, each node/edge feature tensor on a batched graph is obtained by
concatenating the corresponding feature tensors from all graphs.

.. code:: python

    bg.ndata['h']
    # tensor([1., 2., 1., 2., 3.])

Model definition
^^^^^^^^^^^^^^^^

Being aware of the above computation rules, we can define a very simple
model.

.. code:: python

    class Classifier(nn.Module):
        def __init__(self, in_dim, hidden_dim, n_classes):
            super(Classifier, self).__init__()
            self.conv1 = dglnn.GraphConv(in_dim, hidden_dim)
            self.conv2 = dglnn.GraphConv(hidden_dim, hidden_dim)
            self.classify = nn.Linear(hidden_dim, n_classes)
    
        def forward(self, g, feat):
            # Apply graph convolution and activation.
            h = F.relu(self.conv1(g, feat))
            h = F.relu(self.conv2(g, h))
            with g.local_scope():
                g.ndata['h'] = h
                # Calculate graph representation by average readout.
                hg = dgl.mean_nodes(g, 'h')
                return self.classify(hg)

Training loop
~~~~~~~~~~~~~

Data Loading
^^^^^^^^^^^^

Once the model is defined, we can start training. Since graph
classification deals with lots of relatively small graphs instead of a
single big one, we can usually train efficiently on stochastic mini-batches
of graphs, without the need to design sophisticated graph sampling
algorithms.

Assume that we have a graph classification dataset as introduced in
:ref:`guide-data-pipeline`:

.. code:: python

    import dgl.data
    dataset = dgl.data.GINDataset('MUTAG', False)

Each item in the graph classification dataset is a pair of a graph and
its label. We can speed up the data loading process by taking advantage
of PyTorch's ``DataLoader`` with a customized collate function that batches
the graphs:

.. code:: python

    def collate(samples):
        graphs, labels = map(list, zip(*samples))
        batched_graph = dgl.batch(graphs)
        batched_labels = torch.tensor(labels)
        return batched_graph, batched_labels

Then one can create a DataLoader that iterates over the dataset of
graphs in minibatches.

.. code:: python

    from torch.utils.data import DataLoader
    dataloader = DataLoader(
        dataset,
        batch_size=1024,
        collate_fn=collate,
        drop_last=False,
        shuffle=True)

Loop
^^^^

The training loop then simply involves iterating over the dataloader and
updating the model.

.. code:: python

    # The MUTAG node features in GINDataset ('attr') are 7-dimensional.
    model = Classifier(7, 20, 5)
    opt = torch.optim.Adam(model.parameters())
    for epoch in range(20):
        for batched_graph, labels in dataloader:
            feats = batched_graph.ndata['attr']
            logits = model(batched_graph, feats)
            loss = F.cross_entropy(logits, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()

DGL implements
`GIN <https://github.com/dmlc/dgl/tree/master/examples/pytorch/gin>`__
as an example of graph classification. The training loop is inside the
function ``train`` in
`main.py <https://github.com/dmlc/dgl/blob/master/examples/pytorch/gin/main.py>`__.
The model implementation is inside
`gin.py <https://github.com/dmlc/dgl/blob/master/examples/pytorch/gin/gin.py>`__
with more components such as using
:class:`dgl.nn.pytorch.GINConv` (also available in MXNet and Tensorflow)
as the graph convolution layer, batch normalization, etc.

Heterogeneous graph
~~~~~~~~~~~~~~~~~~~

Graph classification with heterogeneous graphs is a little different
from that with homogeneous graphs. In addition to heterogeneous
graph convolution modules, you also need to aggregate over the nodes of
different types in the readout function.

The following shows an example of summing up the average of node
representations for each node type.

.. code:: python

    class RGCN(nn.Module):
        def __init__(self, in_feats, hid_feats, out_feats, rel_names):
            super().__init__()
            
            self.conv1 = dglnn.HeteroGraphConv({
                rel: dglnn.GraphConv(in_feats, hid_feats)
                for rel in rel_names}, aggregate='sum')
            self.conv2 = dglnn.HeteroGraphConv({
                rel: dglnn.GraphConv(hid_feats, out_feats)
                for rel in rel_names}, aggregate='sum')
      
        def forward(self, graph, inputs):
            # inputs are features of nodes
            h = self.conv1(graph, inputs)
            h = {k: F.relu(v) for k, v in h.items()}
            h = self.conv2(graph, h)
            return h
    
    class HeteroClassifier(nn.Module):
        def __init__(self, in_dim, hidden_dim, n_classes, rel_names):
            super().__init__()
            
            self.rgcn = RGCN(in_dim, hidden_dim, hidden_dim, rel_names)
            self.classify = nn.Linear(hidden_dim, n_classes)

        def forward(self, g):
            # For a heterogeneous graph, ndata['feat'] is a dictionary of
            # node types and feature tensors.
            h = g.ndata['feat']
            # Apply the heterogeneous graph convolutions defined in RGCN above.
            h = self.rgcn(g, h)
    
            with g.local_scope():
                g.ndata['h'] = h
                # Calculate graph representation by average readout.
                hg = 0
                for ntype in g.ntypes:
                    hg = hg + dgl.mean_nodes(g, 'h', ntype=ntype)
                return self.classify(hg)

The rest of the code is not different from that for homogeneous graphs.

.. code:: python

    # etypes is the list of edge types (as strings) of the graphs in the dataset.
    model = HeteroClassifier(10, 20, 5, etypes)
    opt = torch.optim.Adam(model.parameters())
    for epoch in range(20):
        for batched_graph, labels in dataloader:
            logits = model(batched_graph)
            loss = F.cross_entropy(logits, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()