main.py

import argparse
import warnings

import dgl
import torch as th
from dgl.dataloading import GraphDataLoader

from dataset import load
from model import MVGRL
from utils import linearsvc

warnings.filterwarnings("ignore")

parser = argparse.ArgumentParser(description="mvgrl")

parser.add_argument(
    "--dataname", type=str, default="MUTAG", help="Name of dataset."
)
parser.add_argument(
    "--gpu", type=int, default=-1, help="GPU index. Default: -1, using cpu."
)
parser.add_argument(
    "--epochs", type=int, default=200, help="Number of training epochs."
)
parser.add_argument(
    "--patience", type=int, default=20, help="Early-stopping patience (epochs)."
)
parser.add_argument(
    "--lr", type=float, default=0.001, help="Learning rate of mvgrl."
)
parser.add_argument(
    "--wd", type=float, default=0.0, help="Weight decay of mvgrl."
)
parser.add_argument("--batch_size", type=int, default=64, help="Batch size.")
parser.add_argument(
    "--n_layers", type=int, default=4, help="Number of GNN layers."
)
parser.add_argument("--hid_dim", type=int, default=32, help="Hidden layer dim.")
args = parser.parse_args()

# check cuda
if args.gpu != -1 and th.cuda.is_available():
    args.device = "cuda:{}".format(args.gpu)
else:
    args.device = "cpu"


def collate(samples):
    """Collate (graph, diff_graph, label) samples into batched graphs and labels."""
    graphs, diff_graphs, labels = map(list, zip(*samples))

    # generate batched graphs and labels
    batched_graph = dgl.batch(graphs)
    batched_labels = th.tensor(labels)
    batched_diff_graph = dgl.batch(diff_graphs)

    # tag every node with the index of the graph it came from
    n_graphs = len(graphs)
    graph_id = th.arange(n_graphs)
    graph_id = dgl.broadcast_nodes(batched_graph, graph_id)

    batched_graph.ndata["graph_id"] = graph_id

    return batched_graph, batched_diff_graph, batched_labels
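
# A minimal illustration of `collate`, assuming two hypothetical 3-node
# samples: dgl.broadcast_nodes expands th.arange(2) so that
# batched_graph.ndata["graph_id"] == tensor([0, 0, 0, 1, 1, 1]), i.e. every
# node carries the index of its source graph, which lets the model pool node
# embeddings back into per-graph summaries.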

if __name__ == "__main__":
    # Step 1: Prepare data =================================================================== #
    dataset = load(args.dataname)

    graphs, diff_graphs, labels = map(list, zip(*dataset))
    print("Number of graphs:", len(graphs))
    # batch all graphs into a single whole-dataset graph for evaluation

    wholegraph = dgl.batch(graphs)
    whole_dg = dgl.batch(diff_graphs)
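
    # `diff_graphs` carry the second MVGRL view of each graph; in the MVGRL
    # paper this is a diffusion version of the original graph (e.g., PPR),
    # here presumably built in `dataset.load` with weights in "edge_weight".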

    # create dataloader for batch training
    dataloader = GraphDataLoader(
        dataset,
        batch_size=args.batch_size,
        collate_fn=collate,
        drop_last=False,
        shuffle=True,
    )
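
    # Each iteration over `dataloader` yields the triple built by `collate`
    # above: (batched_graph, batched_diff_graph, batched_labels).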

    in_dim = wholegraph.ndata["feat"].shape[1]

    # Step 2: Create model =================================================================== #
    model = MVGRL(in_dim, args.hid_dim, args.n_layers)
    model = model.to(args.device)

    # Step 3: Create training components ===================================================== #
    optimizer = th.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.wd
    )

    print("===== Before training =====")
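    # Pre-training baseline: embed the whole dataset with the randomly
    # initialized encoder and score it with the same linear-SVC protocol used
    # after training.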

    wholegraph = wholegraph.to(args.device)
    whole_dg = whole_dg.to(args.device)
    wholefeat = wholegraph.ndata.pop("feat")
    whole_weight = whole_dg.edata.pop("edge_weight")

    embs = model.get_embedding(wholegraph, whole_dg, wholefeat, whole_weight)
    lbls = th.LongTensor(labels)
    acc_mean, acc_std = linearsvc(embs, lbls)
    print("accuracy_mean, {:.4f}, accuracy_std, {:.4f}".format(acc_mean, acc_std))

    best = float("inf")
    cnt_wait = 0
    # Step 4: Training epochs =============================================================== #
    for epoch in range(args.epochs):
        loss_all = 0
        model.train()

        for graph, diff_graph, label in dataloader:
            graph = graph.to(args.device)
            diff_graph = diff_graph.to(args.device)

            feat = graph.ndata["feat"]
            graph_id = graph.ndata["graph_id"]
            edge_weight = diff_graph.edata["edge_weight"]
            n_graph = label.shape[0]
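
            # MVGRL's loss (per the paper) contrasts node-level embeddings of
            # one view with the graph-level embedding of the other view;
            # graph_id maps each node to its source graph for that readout.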

            optimizer.zero_grad()
            loss = model(graph, diff_graph, feat, edge_weight, graph_id)
            loss_all += loss.item()
            loss.backward()
            optimizer.step()

        print("Epoch {}, Loss {:.4f}".format(epoch, loss_all))

        if loss_all < best:
            best = loss_all
            best_t = epoch
            cnt_wait = 0
            th.save(model.state_dict(), f"{args.dataname}.pkl")
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            print("Early stopping")
            break

    print("Training End")

    # Step 5: Linear evaluation ============================================================= #
    model.load_state_dict(th.load(f"{args.dataname}.pkl"))
    embs = model.get_embedding(wholegraph, whole_dg, wholefeat, whole_weight)

    acc_mean, acc_std = linearsvc(embs, lbls)
    print("accuracy_mean, {:.4f}, accuracy_std, {:.4f}".format(acc_mean, acc_std))