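'''Train MVGRL for graph classification and evaluate its embeddings with a linear SVC.

A training script for MVGRL (Hassani & Khasahmadi, ICML 2020), which learns
graph representations by contrasting two structural views of each graph: the
original topology and a diffusion-augmented view (the diff_graph returned by
dataset.load).
'''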
import argparse
import warnings

import torch as th

import dgl
from dgl.dataloading import GraphDataLoader

warnings.filterwarnings('ignore')

from dataset import load
from model import MVGRL
from utils import linearsvc

parser = argparse.ArgumentParser(description='mvgrl')

parser.add_argument('--dataname', type=str, default='MUTAG', help='Name of dataset.')
parser.add_argument('--gpu', type=int, default=-1, help='GPU index. Default: -1, using cpu.')
parser.add_argument('--epochs', type=int, default=200, help='Number of training epochs.')
parser.add_argument('--patience', type=int, default=20, help='Early stopping steps.')
parser.add_argument('--lr', type=float, default=0.001, help='Learning rate of mvgrl.')
parser.add_argument('--wd', type=float, default=0., help='Weight decay of mvgrl.')
parser.add_argument('--batch_size', type=int, default=64, help='Batch size.')
parser.add_argument('--n_layers', type=int, default=4, help='Number of GNN layers.')
parser.add_argument("--hid_dim", type=int, default=32, help='Hidden layer dim.')

args = parser.parse_args()

# check cuda
if args.gpu != -1 and th.cuda.is_available():
    args.device = 'cuda:{}'.format(args.gpu)
else:
    args.device = 'cpu'


def collate(samples):
    '''Collate function for building the graph dataloader.'''
    graphs, diff_graphs, labels = map(list, zip(*samples))

    # generate batched graphs and labels
    batched_graph = dgl.batch(graphs)
    batched_labels = th.tensor(labels)
    batched_diff_graph = dgl.batch(diff_graphs)

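    # tag every node with the index of its source graph: dgl.broadcast_nodes
    # repeats each graph-level id across all nodes of that graph, so the model
    # can map node representations back to the graph they belong to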
    n_graphs = len(graphs)
    graph_id = th.arange(n_graphs)
    graph_id = dgl.broadcast_nodes(batched_graph, graph_id)

    batched_graph.ndata['graph_id'] = graph_id

    return batched_graph, batched_diff_graph, batched_labels

if __name__ == '__main__':

    # Step 1: Prepare data =================================================================== #
    dataset = load(args.dataname)

    graphs, diff_graphs, labels = map(list, zip(*dataset))
    print('Number of graphs:', len(graphs))
    # merge every graph (and its diffusion view) into one batched graph for full-dataset evaluation

    wholegraph = dgl.batch(graphs)
    whole_dg = dgl.batch(diff_graphs)

    # create dataloader for batch training
    dataloader = GraphDataLoader(dataset,
                                 batch_size=args.batch_size,
                                 collate_fn=collate,
                                 drop_last=False,
                                 shuffle=True)
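    # GraphDataLoader wraps torch's DataLoader for DGL graphs; the custom
    # collate above batches (graph, diff_graph, label) triples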

    in_dim = wholegraph.ndata['feat'].shape[1]

    # Step 2: Create model =================================================================== #
    model = MVGRL(in_dim, args.hid_dim, args.n_layers)
    model = model.to(args.device)

    # Step 3: Create training components ===================================================== #
    optimizer = th.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)

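    # evaluate the randomly initialized encoder first to establish an untrained baseline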
    print('===== Before training =====')

    wholegraph = wholegraph.to(args.device)
    whole_dg = whole_dg.to(args.device)
    wholefeat = wholegraph.ndata.pop('feat')
    whole_weight = whole_dg.edata.pop('edge_weight')

    embs = model.get_embedding(wholegraph, whole_dg, wholefeat, whole_weight)
    lbls = th.LongTensor(labels)
    acc_mean, acc_std = linearsvc(embs, lbls)
    print('accuracy_mean, {:.4f}, accuracy_std, {:.4f}'.format(acc_mean, acc_std))

    best = float('inf')
    cnt_wait = 0
    # Step 4: Training epochs =============================================================== #
    for epoch in range(args.epochs):
        loss_all = 0
        model.train()

        for graph, diff_graph, label in dataloader:
            graph = graph.to(args.device)
            diff_graph = diff_graph.to(args.device)

            feat = graph.ndata['feat']
            graph_id = graph.ndata['graph_id']
            edge_weight = diff_graph.edata['edge_weight']

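            # the forward pass returns the multi-view contrastive loss for this batch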
            optimizer.zero_grad()
            loss = model(graph, diff_graph, feat, edge_weight, graph_id)
            loss_all += loss.item()
            loss.backward()
            optimizer.step()

        print('Epoch {}, Loss {:.4f}'.format(epoch, loss_all))

        # checkpoint whenever the epoch-level loss improves
        if loss_all < best:
            best = loss_all
            best_t = epoch
            cnt_wait = 0
            th.save(model.state_dict(), f'{args.dataname}.pkl')
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            print('Early stopping at epoch {}, best epoch {}'.format(epoch, best_t))
            break

    print('Training End')

    # Step 5: Linear evaluation ============================================================= #
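    # reload the checkpoint with the lowest epoch loss before the final evaluation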
    model.load_state_dict(th.load(f'{args.dataname}.pkl'))
    embs = model.get_embedding(wholegraph, whole_dg, wholefeat, whole_weight)

    acc_mean, acc_std = linearsvc(embs, lbls)
    print('accuracy_mean, {:.4f}, accuracy_std, {:.4f}'.format(acc_mean, acc_std))