""" The main file to train a JKNet model using a full graph """

import argparse
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from dgl.data import CiteseerGraphDataset, CoraGraphDataset
from model import JKNet
from sklearn.model_selection import train_test_split
from tqdm import trange


def main(args):
    """Train a JKNet on a full citation graph and return the test accuracy.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI options (dataset, gpu, epochs, lr, lamb, hid_dim,
        num_layers, mode, dropout).

    Returns
    -------
    float
        Test-set accuracy of the model with the best validation accuracy.

    Raises
    ------
    ValueError
        If ``args.dataset`` is neither "Cora" nor "Citeseer".
    """
    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
    # Load from DGL dataset
    if args.dataset == "Cora":
        dataset = CoraGraphDataset()
    elif args.dataset == "Citeseer":
        dataset = CiteseerGraphDataset()
    else:
        raise ValueError("Dataset {} is invalid.".format(args.dataset))

    graph = dataset[0]

    # check cuda
    device = (
        f"cuda:{args.gpu}"
        if args.gpu >= 0 and torch.cuda.is_available()
        else "cpu"
    )

    # retrieve the number of classes
    n_classes = dataset.num_classes

    # retrieve labels of ground truth
    labels = graph.ndata.pop("label").to(device).long()

    # Extract node features
    feats = graph.ndata.pop("feat").to(device)
    n_features = feats.shape[-1]

    # create masks for train / validation / test
    # train : val : test = 6 : 2 : 2
    # FIX: split on CPU indices — sklearn's train_test_split cannot shuffle
    # CUDA tensors — and only then move the index tensors to the device.
    n_nodes = graph.num_nodes()
    idx = torch.arange(n_nodes)
    train_idx, test_idx = train_test_split(idx, test_size=0.2)
    train_idx, val_idx = train_test_split(train_idx, test_size=0.25)
    train_idx = train_idx.to(device)
    val_idx = val_idx.to(device)
    test_idx = test_idx.to(device)

    graph = graph.to(device)

    # Step 2: Create model =================================================================== #
    model = JKNet(
        in_dim=n_features,
        hid_dim=args.hid_dim,
        out_dim=n_classes,
        num_layers=args.num_layers,
        mode=args.mode,
        dropout=args.dropout,
    ).to(device)

    best_model = copy.deepcopy(model)

    # Step 3: Create training components ===================================================== #
    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.lamb)

    # Step 4: training epochs =============================================================== #
    acc = 0
    epochs = trange(args.epochs, desc="Accuracy & Loss")

    for _ in epochs:
        # Training using a full graph
        model.train()

        logits = model(graph, feats)

        # compute loss
        train_loss = loss_fn(logits[train_idx], labels[train_idx])
        train_acc = torch.sum(
            logits[train_idx].argmax(dim=1) == labels[train_idx]
        ).item() / len(train_idx)

        # backward
        opt.zero_grad()
        train_loss.backward()
        opt.step()

        # Validation using a full graph
        model.eval()

        with torch.no_grad():
            # BUG FIX: re-run the forward pass in eval mode. The original
            # reused the training-mode logits (dropout active, pre-update
            # weights) for validation, skewing model selection.
            eval_logits = model(graph, feats)
            valid_loss = loss_fn(eval_logits[val_idx], labels[val_idx])
            valid_acc = torch.sum(
                eval_logits[val_idx].argmax(dim=1) == labels[val_idx]
            ).item() / len(val_idx)

        # Print out performance
        epochs.set_description(
            "Train Acc {:.4f} | Train Loss {:.4f} | Val Acc {:.4f} | Val loss {:.4f}".format(
                train_acc, train_loss.item(), valid_acc, valid_loss.item()
            )
        )

        # Keep a snapshot of the best model by validation accuracy.
        if valid_acc > acc:
            acc = valid_acc
            best_model = copy.deepcopy(model)

    # Evaluate the best snapshot on the held-out test split; no_grad avoids
    # building an unnecessary autograd graph for inference.
    best_model.eval()
    with torch.no_grad():
        logits = best_model(graph, feats)
        test_acc = torch.sum(
            logits[test_idx].argmax(dim=1) == labels[test_idx]
        ).item() / len(test_idx)

    print("Test Acc {:.4f}".format(test_acc))
    return test_acc

if __name__ == "__main__":
    """
    JKNet Hyperparameters
    """
    parser = argparse.ArgumentParser(description="JKNet")

    # CLI options as (flag, keyword-arguments) pairs, grouped in order:
    # data source params, cuda params, training params, model params.
    cli_options = [
        ("--dataset", dict(type=str, default="Cora", help="Name of dataset.")),
        (
            "--gpu",
            dict(
                type=int,
                default=-1,
                help="GPU index. Default: -1, using CPU.",
            ),
        ),
        ("--run", dict(type=int, default=10, help="Running times.")),
        ("--epochs", dict(type=int, default=500, help="Training epochs.")),
        ("--lr", dict(type=float, default=0.005, help="Learning rate.")),
        ("--lamb", dict(type=float, default=0.0005, help="L2 reg.")),
        (
            "--hid-dim",
            dict(
                type=int,
                default=32,
                help="Hidden layer dimensionalities.",
            ),
        ),
        (
            "--num-layers",
            dict(type=int, default=5, help="Number of GCN layers."),
        ),
        (
            "--mode",
            dict(
                type=str,
                default="cat",
                help="Type of aggregation.",
                choices=["cat", "max", "lstm"],
            ),
        ),
        (
            "--dropout",
            dict(
                type=float,
                default=0.5,
                help="Dropout applied at all layers.",
            ),
        ),
    ]
    for flag, kwargs in cli_options:
        parser.add_argument(flag, **kwargs)

    args = parser.parse_args()
    print(args)

    # Run training several times and report mean/std of the test accuracy.
    acc_lists = [main(args) for _ in range(args.run)]

    mean = np.around(np.mean(acc_lists, axis=0), decimals=3)
    std = np.around(np.std(acc_lists, axis=0), decimals=3)

    print("total acc: ", acc_lists)
    print("mean", mean)
    print("std", std)