import argparse
import os
import random
import time

import networkx as nx
import numpy as np
import sklearn.preprocessing
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
from dgl.data import register_data_args

from modules import GraphSAGE
from sampler import ClusterIter
from utils import Logger, evaluate, save_log_dir, load_data


def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Parameters
    ----------
    seed : int
        Seed shared by every random number generator.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _normalize_features(g, train_mask):
    """Standardize ``g.ndata['feat']`` in place using training-node statistics.

    A ``StandardScaler`` is fit on the rows selected by ``train_mask`` only,
    then applied to every node, so validation/test rows do not influence the
    scaling statistics. The result is stored back as a float32 tensor.

    Parameters
    ----------
    g : DGLGraph
        Graph whose ``'feat'`` node data is replaced.
    train_mask : torch.Tensor
        Boolean mask selecting the training nodes.
    """
    feats = g.ndata['feat']
    scaler = sklearn.preprocessing.StandardScaler()
    scaler.fit(feats[train_mask].data.numpy())
    g.ndata['feat'] = torch.FloatTensor(scaler.transform(feats.data.numpy()))


def main(args):
    """Train a GraphSAGE model on ClusterIter mini-batches and report test F1.

    Workflow:
    1. Load the dataset selected by ``args``.
    2. Optionally standardize features and add self-loops.
    3. Build a ``ClusterIter`` over the (int64) graph.
    4. Train for ``args.n_epochs`` epochs over its batches.
    5. Evaluate on the validation mask every ``args.val_every`` epochs,
       saving the best micro-F1 weights to ``<log_dir>/best_model.pkl``.
    6. Print micro/macro F1 on the test mask.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments (see the ``__main__`` section).
    """
    _seed_everything(args.rnd_seed)

    # Datasets trained with a multi-label (BCE) objective instead of a
    # single-label (cross-entropy) one.
    multitask = args.dataset in {'ppi'}

    # load and preprocess dataset
    data = load_data(args)
    g = data.g
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    labels = g.ndata['label']

    # Node ids of the training set, handed to the cluster iterator below.
    train_nid = np.nonzero(train_mask.data.numpy())[0].astype(np.int64)

    if args.normalize:
        _normalize_features(g, train_mask)

    in_feats = g.ndata['feat'].shape[1]
    n_classes = data.num_classes
    # Counted before optional self-loops are added below.
    n_edges = g.number_of_edges()

    n_train_samples = train_mask.int().sum().item()
    n_val_samples = val_mask.int().sum().item()
    n_test_samples = test_mask.int().sum().item()

    print("""----Data statistics------
    #Edges %d
    #Classes %d
    #Train samples %d
    #Val samples %d
    #Test samples %d""" %
          (n_edges, n_classes,
           n_train_samples,
           n_val_samples,
           n_test_samples))

    if args.self_loop and not args.dataset.startswith('reddit'):
        # Remove first so nodes that already have a self-loop don't get two.
        g = dgl.remove_self_loop(g)
        g = dgl.add_self_loop(g)
        print("adding self-loop edges")

    # metis only support int64 graph
    g = g.long()

    cluster_iterator = ClusterIter(
        args.dataset, g, args.psize, args.batch_size, train_nid, use_pp=args.use_pp)

    # set device for dataset tensors
    cuda = args.gpu >= 0
    if cuda:
        torch.cuda.set_device(args.gpu)
        val_mask = val_mask.cuda()
        test_mask = test_mask.cuda()
        # The iterator was built from the int64 CPU graph above; the
        # full graph used for evaluation is moved to the GPU as int32.
        g = g.int().to(args.gpu)

    print('labels shape:', g.ndata['label'].shape)
    print("features shape, ", g.ndata['feat'].shape)

    model = GraphSAGE(in_feats,
                      args.n_hidden,
                      n_classes,
                      args.n_layers,
                      F.relu,
                      args.dropout,
                      args.use_pp)
    if cuda:
        model.cuda()

    # logger and so on
    log_dir = save_log_dir(args)
    logger = Logger(os.path.join(log_dir, 'loggings'))
    logger.write(args)
    best_model_path = os.path.join(log_dir, 'best_model.pkl')

    if multitask:
        print('Using multi-label loss')
        loss_f = nn.BCEWithLogitsLoss()
    else:
        print('Using multi-class loss')
        loss_f = nn.CrossEntropyLoss()

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    if cuda:
        train_nid = torch.from_numpy(train_nid).cuda()
        print("current memory after model before training",
              torch.cuda.memory_allocated(device=train_nid.device) / 1024 / 1024)

    start_time = time.time()
    best_f1 = -1
    # True once a checkpoint has been written during *this* run.
    saved_best = False

    for epoch in range(args.n_epochs):
        for j, cluster in enumerate(cluster_iterator):
            # sync with upper level training graph
            if cuda:
                cluster = cluster.to(torch.cuda.current_device())
            model.train()
            pred = model(cluster)
            batch_labels = cluster.ndata['label']
            batch_train_mask = cluster.ndata['train_mask']
            # Only training nodes inside the cluster contribute to the loss.
            loss = loss_f(pred[batch_train_mask],
                          batch_labels[batch_train_mask])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # in PPI case, `log_every` is chosen to log one time per epoch.
            # Choose your log freq dynamically when you want more info within one epoch
            if j % args.log_every == 0:
                print(f"epoch:{epoch}/{args.n_epochs}, Iteration {j}/"
                      f"{len(cluster_iterator)}:training loss", loss.item())

        # CUDA memory stats are only meaningful (and only callable) on GPU.
        if cuda:
            print("current memory:",
                  torch.cuda.memory_allocated(
                      device=torch.cuda.current_device()) / 1024 / 1024)

        # evaluate
        if epoch % args.val_every == 0:
            val_f1_mic, val_f1_mac = evaluate(
                model, g, labels, val_mask, multitask)
            print(
                "Val F1-mic{:.4f}, Val F1-mac{:.4f}".format(val_f1_mic, val_f1_mac))
            if val_f1_mic > best_f1:
                best_f1 = val_f1_mic
                print('new best val f1:', best_f1)
                torch.save(model.state_dict(), best_model_path)
                saved_best = True

    end_time = time.time()
    print(f'training using time {end_time - start_time}')

    # test
    if args.use_val:
        if saved_best:
            model.load_state_dict(torch.load(best_model_path))
        else:
            # No validation pass ran (e.g. n_epochs == 0), so there is no
            # checkpoint from this run to restore.
            print('no validated checkpoint saved; testing the final model')
    test_f1_mic, test_f1_mac = evaluate(
        model, g, labels, test_mask, multitask)
    print("Test F1-mic{:.4f}, Test F1-mac{:.4f}".format(test_f1_mic, test_f1_mac))


if __name__ == '__main__':
    # Command-line entry point: parse hyper-parameters and run training.
    parser = argparse.ArgumentParser(description='GCN')
    register_data_args(parser)
    parser.add_argument("--dropout", type=float, default=0.5,
                        help="dropout probability")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="GPU device id; a negative value trains on CPU")
    parser.add_argument("--lr", type=float, default=3e-2,
                        help="learning rate")
    parser.add_argument("--n-epochs", type=int, default=200,
                        help="number of training epochs")
    parser.add_argument("--log-every", type=int, default=100,
                        help="print the training loss every N iterations")
    parser.add_argument("--batch-size", type=int, default=20,
                        help="batch size")
    parser.add_argument("--psize", type=int, default=1500,
                        help="partition number")
    # Not read directly by main(); kept for compatibility with helpers
    # that receive the full args namespace.
    parser.add_argument("--test-batch-size", type=int, default=1000,
                        help="test batch size")
    parser.add_argument("--n-hidden", type=int, default=16,
                        help="number of hidden gcn units")
    parser.add_argument("--n-layers", type=int, default=1,
                        help="number of hidden gcn layers")
    parser.add_argument("--val-every", type=int, default=1,
                        help="number of epoch of doing inference on validation")
    parser.add_argument("--rnd-seed", type=int, default=3,
                        help="random seed for python, numpy and torch")
    parser.add_argument("--self-loop", action='store_true',
                        help="graph self-loop (default=False)")
    parser.add_argument("--use-pp", action='store_true',
                        help="whether to use precomputation")
    parser.add_argument("--normalize", action='store_true',
                        help="whether to use normalized feature")
    parser.add_argument("--use-val", action='store_true',
                        help="whether to use validated best model to test")
    parser.add_argument("--weight-decay", type=float, default=5e-4,
                        help="Weight for L2 loss")
    parser.add_argument("--note", type=str, default='none',
                        help="note for log dir")

    args = parser.parse_args()

    print(args)

    main(args)