import argparse
import os
import time
import warnings

import torch
import torch.nn.functional as F
from config import CONFIG
from modules import GCNNet
from sampler import SAINTEdgeSampler, SAINTNodeSampler, SAINTRandomWalkSampler
from torch.utils.data import DataLoader
from utils import Logger, calc_f1, evaluate, load_data, save_log_dir
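
# GraphSAINT-style minibatch training: sample subgraphs of the training graph
# (node / edge / random-walk SAINT samplers), train GCNNet on each sampled
# subgraph, and periodically evaluate on the full graph.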


def main(args, task):
    warnings.filterwarnings("ignore")
    multilabel_data = {"ppi", "yelp", "amazon"}
    multilabel = args.dataset in multilabel_data

    # This flag is needed for very large datasets such as amazon, whose graph is too
    # large to be moved onto a single GPU in one piece. In that case we
    # 1. keep the whole graph on CPU and move only the sampled subgraphs to GPU during training
    # 2. keep the model on GPU during training and move it to CPU for validation/testing
    # Both cpu_flag and cuda (set below) must be checked when moving the model between CPU and GPU.
    if args.dataset in ["amazon"]:
        cpu_flag = True
    else:
        cpu_flag = False

    # load and preprocess dataset
    data = load_data(args, multilabel)
    g = data.g
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    labels = g.ndata["label"]

    train_nid = data.train_nid

    in_feats = g.ndata["feat"].shape[1]
    n_classes = data.num_classes
    n_nodes = g.num_nodes()
    n_edges = g.num_edges()

    n_train_samples = train_mask.int().sum().item()
    n_val_samples = val_mask.int().sum().item()
    n_test_samples = test_mask.int().sum().item()

    print(
        """----Data statistics------'
    #Nodes %d
    #Edges %d
    #Classes/Labels (multi binary labels) %d
    #Train samples %d
    #Val samples %d
    #Test samples %d"""
        % (
            n_nodes,
            n_edges,
            n_classes,
            n_train_samples,
            n_val_samples,
            n_test_samples,
        )
    )
    # load sampler
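    # Shared sampler arguments; the exact semantics live in sampler.py. Roughly:
    # how many subgraphs to pre-sample and with how many worker processes,
    # whether to keep sampling new subgraphs online during training, and how
    # many sampled subgraphs are used to estimate GraphSAINT's normalization
    # coefficients.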

    kwargs = {
        "dn": args.dataset,
        "g": g,
        "train_nid": train_nid,
        "num_workers_sampler": args.num_workers_sampler,
        "num_subg_sampler": args.num_subg_sampler,
        "batch_size_sampler": args.batch_size_sampler,
        "online": args.online,
        "num_subg": args.num_subg,
        "full": args.full,
    }

    if args.sampler == "node":
        saint_sampler = SAINTNodeSampler(args.node_budget, **kwargs)
    elif args.sampler == "edge":
        saint_sampler = SAINTEdgeSampler(args.edge_budget, **kwargs)
    elif args.sampler == "rw":
        saint_sampler = SAINTRandomWalkSampler(
            args.num_roots, args.length, **kwargs
        )
    else:
        raise NotImplementedError
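
    # Each item yielded by the SAINT sampler is already a complete subgraph,
    # so the DataLoader uses batch_size=1 and the sampler's own collate
    # function to return that subgraph directly.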
    loader = DataLoader(
        saint_sampler,
        collate_fn=saint_sampler.__collate_fn__,
        batch_size=1,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=False,
    )
    # set device for dataset tensors
    if args.gpu < 0:
        cuda = False
    else:
        cuda = True
        torch.cuda.set_device(args.gpu)
        val_mask = val_mask.cuda()
        test_mask = test_mask.cuda()
        if not cpu_flag:
            g = g.to("cuda:{}".format(args.gpu))

    print("labels shape:", g.ndata["label"].shape)
    print("features shape:", g.ndata["feat"].shape)

    model = GCNNet(
        in_dim=in_feats,
        hid_dim=args.n_hidden,
        out_dim=n_classes,
        arch=args.arch,
        dropout=args.dropout,
        batch_norm=not args.no_batch_norm,
        aggr=args.aggr,
    )

    if cuda:
        model.cuda()

    # logger and so on
    log_dir = save_log_dir(args)
    logger = Logger(os.path.join(log_dir, "loggings"))
    logger.write(args)

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # set train_nids to cuda tensor
    if cuda:
        train_nid = torch.from_numpy(train_nid).cuda()
        print(
            "GPU memory allocated before training(MB)",
            torch.cuda.memory_allocated(device=train_nid.device) / 1024 / 1024,
        )
    start_time = time.time()
    best_f1 = -1

    for epoch in range(args.n_epochs):
        for j, subg in enumerate(loader):
            if cuda:
                subg = subg.to(torch.cuda.current_device())
            model.train()
            # forward
            pred = model(subg)
            batch_labels = subg.ndata["label"]
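            # "l_n" stores the GraphSAINT loss-normalization coefficients
            # computed by the sampler; weighting each node's loss with them
            # corrects for the bias introduced by subgraph sampling.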

            if multilabel:
                loss = F.binary_cross_entropy_with_logits(
                    pred,
                    batch_labels,
                    reduction="sum",
                    weight=subg.ndata["l_n"].unsqueeze(1),
                )
            else:
                loss = F.cross_entropy(pred, batch_labels, reduction="none")
                loss = (subg.ndata["l_n"] * loss).sum()

            optimizer.zero_grad()
            loss.backward()
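            # clip the total gradient L2 norm at 5 to stabilize training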
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()

            if j == len(loader) - 1:
                model.eval()
                with torch.no_grad():
                    train_f1_mic, train_f1_mac = calc_f1(
                        batch_labels.cpu().numpy(),
                        pred.cpu().numpy(),
                        multilabel,
                    )
                    print(
                        f"epoch:{epoch + 1}/{args.n_epochs}, Iteration {j + 1}/"
                        f"{len(loader)}:training loss",
                        loss.item(),
                    )
                    print(
                        "Train F1-mic {:.4f}, Train F1-mac {:.4f}".format(
                            train_f1_mic, train_f1_mac
                        )
                    )
        # evaluate
        model.eval()
        if epoch % args.val_every == 0:
            if (
                cpu_flag and cuda
            ):  # move the model back to CPU only if it was shifted to GPU for training
                model = model.to("cpu")
            val_f1_mic, val_f1_mac = evaluate(
                model, g, labels, val_mask, multilabel
            )
            print(
                "Val F1-mic {:.4f}, Val F1-mac {:.4f}".format(
                    val_f1_mic, val_f1_mac
                )
            )
            if val_f1_mic > best_f1:
                best_f1 = val_f1_mic
                print("new best val f1:", best_f1)
                torch.save(
                    model.state_dict(),
                    os.path.join(log_dir, "best_model_{}.pkl".format(task)),
                )
            if cpu_flag and cuda:
                model.cuda()

    end_time = time.time()
    print(f"training using time {end_time - start_time}")

    # test
    if args.use_val:
        model.load_state_dict(
            torch.load(os.path.join(log_dir, "best_model_{}.pkl".format(task)))
        )
    if cpu_flag and cuda:
        model = model.to("cpu")
    test_f1_mic, test_f1_mac = evaluate(model, g, labels, test_mask, multilabel)
    print(
        "Test F1-mic {:.4f}, Test F1-mac {:.4f}".format(
            test_f1_mic, test_f1_mac
        )
    )


if __name__ == "__main__":
    warnings.filterwarnings("ignore")

    parser = argparse.ArgumentParser(description="GraphSAINT")
    parser.add_argument(
        "--task", type=str, default="ppi_n", help="type of tasks"
    )
    parser.add_argument(
        "--online",
        dest="online",
        action="store_true",
        help="sampling method in training phase",
    )
    parser.add_argument("--gpu", type=int, default=0, help="the gpu index")
    task = parser.parse_args().task
    args = argparse.Namespace(**CONFIG[task])
    args.online = parser.parse_args().online
    args.gpu = parser.parse_args().gpu
    print(args)

    main(args, task=task)
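
# Example invocation (assuming config.py defines a "ppi_n" entry in CONFIG):
#   python train_sampling.py --task ppi_n --gpu 0
#   python train_sampling.py --task ppi_n --gpu -1 --online  # run on CPU, sample subgraphs online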