import argparse
import time
import dgl
import torch
import torch.nn.functional as F

from dataset import EllipticDataset
from model import EvolveGCNO, EvolveGCNH
from utils import Measure
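
# Example usage (the paths below are placeholders; point them at your local
# copy of the Elliptic dataset):
#   python train.py --model EvolveGCN-O --raw-dir /path/to/elliptic_bitcoin_dataset \
#       --processed-dir /path/to/processed --gpu 0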


def train(args, device):
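    """Train an EvolveGCN model on the Elliptic dataset.

    One node-induced subgraph is built per time step. The model is trained on
    time steps 0-30, early-stopped on validation F1 (time steps 31-35), and
    run on the test split (time steps 36-48) whenever validation F1 improves.
    """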
    elliptic_dataset = EllipticDataset(raw_dir=args.raw_dir,
                                       processed_dir=args.processed_dir,
                                       self_loop=True,
                                       reverse_edge=True)

    g, node_mask_by_time = elliptic_dataset.process()
    num_classes = elliptic_dataset.num_classes

    cached_subgraph = []
    cached_labeled_node_mask = []
    for i in range(len(node_mask_by_time)):
        # self-loop edges were already added when building the full graph, not here
        node_subgraph = dgl.node_subgraph(graph=g, nodes=node_mask_by_time[i])
        cached_subgraph.append(node_subgraph.to(device))
        # nodes with label < 0 are unlabeled and excluded from loss and metrics
        valid_node_mask = node_subgraph.ndata['label'] >= 0
        cached_labeled_node_mask.append(valid_node_mask)

    if args.model == 'EvolveGCN-O':
        model = EvolveGCNO(in_feats=int(g.ndata['feat'].shape[1]),
                           n_hidden=args.n_hidden,
                           num_layers=args.n_layers)
    elif args.model == 'EvolveGCN-H':
        model = EvolveGCNH(in_feats=int(g.ndata['feat'].shape[1]),
                           num_layers=args.n_layers)
    else:
        raise NotImplementedError('Unsupported model {}'.format(args.model))
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # split into train/valid/test time steps (0-30, 31-35, 36-48), following the paper
    train_max_index = 30
    valid_max_index = 35
    test_max_index = 48
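    # at time step i the model consumes the snapshots [i - n_hist_steps, i]
    # and predicts the node labels of step i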
    time_window_size = args.n_hist_steps
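    # per-class loss weights compensate for the licit/illicit class imbalance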
    loss_class_weight = [float(w) for w in args.loss_class_weight.split(',')]
    loss_class_weight = torch.Tensor(loss_class_weight).to(device)

    train_measure = Measure(num_classes=num_classes, target_class=args.eval_class_id)
    valid_measure = Measure(num_classes=num_classes, target_class=args.eval_class_id)
    test_measure = Measure(num_classes=num_classes, target_class=args.eval_class_id)

    test_res_f1 = 0
    for epoch in range(args.num_epochs):
        model.train()
        for i in range(time_window_size, train_max_index + 1):
            g_list = cached_subgraph[i - time_window_size:i + 1]
            predictions = model(g_list)
            # keep predictions only for labeled nodes
            predictions = predictions[cached_labeled_node_mask[i]]
            labels = cached_subgraph[i].ndata['label'][cached_labeled_node_mask[i]].long()
            loss = F.cross_entropy(predictions, labels, weight=loss_class_weight)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_measure.append_measures(predictions, labels)

        # aggregate this epoch's measures over the training time steps
        cl_precision, cl_recall, cl_f1 = train_measure.get_total_measure()
        train_measure.update_best_f1(cl_f1, epoch)
        # reset measures for next epoch
        train_measure.reset_info()

        print("Train Epoch {} | class {} | precision:{:.4f} | recall: {:.4f} | f1: {:.4f}"
              .format(epoch, args.eval_class_id, cl_precision, cl_recall, cl_f1))

        # validation; gradients are not needed for inference
        model.eval()
        with torch.no_grad():
            for i in range(train_max_index + 1, valid_max_index + 1):
                g_list = cached_subgraph[i - time_window_size:i + 1]
                predictions = model(g_list)
                # keep predictions only for labeled nodes
                predictions = predictions[cached_labeled_node_mask[i]]
                labels = cached_subgraph[i].ndata['label'][cached_labeled_node_mask[i]].long()

                valid_measure.append_measures(predictions, labels)

        # aggregate this epoch's measures over the validation time steps
        cl_precision, cl_recall, cl_f1 = valid_measure.get_total_measure()
        valid_measure.update_best_f1(cl_f1, epoch)
        # reset measures for next epoch
        valid_measure.reset_info()

        print("Eval Epoch {} | class {} | precision:{:.4f} | recall: {:.4f} | f1: {:.4f}"
              .format(epoch, args.eval_class_id, cl_precision, cl_recall, cl_f1))

        # early stopping on validation F1
        if epoch - valid_measure.target_best_f1_epoch >= args.patience:
            print("Early stopping: best eval Epoch {}, Cur Epoch {}".format(valid_measure.target_best_f1_epoch, epoch))
            break
        # if the current validation F1 is the best so far, evaluate on the test split
        if epoch == valid_measure.target_best_f1_epoch:
            print("###################Epoch {} Test###################".format(epoch))
            with torch.no_grad():
                for i in range(valid_max_index + 1, test_max_index + 1):
                    g_list = cached_subgraph[i - time_window_size:i + 1]
                    predictions = model(g_list)
                    # keep predictions only for labeled nodes
                    predictions = predictions[cached_labeled_node_mask[i]]
                    labels = cached_subgraph[i].ndata['label'][cached_labeled_node_mask[i]].long()

                    test_measure.append_measures(predictions, labels)

            # report per-time-step measures at test time to match Fig. 4 in the EvolveGCN paper
            cl_precisions, cl_recalls, cl_f1s = test_measure.get_each_timestamp_measure()
            for index, (sub_p, sub_r, sub_f1) in enumerate(zip(cl_precisions, cl_recalls, cl_f1s)):
                print("  Test | Time {} | precision:{:.4f} | recall: {:.4f} | f1: {:.4f}"
                      .format(valid_max_index + index + 2, sub_p, sub_r, sub_f1))

            # aggregate measures over the whole test split
            cl_precision, cl_recall, cl_f1 = test_measure.get_total_measure()
            test_measure.update_best_f1(cl_f1, epoch)
            # reset measures for next test
            test_measure.reset_info()

            test_res_f1 = cl_f1

            print("  Test | Epoch {} | class {} | precision:{:.4f} | recall: {:.4f} | f1: {:.4f}"
                  .format(epoch, args.eval_class_id, cl_precision, cl_recall, cl_f1))

    print("Best test f1 is {}, in Epoch {}"
          .format(test_measure.target_best_f1, test_measure.target_best_f1_epoch))
    if test_measure.target_best_f1_epoch != valid_measure.target_best_f1_epoch:
        print("The Epoch get best Valid measure not get the best Test measure, "
              "please checkout the test result in Epoch {}, which f1 is {}"
              .format(valid_measure.target_best_f1_epoch, test_res_f1))


if __name__ == "__main__":
    argparser = argparse.ArgumentParser("EvolveGCN")
    argparser.add_argument('--model', type=str, default='EvolveGCN-O',
                           help='Choose between EvolveGCN-O and EvolveGCN-H. '
                                'Note that EvolveGCN-H performs poorly on the Elliptic dataset.')
    argparser.add_argument('--raw-dir', type=str,
                           default='/home/Elliptic/elliptic_bitcoin_dataset/',
                           help="Dir after unzip downloaded dataset, which contains 3 csv files.")
    argparser.add_argument('--processed-dir', type=str,
                           default='/home/Elliptic/processed/',
                           help="Dir to store processed raw data.")
    argparser.add_argument('--gpu', type=int, default=0,
                           help="GPU device ID. Use -1 for CPU training.")
    argparser.add_argument('--num-epochs', type=int, default=1000)
    argparser.add_argument('--n-hidden', type=int, default=256)
    argparser.add_argument('--n-layers', type=int, default=2)
    argparser.add_argument('--n-hist-steps', type=int, default=5,
                           help="Number of historical snapshots fed to the model. If set to 5, "
                                "the data of time steps 0-4 is used to predict time step 5 in the first batch.")
    argparser.add_argument('--lr', type=float, default=0.001)
    argparser.add_argument('--loss-class-weight', type=str, default='0.35,0.65',
                           help='Comma-separated per-class weights for the loss function. Following '
                                'the official code, change it to 0.25,0.75 when using EvolveGCN-H.')
    argparser.add_argument('--eval-class-id', type=int, default=1,
                           help="Class type to eval. On Elliptic, type 1(illicit) is the main interest.")
    argparser.add_argument('--patience', type=int, default=100,
                           help="Patience for early stopping.")

    args = argparser.parse_args()

    if args.gpu >= 0:
        device = torch.device('cuda:%d' % args.gpu)
    else:
        device = torch.device('cpu')

    start_time = time.perf_counter()
    train(args, device)
    print("train time is: {}".format(time.perf_counter() - start_time))