train.py 11.2 KB
Newer Older
1
2
3
4
"""Training GCMC model on the MovieLens data set.

The script loads the full graph to the training device.
"""
Zihao Ye's avatar
Zihao Ye committed
5
6
import argparse
import logging
7
import os
Zihao Ye's avatar
Zihao Ye committed
8
9
import random
import string
10
11
import time

Zihao Ye's avatar
Zihao Ye committed
12
13
14
15
import numpy as np
import torch as th
import torch.nn as nn
from data import MovieLens
16
from model import BiDecoder, GCMCLayer
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
17
18
19
20
21
22
23
from utils import (
    get_activation,
    get_optimizer,
    MetricLogger,
    torch_net_info,
    torch_total_param_num,
)
24

Zihao Ye's avatar
Zihao Ye committed
25
26
27
28
29

class Net(nn.Module):
    """GCMC network: a ``GCMCLayer`` encoder followed by a ``BiDecoder``.

    Parameters
    ----------
    args : argparse.Namespace
        Configuration object. The constructor reads ``model_activation``,
        ``rating_vals``, ``src_in_units``, ``dst_in_units``,
        ``gcn_agg_units``, ``gcn_out_units``, ``gcn_dropout``,
        ``gcn_agg_accum``, ``share_param``, ``device`` and
        ``gen_r_num_basis_func`` from it.
    """

    def __init__(self, args):
        super().__init__()
        activation = get_activation(args.model_activation)
        self._act = activation

        # Encoder: produces one representation per user and per movie.
        self.encoder = GCMCLayer(
            args.rating_vals,
            args.src_in_units,
            args.dst_in_units,
            args.gcn_agg_units,
            args.gcn_out_units,
            args.gcn_dropout,
            args.gcn_agg_accum,
            agg_act=activation,
            share_user_item_param=args.share_param,
            device=args.device,
        )

        # Decoder: one output class per distinct rating value.
        num_rating_classes = len(args.rating_vals)
        self.decoder = BiDecoder(
            in_units=args.gcn_out_units,
            num_classes=num_rating_classes,
            num_basis=args.gen_r_num_basis_func,
        )

    def forward(self, enc_graph, dec_graph, ufeat, ifeat):
        """Encode users/movies on ``enc_graph`` and decode on ``dec_graph``.

        Parameters
        ----------
        enc_graph
            Graph handed to the encoder together with the input features.
        dec_graph
            Graph handed to the decoder together with the encoder outputs.
        ufeat
            User input features.
        ifeat
            Movie (item) input features.

        Returns
        -------
        The decoder output. Callers in this file apply a softmax over
        ``dim=1`` to it, i.e. they treat it as per-rating-class scores.
        """
        user_repr, movie_repr = self.encoder(enc_graph, ufeat, ifeat)
        return self.decoder(dec_graph, user_repr, movie_repr)

53
54

def evaluate(args, net, dataset, segment="valid"):
    """Compute the RMSE of ``net`` on the validation or test split.

    The predicted rating of each sample is the expectation of the possible
    rating values under the softmax of the network output.

    Parameters
    ----------
    args
        Configuration object; only ``args.device`` is read here.
    net
        The model. It is switched to eval mode and is not switched back.
    dataset
        Object exposing ``possible_rating_values``, ``user_feature``,
        ``movie_feature`` and the ``<segment>_truths`` /
        ``<segment>_enc_graph`` / ``<segment>_dec_graph`` attributes.
    segment : str
        Either ``"valid"`` or ``"test"``.

    Returns
    -------
    The root-mean-square error over the chosen split.

    Raises
    ------
    NotImplementedError
        If ``segment`` is neither ``"valid"`` nor ``"test"``.
    """
    rating_levels = th.FloatTensor(dataset.possible_rating_values).to(
        args.device
    )

    if segment not in ("valid", "test"):
        raise NotImplementedError
    if segment == "valid":
        truths, enc_graph, dec_graph = (
            dataset.valid_truths,
            dataset.valid_enc_graph,
            dataset.valid_dec_graph,
        )
    else:
        truths, enc_graph, dec_graph = (
            dataset.test_truths,
            dataset.test_enc_graph,
            dataset.test_dec_graph,
        )

    # Forward pass only: eval mode, no autograd bookkeeping.
    net.eval()
    with th.no_grad():
        scores = net(
            enc_graph, dec_graph, dataset.user_feature, dataset.movie_feature
        )

    # Expected rating = sum_k softmax(scores)[k] * rating_value[k].
    class_probs = th.softmax(scores, dim=1)
    expected_ratings = (class_probs * rating_levels.view(1, -1)).sum(dim=1)

    mse = ((expected_ratings - truths) ** 2.0).mean().item()
    return np.sqrt(mse)

84

Zihao Ye's avatar
Zihao Ye committed
85
86
def train(args):
    """Train the GCMC model with full-graph updates and periodic validation.

    Loads the MovieLens data set, builds ``Net``, then runs one full-graph
    gradient step per iteration. Every ``args.train_log_interval`` iterations
    the training loss/RMSE are logged; every ``args.train_valid_interval``
    iterations the model is validated. When validation improves, the test
    split is evaluated as well; otherwise the learning rate is decayed after
    ``args.train_decay_patience`` non-improving validations, and training
    stops early after ``args.train_early_stopping_patience`` non-improving
    validations once the learning rate has reached ``args.train_min_lr``.

    Parameters
    ----------
    args : argparse.Namespace
        Configuration as produced by ``config()``. This function also
        stores ``src_in_units``, ``dst_in_units`` and ``rating_vals`` on it.

    Side effects
    ------------
    Prints progress to stdout and writes ``train_loss<id>.csv``,
    ``valid_loss<id>.csv``, ``test_loss<id>.csv`` and ``net<id>.txt`` into
    ``args.save_dir``.
    """
    print(args)
    dataset = MovieLens(
        args.data_name,
        args.device,
        use_one_hot_fea=args.use_one_hot_fea,
        symm=args.gcn_agg_norm_symm,
        test_ratio=args.data_test_ratio,
        valid_ratio=args.data_valid_ratio,
    )
    print("Loading data finished ...\n")

    args.src_in_units = dataset.user_feature_shape[1]
    args.dst_in_units = dataset.movie_feature_shape[1]
    args.rating_vals = dataset.possible_rating_values

    ### build the net
    net = Net(args=args)
    net = net.to(args.device)
    nd_possible_rating_values = th.FloatTensor(
        dataset.possible_rating_values
    ).to(args.device)
    rating_loss_net = nn.CrossEntropyLoss()
    learning_rate = args.train_lr
    optimizer = get_optimizer(args.train_optimizer)(
        net.parameters(), lr=learning_rate
    )
    print("Loading network finished ...\n")

    ### prepare training data
    train_gt_labels = dataset.train_labels
    train_gt_ratings = dataset.train_truths

    ### prepare the logger
    train_loss_logger = MetricLogger(
        ["iter", "loss", "rmse"],
        ["%d", "%.4f", "%.4f"],
        os.path.join(args.save_dir, "train_loss%d.csv" % args.save_id),
    )
    valid_loss_logger = MetricLogger(
        ["iter", "rmse"],
        ["%d", "%.4f"],
        os.path.join(args.save_dir, "valid_loss%d.csv" % args.save_id),
    )
    test_loss_logger = MetricLogger(
        ["iter", "rmse"],
        ["%d", "%.4f"],
        os.path.join(args.save_dir, "test_loss%d.csv" % args.save_id),
    )

    ### declare the loss information
    best_valid_rmse = np.inf
    # Initialised up front so the final summary cannot hit an unbound name
    # when no validation step ever runs.
    best_test_rmse = np.inf
    no_better_valid = 0
    best_iter = -1
    count_rmse = 0
    count_num = 0
    count_loss = 0

    dataset.train_enc_graph = dataset.train_enc_graph.int().to(args.device)
    dataset.train_dec_graph = dataset.train_dec_graph.int().to(args.device)
    # Validation encodes on the training graph; only its decoder graph differs.
    dataset.valid_enc_graph = dataset.train_enc_graph
    dataset.valid_dec_graph = dataset.valid_dec_graph.int().to(args.device)
    dataset.test_enc_graph = dataset.test_enc_graph.int().to(args.device)
    dataset.test_dec_graph = dataset.test_dec_graph.int().to(args.device)

    print("Start training ...")
    dur = []
    try:
        # NOTE(review): range(1, train_max_iter) runs train_max_iter - 1
        # iterations; kept as-is to preserve the existing schedule.
        for iter_idx in range(1, args.train_max_iter):
            # Rebuilt every iteration so a validation step that does not
            # coincide with a log step never touches a stale/undefined string.
            logging_str = None

            # The first three iterations are excluded from timing (warm-up).
            if iter_idx > 3:
                t0 = time.time()
            net.train()
            pred_ratings = net(
                dataset.train_enc_graph,
                dataset.train_dec_graph,
                dataset.user_feature,
                dataset.movie_feature,
            )
            loss = rating_loss_net(pred_ratings, train_gt_labels).mean()
            count_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(net.parameters(), args.train_grad_clip)
            optimizer.step()

            if iter_idx > 3:
                dur.append(time.time() - t0)

            if iter_idx == 1:
                print("Total #Param of net: %d" % (torch_total_param_num(net)))
                print(
                    torch_net_info(
                        net,
                        save_path=os.path.join(
                            args.save_dir, "net%d.txt" % args.save_id
                        ),
                    )
                )

            # Metric only: no gradient is needed for the training error.
            with th.no_grad():
                real_pred_ratings = (
                    th.softmax(pred_ratings, dim=1)
                    * nd_possible_rating_values.view(1, -1)
                ).sum(dim=1)
                sq_err_sum = (
                    (real_pred_ratings - train_gt_ratings) ** 2
                ).sum()
            count_rmse += sq_err_sum.item()
            count_num += pred_ratings.shape[0]

            if iter_idx % args.train_log_interval == 0:
                # count_loss is never reset and iter_idx is 1-based, so this
                # is the running mean loss; the same value goes to both the
                # CSV logger and the console.
                mean_loss = count_loss / iter_idx
                # Square root of the accumulated mean squared error, so the
                # reported value is an RMSE like the one from evaluate().
                train_rmse = float(np.sqrt(count_rmse / count_num))
                train_loss_logger.log(
                    iter=iter_idx,
                    loss=mean_loss,
                    rmse=train_rmse,
                )
                logging_str = "Iter={:4d}, loss={:.4f}, rmse={:.4f}".format(
                    iter_idx,
                    mean_loss,
                    train_rmse,
                )
                if iter_idx > 3:
                    logging_str += ", time={:.4f}".format(np.average(dur))

                count_rmse = 0
                count_num = 0

            if iter_idx % args.train_valid_interval == 0:
                valid_rmse = evaluate(
                    args=args, net=net, dataset=dataset, segment="valid"
                )
                valid_loss_logger.log(iter=iter_idx, rmse=valid_rmse)
                if logging_str is None:
                    logging_str = "Iter={:4d}".format(iter_idx)
                logging_str += ",\tVal RMSE={:.4f}".format(valid_rmse)

                if valid_rmse < best_valid_rmse:
                    best_valid_rmse = valid_rmse
                    no_better_valid = 0
                    best_iter = iter_idx
                    test_rmse = evaluate(
                        args=args, net=net, dataset=dataset, segment="test"
                    )
                    best_test_rmse = test_rmse
                    test_loss_logger.log(iter=iter_idx, rmse=test_rmse)
                    logging_str += ", Test RMSE={:.4f}".format(test_rmse)
                else:
                    no_better_valid += 1
                    if (
                        no_better_valid > args.train_early_stopping_patience
                        and learning_rate <= args.train_min_lr
                    ):
                        logging.info(
                            "Early stopping threshold reached. Stop training."
                        )
                        break
                    if no_better_valid > args.train_decay_patience:
                        new_lr = max(
                            learning_rate * args.train_lr_decay_factor,
                            args.train_min_lr,
                        )
                        if new_lr < learning_rate:
                            learning_rate = new_lr
                            logging.info("\tChange the LR to %g", new_lr)
                            for p in optimizer.param_groups:
                                p["lr"] = learning_rate
                            no_better_valid = 0

            # Printed on log steps and on validation-only steps alike.
            if logging_str is not None:
                print(logging_str)

        print(
            "Best Iter Idx={}, Best Valid RMSE={:.4f}, Best Test RMSE={:.4f}".format(
                best_iter, best_valid_rmse, best_test_rmse
            )
        )
    finally:
        # Close the loggers even if training is interrupted by an exception.
        train_loss_logger.close()
        valid_loss_logger.close()
        test_loss_logger.close()


def _str2bool(value):
    """Parse a command-line string such as ``"true"``/``"0"`` into a bool."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got %r" % (value,)
    )


def config():
    """Parse the command-line options and prepare the output directory.

    Returns
    -------
    argparse.Namespace
        The parsed arguments, with ``device`` converted to a ``torch.device``
        (CPU when the given index is negative), ``save_dir`` prefixed with
        ``log/`` and ``save_id`` filled in when not given.

    Side effects
    ------------
    Creates ``args.save_dir`` if it does not exist.
    """
    parser = argparse.ArgumentParser(description="GCMC")
    parser.add_argument("--seed", default=123, type=int)
    parser.add_argument(
        "--device",
        default=0,
        type=int,
        help="Running device. E.g `--device 0`, if using cpu, set `--device -1`",
    )
    parser.add_argument("--save_dir", type=str, help="The saving directory")
    parser.add_argument("--save_id", type=int, help="The saving log id")
    parser.add_argument("--silent", action="store_true")
    parser.add_argument(
        "--data_name",
        default="ml-1m",
        type=str,
        help="The dataset name: ml-100k, ml-1m, ml-10m",
    )
    parser.add_argument(
        "--data_test_ratio", type=float, default=0.1
    )  ## for ml-100k the test ratio is 0.2
    parser.add_argument("--data_valid_ratio", type=float, default=0.1)
    parser.add_argument("--use_one_hot_fea", action="store_true", default=False)
    parser.add_argument("--model_activation", type=str, default="leaky")
    parser.add_argument("--gcn_dropout", type=float, default=0.7)
    # `type=bool` would turn any non-empty string (including "False") into
    # True, so the value is parsed explicitly instead.
    parser.add_argument("--gcn_agg_norm_symm", type=_str2bool, default=True)
    parser.add_argument("--gcn_agg_units", type=int, default=500)
    parser.add_argument("--gcn_agg_accum", type=str, default="sum")
    parser.add_argument("--gcn_out_units", type=int, default=75)
    parser.add_argument("--gen_r_num_basis_func", type=int, default=2)
    parser.add_argument("--train_max_iter", type=int, default=2000)
    parser.add_argument("--train_log_interval", type=int, default=1)
    parser.add_argument("--train_valid_interval", type=int, default=1)
    parser.add_argument("--train_optimizer", type=str, default="adam")
    parser.add_argument("--train_grad_clip", type=float, default=1.0)
    parser.add_argument("--train_lr", type=float, default=0.01)
    parser.add_argument("--train_min_lr", type=float, default=0.001)
    parser.add_argument("--train_lr_decay_factor", type=float, default=0.5)
    parser.add_argument("--train_decay_patience", type=int, default=50)
    parser.add_argument(
        "--train_early_stopping_patience", type=int, default=100
    )
    parser.add_argument("--share_param", default=False, action="store_true")

    args = parser.parse_args()
    # A negative index selects the CPU; otherwise the index is handed to
    # torch.device as-is.
    args.device = (
        th.device(args.device) if args.device >= 0 else th.device("cpu")
    )

    ### configure save_dir to save all the info
    if args.save_dir is None:
        # Default: data set name plus a random two-character suffix.
        args.save_dir = (
            args.data_name
            + "_"
            + "".join(
                random.choices(string.ascii_uppercase + string.digits, k=2)
            )
        )
    if args.save_id is None:
        args.save_id = np.random.randint(20)
    args.save_dir = os.path.join("log", args.save_dir)
    os.makedirs(args.save_dir, exist_ok=True)

    return args


325
if __name__ == "__main__":
    # Script entry point: parse options, seed NumPy and PyTorch (plus all
    # CUDA devices when CUDA is available), then train.
    args = config()
    seed = args.seed
    np.random.seed(seed)
    th.manual_seed(seed)
    if th.cuda.is_available():
        th.cuda.manual_seed_all(seed)
    train(args)