"""Train a subgraph-sampling LANDER model (iNaturalist variant).

Builds one LanderDataset per (knn_k, levels) pair, trains LANDER with
neighbor-sampled minibatches, and checkpoints to ``--model_filename``.
"""
import argparse
import os
import pickle
import random
import sys
import time

import dgl
import numpy as np
import torch
import torch.optim as optim

# The local `dataset` / `models` modules live one directory up, so the
# parent directory must be on sys.path before they are imported.
sys.path.append("..")
from dataset import LanderDataset
from models import LANDER

###########
# ArgParser
parser = argparse.ArgumentParser()

# (flag, kwargs) specs, grouped as: dataset, KNN, model, training.
_ARG_SPECS = [
    # Dataset
    ("--data_path", dict(type=str, required=True)),
    ("--levels", dict(type=str, default="1")),
    ("--faiss_gpu", dict(action="store_true")),
    ("--model_filename", dict(type=str, default="lander.pth")),
    # KNN
    ("--knn_k", dict(type=str, default="10")),
    ("--num_workers", dict(type=int, default=0)),
    # Model
    ("--hidden", dict(type=int, default=512)),
    ("--num_conv", dict(type=int, default=1)),
    ("--dropout", dict(type=float, default=0.0)),
    ("--gat", dict(action="store_true")),
    ("--gat_k", dict(type=int, default=1)),
    ("--balance", dict(action="store_true")),
    ("--use_cluster_feat", dict(action="store_true")),
    ("--use_focal_loss", dict(action="store_true")),
    # Training
    ("--epochs", dict(type=int, default=100)),
    ("--batch_size", dict(type=int, default=1024)),
    ("--lr", dict(type=float, default=0.1)),
    ("--momentum", dict(type=float, default=0.9)),
    ("--weight_decay", dict(type=float, default=1e-5)),
]
for _flag, _kwargs in _ARG_SPECS:
    parser.add_argument(_flag, **_kwargs)

args = parser.parse_args()
print(args)

###########################
# Environment Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def setup_seed(seed):
    """Seed the stdlib, NumPy and torch (CPU + all CUDA) RNGs and force
    deterministic cuDNN kernels, so training runs are reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


# setup_seed(20)

##################
# Data Preparation
with open(args.data_path, "rb") as f:
    # Pickle layout: (path2idx, features, labels, <unused>, masks).
    path2idx, features, labels, _, masks = pickle.load(f)
    print("features.shape:", features.shape)
    print("labels.shape:", labels.shape)

# One LanderDataset is built per (k, levels) pair.
# NOTE(review): zip() silently truncates to the shorter list if --knn_k
# and --levels have a different number of comma-separated entries —
# confirm callers always pass equal-length lists.
k_list = [int(k) for k in args.knn_k.split(",")]
lvl_list = [int(l) for l in args.levels.split(",")]
gs = []
nbrs = []
ks = []
datasets = []
for k, l in zip(k_list, lvl_list):
    print("k:", k)
    print("levels:", l)
    dataset = LanderDataset(
        features=features,
        labels=labels,
        k=k,
        levels=l,
        faiss_gpu=args.faiss_gpu,
    )
    # Flatten the per-level graphs/neighbours across datasets, and
    # remember which k produced each graph (used for sampler fan-out).
    gs.extend(dataset.gs)
    ks.extend([k] * len(dataset.gs))
    nbrs.extend(dataset.nbrs)
    datasets.append(dataset)

# Cache the prepared datasets so graph construction can be skipped later.
with open("./dataset.pkl", "wb") as f:
    pickle.dump(datasets, f)

print("Dataset Prepared.")

def set_train_sampler_loader(g, k):
    """Build a neighbor-sampling DataLoader over every node of ``g``.

    Each of the ``args.num_conv + 1`` hops samples at most ``k - 1``
    neighbors, which fixes the number of edges per minibatch.

    Parameters
    ----------
    g : dgl.DGLGraph
        Graph whose nodes (all of them) are used as seed nodes.
    k : int
        The kNN ``k`` this graph was built with; fan-out is ``k - 1``.

    Returns
    -------
    dgl.dataloading.DataLoader
        Shuffled loader yielding (input_nodes, output_nodes, blocks).
    """
    # One fan-out entry per sampled hop; k - 1 excludes the self edge.
    fanouts = [k - 1] * (args.num_conv + 1)
    sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
    # fix the number of edges
    train_dataloader = dgl.dataloading.DataLoader(
        g,
        torch.arange(g.number_of_nodes()),
        sampler,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers,
    )
    return train_dataloader

# One dataloader per graph, each with the fan-out implied by its own k.
# (The original enumerate(gs) loop ignored its loop variable and
# re-indexed gs/ks; zip expresses the pairing directly.)
train_loaders = [set_train_sampler_loader(g, k) for g, k in zip(gs, ks)]

##################
# Model Definition
feature_dim = gs[0].ndata["features"].shape[1]
print("feature dimension:", feature_dim)
model = LANDER(
    feature_dim=feature_dim,
    nhid=args.hidden,
    num_conv=args.num_conv,
    dropout=args.dropout,
    use_GAT=args.gat,
    K=args.gat_k,
    balance=args.balance,
    use_cluster_feat=args.use_cluster_feat,
    use_focal_loss=args.use_focal_loss,
)
model = model.to(device)
model.train()

#################
# Hyperparameters
opt = optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
)

# keep num_batch_per_loader the same for every sub_dataloader
num_batch_per_loader = len(train_loaders[0])
train_loaders = [iter(train_loader) for train_loader in train_loaders]
num_loaders = len(train_loaders)
# Anneal the LR over the total number of optimizer steps: the training
# loop calls scheduler.step() once per minibatch of every loader.
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    opt, T_max=args.epochs * num_batch_per_loader * num_loaders, eta_min=1e-5
)

print("Start Training.")

###############
# Training Loop
for epoch in range(args.epochs):
    loss_den_val_total = []
    loss_conn_val_total = []
    loss_val_total = []
    for batch in range(num_batch_per_loader):
        # Round-robin over the per-graph loaders so every graph
        # contributes one minibatch per step.
        for loader_id in range(num_loaders):
            try:
                minibatch = next(train_loaders[loader_id])
            except StopIteration:
                # A shorter loader is exhausted: rebuild and restart it
                # so each loader yields num_batch_per_loader minibatches
                # per epoch. (Was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit.)
                train_loaders[loader_id] = iter(
                    set_train_sampler_loader(gs[loader_id], ks[loader_id])
                )
                minibatch = next(train_loaders[loader_id])
            input_nodes, sub_g, bipartites = minibatch
            sub_g = sub_g.to(device)
            bipartites = [b.to(device) for b in bipartites]
            # get the feature for the input_nodes
            opt.zero_grad()
            output_bipartite = model(bipartites)
            loss, loss_den_val, loss_conn_val = model.compute_loss(
                output_bipartite
            )
            loss_den_val_total.append(loss_den_val)
            loss_conn_val_total.append(loss_conn_val)
            loss_val_total.append(loss.item())
            loss.backward()
            opt.step()
            if (batch + 1) % 10 == 0:
                print(
                    "epoch: %d, batch: %d / %d, loader_id : %d / %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f"
                    % (
                        epoch,
                        batch,
                        num_batch_per_loader,
                        loader_id,
                        num_loaders,
                        loss.item(),
                        loss_den_val,
                        loss_conn_val,
                    )
                )
            # Cosine schedule is stepped once per optimizer step; its
            # T_max was sized for exactly this cadence.
            scheduler.step()
    print(
        "epoch: %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f"
        % (
            epoch,
            np.array(loss_val_total).mean(),
            np.array(loss_den_val_total).mean(),
            np.array(loss_conn_val_total).mean(),
        )
    )
    # Checkpoint after every epoch so training can be resumed/inspected.
    torch.save(model.state_dict(), args.model_filename)

torch.save(model.state_dict(), args.model_filename)