# -*- coding: utf-8 -*-
"""
HAN mini-batch training with RandomWalkSampler.
Note: this demo uses RandomWalkSampler to sample neighbors. Since it is hard to cover all neighbors
during validation/testing, we sample twice as many neighbors for val/test as for training.
"""
import argparse

import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from model_hetero import SemanticAttention
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader
from utils import EarlyStopping, set_random_seed

import dgl
from dgl.nn.pytorch import GATConv
from dgl.sampling import RandomWalkNeighborSampler


class HANLayer(torch.nn.Module):
    """
    HAN layer.

    Arguments
    ---------
    num_metapath : number of metapath-based sub-graphs
    in_size : input feature dimension
    out_size : output feature dimension
    layer_num_heads : number of attention heads
    dropout : Dropout probability

    Inputs
    ------
    block_list : list of DGLBlock
        Sampled blocks, one per metapath
    h_list : list of tensor
        Input features of the source nodes of each block

    Outputs
    -------
    tensor
        The output feature
    """

    def __init__(
        self, num_metapath, in_size, out_size, layer_num_heads, dropout
    ):
        super(HANLayer, self).__init__()

        # One GAT layer for each meta path based adjacency matrix
        self.gat_layers = nn.ModuleList()
        for i in range(num_metapath):
            self.gat_layers.append(
                GATConv(
                    in_size,
                    out_size,
                    layer_num_heads,
                    dropout,
                    dropout,
                    activation=F.elu,
                    allow_zero_in_degree=True,
                )
            )
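        # Semantic attention fuses the per-metapath node embeddings into a single embedding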
        self.semantic_attention = SemanticAttention(
            in_size=out_size * layer_num_heads
        )
        self.num_metapath = num_metapath

    def forward(self, block_list, h_list):
        semantic_embeddings = []

        for i, block in enumerate(block_list):
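            # apply the i-th metapath-specific GAT layer to its block and flatten the head dimension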
            semantic_embeddings.append(
                self.gat_layers[i](block, h_list[i]).flatten(1)
            )
        semantic_embeddings = torch.stack(
            semantic_embeddings, dim=1
        )  # (N, M, D * K)

        return self.semantic_attention(semantic_embeddings)  # (N, D * K)


class HAN(nn.Module):
    def __init__(
        self, num_metapath, in_size, hidden_size, out_size, num_heads, dropout
    ):
        super(HAN, self).__init__()

        self.layers = nn.ModuleList()
        self.layers.append(
            HANLayer(num_metapath, in_size, hidden_size, num_heads[0], dropout)
        )
        for l in range(1, len(num_heads)):
            self.layers.append(
                HANLayer(
                    num_metapath,
                    hidden_size * num_heads[l - 1],
                    hidden_size,
                    num_heads[l],
                    dropout,
                )
            )
        self.predict = nn.Linear(hidden_size * num_heads[-1], out_size)

    def forward(self, g, h):
        for gnn in self.layers:
            h = gnn(g, h)

        return self.predict(h)


class HANSampler(object):
    def __init__(self, g, metapath_list, num_neighbors):
        self.sampler_list = []
        for metapath in metapath_list:
            # note: a random walk may traverse the same route (same edge) more than once; duplicates
            # are removed in the sampled graph, so it may contain fewer edges than
            # num_random_walks (num_neighbors).
            self.sampler_list.append(
                RandomWalkNeighborSampler(
                    G=g,
                    num_traversals=1,
                    termination_prob=0,
                    num_random_walks=num_neighbors,
                    num_neighbors=num_neighbors,
                    metapath=metapath,
                )
            )

    def sample_blocks(self, seeds):
        block_list = []
        for sampler in self.sampler_list:
            frontier = sampler(seeds)
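            # frontier: a homogeneous graph connecting the seed nodes to their sampled metapath-based neighbors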
            # remove any existing self-loops, then add one self-loop per seed node
            frontier = dgl.remove_self_loop(frontier)
            frontier.add_edges(torch.tensor(seeds), torch.tensor(seeds))
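            # convert the frontier into a block (bipartite graph) whose destination nodes are the seeds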
            block = dgl.to_block(frontier, seeds)
            block_list.append(block)

        return seeds, block_list


def score(logits, labels):
    _, indices = torch.max(logits, dim=1)
    prediction = indices.long().cpu().numpy()
    labels = labels.cpu().numpy()

    accuracy = (prediction == labels).sum() / len(prediction)
    micro_f1 = f1_score(labels, prediction, average="micro")
    macro_f1 = f1_score(labels, prediction, average="macro")

    return accuracy, micro_f1, macro_f1


def evaluate(
    model,
    g,
    metapath_list,
    num_neighbors,
    features,
    labels,
    val_nid,
    loss_fcn,
    batch_size,
):
    model.eval()

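    # sample twice as many neighbors for validation/testing (see module docstring)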
    han_valid_sampler = HANSampler(
        g, metapath_list, num_neighbors=num_neighbors * 2
    )
    dataloader = DataLoader(
        dataset=val_nid,
        batch_size=batch_size,
        collate_fn=han_valid_sampler.sample_blocks,
        shuffle=False,
        drop_last=False,
        num_workers=4,
    )
    correct = total = 0
    prediction_list = []
    labels_list = []
    with torch.no_grad():
        for step, (seeds, blocks) in enumerate(dataloader):
            h_list = load_subtensors(blocks, features)
            blocks = [block.to(args["device"]) for block in blocks]
            hs = [h.to(args["device"]) for h in h_list]

            logits = model(blocks, hs)
            loss = loss_fcn(
                logits, labels[numpy.asarray(seeds)].to(args["device"])
            )
            # get the predicted label for each seed node
            _, indices = torch.max(logits, dim=1)
            prediction = indices.long().cpu().numpy()
            labels_batch = labels[numpy.asarray(seeds)].cpu().numpy()

            prediction_list.append(prediction)
            labels_list.append(labels_batch)

            correct += (prediction == labels_batch).sum()
            total += prediction.shape[0]

    total_prediction = numpy.concatenate(prediction_list)
    total_labels = numpy.concatenate(labels_list)
    micro_f1 = f1_score(total_labels, total_prediction, average="micro")
    macro_f1 = f1_score(total_labels, total_prediction, average="macro")
    accuracy = correct / total
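    # note: `loss` holds only the loss of the last validation batch, not an average over batches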

    return loss, accuracy, micro_f1, macro_f1


def load_subtensors(blocks, features):
    h_list = []
    for block in blocks:
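        # block.srcdata[dgl.NID] holds the original graph IDs of the block's input (source) nodes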
        input_nodes = block.srcdata[dgl.NID]
        h_list.append(features[input_nodes])
    return h_list


def main(args):
    # acm data
    if args["dataset"] == "ACMRaw":
        from utils import load_data

        (
            g,
            features,
            labels,
            n_classes,
            train_nid,
            val_nid,
            test_nid,
            train_mask,
            val_mask,
            test_mask,
        ) = load_data("ACMRaw")
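        # metapaths on the ACM graph: paper-author-paper ('pa' -> 'ap') and paper-field-paper ('pf' -> 'fp')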
        metapath_list = [["pa", "ap"], ["pf", "fp"]]
    else:
        raise NotImplementedError(
            "Unsupported dataset {}".format(args["dataset"])
        )

    # TODO: do we need different numbers of neighbors for different metapath-based graphs?
    num_neighbors = args["num_neighbors"]
    han_sampler = HANSampler(g, metapath_list, num_neighbors)
    # Create PyTorch DataLoader for constructing blocks
    dataloader = DataLoader(
        dataset=train_nid,
        batch_size=args["batch_size"],
        collate_fn=han_sampler.sample_blocks,
        shuffle=True,
        drop_last=False,
        num_workers=4,
    )

    model = HAN(
        num_metapath=len(metapath_list),
        in_size=features.shape[1],
        hidden_size=args["hidden_units"],
        out_size=n_classes,
        num_heads=args["num_heads"],
        dropout=args["dropout"],
    ).to(args["device"])

    total_params = sum(p.numel() for p in model.parameters())
    print("total_params: {:d}".format(total_params))
    total_trainable_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    print("total trainable params: {:d}".format(total_trainable_params))

    stopper = EarlyStopping(patience=args["patience"])
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args["lr"], weight_decay=args["weight_decay"]
    )

    for epoch in range(args["num_epochs"]):
        model.train()
        for step, (seeds, blocks) in enumerate(dataloader):
            h_list = load_subtensors(blocks, features)
            blocks = [block.to(args["device"]) for block in blocks]
            hs = [h.to(args["device"]) for h in h_list]

            logits = model(blocks, hs)
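            # seeds are original node IDs, so they index the full `labels` tensor directly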
            loss = loss_fn(
                logits, labels[numpy.asarray(seeds)].to(args["device"])
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # print info in each batch
            train_acc, train_micro_f1, train_macro_f1 = score(
                logits, labels[numpy.asarray(seeds)]
            )
            print(
                "Epoch {:d} | loss: {:.4f} | train_acc: {:.4f} | train_micro_f1: {:.4f} | train_macro_f1: {:.4f}".format(
                    epoch + 1, loss, train_acc, train_micro_f1, train_macro_f1
                )
            )
        val_loss, val_acc, val_micro_f1, val_macro_f1 = evaluate(
            model,
            g,
            metapath_list,
            num_neighbors,
            features,
            labels,
            val_nid,
            loss_fn,
            args["batch_size"],
        )
        early_stop = stopper.step(val_loss.data.item(), val_acc, model)

        print(
            "Epoch {:d} | Val loss {:.4f} | Val Accuracy {:.4f} | Val Micro f1 {:.4f} | Val Macro f1 {:.4f}".format(
                epoch + 1, val_loss.item(), val_acc, val_micro_f1, val_macro_f1
            )
        )

        if early_stop:
            break

    stopper.load_checkpoint(model)
    test_loss, test_acc, test_micro_f1, test_macro_f1 = evaluate(
        model,
        g,
        metapath_list,
        num_neighbors,
        features,
        labels,
        test_nid,
        loss_fn,
        args["batch_size"],
    )
    print(
        "Test loss {:.4f} | Test Accuracy {:.4f} | Test Micro f1 {:.4f} | Test Macro f1 {:.4f}".format(
            test_loss.item(), test_acc, test_micro_f1, test_macro_f1
        )
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser("mini-batch HAN")
    parser.add_argument("-s", "--seed", type=int, default=1, help="Random seed")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--num_neighbors", type=int, default=20)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--num_heads", type=list, default=[8])
    parser.add_argument("--hidden_units", type=int, default=8)
    parser.add_argument("--dropout", type=float, default=0.6)
    parser.add_argument("--weight_decay", type=float, default=0.001)
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--patience", type=int, default=10)
    parser.add_argument("--dataset", type=str, default="ACMRaw")
    parser.add_argument("--device", type=str, default="cuda:0")

    args = parser.parse_args().__dict__
    # set_random_seed(args['seed'])

    main(args)