"""Train EvolveGCN (O/H variants) on the Elliptic bitcoin dataset."""

import argparse
import time

import dgl

import torch
import torch.nn.functional as F

from dataset import EllipticDataset
from model import EvolveGCNH, EvolveGCNO
from utils import Measure
def train(args, device):
    """Train, validate and test an EvolveGCN model on the Elliptic dataset.

    Follows the temporal split from the EvolveGCN paper: timestamps 0-30 for
    training, 31-35 for validation, 36-48 for test.  For each timestamp ``i``
    the model consumes the window of subgraphs ``[i - n_hist_steps, i]`` and
    predicts labels for the labeled nodes of subgraph ``i``.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments (see the ``__main__`` block).
    device : torch.device
        Device on which graphs and the model are placed.

    Raises
    ------
    NotImplementedError
        If ``args.model`` is neither "EvolveGCN-O" nor "EvolveGCN-H".
    """
    elliptic_dataset = EllipticDataset(
        raw_dir=args.raw_dir,
        processed_dir=args.processed_dir,
        self_loop=True,
        reverse_edge=True,
    )

    g, node_mask_by_time = elliptic_dataset.process()
    num_classes = elliptic_dataset.num_classes

    # Pre-build one subgraph per timestamp (and its labeled-node mask) once,
    # so the epoch loop below only does forward/backward work.
    cached_subgraph = []
    cached_labeled_node_mask = []
    for i in range(len(node_mask_by_time)):
        # we add self loop edge when we construct full graph, not here
        node_subgraph = dgl.node_subgraph(graph=g, nodes=node_mask_by_time[i])
        cached_subgraph.append(node_subgraph.to(device))
        # Nodes with label < 0 are unlabeled; exclude them from loss/metrics.
        valid_node_mask = node_subgraph.ndata["label"] >= 0
        cached_labeled_node_mask.append(valid_node_mask)

    if args.model == "EvolveGCN-O":
        model = EvolveGCNO(
            in_feats=int(g.ndata["feat"].shape[1]),
            n_hidden=args.n_hidden,
            num_layers=args.n_layers,
        )
    elif args.model == "EvolveGCN-H":
        model = EvolveGCNH(
            in_feats=int(g.ndata["feat"].shape[1]), num_layers=args.n_layers
        )
    else:
        # BUG FIX: original code *returned* the exception instance instead of
        # raising it, so execution fell through to model.to(device) below.
        raise NotImplementedError("Unsupported model {}".format(args.model))

    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # split train, valid, test(0-30,31-35,36-48)
    # train/valid/test split follow the paper.
    train_max_index = 30
    valid_max_index = 35
    test_max_index = 48
    time_window_size = args.n_hist_steps
    # Per-class weights for the (imbalanced) cross-entropy loss,
    # e.g. "0.35,0.65" -> [0.35, 0.65].
    loss_class_weight = [float(w) for w in args.loss_class_weight.split(",")]
    loss_class_weight = torch.Tensor(loss_class_weight).to(device)

    train_measure = Measure(
        num_classes=num_classes, target_class=args.eval_class_id
    )
    valid_measure = Measure(
        num_classes=num_classes, target_class=args.eval_class_id
    )
    test_measure = Measure(
        num_classes=num_classes, target_class=args.eval_class_id
    )

    test_res_f1 = 0
    for epoch in range(args.num_epochs):
        model.train()
        for i in range(time_window_size, train_max_index + 1):
            g_list = cached_subgraph[i - time_window_size : i + 1]
            predictions = model(g_list)
            # get predictions which has label
            predictions = predictions[cached_labeled_node_mask[i]]
            labels = (
                cached_subgraph[i]
                .ndata["label"][cached_labeled_node_mask[i]]
                .long()
            )
            loss = F.cross_entropy(
                predictions, labels, weight=loss_class_weight
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_measure.append_measures(predictions, labels)

        # get each epoch measures during training.
        cl_precision, cl_recall, cl_f1 = train_measure.get_total_measure()
        train_measure.update_best_f1(cl_f1, epoch)
        # reset measures for next epoch
        train_measure.reset_info()

        print(
            "Train Epoch {} | class {} | precision:{:.4f} | recall: {:.4f} | f1: {:.4f}".format(
                epoch, args.eval_class_id, cl_precision, cl_recall, cl_f1
            )
        )

        # eval
        model.eval()
        # no_grad: inference only — skip autograd bookkeeping during eval.
        with torch.no_grad():
            for i in range(train_max_index + 1, valid_max_index + 1):
                g_list = cached_subgraph[i - time_window_size : i + 1]
                predictions = model(g_list)
                # get node predictions which has label
                predictions = predictions[cached_labeled_node_mask[i]]
                labels = (
                    cached_subgraph[i]
                    .ndata["label"][cached_labeled_node_mask[i]]
                    .long()
                )

                valid_measure.append_measures(predictions, labels)

        # get each epoch measure during eval.
        cl_precision, cl_recall, cl_f1 = valid_measure.get_total_measure()
        valid_measure.update_best_f1(cl_f1, epoch)
        # reset measures for next epoch
        valid_measure.reset_info()

        print(
            "Eval Epoch {} | class {} | precision:{:.4f} | recall: {:.4f} | f1: {:.4f}".format(
                epoch, args.eval_class_id, cl_precision, cl_recall, cl_f1
            )
        )

        # early stop
        if epoch - valid_measure.target_best_f1_epoch >= args.patience:
            print(
                "Best eval Epoch {}, Cur Epoch {}".format(
                    valid_measure.target_best_f1_epoch, epoch
                )
            )
            break
        # if cur valid f1 score is best, do test
        if epoch == valid_measure.target_best_f1_epoch:
            print(
                "###################Epoch {} Test###################".format(
                    epoch
                )
            )
            with torch.no_grad():
                for i in range(valid_max_index + 1, test_max_index + 1):
                    g_list = cached_subgraph[i - time_window_size : i + 1]
                    predictions = model(g_list)
                    # get predictions which has label
                    predictions = predictions[cached_labeled_node_mask[i]]
                    labels = (
                        cached_subgraph[i]
                        .ndata["label"][cached_labeled_node_mask[i]]
                        .long()
                    )

                    test_measure.append_measures(predictions, labels)

            # we get each subgraph measure when testing to match fig 4 in EvolveGCN paper.
            (
                cl_precisions,
                cl_recalls,
                cl_f1s,
            ) = test_measure.get_each_timestamp_measure()
            for index, (sub_p, sub_r, sub_f1) in enumerate(
                zip(cl_precisions, cl_recalls, cl_f1s)
            ):
                print(
                    "  Test | Time {} | precision:{:.4f} | recall: {:.4f} | f1: {:.4f}".format(
                        valid_max_index + index + 2, sub_p, sub_r, sub_f1
                    )
                )

            # get each epoch measure during test.
            cl_precision, cl_recall, cl_f1 = test_measure.get_total_measure()
            test_measure.update_best_f1(cl_f1, epoch)
            # reset measures for next test
            test_measure.reset_info()

            # remember the f1 obtained at the best-validation epoch so far
            test_res_f1 = cl_f1

            print(
                "  Test | Epoch {} | class {} | precision:{:.4f} | recall: {:.4f} | f1: {:.4f}".format(
                    epoch, args.eval_class_id, cl_precision, cl_recall, cl_f1
                )
            )

    print(
        "Best test f1 is {}, in Epoch {}".format(
            test_measure.target_best_f1, test_measure.target_best_f1_epoch
        )
    )
    if test_measure.target_best_f1_epoch != valid_measure.target_best_f1_epoch:
        print(
            "The Epoch get best Valid measure not get the best Test measure, "
            "please checkout the test result in Epoch {}, which f1 is {}".format(
                valid_measure.target_best_f1_epoch, test_res_f1
            )
        )


if __name__ == "__main__":
    # Command-line interface; defaults reproduce the paper's Elliptic setup.
    argparser = argparse.ArgumentParser("EvolveGCN")
    argparser.add_argument(
        "--model",
        type=str,
        default="EvolveGCN-O",
        help="We can choose EvolveGCN-O or EvolveGCN-H,"
        "but the EvolveGCN-H performance on Elliptic dataset is not good.",
    )
    argparser.add_argument(
        "--raw-dir",
        type=str,
        default="/home/Elliptic/elliptic_bitcoin_dataset/",
        help="Dir after unzip downloaded dataset, which contains 3 csv files.",
    )
    argparser.add_argument(
        "--processed-dir",
        type=str,
        default="/home/Elliptic/processed/",
        help="Dir to store processed raw data.",
    )
    argparser.add_argument(
        "--gpu",
        type=int,
        default=0,
        help="GPU device ID. Use -1 for CPU training.",
    )
    argparser.add_argument("--num-epochs", type=int, default=1000)
    argparser.add_argument("--n-hidden", type=int, default=256)
    argparser.add_argument("--n-layers", type=int, default=2)
    argparser.add_argument(
        "--n-hist-steps",
        type=int,
        default=5,
        help="If it is set to 5, it means in the first batch,"
        "we use historical data of 0-4 to predict the data of time 5.",
    )
    argparser.add_argument("--lr", type=float, default=0.001)
    argparser.add_argument(
        "--loss-class-weight",
        type=str,
        default="0.35,0.65",
        help="Weight for loss function. Follow the official code,"
        "we need to change it to 0.25, 0.75 when use EvolveGCN-H",
    )
    argparser.add_argument(
        "--eval-class-id",
        type=int,
        default=1,
        help="Class type to eval. On Elliptic, type 1(illicit) is the main interest.",
    )
    argparser.add_argument(
        "--patience", type=int, default=100, help="Patience for early stopping."
    )

    args = argparser.parse_args()

    # Select device: non-negative --gpu picks that CUDA device, -1 means CPU.
    if args.gpu >= 0:
        device = torch.device("cuda:%d" % args.gpu)
    else:
        device = torch.device("cpu")

    start_time = time.perf_counter()
    train(args, device)
    print("train time is: {}".format(time.perf_counter() - start_time))