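"""Train a LANDER model on neighbor-sampled subgraphs.

Loads pickled (features, labels) data, builds hierarchical k-NN graphs
with LanderDataset, and trains LANDER on minibatches drawn from DGL
neighbor samplers, one dataloader per graph in the hierarchy.
"""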
import argparse
import pickle

import dgl

import numpy as np
import torch
import torch.optim as optim
from dataset import LanderDataset
from models import LANDER

###########
# ArgParser
parser = argparse.ArgumentParser()

# Dataset
parser.add_argument("--data_path", type=str, required=True)
parser.add_argument("--levels", type=str, default="1")
parser.add_argument("--faiss_gpu", action="store_true")
parser.add_argument("--model_filename", type=str, default="lander.pth")

# KNN
parser.add_argument("--knn_k", type=str, default="10")
parser.add_argument("--num_workers", type=int, default=0)

# Model
parser.add_argument("--hidden", type=int, default=512)
parser.add_argument("--num_conv", type=int, default=1)
parser.add_argument("--dropout", type=float, default=0.0)
parser.add_argument("--gat", action="store_true")
parser.add_argument("--gat_k", type=int, default=1)
parser.add_argument("--balance", action="store_true")
parser.add_argument("--use_cluster_feat", action="store_true")
parser.add_argument("--use_focal_loss", action="store_true")

# Training
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--batch_size", type=int, default=1024)
parser.add_argument("--lr", type=float, default=0.1)
parser.add_argument("--momentum", type=float, default=0.9)
parser.add_argument("--weight_decay", type=float, default=1e-5)

args = parser.parse_args()
print(args)

###########################
# Environment Configuration
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

##################
# Data Preparation
with open(args.data_path, "rb") as f:
    features, labels = pickle.load(f)

k_list = [int(k) for k in args.knn_k.split(",")]
lvl_list = [int(l) for l in args.levels.split(",")]
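# build one hierarchy of k-NN graphs per (k, levels) pair; graphs from
# all hierarchies are trained on jointly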
gs = []
nbrs = []
ks = []
for k, l in zip(k_list, lvl_list):
    dataset = LanderDataset(
        features=features,
        labels=labels,
        k=k,
        levels=l,
        faiss_gpu=args.faiss_gpu,
    )
    gs += list(dataset.gs)
    ks += [k] * len(dataset.gs)
    nbrs += list(dataset.nbrs)

print("Dataset Prepared.")


def set_train_sampler_loader(g, k):
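    # one sampled block per conv layer plus one for the input layer,
    # each keeping up to k - 1 neighbors per node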
    fanouts = [k - 1 for _ in range(args.num_conv + 1)]
    sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
    # fix the number of edges
    train_dataloader = dgl.dataloading.DataLoader(
        g,
        torch.arange(g.number_of_nodes()),
        sampler,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers,
    )
    return train_dataloader


train_loaders = []
for g, k in zip(gs, ks):
    train_loaders.append(set_train_sampler_loader(g, k))

##################
# Model Definition
feature_dim = gs[0].ndata["features"].shape[1]
model = LANDER(
    feature_dim=feature_dim,
    nhid=args.hidden,
    num_conv=args.num_conv,
    dropout=args.dropout,
    use_GAT=args.gat,
    K=args.gat_k,
    balance=args.balance,
    use_cluster_feat=args.use_cluster_feat,
    use_focal_loss=args.use_focal_loss,
)
model = model.to(device)
model.train()

#################
# Hyperparameters
opt = optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
)

# keep num_batch_per_loader the same for every sub_dataloader
num_batch_per_loader = len(train_loaders[0])
train_loaders = [iter(train_loader) for train_loader in train_loaders]
num_loaders = len(train_loaders)
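# one scheduler step is taken per optimizer step, so T_max spans the
# whole run: epochs * batches per loader * number of loaders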
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    opt, T_max=args.epochs * num_batch_per_loader * num_loaders, eta_min=1e-5
)

print("Start Training.")

###############
# Training Loop
for epoch in range(args.epochs):
    loss_den_val_total = []
    loss_conn_val_total = []
    loss_val_total = []
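    # round-robin over the loaders: draw one minibatch from each loader
    # before advancing to the next batch index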
    for batch in range(num_batch_per_loader):
        for loader_id in range(num_loaders):
            try:
                minibatch = next(train_loaders[loader_id])
            except StopIteration:
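                # shorter loaders exhaust first; rebuild them so every
                # loader contributes num_batch_per_loader batches per epoch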
                train_loaders[loader_id] = iter(
                    set_train_sampler_loader(gs[loader_id], ks[loader_id])
                )
                minibatch = next(train_loaders[loader_id])
            input_nodes, sub_g, bipartites = minibatch
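            # bipartites holds the per-layer blocks produced by the sampler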
            sub_g = sub_g.to(device)
            bipartites = [b.to(device) for b in bipartites]
            # the blocks already carry the input node features in srcdata
            opt.zero_grad()
            output_bipartite = model(bipartites)
            loss, loss_den_val, loss_conn_val = model.compute_loss(
                output_bipartite
            )
            loss_den_val_total.append(loss_den_val)
            loss_conn_val_total.append(loss_conn_val)
            loss_val_total.append(loss.item())
            loss.backward()
            opt.step()
            if (batch + 1) % 10 == 0:
                print(
                    "epoch: %d, batch: %d / %d, loader_id : %d / %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f"
                    % (
                        epoch,
                        batch,
                        num_batch_per_loader,
                        loader_id,
                        num_loaders,
                        loss.item(),
                        loss_den_val,
                        loss_conn_val,
                    )
                )
            scheduler.step()
    print(
        "epoch: %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f"
        % (
            epoch,
            np.array(loss_val_total).mean(),
            np.array(loss_den_val_total).mean(),
            np.array(loss_conn_val_total).mean(),
        )
    )
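    # checkpoint after every epoch, overwriting the previous save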
    torch.save(model.state_dict(), args.model_filename)

torch.save(model.state_dict(), args.model_filename)