train.py 4.28 KB
Newer Older
czkkkkkk's avatar
czkkkkkk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
from dgl import AddSelfLoop
from torch.nn import init

from dgl.mock_sparse import create_from_coo, diag, identity


class GraphConv(nn.Module):
    """Single graph-convolution layer: ``activation(A @ (x @ W) + bias)``.

    Parameters
    ----------
    in_size : int
        Number of columns of the input node-feature matrix.
    out_size : int
        Number of columns produced by the layer.
    activation : callable, optional
        Applied to the result when truthy; ``None`` returns the raw output.
    """

    def __init__(self, in_size, out_size, activation=None):
        super().__init__()
        # Allocated uninitialized here; reset_parameters() fills them below.
        self.W = nn.Parameter(torch.empty(in_size, out_size))
        self.activation = activation
        self.bias = nn.Parameter(torch.empty(out_size))
        self.reset_parameters()

    def forward(self, A, x):
        """Project ``x`` with ``W``, aggregate with ``A``, add bias, activate.

        ``A`` only needs to support ``A @ dense_tensor`` (in this script it is
        a sparse adjacency matrix, so the aggregation is an SpMM).
        """
        projected = x @ self.W  # dense matmul, pytorch op
        aggregated = A @ projected + self.bias  # SpMM, then broadcast bias
        return self.activation(aggregated) if self.activation else aggregated

    def reset_parameters(self):
        """Xavier-uniform weights and an all-zero bias."""
        init.xavier_uniform_(self.W)
        init.zeros_(self.bias)


class GCN(nn.Module):
    """Two-layer graph convolutional network returning per-node logits.

    The hidden layer uses ReLU and the output layer has no activation.
    Dropout (p=0.5) is applied to the input of every layer except the
    first, so the raw input features are never dropped.
    """

    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        # two-layer GCN, registered as layers.0 (in -> hid) and layers.1 (hid -> out)
        self.layers = nn.ModuleList(
            [
                GraphConv(in_size, hid_size, activation=F.relu),
                GraphConv(hid_size, out_size),
            ]
        )
        self.dropout = nn.Dropout(0.5)

    def forward(self, A, features):
        """Run ``features`` through every layer, sharing adjacency ``A``."""
        h = features
        for layer_idx, layer in enumerate(self.layers):
            h = layer(A, h if layer_idx == 0 else self.dropout(h))
        return h


def gcn_norm(A):
    """Return the symmetrically normalized adjacency ``D^-1/2 (A + I) D^-1/2``.

    ``D`` is the diagonal matrix built from ``(A + I).sum(0)``. The identity
    is added unconditionally, so self-loops already present in ``A`` are not
    deduplicated.
    """
    A_hat = A + identity(A.shape)  # add self-loop to A
    # DiagMatrix has no power() method, so raise the degree vector first and
    # build the inverse-square-root diagonal matrix from it directly.
    deg_inv_sqrt = A_hat.sum(0) ** -0.5
    D_inv_sqrt = diag(deg_inv_sqrt)
    return D_inv_sqrt @ A_hat @ D_inv_sqrt


def evaluate(A, features, labels, mask, model):
    """Return the accuracy of ``model`` on the nodes selected by ``mask``.

    The model is switched to eval mode (and left there) and run under
    ``torch.no_grad()``. Each node's prediction is the index of its largest
    logit. Raises ZeroDivisionError when ``mask`` selects no nodes.
    """
    model.eval()
    with torch.no_grad():
        masked_logits = model(A, features)[mask]
        targets = labels[mask]
        predicted = torch.max(masked_logits, dim=1).indices
        num_correct = (predicted == targets).sum().item()
        return num_correct * 1.0 / len(targets)


def train(
    A, features, labels, masks, model, *, num_epochs=200, lr=1e-2, weight_decay=5e-4
):
    """Train ``model`` full-batch with Adam, printing validation accuracy per epoch.

    Parameters
    ----------
    A :
        Adjacency matrix passed straight through to ``model``.
    features :
        Node-feature matrix passed straight through to ``model``.
    labels :
        Per-node class labels, indexed by the masks.
    masks :
        Sequence whose entries 0 and 1 are the training and validation masks;
        any further entries (e.g. a test mask) are ignored here.
    model :
        Module called as ``model(A, features)`` that returns per-node logits.
    num_epochs, lr, weight_decay :
        Optimization settings. The defaults reproduce the previously
        hard-coded values (200 epochs, lr=1e-2, weight_decay=5e-4).
    """
    # define train/val samples, loss function and optimizer
    train_mask, val_mask = masks[0], masks[1]
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=lr, weight_decay=weight_decay
    )

    # training loop
    for epoch in range(num_epochs):
        # evaluate() leaves the model in eval mode, so re-enable training
        # behavior (dropout) at the start of every epoch.
        model.train()
        logits = model(A, features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc = evaluate(A, features, labels, val_mask, model)
        print(
            "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(
                epoch, loss.item(), acc
            )
        )


if __name__ == "__main__":
    # Supported datasets; the keys are also the accepted --dataset values.
    datasets = {
        "cora": CoraGraphDataset,
        "citeseer": CiteseerGraphDataset,
        "pubmed": PubmedGraphDataset,
    }

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        default="cora",
        choices=list(datasets),
        help="Dataset name ('cora', 'citeseer', 'pubmed').",
    )
    args = parser.parse_args()
    print("Training with DGL SparseMatrix GraphConv module.")

    # load and preprocess dataset
    # No AddSelfLoop transform: gcn_norm() adds the identity itself, and
    # pre-adding self-loops would give each diagonal entry a weight of 2.
    data = datasets[args.dataset]()
    g = data[0].int()
    features = g.ndata["feat"]
    labels = g.ndata["label"]
    masks = g.ndata["train_mask"], g.ndata["val_mask"], g.ndata["test_mask"]

    # Drop any self-loops present in the raw graph so that gcn_norm()'s
    # A + I leaves exactly one unit self-loop per node.
    row, col = g.adj_sparse("coo")
    off_diagonal = row != col
    num_nodes = g.num_nodes()
    A = create_from_coo(
        row[off_diagonal], col[off_diagonal], shape=(num_nodes, num_nodes)
    )
    A = gcn_norm(A)

    # create GCN model
    in_size = features.shape[1]
    out_size = data.num_classes
    model = GCN(in_size, 16, out_size)

    # model training
    print("Training...")
    train(A, features, labels, masks, model)

    # test the model
    print("Testing...")
    acc = evaluate(A, features, labels, masks[2], model)
    print("Test accuracy {:.4f}".format(acc))