import argparse
import json
import logging
import os
from time import time

import dgl

import torch
import torch.nn
import torch.nn.functional as F
from dgl.data import LegacyTUDataset
from dgl.dataloading import GraphDataLoader
from network import get_sag_network
from torch.utils.data import random_split
from utils import get_stats


def parse_args():
    parser = argparse.ArgumentParser(description="Self-Attention Graph Pooling")
    parser.add_argument(
        "--dataset",
        type=str,
        default="DD",
        choices=["DD", "PROTEINS", "NCI1", "NCI109", "Mutagenicity"],
        help="DD/PROTEINS/NCI1/NCI109/Mutagenicity",
    )
    parser.add_argument(
        "--batch_size", type=int, default=128, help="batch size"
    )
    parser.add_argument("--lr", type=float, default=5e-4, help="learning rate")
    parser.add_argument(
        "--weight_decay", type=float, default=1e-4, help="weight decay"
    )
    parser.add_argument(
        "--pool_ratio", type=float, default=0.5, help="pooling ratio"
    )
    parser.add_argument("--hid_dim", type=int, default=128, help="hidden size")
    parser.add_argument(
        "--dropout", type=float, default=0.5, help="dropout ratio"
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=100000,
        help="max number of training epochs",
    )
    parser.add_argument(
        "--patience", type=int, default=50, help="patience for early stopping"
    )
    parser.add_argument(
        "--device", type=int, default=-1, help="device id, -1 for cpu"
    )
    parser.add_argument(
        "--architecture",
        type=str,
        default="hierarchical",
        choices=["hierarchical", "global"],
        help="model architecture",
    )
    parser.add_argument(
        "--dataset_path", type=str, default="./dataset", help="path to dataset"
    )
    parser.add_argument(
        "--conv_layers", type=int, default=3, help="number of conv layers"
    )
    parser.add_argument(
        "--print_every",
        type=int,
        default=10,
        help="print trainlog every k epochs, -1 for silent training",
    )
    parser.add_argument(
        "--num_trials", type=int, default=1, help="number of trials"
    )
    parser.add_argument("--output_path", type=str, default="./output")

    args = parser.parse_args()

    # device
    args.device = "cpu" if args.device == -1 else "cuda:{}".format(args.device)
    if not torch.cuda.is_available():
        logging.warning("CUDA is not available, use CPU for training.")
        args.device = "cpu"

    # print every
    if args.print_every == -1:
        args.print_every = args.epochs + 1

    # paths
    if not os.path.exists(args.dataset_path):
        os.makedirs(args.dataset_path)
    if not os.path.exists(args.output_path):
        os.makedirs(args.output_path)
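    # The output file is named after the key hyper-parameters; the final results
    # are dumped into it as JSON at the end of the run.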
    name = "Data={}_Hidden={}_Arch={}_Pool={}_WeightDecay={}_Lr={}.log".format(
        args.dataset,
        args.hid_dim,
        args.architecture,
        args.pool_ratio,
        args.weight_decay,
        args.lr,
    )
    args.output_path = os.path.join(args.output_path, name)

    return args


def train(model: torch.nn.Module, optimizer, trainloader, device):
    """Run one training epoch and return the average batch loss."""
    model.train()
    total_loss = 0.0
    num_batches = len(trainloader)

    for batch in trainloader:
        optimizer.zero_grad()
        batch_graphs, batch_labels = batch
        batch_graphs = batch_graphs.to(device)
        batch_labels = batch_labels.long().to(device)
        out = model(batch_graphs)
        loss = F.nll_loss(out, batch_labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / num_batches


@torch.no_grad()
def test(model: torch.nn.Module, loader, device):
    """Evaluate the model on a loader; return (accuracy, average NLL loss)."""
    model.eval()
    correct = 0.0
    loss = 0.0
    num_graphs = 0

    for batch in loader:
        batch_graphs, batch_labels = batch
        num_graphs += batch_labels.size(0)
        batch_graphs = batch_graphs.to(device)
        batch_labels = batch_labels.long().to(device)
        out = model(batch_graphs)
        pred = out.argmax(dim=1)
        loss += F.nll_loss(out, batch_labels, reduction="sum").item()
        correct += pred.eq(batch_labels).sum().item()
    return correct / num_graphs, loss / num_graphs


def main(args):
    # Step 1: Prepare graph data and split it into train/validation/test sets ===================== #
    dataset = LegacyTUDataset(args.dataset, raw_dir=args.dataset_path)

    # Add a self-loop to each graph individually, since dgl.add_self_loop does
    # not operate on batched graphs.
    for i in range(len(dataset)):
        dataset.graph_lists[i] = dgl.add_self_loop(dataset.graph_lists[i])

    # Randomly split the graphs into 80% train, 10% validation, 10% test.
    num_training = int(len(dataset) * 0.8)
    num_val = int(len(dataset) * 0.1)
    num_test = len(dataset) - num_val - num_training
    train_set, val_set, test_set = random_split(
        dataset, [num_training, num_val, num_test]
    )

    # GraphDataLoader collates each minibatch of DGLGraphs into a single batched graph.
    train_loader = GraphDataLoader(
        train_set, batch_size=args.batch_size, shuffle=True, num_workers=6
    )
    val_loader = GraphDataLoader(
        val_set, batch_size=args.batch_size, num_workers=2
    )
    test_loader = GraphDataLoader(
        test_set, batch_size=args.batch_size, num_workers=2
    )

    device = torch.device(args.device)

    # Step 2: Create model =================================================================== #
    num_feature, num_classes, _ = dataset.statistics()
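    # get_sag_network selects the SAGPool model variant ("hierarchical" or "global").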
    model_op = get_sag_network(args.architecture)
    model = model_op(
        in_dim=num_feature,
        hid_dim=args.hid_dim,
        out_dim=num_classes,
        num_convs=args.conv_layers,
        pool_ratio=args.pool_ratio,
        dropout=args.dropout,
    ).to(device)
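    # Record the dataset statistics on args so they are saved together with the
    # hyper-parameters in the output JSON.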
    args.num_feature = int(num_feature)
    args.num_classes = int(num_classes)

    # Step 3: Create training components ===================================================== #
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    # Step 4: training epochs ================================================================ #
    # Early stopping: keep the test accuracy measured at the best validation loss
    # and stop after `patience` epochs without improvement.
    bad_count = 0
    best_val_loss = float("inf")
    final_test_acc = 0.0
    best_epoch = 0
    train_times = []
    for e in range(args.epochs):
        s_time = time()
        train_loss = train(model, optimizer, train_loader, device)
        train_times.append(time() - s_time)
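        # Evaluate every epoch; only the test accuracy at the best validation loss is reported.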
        val_acc, val_loss = test(model, val_loader, device)
        test_acc, _ = test(model, test_loader, device)
        if best_val_loss > val_loss:
            best_val_loss = val_loss
            final_test_acc = test_acc
            bad_count = 0
            best_epoch = e + 1
        else:
            bad_count += 1
        if bad_count >= args.patience:
            break

        if (e + 1) % args.print_every == 0:
            log_format = (
                "Epoch {}: loss={:.4f}, val_acc={:.4f}, final_test_acc={:.4f}"
            )
            print(log_format.format(e + 1, train_loss, val_acc, final_test_acc))
    print(
        "Best Epoch {}, final test acc {:.4f}".format(
            best_epoch, final_test_acc
        )
    )
    return final_test_acc, sum(train_times) / len(train_times)


if __name__ == "__main__":
    args = parse_args()
    res = []
    train_times = []
    for i in range(args.num_trials):
        print("Trial {}/{}".format(i + 1, args.num_trials))
        acc, train_time = main(args)
        res.append(acc)
        train_times.append(train_time)

    mean, err_bd = get_stats(res)
    print("mean acc: {:.4f}, error bound: {:.4f}".format(mean, err_bd))

    out_dict = {
        "hyper-parameters": vars(args),
        "result": "{:.4f}(+-{:.4f})".format(mean, err_bd),
        "train_time": "{:.4f}".format(sum(train_times) / len(train_times)),
    }

    with open(args.output_path, "w") as f:
        json.dump(out_dict, f, sort_keys=True, indent=4)