import argparse

import torch as th
from evaluate_embedding import evaluate_embedding
from model import InfoGraph

import dgl
from dgl.data import GINDataset
from dgl.dataloading import GraphDataLoader


def argument():
    parser = argparse.ArgumentParser(description="InfoGraph")
    # data source params
    parser.add_argument(
        "--dataname", type=str, default="MUTAG", help="Name of dataset."
    )

    # training params
    parser.add_argument(
        "--gpu", type=int, default=-1, help="GPU index; -1 (default) runs on CPU."
    )
    parser.add_argument(
        "--epochs", type=int, default=20, help="Training epochs."
    )
    parser.add_argument(
        "--batch_size", type=int, default=128, help="Training batch size."
    )
    parser.add_argument("--lr", type=float, default=0.01, help="Learning rate.")
    parser.add_argument(
        "--log_interval",
        type=int,
        default=1,
        help="Interval between two evaluations.",
    )

    # model params
    parser.add_argument(
        "--n_layers",
        type=int,
        default=3,
        help="Number of graph convolution layers before each pooling.",
    )
    parser.add_argument(
        "--hid_dim", type=int, default=32, help="Hidden layer dimensionality."
    )

    args = parser.parse_args()

    # check cuda
    if args.gpu != -1 and th.cuda.is_available():
        args.device = "cuda:{}".format(args.gpu)
    else:
        args.device = "cpu"

    return args


def collate(samples):
    """collate function for building graph dataloader"""

    graphs, labels = map(list, zip(*samples))

    # generate batched graphs and labels
    batched_graph = dgl.batch(graphs)
    batched_labels = th.tensor(labels)

    n_graphs = len(graphs)
    graph_id = th.arange(n_graphs)
    graph_id = dgl.broadcast_nodes(batched_graph, graph_id)
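    # each node now carries the index of the graph it belongs to; the model
    # presumably uses these ids to pair node-level (local) and graph-level
    # (global) representations in the InfoGraph objective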

    batched_graph.ndata["graph_id"] = graph_id

    return batched_graph, batched_labels


if __name__ == "__main__":

    # Step 1: Prepare graph data   ===================================== #
    args = argument()
    print(args)

    # load dataset from dgl.data.GINDataset
    dataset = GINDataset(args.dataname, self_loop=False)
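    # each sample in GINDataset is a (DGLGraph, label) pair, with
    # node features stored in ndata["attr"]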

    # get graphs and labels
    graphs, labels = map(list, zip(*dataset))

    # generate a full-graph with all examples for evaluation
    wholegraph = dgl.batch(graphs)
    wholegraph.ndata["attr"] = wholegraph.ndata["attr"].to(th.float32)

    # create dataloader for batch training
    dataloader = GraphDataLoader(
        dataset,
        batch_size=args.batch_size,
        collate_fn=collate,
        drop_last=False,
        shuffle=True,
    )

    in_dim = wholegraph.ndata["attr"].shape[1]  # node feature dimension

    # Step 2: Create model =================================================================== #
    model = InfoGraph(in_dim, args.hid_dim, args.n_layers)
    model = model.to(args.device)

    # Step 3: Create training components ===================================================== #
    optimizer = th.optim.Adam(model.parameters(), lr=args.lr)

    print("===== Before training ======")

    wholegraph = wholegraph.to(args.device)
    wholefeat = wholegraph.ndata["attr"]

    # embeddings from the untrained encoder serve as a baseline for comparison
    emb = model.get_embedding(wholegraph, wholefeat).cpu()
    res = evaluate_embedding(emb, labels, args.device)

    """ Evaluate the initialized embeddings """
    """ using logistic regression and SVM(non-linear) """
    print("logreg {:4f}, svc {:4f}".format(res[0], res[1]))

    # track the best downstream accuracies (and the epochs they occur at)
    best_logreg = 0
    best_logreg_epoch = 0
    best_svc = 0
    best_svc_epoch = 0

    # Step 4: training epochs =============================================================== #
    for epoch in range(args.epochs):
        loss_all = 0
        model.train()

        for graph, label in dataloader:

            graph = graph.to(args.device)
            feat = graph.ndata["attr"]
            graph_id = graph.ndata["graph_id"]

            n_graph = label.shape[0]

            optimizer.zero_grad()
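            # the forward pass returns the unsupervised InfoGraph loss for this
            # batch (presumably a local-global mutual-information objective)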
            loss = model(graph, feat, graph_id)
            loss.backward()
            optimizer.step()
            loss_all += loss.item()

        print("Epoch {}, Loss {:.4f}".format(epoch, loss_all))

        if epoch % args.log_interval == 0:

            # evaluate embeddings
            model.eval()
            emb = model.get_embedding(wholegraph, wholefeat).cpu()
            res = evaluate_embedding(emb, labels, args.device)

            if res[0] > best_logreg:
                best_logreg = res[0]
                best_logreg_epoch = epoch

            if res[1] > best_svc:
                best_svc = res[1]
                best_svc_epoch = epoch

            print(
                "best logreg {:.4f}, epoch {} | best svc: {:.4f}, epoch {}".format(
                    best_logreg, best_logreg_epoch, best_svc, best_svc_epoch
                )
            )

    print("Training End")
    print("best logreg {:4f} ,best svc {:4f}".format(best_logreg, best_svc))