train.py 11.2 KB
Newer Older
1
2
3
4
"""Training GCMC model on the MovieLens data set.

The script loads the full graph to the training device.
"""
Zihao Ye's avatar
Zihao Ye committed
5
6
import argparse
import logging
7
import os
Zihao Ye's avatar
Zihao Ye committed
8
9
import random
import string
10
11
import time

Zihao Ye's avatar
Zihao Ye committed
12
13
14
15
import numpy as np
import torch as th
import torch.nn as nn
from data import MovieLens
16
from model import BiDecoder, GCMCLayer
17
18
19
from utils import (MetricLogger, get_activation, get_optimizer, torch_net_info,
                   torch_total_param_num)

Zihao Ye's avatar
Zihao Ye committed
20
21
22
23
24

class Net(nn.Module):
    """GCMC network: a ``GCMCLayer`` encoder followed by a ``BiDecoder``.

    All hyper-parameters are read from ``args``. The decoder is built with
    ``num_classes=len(args.rating_vals)``, i.e. one output per possible
    rating value.
    """

    def __init__(self, args):
        super().__init__()
        activation = get_activation(args.model_activation)
        # Kept as an attribute as well, in case other code inspects it.
        self._act = activation

        rating_vals = args.rating_vals
        # Positional order follows GCMCLayer's signature (defined in model.py).
        self.encoder = GCMCLayer(
            rating_vals,
            args.src_in_units,
            args.dst_in_units,
            args.gcn_agg_units,
            args.gcn_out_units,
            args.gcn_dropout,
            args.gcn_agg_accum,
            agg_act=activation,
            share_user_item_param=args.share_param,
            device=args.device,
        )
        self.decoder = BiDecoder(
            in_units=args.gcn_out_units,
            num_classes=len(rating_vals),
            num_basis=args.gen_r_num_basis_func,
        )

    def forward(self, enc_graph, dec_graph, ufeat, ifeat):
        """Encode users/movies on ``enc_graph`` and decode on ``dec_graph``.

        Returns whatever ``BiDecoder`` produces. The callers in this script
        treat it as per-class rating logits (softmax over dim 1).
        """
        embeddings = self.encoder(enc_graph, ufeat, ifeat)
        user_emb, movie_emb = embeddings
        return self.decoder(dec_graph, user_emb, movie_emb)

48
49

def evaluate(args, net, dataset, segment="valid"):
    """Return the RMSE of ``net`` on the validation or test split of ``dataset``.

    The network's per-class outputs are softmax-normalised over dim 1 and
    used to take a weighted average of ``dataset.possible_rating_values``.
    That expected rating is compared with the split's ground-truth ratings.

    Args:
        args: namespace providing ``device``.
        net: model called as
            ``net(enc_graph, dec_graph, user_feature, movie_feature)``.
            It is switched to eval mode and left in it.
        dataset: object exposing ``possible_rating_values``,
            ``user_feature``, ``movie_feature`` and, for the chosen split,
            ``<segment>_truths``, ``<segment>_enc_graph`` and
            ``<segment>_dec_graph``.
        segment: ``"valid"`` or ``"test"``.

    Returns:
        The root-mean-squared error, as a NumPy float.

    Raises:
        NotImplementedError: if ``segment`` is neither ``"valid"`` nor ``"test"``.
    """
    rating_scale = th.FloatTensor(dataset.possible_rating_values).to(args.device)

    if segment not in ("valid", "test"):
        raise NotImplementedError
    truths = getattr(dataset, segment + "_truths")
    enc_graph = getattr(dataset, segment + "_enc_graph")
    dec_graph = getattr(dataset, segment + "_dec_graph")

    # Evaluate RMSE without tracking gradients.
    net.eval()
    with th.no_grad():
        logits = net(
            enc_graph, dec_graph, dataset.user_feature, dataset.movie_feature
        )
    probabilities = th.softmax(logits, dim=1)
    expected_ratings = (probabilities * rating_scale.view(1, -1)).sum(dim=1)
    squared_error = (expected_ratings - truths) ** 2.0
    mean_squared_error = squared_error.mean().item()
    return np.sqrt(mean_squared_error)

79

Zihao Ye's avatar
Zihao Ye committed
80
81
def train(args):
    """Train a GCMC ``Net`` on MovieLens and print the best valid/test RMSE.

    Steps:

    * Build the ``MovieLens`` dataset and the network from ``args``.
    * Run full-graph training for iterations ``1..args.train_max_iter``.
    * Evaluate on the validation split every ``args.train_valid_interval``
      iterations, and on the test split whenever validation RMSE improves.
    * Decay the learning rate after ``args.train_decay_patience``
      non-improving validations.
    * Stop early once validation has not improved for more than
      ``args.train_early_stopping_patience`` rounds and the LR is at
      ``args.train_min_lr``.

    Metrics go to ``MetricLogger`` instances targeting ``*.csv`` paths under
    ``args.save_dir``.

    Args:
        args: namespace produced by ``config()``. This function also sets
            ``args.src_in_units``, ``args.dst_in_units`` and
            ``args.rating_vals`` on it before building the network.
    """
    print(args)
    dataset = MovieLens(
        args.data_name,
        args.device,
        use_one_hot_fea=args.use_one_hot_fea,
        symm=args.gcn_agg_norm_symm,
        test_ratio=args.data_test_ratio,
        valid_ratio=args.data_valid_ratio,
    )
    print("Loading data finished ...\n")

    args.src_in_units = dataset.user_feature_shape[1]
    args.dst_in_units = dataset.movie_feature_shape[1]
    args.rating_vals = dataset.possible_rating_values

    ### build the net
    net = Net(args=args)
    net = net.to(args.device)
    # One rating value per output class. Used to turn class probabilities
    # into an expected, real-valued rating for the RMSE metric.
    nd_possible_rating_values = th.FloatTensor(
        dataset.possible_rating_values
    ).to(args.device)
    rating_loss_net = nn.CrossEntropyLoss()
    learning_rate = args.train_lr
    optimizer = get_optimizer(args.train_optimizer)(
        net.parameters(), lr=learning_rate
    )
    print("Loading network finished ...\n")

    ### prepare training data
    train_gt_labels = dataset.train_labels
    train_gt_ratings = dataset.train_truths

    ### prepare the logger
    train_loss_logger = MetricLogger(
        ["iter", "loss", "rmse"],
        ["%d", "%.4f", "%.4f"],
        os.path.join(args.save_dir, "train_loss%d.csv" % args.save_id),
    )
    valid_loss_logger = MetricLogger(
        ["iter", "rmse"],
        ["%d", "%.4f"],
        os.path.join(args.save_dir, "valid_loss%d.csv" % args.save_id),
    )
    test_loss_logger = MetricLogger(
        ["iter", "rmse"],
        ["%d", "%.4f"],
        os.path.join(args.save_dir, "test_loss%d.csv" % args.save_id),
    )

    ### declare the loss information
    best_valid_rmse = np.inf
    # Initialised here so the final summary cannot hit a NameError when no
    # validation round ever ran or improved.
    best_test_rmse = np.inf
    no_better_valid = 0
    best_iter = -1
    count_rmse = 0  # sum of squared errors since the last train log line
    count_num = 0  # number of ratings contributing to count_rmse
    count_loss = 0  # sum of losses since iteration 1 (never reset)

    dataset.train_enc_graph = dataset.train_enc_graph.int().to(args.device)
    dataset.train_dec_graph = dataset.train_dec_graph.int().to(args.device)
    # Validation reuses the training encoder graph; only the validation
    # decoder graph is moved separately.
    dataset.valid_enc_graph = dataset.train_enc_graph
    dataset.valid_dec_graph = dataset.valid_dec_graph.int().to(args.device)
    dataset.test_enc_graph = dataset.test_enc_graph.int().to(args.device)
    dataset.test_dec_graph = dataset.test_dec_graph.int().to(args.device)

    print("Start training ...")
    dur = []
    # Iterations are numbered 1..train_max_iter inclusive.
    for iter_idx in range(1, args.train_max_iter + 1):
        # The first three iterations are excluded from timing (presumably
        # warm-up / lazy initialisation).
        if iter_idx > 3:
            t0 = time.time()
        net.train()
        pred_ratings = net(
            dataset.train_enc_graph,
            dataset.train_dec_graph,
            dataset.user_feature,
            dataset.movie_feature,
        )
        loss = rating_loss_net(pred_ratings, train_gt_labels).mean()
        count_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(net.parameters(), args.train_grad_clip)
        optimizer.step()

        if iter_idx > 3:
            dur.append(time.time() - t0)

        if iter_idx == 1:
            print("Total #Param of net: %d" % (torch_total_param_num(net)))
            print(
                torch_net_info(
                    net,
                    save_path=os.path.join(
                        args.save_dir, "net%d.txt" % args.save_id
                    ),
                )
            )

        # Metric bookkeeping only; no autograd graph is needed.
        with th.no_grad():
            real_pred_ratings = (
                th.softmax(pred_ratings, dim=1)
                * nd_possible_rating_values.view(1, -1)
            ).sum(dim=1)
            sq_err = ((real_pred_ratings - train_gt_ratings) ** 2).sum()
        count_rmse += sq_err.item()
        count_num += pred_ratings.shape[0]

        # Built by whichever of the log/valid steps runs this iteration.
        logging_str = None

        if iter_idx % args.train_log_interval == 0:
            train_loss = count_loss / iter_idx  # mean loss since iteration 1
            # count_rmse / count_num is the MSE; take the root for RMSE.
            train_rmse = np.sqrt(count_rmse / count_num)
            # dur is empty until iteration 4; avoid np.average([]) warnings.
            avg_dur = np.average(dur) if dur else float("nan")
            train_loss_logger.log(
                iter=iter_idx,
                loss=train_loss,
                rmse=train_rmse,
            )
            logging_str = (
                "Iter={}, loss={:.4f}, rmse={:.4f}, time={:.4f}".format(
                    iter_idx, train_loss, train_rmse, avg_dur
                )
            )
            count_rmse = 0
            count_num = 0

        if iter_idx % args.train_valid_interval == 0:
            valid_rmse = evaluate(
                args=args, net=net, dataset=dataset, segment="valid"
            )
            valid_loss_logger.log(iter=iter_idx, rmse=valid_rmse)
            if logging_str is None:
                # Validation on an iteration without a train log line.
                logging_str = "Iter={}".format(iter_idx)
            logging_str += ",\tVal RMSE={:.4f}".format(valid_rmse)

            if valid_rmse < best_valid_rmse:
                best_valid_rmse = valid_rmse
                no_better_valid = 0
                best_iter = iter_idx
                test_rmse = evaluate(
                    args=args, net=net, dataset=dataset, segment="test"
                )
                best_test_rmse = test_rmse
                test_loss_logger.log(iter=iter_idx, rmse=test_rmse)
                logging_str += ", Test RMSE={:.4f}".format(test_rmse)
            else:
                no_better_valid += 1
                if (
                    no_better_valid > args.train_early_stopping_patience
                    and learning_rate <= args.train_min_lr
                ):
                    print(logging_str)
                    logging.info(
                        "Early stopping threshold reached. Stop training."
                    )
                    break
                if no_better_valid > args.train_decay_patience:
                    new_lr = max(
                        learning_rate * args.train_lr_decay_factor,
                        args.train_min_lr,
                    )
                    if new_lr < learning_rate:
                        learning_rate = new_lr
                        logging.info("\tChange the LR to %g" % new_lr)
                        for p in optimizer.param_groups:
                            p["lr"] = learning_rate
                        no_better_valid = 0

        if logging_str is not None:
            print(logging_str)

    print(
        "Best Iter Idx={}, Best Valid RMSE={:.4f}, Best Test RMSE={:.4f}".format(
            best_iter, best_valid_rmse, best_test_rmse
        )
    )
    train_loss_logger.close()
    valid_loss_logger.close()
    test_loss_logger.close()


def _str_to_bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    argparse's ``type=bool`` calls ``bool()`` on the raw string. That makes
    every non-empty string, including ``"False"``, truthy. This parser makes
    ``--flag False`` mean False.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(
        "Expected a boolean value, got %r" % (value,)
    )


def config():
    """Parse command-line options and derive runtime settings.

    Post-processing applied after parsing:

    * ``args.device`` becomes a ``torch.device``: the given CUDA index, or
      CPU for negative values.
    * A missing ``--save_dir`` is replaced by ``<data_name>_XY``, where
      ``XY`` is two random uppercase letters or digits.
    * A missing ``--save_id`` is drawn from ``np.random.randint(20)``.
    * ``args.save_dir`` is placed under ``log/`` and the directory is created.

    Returns:
        argparse.Namespace: the parsed and post-processed options.
    """
    parser = argparse.ArgumentParser(description="GCMC")
    parser.add_argument("--seed", default=123, type=int)
    parser.add_argument(
        "--device",
        default=0,
        type=int,
        help="Running device. E.g `--device 0`, if using cpu, set `--device -1`",
    )
    parser.add_argument("--save_dir", type=str, help="The saving directory")
    parser.add_argument("--save_id", type=int, help="The saving log id")
    parser.add_argument("--silent", action="store_true")
    parser.add_argument(
        "--data_name",
        default="ml-1m",
        type=str,
        help="The dataset name: ml-100k, ml-1m, ml-10m",
    )
    parser.add_argument(
        "--data_test_ratio", type=float, default=0.1
    )  ## for ml-100k the test ratio is 0.2
    parser.add_argument("--data_valid_ratio", type=float, default=0.1)
    parser.add_argument("--use_one_hot_fea", action="store_true", default=False)
    parser.add_argument("--model_activation", type=str, default="leaky")
    parser.add_argument("--gcn_dropout", type=float, default=0.7)
    # Accepts True/False, yes/no, 1/0 (case-insensitive).
    parser.add_argument("--gcn_agg_norm_symm", type=_str_to_bool, default=True)
    parser.add_argument("--gcn_agg_units", type=int, default=500)
    parser.add_argument("--gcn_agg_accum", type=str, default="sum")
    parser.add_argument("--gcn_out_units", type=int, default=75)
    parser.add_argument("--gen_r_num_basis_func", type=int, default=2)
    parser.add_argument("--train_max_iter", type=int, default=2000)
    parser.add_argument("--train_log_interval", type=int, default=1)
    parser.add_argument("--train_valid_interval", type=int, default=1)
    parser.add_argument("--train_optimizer", type=str, default="adam")
    parser.add_argument("--train_grad_clip", type=float, default=1.0)
    parser.add_argument("--train_lr", type=float, default=0.01)
    parser.add_argument("--train_min_lr", type=float, default=0.001)
    parser.add_argument("--train_lr_decay_factor", type=float, default=0.5)
    parser.add_argument("--train_decay_patience", type=int, default=50)
    parser.add_argument(
        "--train_early_stopping_patience", type=int, default=100
    )
    parser.add_argument("--share_param", default=False, action="store_true")

    args = parser.parse_args()
    args.device = (
        th.device(args.device) if args.device >= 0 else th.device("cpu")
    )

    ### configure save_dir to save all the info
    if args.save_dir is None:
        suffix = "".join(
            random.choices(string.ascii_uppercase + string.digits, k=2)
        )
        args.save_dir = args.data_name + "_" + suffix
    if args.save_id is None:
        args.save_id = np.random.randint(20)
    args.save_dir = os.path.join("log", args.save_dir)
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(args.save_dir, exist_ok=True)

    return args


320
if __name__ == "__main__":
    args = config()
    # Seed NumPy and PyTorch (and every CUDA device when one is available)
    # before training starts.
    seed = args.seed
    np.random.seed(seed)
    th.manual_seed(seed)
    if th.cuda.is_available():
        th.cuda.manual_seed_all(seed)
    train(args)