"""
Gated Graph Convolutional Network module for graph classification tasks
"""
import argparse
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from dgl.dataloading import GraphDataLoader
from dgl.nn.pytorch import GatedGCNConv
from dgl.nn.pytorch.glob import AvgPooling
from ogb.graphproppred import DglGraphPropPredDataset, Evaluator
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder


class GatedGCN(nn.Module):
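    """GatedGCN model for graph property prediction.

    Node and edge inputs are embedded with OGB's AtomEncoder and BondEncoder,
    refined by ``num_layers`` GatedGCNConv layers, averaged over each graph
    with AvgPooling, and projected to ``out_dim`` task logits.
    """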
    def __init__(
        self,
        hid_dim,
        out_dim,
        num_layers,
        dropout=0.2,
        batch_norm=True,
        residual=True,
        activation=F.relu,
    ):
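        """
        Parameters
        ----------
        hid_dim : int
            Hidden size of the node and edge embeddings.
        out_dim : int
            Number of prediction tasks (output logits per graph).
        num_layers : int
            Number of stacked GatedGCNConv layers.
        dropout : float
            Dropout rate used inside each GatedGCNConv layer.
        batch_norm : bool
            Whether each layer applies batch normalization.
        residual : bool
            Whether each layer uses residual connections.
        activation : callable
            Activation function used inside each layer.
        """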
        super(GatedGCN, self).__init__()

        self.num_layers = num_layers
        self.dropout = dropout

        self.node_encoder = AtomEncoder(hid_dim)
        self.edge_encoder = BondEncoder(hid_dim)

        self.layers = nn.ModuleList()
        for _ in range(self.num_layers):
            layer = GatedGCNConv(
                input_feats=hid_dim,
                edge_feats=hid_dim,
                output_feats=hid_dim,
                dropout=dropout,
                batch_norm=batch_norm,
                residual=residual,
                activation=activation,
            )
            self.layers.append(layer)

        self.pooling = AvgPooling()
        self.output = nn.Linear(hid_dim, out_dim)

    def forward(self, g, node_feat, edge_feat):
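        """Compute graph-level logits for a (batched) graph.

        ``node_feat`` and ``edge_feat`` are the raw OGB features consumed by
        AtomEncoder and BondEncoder; the returned tensor has shape
        (num_graphs, out_dim).
        """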
        # Encode node and edge features.
        hv = self.node_encoder(node_feat)
        he = self.edge_encoder(edge_feat)

        # GatedGCNConv layers.
        for layer in self.layers:
            hv, he = layer(g, hv, he)

        # Graph pooling and output projection.
        h_g = self.pooling(g, hv)

        return self.output(h_g)


def train(model, device, data_loader, opt, loss_fn):
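    """Run one training epoch and return the mean mini-batch loss."""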
    model.train()
    train_loss = []

    for g, labels in data_loader:
        g = g.to(device)
        labels = labels.to(torch.float32).to(device)
        logits = model(g, g.ndata["feat"], g.edata["feat"])
        loss = loss_fn(logits, labels)
        opt.zero_grad()
        loss.backward()
        opt.step()
        train_loss.append(loss.item())

    return sum(train_loss) / len(train_loss)


@torch.no_grad()
def evaluate(model, device, data_loader, evaluator):
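    """Score the model with the OGB evaluator and return the ROC-AUC."""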
    model.eval()
    y_true, y_pred = [], []

    for g, labels in data_loader:
        g = g.to(device)
        logits = model(g, g.ndata["feat"], g.edata["feat"])
        y_true.append(labels.detach().cpu())
        y_pred.append(logits.detach().cpu())

    y_true = torch.cat(y_true, dim=0).numpy()
    y_pred = torch.cat(y_pred, dim=0).numpy()

    return evaluator.eval({"y_true": y_true, "y_pred": y_pred})["rocauc"]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        default="ogbg-molhiv",
        help="Dataset name ('ogbg-molhiv', 'ogbg-molbace', 'ogbg-molmuv').",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=200,
        help="Number of epochs for train.",
    )
    parser.add_argument(
        "--num_gpus",
        type=int,
        default=0,
        help="Number of GPUs used for train and evaluation.",
    )
    args = parser.parse_args()
    print("Training with DGL built-in GATConv module.")

    # Load the OGB dataset and evaluator.
    dataset = DglGraphPropPredDataset(name=args.dataset)
    evaluator = Evaluator(name=args.dataset)

    if args.num_gpus > 0 and torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    n_classes = dataset.num_tasks

    split_idx = dataset.get_idx_split()
    train_loader = GraphDataLoader(
        dataset[split_idx["train"]],
        batch_size=32,
        shuffle=True,
    )
    valid_loader = GraphDataLoader(dataset[split_idx["valid"]], batch_size=32)
    test_loader = GraphDataLoader(dataset[split_idx["test"]], batch_size=32)

    # Load model.
    model = GatedGCN(hid_dim=256, out_dim=n_classes, num_layers=8).to(device)

    print(model)

    opt = optim.Adam(model.parameters(), lr=0.01)
    loss_fn = nn.BCEWithLogitsLoss()

    print("---------- Training ----------")
    for epoch in range(args.num_epochs):
        # Kick off training.
        t0 = time.time()
        loss = train(model, device, train_loader, opt, loss_fn)
        t1 = time.time()
        # Evaluate the prediction.
        val_acc = evaluate(model, device, valid_loader, evaluator)
        print(
            f"Epoch {epoch:05d} | Loss {loss:.4f} | Accuracy {val_acc:.4f} | "
            f"Time {t1 - t0:.4f}"
        )
    acc = evaluate(model, device, test_loader, evaluator)
    print(f"Test accuracy {acc:.4f}")