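"""Train MVGRL for graph-level representation learning.

Each graph is paired with a diffusion-based second view; the model is
trained contrastively on the two views, and the learned graph embeddings
are evaluated with a linear SVM.
"""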
import argparse
import warnings

import torch as th

import dgl
from dgl.dataloading import GraphDataLoader

warnings.filterwarnings("ignore")

from dataset import load
from model import MVGRL
from utils import linearsvc
parser = argparse.ArgumentParser(description="mvgrl")

parser.add_argument(
    "--dataname", type=str, default="MUTAG", help="Name of dataset."
)
parser.add_argument(
    "--gpu", type=int, default=-1, help="GPU index. Default: -1, using cpu."
)
parser.add_argument(
    "--epochs", type=int, default=200, help="Number of training epochs."
)
parser.add_argument(
    "--patience", type=int, default=20, help="Early stopping steps."
)
parser.add_argument(
    "--lr", type=float, default=0.001, help="Learning rate of mvgrl."
)
parser.add_argument(
    "--wd", type=float, default=0.0, help="Weight decay of mvgrl."
)
parser.add_argument("--batch_size", type=int, default=64, help="Batch size.")
parser.add_argument(
    "--n_layers", type=int, default=4, help="Number of GNN layers."
)
parser.add_argument("--hid_dim", type=int, default=32, help="Hidden layer dim.")

args = parser.parse_args()

# check cuda
if args.gpu != -1 and th.cuda.is_available():
    args.device = "cuda:{}".format(args.gpu)
else:
    args.device = "cpu"


def collate(samples):
    """Collate (graph, diff_graph, label) triples for the graph dataloader."""
    graphs, diff_graphs, labels = map(list, zip(*samples))

    # generate batched graphs and labels
    batched_graph = dgl.batch(graphs)
    batched_labels = th.tensor(labels)
    batched_diff_graph = dgl.batch(diff_graphs)

    # tag every node with the id of the graph it belongs to, so node
    # representations can later be pooled per graph
    n_graphs = len(graphs)
    graph_id = th.arange(n_graphs)
    graph_id = dgl.broadcast_nodes(batched_graph, graph_id)

    batched_graph.ndata["graph_id"] = graph_id

    return batched_graph, batched_diff_graph, batched_labels


if __name__ == "__main__":

    # Step 1: Prepare data =================================================================== #
    dataset = load(args.dataname)

    graphs, diff_graphs, labels = map(list, zip(*dataset))
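    # graphs and diff_graphs are the two contrastive views of every sample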
    print("Number of graphs:", len(graphs))
    # generate a full graph with all examples for evaluation
    wholegraph = dgl.batch(graphs)
    whole_dg = dgl.batch(diff_graphs)

    # create dataloader for batch training
    dataloader = GraphDataLoader(
        dataset,
        batch_size=args.batch_size,
        collate_fn=collate,
        drop_last=False,
        shuffle=True,
    )

    in_dim = wholegraph.ndata["feat"].shape[1]
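    # node feature dimension, shared by all graphs in the dataset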

    # Step 2: Create model =================================================================== #
    model = MVGRL(in_dim, args.hid_dim, args.n_layers)
    model = model.to(args.device)

    # Step 3: Create training components ===================================================== #
    optimizer = th.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.wd
    )

    print("===== Before training ======")

    wholegraph = wholegraph.to(args.device)
    whole_dg = whole_dg.to(args.device)
    wholefeat = wholegraph.ndata.pop("feat")
    whole_weight = whole_dg.edata.pop("edge_weight")

    embs = model.get_embedding(wholegraph, whole_dg, wholefeat, whole_weight)
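    # embeddings from the untrained model give a random-initialization baseline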
    lbls = th.LongTensor(labels)
    acc_mean, acc_std = linearsvc(embs, lbls)
    print("accuracy: mean {:.4f}, std {:.4f}".format(acc_mean, acc_std))

    best = float("inf")
    cnt_wait = 0
    # Step 4: Training epochs =============================================================== #
    for epoch in range(args.epochs):
        loss_all = 0
        model.train()

        for graph, diff_graph, label in dataloader:
            graph = graph.to(args.device)
            diff_graph = diff_graph.to(args.device)

            feat = graph.ndata["feat"]
            graph_id = graph.ndata["graph_id"]
            edge_weight = diff_graph.edata["edge_weight"]

            optimizer.zero_grad()
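            # MVGRL loss: contrast node-level and graph-level
            # representations across the two views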
            loss = model(graph, diff_graph, feat, edge_weight, graph_id)
            loss_all += loss.item()
            loss.backward()
            optimizer.step()

        print("Epoch {}, Loss {:.4f}".format(epoch, loss_all))

        if loss_all < best:
            best = loss_all
            best_t = epoch
            cnt_wait = 0
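            # checkpoint whenever the epoch loss improves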
            th.save(model.state_dict(), f"{args.dataname}.pkl")
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            print("Early stopping")
            break

    print("Training End")

    # Step 5: Linear evaluation ============================================================ #
    model.load_state_dict(th.load(f"{args.dataname}.pkl"))
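    # recompute graph embeddings with the best checkpoint and re-evaluate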
    embs = model.get_embedding(wholegraph, whole_dg, wholefeat, whole_weight)

    acc_mean, acc_std = linearsvc(embs, lbls)
    print("accuracy: mean {:.4f}, std {:.4f}".format(acc_mean, acc_std))