import argparse
import os
import time
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sampler import SAINTNodeSampler, SAINTEdgeSampler, SAINTRandomWalkSampler
from config import CONFIG
from modules import GCNNet
from utils import Logger, evaluate, save_log_dir, load_data, calc_f1
import warnings

def main(args, task):
    warnings.filterwarnings('ignore')
    multilabel_data = {'ppi', 'yelp', 'amazon'}
    multilabel = args.dataset in multilabel_data
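    # ppi, yelp and amazon are multi-label datasets; this flag selects the loss function
    # (BCE-with-logits vs. cross-entropy) and the F1 computation used further below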

    # Some datasets, such as amazon, have graphs that are too large to be moved onto a single
    # GPU. For those we set cpu_flag so that we
    # 1. keep the whole graph on CPU and move only the sampled subgraphs to GPU during training
    # 2. keep the model on GPU during training and move it to CPU for validation/testing
    # Both cpu_flag and cuda (set below) must be checked whenever the model is moved between devices.
    if args.dataset in ['amazon']:
        cpu_flag = True
    else:
        cpu_flag = False

    # load and preprocess dataset
    data = load_data(args, multilabel)
    g = data.g
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    labels = g.ndata['label']

    train_nid = data.train_nid

    in_feats = g.ndata['feat'].shape[1]
    n_classes = data.num_classes
    n_nodes = g.num_nodes()
    n_edges = g.num_edges()

    n_train_samples = train_mask.int().sum().item()
    n_val_samples = val_mask.int().sum().item()
    n_test_samples = test_mask.int().sum().item()

    print("""----Data statistics------
    #Nodes %d
    #Edges %d
    #Classes/Labels (multi binary labels) %d
    #Train samples %d
    #Val samples %d
    #Test samples %d""" %
          (n_nodes, n_edges, n_classes,
           n_train_samples,
           n_val_samples,
           n_test_samples))
    # load sampler

    kwargs = {
        'dn': args.dataset, 'g': g, 'train_nid': train_nid, 'num_workers_sampler': args.num_workers_sampler,
        'num_subg_sampler': args.num_subg_sampler, 'batch_size_sampler': args.batch_size_sampler,
        'online': args.online, 'num_subg': args.num_subg, 'full': args.full
    }
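    # keyword arguments shared by the GraphSAINT samplers; the sampler-specific budgets
    # (node budget, edge budget, number/length of random walks) are passed separately below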

    if args.sampler == "node":
        saint_sampler = SAINTNodeSampler(args.node_budget, **kwargs)
    elif args.sampler == "edge":
        saint_sampler = SAINTEdgeSampler(args.edge_budget, **kwargs)
    elif args.sampler == "rw":
        saint_sampler = SAINTRandomWalkSampler(args.num_roots, args.length, **kwargs)
    else:
        raise NotImplementedError(f"unsupported sampler: {args.sampler}")
    loader = DataLoader(saint_sampler, collate_fn=saint_sampler.__collate_fn__, batch_size=1,
                        shuffle=True, num_workers=args.num_workers, drop_last=False)
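    # batch_size=1 because each item yielded by the SAINT sampler is already a full subgraph;
    # shuffling only changes the order in which the subgraphs are visited every epoch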
    # set device for dataset tensors
    if args.gpu < 0:
        cuda = False
    else:
        cuda = True
        torch.cuda.set_device(args.gpu)
        val_mask = val_mask.cuda()
        test_mask = test_mask.cuda()
        if not cpu_flag:
            g = g.to('cuda:{}'.format(args.gpu))

    print('labels shape:', g.ndata['label'].shape)
    print("features shape:", g.ndata['feat'].shape)

    model = GCNNet(
        in_dim=in_feats,
        hid_dim=args.n_hidden,
        out_dim=n_classes,
        arch=args.arch,
        dropout=args.dropout,
        batch_norm=not args.no_batch_norm,
        aggr=args.aggr
    )

    if cuda:
        model.cuda()

    # logger and so on
    log_dir = save_log_dir(args)
    logger = Logger(os.path.join(log_dir, 'loggings'))
    logger.write(args)

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr)

    # set train_nids to cuda tensor
    if cuda:
        train_nid = torch.from_numpy(train_nid).cuda()
        print("GPU memory allocated before training (MB):",
              torch.cuda.memory_allocated(device=train_nid.device) / 1024 / 1024)
    start_time = time.time()
    best_f1 = -1

    for epoch in range(args.n_epochs):
        for j, subg in enumerate(loader):
            if cuda:
                subg = subg.to(torch.cuda.current_device())
            model.train()
            # forward
            pred = model(subg)
            batch_labels = subg.ndata['label']

            if multilabel:
                loss = F.binary_cross_entropy_with_logits(pred, batch_labels, reduction='sum',
                                                          weight=subg.ndata['l_n'].unsqueeze(1))
            else:
                loss = F.cross_entropy(pred, batch_labels, reduction='none')
                loss = (subg.ndata['l_n'] * loss).sum()
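            # 'l_n' appears to hold the per-node loss-normalization weights computed by the
            # SAINT sampler; weighting the loss with them compensates for the bias introduced
            # by training on sampled subgraphs (as in the GraphSAINT paper)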

            optimizer.zero_grad()
            loss.backward()
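            # clip gradients to a maximum norm of 5 to keep updates stable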
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()

            if j == len(loader) - 1:
                model.eval()
                with torch.no_grad():
                    train_f1_mic, train_f1_mac = calc_f1(batch_labels.cpu().numpy(),
                                                         pred.cpu().numpy(), multilabel)
                    print(f"epoch {epoch + 1}/{args.n_epochs}, iteration {j + 1}/{len(loader)}: "
                          f"training loss {loss.item():.4f}")
                    print("Train F1-mic {:.4f}, Train F1-mac {:.4f}".format(train_f1_mic, train_f1_mac))
        # evaluate
        model.eval()
        if epoch % args.val_every == 0:
            if cpu_flag and cuda:  # the model was trained on GPU; move it back to CPU for full-graph validation
                model = model.to('cpu')
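            # evaluate() runs full-graph inference, so the model must be on the same device as g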
            val_f1_mic, val_f1_mac = evaluate(
                model, g, labels, val_mask, multilabel)
            print(
                "Val F1-mic {:.4f}, Val F1-mac {:.4f}".format(val_f1_mic, val_f1_mac))
            if val_f1_mic > best_f1:
                best_f1 = val_f1_mic
                print('new best val f1:', best_f1)
                torch.save(model.state_dict(), os.path.join(
                    log_dir, 'best_model_{}.pkl'.format(task)))
            if cpu_flag and cuda:
                model.cuda()

    end_time = time.time()
    print(f'total training time: {end_time - start_time:.2f} s')

    # test
    if args.use_val:
        model.load_state_dict(torch.load(os.path.join(
            log_dir, 'best_model_{}.pkl'.format(task))))
    if cpu_flag and cuda:
        model = model.to('cpu')
    test_f1_mic, test_f1_mac = evaluate(
        model, g, labels, test_mask, multilabel)
    print("Test F1-mic {:.4f}, Test F1-mac {:.4f}".format(test_f1_mic, test_f1_mac))

if __name__ == '__main__':
    warnings.filterwarnings('ignore')

    parser = argparse.ArgumentParser(description='GraphSAINT')
    parser.add_argument("--task", type=str, default="ppi_n", help="type of task")
    parser.add_argument("--online", dest='online', action='store_true',
                        help="use online subgraph sampling in the training phase")
    parser.add_argument("--gpu", type=int, default=0, help="the gpu index")
    cli_args = parser.parse_args()
    task = cli_args.task
    # all remaining hyperparameters come from the per-task entry in CONFIG
    args = argparse.Namespace(**CONFIG[task])
    args.online = cli_args.online
    args.gpu = cli_args.gpu
    print(args)

    main(args, task=task)