9_gat.py 20.3 KB
Newer Older
Hao Zhang's avatar
Hao Zhang committed
1
2
3
"""
.. _model-gat:

4
5
Understand Graph Attention Network
=======================================
Hao Zhang's avatar
Hao Zhang committed
6
7
8
9
10
11

**Authors:** `Hao Zhang <https://github.com/sufeidechabei/>`_, `Mufei Li
<https://github.com/mufeili>`_, `Minjie Wang
<https://jermainewang.github.io/>`_  `Zheng Zhang
<https://shanghai.nyu.edu/academics/faculty/directory/zheng-zhang>`_

12
13
14
15
16
17
18
.. warning::

    The tutorial aims at gaining insights into the paper, with code as a mean
    of explanation. The implementation thus is NOT optimized for running
    efficiency. For recommended implementation, please refer to the `official
    examples <https://github.com/dmlc/dgl/tree/master/examples>`_.

19
20
21
In this tutorial, you learn about a graph attention network (GAT) and how it can be 
implemented in PyTorch. You can also learn to visualize and understand what the attention 
mechanism has learned.
Hao Zhang's avatar
Hao Zhang committed
22

23
24
25
26
The research described in the paper `Graph Convolutional Network (GCN) <https://arxiv.org/abs/1609.02907>`_,
indicates that combining local graph structure and node-level features yields
good performance on node classification tasks. However, the way GCN aggregates
is structure-dependent, which can hurt its generalizability.
Hao Zhang's avatar
Hao Zhang committed
27

28
29
30
31
32
33
One workaround is to simply average over all neighbor node features as described in
the research paper `GraphSAGE
<https://www-cs-faculty.stanford.edu/people/jure/pubs/graphsage-nips17.pdf>`_.
However, `Graph Attention Network <https://arxiv.org/abs/1710.10903>`_ proposes a
different type of aggregation. GAN uses weighting neighbor features with feature dependent and
structure-free normalization, in the style of attention.
Hao Zhang's avatar
Hao Zhang committed
34
35
"""
###############################################################
36
# Introducing attention to GCN
Hao Zhang's avatar
Hao Zhang committed
37
38
39
40
# ----------------------------
#
# The key difference between GAT and GCN is how the information from the one-hop neighborhood is aggregated.
#
41
# For GCN, a graph convolution operation produces the normalized sum of the node features of neighbors.
Hao Zhang's avatar
Hao Zhang committed
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#
#
# .. math::
#
#   h_i^{(l+1)}=\sigma\left(\sum_{j\in \mathcal{N}(i)} {\frac{1}{c_{ij}} W^{(l)}h^{(l)}_j}\right)
#
#
# where :math:`\mathcal{N}(i)` is the set of its one-hop neighbors (to include
# :math:`v_i` in the set, simply add a self-loop to each node),
# :math:`c_{ij}=\sqrt{|\mathcal{N}(i)|}\sqrt{|\mathcal{N}(j)|}` is a
# normalization constant based on graph structure, :math:`\sigma` is an
# activation function (GCN uses ReLU), and :math:`W^{(l)}` is a shared
# weight matrix for node-wise feature transformation. Another model proposed in
# `GraphSAGE
# <https://www-cs-faculty.stanford.edu/people/jure/pubs/graphsage-nips17.pdf>`_
# employs the same update rule except that they set
# :math:`c_{ij}=|\mathcal{N}(i)|`.
#
# GAT introduces the attention mechanism as a substitute for the statically
# normalized convolution operation. Below are the equations to compute the node
# embedding :math:`h_i^{(l+1)}` of layer :math:`l+1` from the embeddings of
63
# layer :math:`l`.
Hao Zhang's avatar
Hao Zhang committed
64
#
Jinjing Zhou's avatar
Jinjing Zhou committed
65
# .. image:: https://data.dgl.ai/tutorial/gat/gat.png
Hao Zhang's avatar
Hao Zhang committed
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#   :width: 450px
#   :align: center
#
# .. math::
#
#   \begin{align}
#   z_i^{(l)}&=W^{(l)}h_i^{(l)},&(1) \\
#   e_{ij}^{(l)}&=\text{LeakyReLU}(\vec a^{(l)^T}(z_i^{(l)}||z_j^{(l)})),&(2)\\
#   \alpha_{ij}^{(l)}&=\frac{\exp(e_{ij}^{(l)})}{\sum_{k\in \mathcal{N}(i)}^{}\exp(e_{ik}^{(l)})},&(3)\\
#   h_i^{(l+1)}&=\sigma\left(\sum_{j\in \mathcal{N}(i)} {\alpha^{(l)}_{ij} z^{(l)}_j }\right),&(4)
#   \end{align}
#
#
# Explanations:
#
#
# * Equation (1) is a linear transformation of the lower layer embedding :math:`h_i^{(l)}`
#   and :math:`W^{(l)}` is its learnable weight matrix.
84
# * Equation (2) computes a pair-wise *un-normalized* attention score between two neighbors.
Hao Zhang's avatar
Hao Zhang committed
85
86
87
88
89
90
#   Here, it first concatenates the :math:`z` embeddings of the two nodes, where :math:`||`
#   denotes concatenation, then takes a dot product of it and a learnable weight vector
#   :math:`\vec a^{(l)}`, and applies a LeakyReLU in the end. This form of attention is
#   usually called *additive attention*, contrast with the dot-product attention in the
#   Transformer model.
# * Equation (3) applies a softmax to normalize the attention scores on each node's
91
#   incoming edges.
Hao Zhang's avatar
Hao Zhang committed
92
93
94
95
# * Equation (4) is similar to GCN. The embeddings from neighbors are aggregated together,
#   scaled by the attention scores.
#
# There are other details from the paper, such as dropout and skip connections.
96
97
# For the purpose of simplicity, those details are left out of this tutorial. To see more details, 
# download the `full example <https://github.com/dmlc/dgl/blob/master/examples/pytorch/gat/gat.py>`_.
Hao Zhang's avatar
Hao Zhang committed
98
99
100
101
102
103
# In its essence, GAT is just a different aggregation function with attention
# over features of neighbors, instead of a simple mean aggregation.
#
# GAT in DGL
# ----------
#
Minjie Wang's avatar
Minjie Wang committed
104
105
106
107
108
109
110
111
112
# DGL provides an off-the-shelf implementation of the GAT layer under the ``dgl.nn.<backend>``
# subpackage. Simply import the ``GATConv`` as the follows.

from dgl.nn.pytorch import GATConv

###############################################################
# Readers can skip the following step-by-step explanation of the implementation and
# jump to the `Put everything together`_ for training and visualization results.
#
113
114
115
# To begin, you can get an overall impression about how a ``GATLayer`` module is
# implemented in DGL. In this section, the four equations above are broken down 
# one at a time.
116
117
118
119
120
121
#
# .. note::
#
#    This is showing how to implement a GAT from scratch.  DGL provides a more
#    efficient :class:`builtin GAT layer module <dgl.nn.pytorch.conv.GATConv>`.
#
Hao Zhang's avatar
Hao Zhang committed
122
123
124
125
126
127
128
129
130
131
132
133
134
135

import torch
import torch.nn as nn
import torch.nn.functional as F


class GATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim):
        super(GATLayer, self).__init__()
        self.g = g
        # equation (1)
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        # equation (2)
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)
136
137
138
139
140
141
142
        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(self.fc.weight, gain=gain)
        nn.init.xavier_normal_(self.attn_fc.weight, gain=gain)
Hao Zhang's avatar
Hao Zhang committed
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179

    def edge_attention(self, edges):
        # edge UDF for equation (2)
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e': F.leaky_relu(a)}

    def message_func(self, edges):
        # message UDF for equation (3) & (4)
        return {'z': edges.src['z'], 'e': edges.data['e']}

    def reduce_func(self, nodes):
        # reduce UDF for equation (3) & (4)
        # equation (3)
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        # equation (4)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {'h': h}

    def forward(self, h):
        # equation (1)
        z = self.fc(h)
        self.g.ndata['z'] = z
        # equation (2)
        self.g.apply_edges(self.edge_attention)
        # equation (3) & (4)
        self.g.update_all(self.message_func, self.reduce_func)
        return self.g.ndata.pop('h')

##################################################################
# Equation (1)
# ^^^^^^^^^^^^
#
# .. math::
#
#   z_i^{(l)}=W^{(l)}h_i^{(l)},(1)
#
180
# The first one shows linear transformation. It's common and can be
Hao Zhang's avatar
Hao Zhang committed
181
182
183
184
185
186
187
188
189
# easily implemented in Pytorch using ``torch.nn.Linear``.
#
# Equation (2)
# ^^^^^^^^^^^^
#
# .. math::
#
#   e_{ij}^{(l)}=\text{LeakyReLU}(\vec a^{(l)^T}(z_i^{(l)}|z_j^{(l)})),(2)
#
190
# The un-normalized attention score :math:`e_{ij}` is calculated using the
Hao Zhang's avatar
Hao Zhang committed
191
# embeddings of adjacent nodes :math:`i` and :math:`j`. This suggests that the
192
# attention scores can be viewed as edge data, which can be calculated by the
Hao Zhang's avatar
Hao Zhang committed
193
194
195
196
197
198
199
200
201
202
203
# ``apply_edges`` API. The argument to the ``apply_edges`` is an **Edge UDF**,
# which is defined as below:

def edge_attention(self, edges):
    # edge UDF for equation (2)
    z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
    a = self.attn_fc(z2)
    return {'e' : F.leaky_relu(a)}

########################################################################3
# Here, the dot product with the learnable weight vector :math:`\vec{a^{(l)}}`
204
# is implemented again using PyTorch's linear transformation ``attn_fc``. Note
Hao Zhang's avatar
Hao Zhang committed
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
# that ``apply_edges`` will **batch** all the edge data in one tensor, so the
# ``cat``, ``attn_fc`` here are applied on all the edges in parallel.
#
# Equation (3) & (4)
# ^^^^^^^^^^^^^^^^^^
#
# .. math::
#
#   \begin{align}
#   \alpha_{ij}^{(l)}&=\frac{\exp(e_{ij}^{(l)})}{\sum_{k\in \mathcal{N}(i)}^{}\exp(e_{ik}^{(l)})},&(3)\\
#   h_i^{(l+1)}&=\sigma\left(\sum_{j\in \mathcal{N}(i)} {\alpha^{(l)}_{ij} z^{(l)}_j }\right),&(4)
#   \end{align}
#
# Similar to GCN, ``update_all`` API is used to trigger message passing on all
# the nodes. The message function sends out two tensors: the transformed ``z``
220
# embedding of the source node and the un-normalized attention score ``e`` on
Hao Zhang's avatar
Hao Zhang committed
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
# each edge. The reduce function then performs two tasks:
#
#
# * Normalize the attention scores using softmax (equation (3)).
# * Aggregate neighbor embeddings weighted by the attention scores (equation(4)).
#
# Both tasks first fetch data from the mailbox and then manipulate it on the
# second dimension (``dim=1``), on which the messages are batched.

def reduce_func(self, nodes):
    # reduce UDF for equation (3) & (4)
    # equation (3)
    alpha = F.softmax(nodes.mailbox['e'], dim=1)
    # equation (4)
    h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
    return {'h' : h}

#####################################################################
239
# Multi-head attention
Hao Zhang's avatar
Hao Zhang committed
240
241
242
243
244
245
246
247
248
249
250
251
252
# ^^^^^^^^^^^^^^^^^^^^
#
# Analogous to multiple channels in ConvNet, GAT introduces **multi-head
# attention** to enrich the model capacity and to stabilize the learning
# process. Each attention head has its own parameters and their outputs can be
# merged in two ways:
#
# .. math:: \text{concatenation}: h^{(l+1)}_{i} =||_{k=1}^{K}\sigma\left(\sum_{j\in \mathcal{N}(i)}\alpha_{ij}^{k}W^{k}h^{(l)}_{j}\right)
#
# or
#
# .. math:: \text{average}: h_{i}^{(l+1)}=\sigma\left(\frac{1}{K}\sum_{k=1}^{K}\sum_{j\in\mathcal{N}(i)}\alpha_{ij}^{k}W^{k}h^{(l)}_{j}\right)
#
253
# where :math:`K` is the number of heads. You can use
Hao Zhang's avatar
Hao Zhang committed
254
255
# concatenation for intermediary layers and average for the final layer.
#
256
# Use the above defined single-head ``GATLayer`` as the building block
Hao Zhang's avatar
Hao Zhang committed
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
# for the ``MultiHeadGATLayer`` below:

class MultiHeadGATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):
            self.heads.append(GATLayer(g, in_dim, out_dim))
        self.merge = merge

    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            # concat on the output feature dimension (dim=1)
            return torch.cat(head_outs, dim=1)
        else:
            # merge using average
            return torch.mean(torch.stack(head_outs))

###########################################################################
# Put everything together
# ^^^^^^^^^^^^^^^^^^^^^^^
#
280
# Now, you can define a two-layer GAT model.
Hao Zhang's avatar
Hao Zhang committed
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297

class GAT(nn.Module):
    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super(GAT, self).__init__()
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        # Be aware that the input dimension is hidden_dim*num_heads since
        # multiple head outputs are concatenated together. Also, only
        # one attention head in the output layer.
        self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)

    def forward(self, h):
        h = self.layer1(h)
        h = F.elu(h)
        h = self.layer2(h)
        return h

#############################################################################
298
# We then load the Cora dataset using DGL's built-in data module.
Hao Zhang's avatar
Hao Zhang committed
299
300
301

from dgl import DGLGraph
from dgl.data import citation_graph as citegrh
302
import networkx as nx
Hao Zhang's avatar
Hao Zhang committed
303
304
305
306
307

def load_cora_data():
    data = citegrh.load_cora()
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
Mufei Li's avatar
Mufei Li committed
308
    mask = torch.BoolTensor(data.train_mask)
Minjie Wang's avatar
Minjie Wang committed
309
    g = DGLGraph(data.graph)
Hao Zhang's avatar
Hao Zhang committed
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
    return g, features, labels, mask

##############################################################################
# The training loop is exactly the same as in the GCN tutorial.

import time
import numpy as np

g, features, labels, mask = load_cora_data()

# create the model, 2 heads, each head has hidden size 8
net = GAT(g,
          in_dim=features.size()[1],
          hidden_dim=8,
          out_dim=7,
          num_heads=2)

# create optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

# main loop
dur = []
for epoch in range(30):
    if epoch >= 3:
        t0 = time.time()

    logits = net(features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[mask], labels[mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch >= 3:
        dur.append(time.time() - t0)

    print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(
        epoch, loss.item(), np.mean(dur)))

#########################################################################
351
# Visualizing and understanding attention learned
Hao Zhang's avatar
Hao Zhang committed
352
353
354
355
356
# ----------------------------------------------
#
# Cora
# ^^^^
#
357
358
# The following table summarizes the model performance on Cora that is reported in
# `the GAT paper <https://arxiv.org/pdf/1710.10903.pdf>`_ and obtained with DGL 
Hao Zhang's avatar
Hao Zhang committed
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
# implementations.
#
# .. list-table::
#    :header-rows: 1
#
#    * - Model
#      - Accuracy
#    * - GCN (paper)
#      - :math:`81.4\pm 0.5%`
#    * - GCN (dgl)
#      - :math:`82.05\pm 0.33%`
#    * - GAT (paper)
#      - :math:`83.0\pm 0.7%`
#    * - GAT (dgl)
#      - :math:`83.69\pm 0.529%`
#
375
# *What kind of attention distribution has our model learned?*
Hao Zhang's avatar
Hao Zhang committed
376
#
377
378
# Because the attention weight :math:`a_{ij}` is associated with edges, you can
# visualize it by coloring edges. Below you can pick a subgraph of Cora and plot the
Hao Zhang's avatar
Hao Zhang committed
379
380
381
382
# attention weights of the last ``GATLayer``. The nodes are colored according
# to their labels, whereas the edges are colored according to the magnitude of
# the attention weights, which can be referred with the colorbar on the right.
#
Jinjing Zhou's avatar
Jinjing Zhou committed
383
# .. image:: https://data.dgl.ai/tutorial/gat/cora-attention.png
Hao Zhang's avatar
Hao Zhang committed
384
385
386
#   :width: 600px
#   :align: center
#
387
388
# You can see that the model seems to learn different attention weights. To
# understand the distribution more thoroughly, measure the `entropy
Hao Zhang's avatar
Hao Zhang committed
389
390
391
392
393
394
395
# <https://en.wikipedia.org/wiki/Entropy_(information_theory>`_) of the
# attention distribution. For any node :math:`i`,
# :math:`\{\alpha_{ij}\}_{j\in\mathcal{N}(i)}` forms a discrete probability
# distribution over all its neighbors with the entropy given by
#
# .. math:: H({\alpha_{ij}}_{j\in\mathcal{N}(i)})=-\sum_{j\in\mathcal{N}(i)} \alpha_{ij}\log\alpha_{ij}
#
396
397
# A low entropy means a high degree of concentration, and vice
# versa. An entropy of 0 means all attention is on one source node. The uniform
Hao Zhang's avatar
Hao Zhang committed
398
# distribution has the highest entropy of :math:`\log(\mathcal{N}(i))`.
399
# Ideally, you want to see the model learns a distribution of lower entropy
Hao Zhang's avatar
Hao Zhang committed
400
401
402
# (i.e, one or two neighbors are much more important than the others).
#
# Note that since nodes can have different degrees, the maximum entropy will
403
# also be different. Therefore, you plot the aggregated histogram of entropy
Hao Zhang's avatar
Hao Zhang committed
404
405
406
407
408
409
410
# values of all nodes in the entire graph. Below are the attention histogram of
# learned by each attention head.
#
# |image2|
#
# As a reference, here is the histogram if all the nodes have uniform attention weight distribution.
#
Jinjing Zhou's avatar
Jinjing Zhou committed
411
# .. image:: https://data.dgl.ai/tutorial/gat/cora-attention-uniform-hist.png
Hao Zhang's avatar
Hao Zhang committed
412
413
414
415
416
417
418
419
#   :width: 250px
#   :align: center
#
# One can see that **the attention values learned is quite similar to uniform distribution**
# (i.e, all neighbors are equally important). This partially
# explains why the performance of GAT is close to that of GCN on Cora
# (according to `author's reported result
# <https://arxiv.org/pdf/1710.10903.pdf>`_, the accuracy difference averaged
420
421
# over 100 runs is less than 2 percent). Attention does not matter
# since it does not differentiate much.
Hao Zhang's avatar
Hao Zhang committed
422
423
#
# *Does that mean the attention mechanism is not useful?* No! A different
424
# dataset exhibits an entirely different pattern, as you can see next.
Hao Zhang's avatar
Hao Zhang committed
425
#
426
# Protein-protein interaction (PPI) networks
Hao Zhang's avatar
Hao Zhang committed
427
428
429
430
431
432
433
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# The PPI dataset used here consists of :math:`24` graphs corresponding to
# different human tissues. Nodes can have up to :math:`121` kinds of labels, so
# the label of node is represented as a binary tensor of size :math:`121`. The
# task is to predict node label.
#
434
# Use :math:`20` graphs for training, :math:`2` for validation and :math:`2`
Hao Zhang's avatar
Hao Zhang committed
435
436
# for test. The average number of nodes per graph is :math:`2372`. Each node
# has :math:`50` features that are composed of positional gene sets, motif gene
437
# sets, and immunological signatures. Critically, test graphs remain completely
Hao Zhang's avatar
Hao Zhang committed
438
439
# unobserved during training, a setting called "inductive learning".
#
440
# Compare the performance of GAT and GCN for :math:`10` random runs on this
Hao Zhang's avatar
Hao Zhang committed
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
# task and use hyperparameter search on the validation set to find the best
# model.
#
# .. list-table::
#    :header-rows: 1
#
#    * - Model
#      - F1 Score(micro)
#    * - GAT
#      - :math:`0.975 \pm 0.006`
#    * - GCN
#      - :math:`0.509 \pm 0.025`
#    * - Paper
#      - :math:`0.973 \pm 0.002`
#
456
# The table above is the result of this experiment, where you use micro `F1
Hao Zhang's avatar
Hao Zhang committed
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
# score <https://en.wikipedia.org/wiki/F1_score>`_ to evaluate the model
# performance.
#
# .. note::
#
#   Below is the calculation process of F1 score:
#  
#   .. math::
#  
#      precision=\frac{\sum_{t=1}^{n}TP_{t}}{\sum_{t=1}^{n}(TP_{t} +FP_{t})}
#  
#      recall=\frac{\sum_{t=1}^{n}TP_{t}}{\sum_{t=1}^{n}(TP_{t} +FN_{t})}
#  
#      F1_{micro}=2\frac{precision*recall}{precision+recall}
#  
#   * :math:`TP_{t}` represents for number of nodes that both have and are predicted to have label :math:`t`
#   * :math:`FP_{t}` represents for number of nodes that do not have but are predicted to have label :math:`t`
#   * :math:`FN_{t}` represents for number of output classes labeled as :math:`t` but predicted as others.
#   * :math:`n` is the number of labels, i.e. :math:`121` in our case.
#
477
# During training, use ``BCEWithLogitsLoss`` as the loss function. The
Hao Zhang's avatar
Hao Zhang committed
478
479
480
# learning curves of GAT and GCN are presented below; what is evident is the
# dramatic performance adavantage of GAT over GCN.
#
Jinjing Zhou's avatar
Jinjing Zhou committed
481
# .. image:: https://data.dgl.ai/tutorial/gat/ppi-curve.png
Hao Zhang's avatar
Hao Zhang committed
482
483
484
#   :width: 300px
#   :align: center
#
485
# As before, you can have a statistical understanding of the attentions learned
Hao Zhang's avatar
Hao Zhang committed
486
# by showing the histogram plot for the node-wise attention entropy. Below are
487
# the attention histograms learned by different attention layers.
Hao Zhang's avatar
Hao Zhang committed
488
#
489
# *Attention learned in layer 1:*
Hao Zhang's avatar
Hao Zhang committed
490
491
492
#
# |image5|
#
493
# *Attention learned in layer 2:*
Hao Zhang's avatar
Hao Zhang committed
494
495
496
#
# |image6|
#
497
# *Attention learned in final layer:*
Hao Zhang's avatar
Hao Zhang committed
498
499
500
501
502
#
# |image7|
#
# Again, comparing with uniform distribution: 
#
Jinjing Zhou's avatar
Jinjing Zhou committed
503
# .. image:: https://data.dgl.ai/tutorial/gat/ppi-uniform-hist.png
Hao Zhang's avatar
Hao Zhang committed
504
505
506
507
#   :width: 250px
#   :align: center
#
# Clearly, **GAT does learn sharp attention weights**! There is a clear pattern
508
# over the layers as well: **the attention gets sharper with a higher
Mufei Li's avatar
Mufei Li committed
509
# layer**.
Hao Zhang's avatar
Hao Zhang committed
510
#
511
# Unlike the Cora dataset where GAT's gain is minimal at best, for PPI there
Hao Zhang's avatar
Hao Zhang committed
512
# is a significant performance gap between GAT and other GNN variants compared
513
# in `the GAT paper <https://arxiv.org/pdf/1710.10903.pdf>`_ (at least 20 percent),
Hao Zhang's avatar
Hao Zhang committed
514
515
516
517
518
# and the attention distributions between the two clearly differ. While this
# deserves further research, one immediate conclusion is that GAT's advantage
# lies perhaps more in its ability to handle a graph with more complex
# neighborhood structure.
#
519
# What's next?
Hao Zhang's avatar
Hao Zhang committed
520
521
# ------------
#
522
523
524
525
# So far, you have seen how to use DGL to implement GAT. There are some
# missing details such as dropout, skip connections, and hyper-parameter tuning,
# which are practices that do not involve DGL-related concepts. For more information
# check out the full example.
Hao Zhang's avatar
Hao Zhang committed
526
#
527
528
# * See the optimized `full example <https://github.com/dmlc/dgl/blob/master/examples/pytorch/gat/gat.py>`_.
# * The next tutorial describes how to speedup GAT models by parallelizing multiple attention heads and SPMV optimization.
Hao Zhang's avatar
Hao Zhang committed
529
#
Jinjing Zhou's avatar
Jinjing Zhou committed
530
531
532
533
# .. |image2| image:: https://data.dgl.ai/tutorial/gat/cora-attention-hist.png
# .. |image5| image:: https://data.dgl.ai/tutorial/gat/ppi-first-layer-hist.png
# .. |image6| image:: https://data.dgl.ai/tutorial/gat/ppi-second-layer-hist.png
# .. |image7| image:: https://data.dgl.ai/tutorial/gat/ppi-final-layer-hist.png