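"""Cluster-GCN style training of GraphSAGE with DGL.

Example invocation (a sketch; the dataset flag comes from
dgl.data.register_data_args and the rest are defined at the bottom of this file):

    python cluster_gcn.py --dataset ppi --gpu 0 --self-loop --use-pp --normalize --use-val
"""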
import argparse
import os
import time
import random

import numpy as np
import networkx as nx
import sklearn.preprocessing
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl.data import register_data_args

from modules import GraphSAGE
from sampler import ClusterIter
from utils import Logger, evaluate, save_log_dir, load_data


def main(args):
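    # fix all random seeds and force deterministic cuDNN behavior for reproducibility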
    torch.manual_seed(args.rnd_seed)
    np.random.seed(args.rnd_seed)
    random.seed(args.rnd_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

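    # PPI is a multi-label (multitask) dataset and needs a multi-label loss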
    multitask_data = set(['ppi'])
    multitask = args.dataset in multitask_data

    # load and preprocess dataset
    data = load_data(args)
    g = data.g
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    labels = g.ndata['label']

    train_nid = np.nonzero(train_mask.data.numpy())[0].astype(np.int64)

    # Normalize features
    if args.normalize:
        feats = g.ndata['feat']
        train_feats = feats[train_mask]
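        # fit the scaler on training-node features only, to avoid information
        # leakage from validation/test nodes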
        scaler = sklearn.preprocessing.StandardScaler()
        scaler.fit(train_feats.data.numpy())
        features = scaler.transform(feats.data.numpy())
        g.ndata['feat'] = torch.FloatTensor(features)

    in_feats = g.ndata['feat'].shape[1]
    n_classes = data.num_classes
    n_edges = g.number_of_edges()

    n_train_samples = train_mask.int().sum().item()
    n_val_samples = val_mask.int().sum().item()
    n_test_samples = test_mask.int().sum().item()

    print("""----Data statistics------'
    #Edges %d
    #Classes %d
    #Train samples %d
    #Val samples %d
    #Test samples %d""" %
            (n_edges, n_classes,
            n_train_samples,
            n_val_samples,
            n_test_samples))
    # optionally add self-loop edges (not applied to reddit)
    if args.self_loop and not args.dataset.startswith('reddit'):
        g = dgl.remove_self_loop(g)
        g = dgl.add_self_loop(g)
        print("adding self-loop edges")
    # METIS graph partitioning only supports int64 graphs
    g = g.long()

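    # with --use-pp, precompute one round of neighbor feature aggregation on the
    # full graph and concatenate it to the input features, so the first layer's
    # aggregation need not be recomputed inside each sampled cluster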
    if args.use_pp:
        g.update_all(fn.copy_u('feat', 'm'), fn.sum('m', 'feat_agg'))
        g.ndata['feat'] = torch.cat([g.ndata['feat'], g.ndata['feat_agg']], 1)
        del g.ndata['feat_agg']

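    # partition the training subgraph into args.psize clusters (METIS results
    # are cached under ./cache) and iterate over batches of clusters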
    cluster_iterator = dgl.dataloading.GraphDataLoader(
        dgl.dataloading.ClusterGCNSubgraphIterator(
            dgl.node_subgraph(g, train_nid), args.psize, './cache'),
        batch_size=args.batch_size, num_workers=4)
    #cluster_iterator = ClusterIter(
    #    args.dataset, g, args.psize, args.batch_size, train_nid, use_pp=args.use_pp)

    # set device for dataset tensors
    if args.gpu < 0:
        cuda = False
    else:
        cuda = True
        torch.cuda.set_device(args.gpu)
        val_mask = val_mask.cuda()
        test_mask = test_mask.cuda()
        g = g.int().to(args.gpu)

    print('labels shape:', g.ndata['label'].shape)
    print('features shape:', g.ndata['feat'].shape)

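    # create the GraphSAGE model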
    model = GraphSAGE(in_feats,
                      args.n_hidden,
                      n_classes,
                      args.n_layers,
                      F.relu,
                      args.dropout,
                      args.use_pp)

    if cuda:
        model.cuda()

    # logger and so on
    log_dir = save_log_dir(args)
    logger = Logger(os.path.join(log_dir, 'loggings'))
    logger.write(args)

    # Loss function
    if multitask:
        print('Using multi-label loss')
        loss_f = nn.BCEWithLogitsLoss()
    else:
        print('Using multi-class loss')
        loss_f = nn.CrossEntropyLoss()

    # Adam optimizer; weight_decay adds L2 regularization
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    # move train_nid to a CUDA tensor
    if cuda:
        train_nid = torch.from_numpy(train_nid).cuda()
        print("current memory after model before training",
              torch.cuda.memory_allocated(device=train_nid.device) / 1024 / 1024)
    start_time = time.time()
    best_f1 = -1

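    # training loop: one pass over all cluster batches per epoch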
    for epoch in range(args.n_epochs):
        for j, cluster in enumerate(cluster_iterator):
            # move the sampled cluster to the training device
            if cuda:
                cluster = cluster.to(torch.cuda.current_device())
            model.train()
            # forward
            batch_labels = cluster.ndata['label']
            batch_train_mask = cluster.ndata['train_mask']
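            # skip clusters that contain no training nodes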
            if batch_train_mask.sum().item() == 0:
                continue
            pred = model(cluster)
            loss = loss_f(pred[batch_train_mask],
                          batch_labels[batch_train_mask])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # In the PPI case, `log_every` is chosen so that logging happens once
            # per epoch; raise the frequency for more information within an epoch.
            if j % args.log_every == 0:
                print(f"epoch:{epoch}/{args.n_epochs}, Iteration {j}/"
                      f"{len(cluster_iterator)}:training loss", loss.item())
        print("current memory:",
              torch.cuda.memory_allocated(device=pred.device) / 1024 / 1024)

        # evaluate
        if epoch % args.val_every == 0:
            val_f1_mic, val_f1_mac = evaluate(
                model, g, labels, val_mask, multitask)
            print("Val F1-mic {:.4f}, Val F1-mac {:.4f}".format(
                val_f1_mic, val_f1_mac))
            if val_f1_mic > best_f1:
                best_f1 = val_f1_mic
                print('new best val f1:', best_f1)
                torch.save(model.state_dict(), os.path.join(
                    log_dir, 'best_model.pkl'))

    end_time = time.time()
    print(f'training took {end_time - start_time:.2f} seconds')

    # test
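    # optionally restore the checkpoint with the best validation F1-mic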
    if args.use_val:
        model.load_state_dict(torch.load(os.path.join(
            log_dir, 'best_model.pkl')))
    test_f1_mic, test_f1_mac = evaluate(
        model, g, labels, test_mask, multitask)
    print("Test F1-mic{:.4f}, Test F1-mac{:.4f}". format(test_f1_mic, test_f1_mac))

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GCN')
    register_data_args(parser)
    parser.add_argument("--dropout", type=float, default=0.5,
                        help="dropout probability")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu")
    parser.add_argument("--lr", type=float, default=3e-2,
                        help="learning rate")
    parser.add_argument("--n-epochs", type=int, default=200,
                        help="number of training epochs")
    parser.add_argument("--log-every", type=int, default=100,
                        help="the frequency to save model")
    parser.add_argument("--batch-size", type=int, default=20,
                        help="batch size")
    parser.add_argument("--psize", type=int, default=1500,
                        help="partition number")
    parser.add_argument("--test-batch-size", type=int, default=1000,
                        help="test batch size")
    parser.add_argument("--n-hidden", type=int, default=16,
                        help="number of hidden gcn units")
    parser.add_argument("--n-layers", type=int, default=1,
                        help="number of hidden gcn layers")
    parser.add_argument("--val-every", type=int, default=1,
                        help="number of epoch of doing inference on validation")
    parser.add_argument("--rnd-seed", type=int, default=3,
                        help="number of epoch of doing inference on validation")
    parser.add_argument("--self-loop", action='store_true',
                        help="graph self-loop (default=False)")
    parser.add_argument("--use-pp", action='store_true',
                        help="whether to use precomputation")
    parser.add_argument("--normalize", action='store_true',
                        help="whether to use normalized feature")
    parser.add_argument("--use-val", action='store_true',
                        help="whether to use validated best model to test")
    parser.add_argument("--weight-decay", type=float, default=5e-4,
                        help="Weight for L2 loss")
    parser.add_argument("--note", type=str, default='none',
                        help="note for log dir")

    args = parser.parse_args()

    print(args)

    main(args)