import argparse
import pickle

import dgl
import numpy as np
import torch
import torch.optim as optim

from dataset import LanderDataset
from models import LANDER

###########
# ArgParser
parser = argparse.ArgumentParser()

# Dataset
parser.add_argument("--data_path", type=str, required=True)
parser.add_argument("--test_data_path", type=str, required=True)
parser.add_argument("--levels", type=str, default="1")
parser.add_argument("--faiss_gpu", action="store_true")
parser.add_argument("--model_filename", type=str, default="lander.pth")

# KNN
parser.add_argument("--knn_k", type=str, default="10")

# Model
parser.add_argument("--hidden", type=int, default=512)
parser.add_argument("--num_conv", type=int, default=4)
parser.add_argument("--dropout", type=float, default=0.0)
parser.add_argument("--gat", action="store_true")
parser.add_argument("--gat_k", type=int, default=1)
parser.add_argument("--balance", action="store_true")
parser.add_argument("--use_cluster_feat", action="store_true")
parser.add_argument("--use_focal_loss", action="store_true")

# Training
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--lr", type=float, default=0.1)
parser.add_argument("--momentum", type=float, default=0.9)
parser.add_argument("--weight_decay", type=float, default=1e-5)

args = parser.parse_args()
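
# A hypothetical invocation (paths and values are illustrative, not from the
# source); each comma-separated entry in --knn_k pairs with one in --levels:
#   python train.py --data_path data/train.pkl --test_data_path data/test.pkl \
#       --levels 2,3 --knn_k 10,5 --hidden 512 --epochs 250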

###########################
# Environment Configuration
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

##################
# Data Preparation
def prepare_dataset_graphs(data_path, k_list, lvl_list):
    """Load (features, labels) from a pickle and build one LanderDataset per
    (k, levels) pair, collecting every level graph on the target device."""
    with open(data_path, "rb") as f:
        features, labels = pickle.load(f)
    gs = []
    for k, l in zip(k_list, lvl_list):
        dataset = LanderDataset(
            features=features,
            labels=labels,
            k=k,
            levels=l,
            faiss_gpu=args.faiss_gpu,
        )
        gs += [g.to(device) for g in dataset.gs]
    return gs
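
# A minimal sketch of the expected pickle layout: the loader above unpacks a
# 2-tuple (features, labels). The shapes/dtypes below are assumptions, not
# verified against LanderDataset: `features` an (N, D) float array, `labels`
# an (N,) integer array of ground-truth cluster ids.
#
#   import numpy as np
#   import pickle
#   features = np.random.rand(1000, 256).astype("float32")
#   labels = np.random.randint(0, 100, size=1000)
#   with open("data/train.pkl", "wb") as f:
#       pickle.dump((features, labels), f)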

k_list = [int(k) for k in args.knn_k.split(",")]
lvl_list = [int(l) for l in args.levels.split(",")]
gs = prepare_dataset_graphs(args.data_path, k_list, lvl_list)
test_gs = prepare_dataset_graphs(args.test_data_path, k_list, lvl_list)
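
# All level graphs for both splits are built once up front and moved to
# `device`; the loops below run full-graph passes with no per-epoch rebuilding.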

##################
# Model Definition
# Infer the input feature dimension from the prepared graphs.
feature_dim = gs[0].ndata["features"].shape[1]
model = LANDER(
    feature_dim=feature_dim,
    nhid=args.hidden,
    num_conv=args.num_conv,
    dropout=args.dropout,
    use_GAT=args.gat,
    K=args.gat_k,
    balance=args.balance,
    use_cluster_feat=args.use_cluster_feat,
    use_focal_loss=args.use_focal_loss,
)
model = model.to(device)
model.train()
best_loss = np.inf

#################
# Hyperparameters
opt = optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    opt, T_max=args.epochs, eta_min=1e-5
)
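
# Cosine annealing decays the learning rate from --lr toward eta_min=1e-5
# over T_max=--epochs epochs; scheduler.step() runs once per epoch below.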

###############
# Training Loop
for epoch in range(args.epochs):
    all_loss_den_val = 0
    all_loss_conn_val = 0
    # One full-graph forward/backward pass per prepared level graph.
    for g in gs:
        opt.zero_grad()
        g = model(g)
        loss, loss_den_val, loss_conn_val = model.compute_loss(g)
        all_loss_den_val += loss_den_val
        all_loss_conn_val += loss_conn_val
        loss.backward()
        opt.step()
    scheduler.step()
    print(
        "Training, epoch: %d, loss_den: %.6f, loss_conn: %.6f"
        % (epoch, all_loss_den_val, all_loss_conn_val)
    )
    # Report test loss; disable gradient tracking and switch to eval mode so
    # dropout (if any) is off during evaluation.
    model.eval()
    all_test_loss_den_val = 0
    all_test_loss_conn_val = 0
    with torch.no_grad():
        for g in test_gs:
            g = model(g)
            loss, loss_den_val, loss_conn_val = model.compute_loss(g)
            all_test_loss_den_val += loss_den_val
            all_test_loss_conn_val += loss_conn_val
    model.train()
    print(
        "Testing, epoch: %d, loss_den: %.6f, loss_conn: %.6f"
        % (epoch, all_test_loss_den_val, all_test_loss_conn_val)
    )
    # Track the best summed test loss and checkpoint that model separately.
    if all_test_loss_conn_val + all_test_loss_den_val < best_loss:
        best_loss = all_test_loss_conn_val + all_test_loss_den_val
        print("New best epoch", epoch)
        torch.save(model.state_dict(), args.model_filename + "_best")
    # Also save the latest checkpoint every epoch; after the final epoch this
    # file already holds the last model state.
    torch.save(model.state_dict(), args.model_filename)
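
# To reload a checkpoint later (a hypothetical follow-up, mirroring the model
# construction above):
#   model = LANDER(feature_dim=feature_dim, nhid=args.hidden,
#                  num_conv=args.num_conv, dropout=args.dropout,
#                  use_GAT=args.gat, K=args.gat_k, balance=args.balance,
#                  use_cluster_feat=args.use_cluster_feat,
#                  use_focal_loss=args.use_focal_loss)
#   model.load_state_dict(torch.load(args.model_filename + "_best"))
#   model.eval()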