main.py 5.06 KB
Newer Older
xnouhz's avatar
xnouhz committed
1
2
3
4
""" The main file to train a JKNet model using a full graph """

import argparse
import copy
5
6

import numpy as np
xnouhz's avatar
xnouhz committed
7
8
import torch
import torch.nn as nn
9
import torch.optim as optim
xnouhz's avatar
xnouhz committed
10
from model import JKNet
11
12
13
14
15
from sklearn.model_selection import train_test_split
from tqdm import trange

from dgl.data import CiteseerGraphDataset, CoraGraphDataset

xnouhz's avatar
xnouhz committed
16
17
18
19

def main(args):
    """Train a JKNet model on a full graph and return its test accuracy.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI options: ``dataset``, ``gpu``, ``epochs``, ``lr``,
        ``lamb``, ``hid_dim``, ``num_layers``, ``mode``, ``dropout``.

    Returns
    -------
    float
        Test accuracy of the model snapshot that achieved the best
        validation accuracy during training.

    Raises
    ------
    ValueError
        If ``args.dataset`` is neither "Cora" nor "Citeseer".
    """
    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
    # Load from DGL dataset
    if args.dataset == "Cora":
        dataset = CoraGraphDataset()
    elif args.dataset == "Citeseer":
        dataset = CiteseerGraphDataset()
    else:
        raise ValueError("Dataset {} is invalid.".format(args.dataset))

    graph = dataset[0]

    # check cuda; fall back to CPU when the requested GPU is unavailable
    device = (
        f"cuda:{args.gpu}"
        if args.gpu >= 0 and torch.cuda.is_available()
        else "cpu"
    )

    # retrieve the number of classes
    n_classes = dataset.num_classes

    # retrieve labels of ground truth
    labels = graph.ndata.pop("label").to(device).long()

    # Extract node features
    feats = graph.ndata.pop("feat").to(device)
    n_features = feats.shape[-1]

    # create masks for train / validation / test
    # train : val : test = 6 : 2 : 2 (random split each run)
    n_nodes = graph.num_nodes()
    idx = torch.arange(n_nodes).to(device)
    train_idx, test_idx = train_test_split(idx, test_size=0.2)
    train_idx, val_idx = train_test_split(train_idx, test_size=0.25)

    graph = graph.to(device)

    # Step 2: Create model =================================================================== #
    model = JKNet(
        in_dim=n_features,
        hid_dim=args.hid_dim,
        out_dim=n_classes,
        num_layers=args.num_layers,
        mode=args.mode,
        dropout=args.dropout,
    ).to(device)

    # Snapshot used to remember the best-by-validation model
    best_model = copy.deepcopy(model)

    # Step 3: Create training components ===================================================== #
    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.lamb)

    # Step 4: training epochs =============================================================== #
    acc = 0
    epochs = trange(args.epochs, desc="Accuracy & Loss")

    for _ in epochs:
        # Training using a full graph
        model.train()

        logits = model(graph, feats)

        # compute loss
        train_loss = loss_fn(logits[train_idx], labels[train_idx])
        train_acc = torch.sum(
            logits[train_idx].argmax(dim=1) == labels[train_idx]
        ).item() / len(train_idx)

        # backward
        opt.zero_grad()
        train_loss.backward()
        opt.step()

        # Validation using a full graph.
        # BUGFIX: re-run the forward pass in eval mode; the training-mode
        # `logits` above were computed with dropout active, so reusing them
        # would make validation metrics (and model selection) noisy.
        model.eval()

        with torch.no_grad():
            val_logits = model(graph, feats)
            valid_loss = loss_fn(val_logits[val_idx], labels[val_idx])
            valid_acc = torch.sum(
                val_logits[val_idx].argmax(dim=1) == labels[val_idx]
            ).item() / len(val_idx)

        # Print out performance
        epochs.set_description(
            "Train Acc {:.4f} | Train Loss {:.4f} | Val Acc {:.4f} | Val loss {:.4f}".format(
                train_acc, train_loss.item(), valid_acc, valid_loss.item()
            )
        )

        # Keep the snapshot with the best validation accuracy so far
        if valid_acc > acc:
            acc = valid_acc
            best_model = copy.deepcopy(model)

    # Evaluate the best snapshot on the held-out test split; no_grad avoids
    # building an unused autograd graph for the inference-only forward pass.
    best_model.eval()
    with torch.no_grad():
        logits = best_model(graph, feats)
    test_acc = torch.sum(
        logits[test_idx].argmax(dim=1) == labels[test_idx]
    ).item() / len(test_idx)

    print("Test Acc {:.4f}".format(test_acc))
    return test_acc

121

xnouhz's avatar
xnouhz committed
122
123
124
125
if __name__ == "__main__":
    # JKNet hyperparameters (command-line interface)
    parser = argparse.ArgumentParser(description="JKNet")

    # data source params
    parser.add_argument("--dataset", type=str, default="Cora", help="Name of dataset.")

    # cuda params
    parser.add_argument("--gpu", type=int, default=-1, help="GPU index. Default: -1, using CPU.")

    # training params
    parser.add_argument("--run", type=int, default=10, help="Running times.")
    parser.add_argument("--epochs", type=int, default=500, help="Training epochs.")
    parser.add_argument("--lr", type=float, default=0.005, help="Learning rate.")
    parser.add_argument("--lamb", type=float, default=0.0005, help="L2 reg.")

    # model params
    parser.add_argument("--hid-dim", type=int, default=32, help="Hidden layer dimensionalities.")
    parser.add_argument("--num-layers", type=int, default=5, help="Number of GCN layers.")
    parser.add_argument(
        "--mode",
        type=str,
        default="cat",
        help="Type of aggregation.",
        choices=["cat", "max", "lstm"],
    )
    parser.add_argument(
        "--dropout",
        type=float,
        default=0.5,
        help="Dropout applied at all layers.",
    )

    args = parser.parse_args()
    print(args)

    # Train `--run` independent times and summarize the test accuracies
    acc_lists = [main(args) for _ in range(args.run)]

    mean = np.around(np.mean(acc_lists, axis=0), decimals=3)
    std = np.around(np.std(acc_lists, axis=0), decimals=3)

    print("total acc: ", acc_lists)
    print("mean", mean)
    print("std", std)