# train.py (3.64 KB)
1
2
import argparse

3
import torch
4
import torch.nn as nn
5
import torch.nn.functional as F
6

7
8
import dgl.nn as dglnn
from dgl import AddSelfLoop
9
10
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset

11

12
class GAT(nn.Module):
    """Graph attention network stacked from DGL ``GATConv`` layers.

    One ``GATConv`` layer is created per entry of ``heads``. Every layer
    except the last produces ``hid_size`` features per head, uses ELU as its
    activation and has its per-head outputs concatenated. The last layer
    produces ``out_size`` features per head, has no activation and has its
    per-head outputs averaged.

    Args:
        in_size: Size of each input node feature vector.
        hid_size: Per-head output size of every hidden layer.
        out_size: Per-head output size of the final layer.
        heads: Sequence with the number of attention heads of each layer.
            ``heads=[8, 1]`` yields the two-layer configuration.
        feat_drop: ``feat_drop`` value forwarded to every ``GATConv``.
        attn_drop: ``attn_drop`` value forwarded to every ``GATConv``.
    """

    def __init__(
        self, in_size, hid_size, out_size, heads, feat_drop=0.6, attn_drop=0.6
    ):
        super().__init__()
        self.gat_layers = nn.ModuleList()

        # Hidden layers: their heads are concatenated in ``forward``, so the
        # next layer's input width is ``hid_size * num_heads``.
        layer_in = in_size
        for num_heads in heads[:-1]:
            self.gat_layers.append(
                dglnn.GATConv(
                    layer_in,
                    hid_size,
                    num_heads,
                    feat_drop=feat_drop,
                    attn_drop=attn_drop,
                    activation=F.elu,
                )
            )
            layer_in = hid_size * num_heads

        # Output layer: no activation, it emits the final scores.
        self.gat_layers.append(
            dglnn.GATConv(
                layer_in,
                out_size,
                heads[-1],
                feat_drop=feat_drop,
                attn_drop=attn_drop,
                activation=None,
            )
        )

    def forward(self, g, inputs):
        """Run every layer on graph ``g`` and return the final node outputs.

        Each layer's output is treated as having the attention heads on
        dimension 1: hidden layers flatten that dimension into the feature
        dimension, the last layer averages over it.
        """
        h = inputs
        # Derive the last index from the layer list instead of hard-coding it
        # so the head-averaging always lands on the final layer.
        last_index = len(self.gat_layers) - 1
        for i, layer in enumerate(self.gat_layers):
            h = layer(g, h)
            if i == last_index:  # last layer: average the heads
                h = h.mean(1)
            else:  # other layer(s): concatenate the heads
                h = h.flatten(1)
        return h
47
48


49
def evaluate(g, features, labels, mask, model):
    """Return the classification accuracy of ``model`` on the masked nodes.

    The model is switched to evaluation mode and run without gradient
    tracking. The predicted class of each node is the index of its largest
    logit along dimension 1; accuracy is the fraction of masked nodes whose
    prediction equals its label.
    """
    model.eval()
    with torch.no_grad():
        masked_logits = model(g, features)[mask]
        targets = labels[mask]
        predictions = torch.max(masked_logits, dim=1).indices
        num_correct = (predictions == targets).sum().item()
        return float(num_correct) / len(targets)
58
59


60
61
62
63
64
65
66
def train(
    g, features, labels, masks, model, epochs=200, lr=5e-3, weight_decay=5e-4
):
    """Train ``model`` on the full graph and print per-epoch progress.

    Args:
        g: Graph forwarded to ``model`` and ``evaluate``.
        features: Node features forwarded to ``model``.
        labels: Target labels; indexed with the train/validation masks.
        masks: Sequence whose first item is the training mask and whose
            second item is the validation mask.
        model: Module to optimise in place.
        epochs: Number of optimisation steps (one full-graph pass each).
        lr: Learning rate given to the Adam optimizer.
        weight_decay: Weight decay given to the Adam optimizer.
    """
    # define train/val samples, loss function and optimizer
    train_mask, val_mask = masks[0], masks[1]
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=lr, weight_decay=weight_decay
    )

    # training loop
    for epoch in range(epochs):
        # evaluate() puts the model in eval mode, so training mode has to be
        # re-enabled at the start of every epoch.
        model.train()
        logits = model(g, features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc = evaluate(g, features, labels, val_mask, model)
        print(
            "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(
                epoch, loss.item(), acc
            )
        )


if __name__ == "__main__":
    # Command-line interface: the only option selects the dataset.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--dataset",
        type=str,
        default="cora",
        help="Dataset name ('cora', 'citeseer', 'pubmed').",
    )
    args = arg_parser.parse_args()
    print("Training with DGL built-in GATConv module.")

    # load and preprocess dataset
    # by default, AddSelfLoop first removes self-loops to prevent duplication
    transform = AddSelfLoop()
    dataset_classes = {
        "cora": CoraGraphDataset,
        "citeseer": CiteseerGraphDataset,
        "pubmed": PubmedGraphDataset,
    }
    if args.dataset not in dataset_classes:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
    data = dataset_classes[args.dataset](transform=transform)

    # Move the single graph of the dataset to the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    g = data[0].int().to(device)
    features = g.ndata["feat"]
    labels = g.ndata["label"]
    masks = g.ndata["train_mask"], g.ndata["val_mask"], g.ndata["test_mask"]

    # create GAT model: 8 hidden features per head, 8 heads then 1 head
    model = GAT(features.shape[1], 8, data.num_classes, heads=[8, 1]).to(device)

    # model training
    print("Training...")
    train(g, features, labels, masks, model)

    # test the model on the held-out test mask (third entry of ``masks``)
    print("Testing...")
    test_acc = evaluate(g, features, labels, masks[2], model)
    print("Test accuracy {:.4f}".format(test_acc))