"""
Hypergraph Convolution and Hypergraph Attention
(https://arxiv.org/pdf/1901.08150.pdf).
"""
import argparse

import dgl.sparse as dglsp

import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from dgl.data import CoraGraphDataset
from torchmetrics.functional import accuracy


def hypergraph_laplacian(H):
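    """Compute the normalized hypergraph Laplacian

        L = D_V^(-1/2) @ H @ W @ D_E^(-1) @ H^T @ D_V^(-1/2)

    where H is the |V| x |E| incidence matrix, W is the diagonal hyperedge
    weight matrix (identity here, i.e. all hyperedges weighted equally), and
    D_V and D_E are the diagonal node- and hyperedge-degree matrices.
    """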
    ###########################################################
    # (HIGHLIGHT) Compute the Laplacian with Sparse Matrix API
    ###########################################################
    d_V = H.sum(1)  # node degree
    d_E = H.sum(0)  # edge degree
    n_edges = d_E.shape[0]
    D_V_invsqrt = dglsp.diag(d_V**-0.5)  # D_V ** (-1/2)
    D_E_inv = dglsp.diag(d_E**-1)  # D_E ** (-1)
    W = dglsp.identity((n_edges, n_edges))
    return D_V_invsqrt @ H @ W @ D_E_inv @ H.T @ D_V_invsqrt
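
# Shape sanity check on a hypothetical toy incidence matrix with 3 nodes and
# 2 hyperedges; the Laplacian is square in the number of nodes:
#
#     H = dglsp.create_from_coo(
#         torch.tensor([0, 1, 1, 2]), torch.tensor([0, 0, 1, 1])
#     )
#     hypergraph_laplacian(H).shape  # (3, 3)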


class HypergraphAttention(nn.Module):
    """Hypergraph Attention module as in the paper
    `Hypergraph Convolution and Hypergraph Attention
    <https://arxiv.org/pdf/1901.08150.pdf>`_.
    """

    def __init__(self, in_size, out_size):
        super().__init__()

        self.P = nn.Linear(in_size, out_size)
        self.a = nn.Linear(2 * out_size, 1)

    def forward(self, H, X, X_edges):
        Z = self.P(X)
        Z_edges = self.P(X_edges)
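        # Score each nonzero (node, hyperedge) pair of the incidence matrix
        # by concatenating its node and hyperedge embeddings.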
        sim = self.a(torch.cat([Z[H.row], Z_edges[H.col]], 1))
        sim = F.leaky_relu(sim, 0.2).squeeze(1)
        # Reassign the attention scores as the new nonzero values of the
        # incidence matrix.
        H_att = dglsp.val_like(H, sim)
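        # Normalize the attention scores with a sparse softmax over the
        # nonzero entries.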
        H_att = H_att.softmax()
        return hypergraph_laplacian(H_att) @ Z
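
# Minimal usage sketch for a single layer (hypothetical sizes). X_edges must
# have one row per hyperedge; the hypergraph built below has one hyperedge
# per node, so the node features can be reused as X_edges:
#
#     layer = HypergraphAttention(in_size=8, out_size=4)
#     X = torch.randn(H.shape[0], 8)
#     Z = layer(H, X, X)  # Z.shape == (H.shape[0], 4)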


class Net(nn.Module):
    def __init__(self, in_size, out_size, hidden_size=16):
        super().__init__()

        self.layer1 = HypergraphAttention(in_size, hidden_size)
        self.layer2 = HypergraphAttention(hidden_size, out_size)

    def forward(self, H, X):
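        # The hypergraph built in load_data has one hyperedge per node, so
        # the node features also serve as the hyperedge features.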
        Z = self.layer1(H, X, X)
        Z = F.elu(Z)
        Z = self.layer2(H, Z, Z)
        return Z


def train(model, optimizer, H, X, Y, train_mask):
    model.train()
    Y_hat = model(H, X)
    loss = F.cross_entropy(Y_hat[train_mask], Y[train_mask])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


def evaluate(model, H, X, Y, val_mask, test_mask, num_classes):
    model.eval()
    # Disable gradient tracking during evaluation.
    with torch.no_grad():
        Y_hat = model(H, X)
    val_acc = accuracy(
        Y_hat[val_mask], Y[val_mask], task="multiclass", num_classes=num_classes
    )
    test_acc = accuracy(
        Y_hat[test_mask],
        Y[test_mask],
        task="multiclass",
        num_classes=num_classes,
    )
    return val_acc, test_acc


def load_data():
    dataset = CoraGraphDataset()

    graph = dataset[0]
    # The paper created a hypergraph from the original graph. For each node in
    # the original graph, a hyperedge in the hypergraph is created to connect
    # its neighbors and itself. In this case, the incidence matrix of the
    # hypergraph is the same as the adjacency matrix of the original graph (with
    # self-loops).
    # We follow the paper and assume that the rows of the incidence matrix
    # are for nodes and the columns are for edges.
    src, dst = graph.edges()
    H = dglsp.create_from_coo(dst, src)
    H = H + dglsp.identity(H.shape)

    X = graph.ndata["feat"]
    Y = graph.ndata["label"]
    train_mask = graph.ndata["train_mask"]
    val_mask = graph.ndata["val_mask"]
    test_mask = graph.ndata["test_mask"]
    return H, X, Y, dataset.num_classes, train_mask, val_mask, test_mask
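
# Illustration (hypothetical toy graph): for undirected edges (0, 1) and
# (1, 2) stored in both directions, the incidence matrix with self-loops is
#
#     [[1, 1, 0],
#      [1, 1, 1],
#      [0, 1, 1]]
#
# where column j is the hyperedge grouping node j and its neighbors.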


def main(args):
    H, X, Y, num_classes, train_mask, val_mask, test_mask = load_data()
    model = Net(X.shape[1], num_classes)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    with tqdm.trange(args.epochs) as tq:
        for epoch in tq:
            loss = train(model, optimizer, H, X, Y, train_mask)
            val_acc, test_acc = evaluate(
                model, H, X, Y, val_mask, test_mask, num_classes
            )
            tq.set_postfix(
                {
                    "Loss": f"{loss:.5f}",
                    "Val acc": f"{val_acc:.5f}",
                    "Test acc": f"{test_acc:.5f}",
                },
                refresh=False,
            )

    print(f"Test acc: {test_acc:.3f}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Hypergraph Attention Example")
    parser.add_argument(
        "--epochs", type=int, default=500, help="Number of training epochs."
    )
    args = parser.parse_args()
    main(args)