"""Training GCMC model on the MovieLens data set by mini-batch sampling.

The script loads the full graph in CPU and samples subgraphs for computing
gradients on the training device. The script also supports multi-GPU for
further acceleration.
"""
import argparse
import logging
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
9
import os, time
10
11
12
import random
import string
import traceback
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
13
14

import dgl
15
16
import numpy as np
import torch as th
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
17
import torch.multiprocessing as mp
18
import torch.nn as nn
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
19
import tqdm
20
from data import MovieLens
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
21
22
23
24
25
26
27
28
29
30
31
32
from model import BiDecoder, DenseBiDecoder, GCMCLayer
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from utils import (
    get_activation,
    get_optimizer,
    MetricLogger,
    to_etype_name,
    torch_net_info,
    torch_total_param_num,
)

33
34
35
36
37

class Net(nn.Module):
    """End-to-end GCMC model: graph-conv encoder plus bilinear rating decoder."""

    def __init__(self, args, dev_id):
        super(Net, self).__init__()
        activation = get_activation(args.model_activation)
        self._act = activation

        # Encoder: produces user and movie embeddings from the rating graph.
        encoder = GCMCLayer(
            args.rating_vals,
            args.src_in_units,
            args.dst_in_units,
            args.gcn_agg_units,
            args.gcn_out_units,
            args.gcn_dropout,
            args.gcn_agg_accum,
            agg_act=activation,
            share_user_item_param=args.share_param,
            device=dev_id,
        )
        if args.mix_cpu_gpu and args.use_one_hot_fea:
            # With one-hot features (user/movie feature is None) the input
            # projection W can be extremely large; under mix_cpu_gpu keep W
            # on the CPU and move only the remaining parameters.
            encoder.partial_to(dev_id)
        else:
            encoder.to(dev_id)
        self.encoder = encoder

        # Decoder: scores (user, movie) embedding pairs per rating class.
        decoder = BiDecoder(
            in_units=args.gcn_out_units,
            num_classes=len(args.rating_vals),
            num_basis=args.gen_r_num_basis_func,
        )
        decoder.to(dev_id)
        self.decoder = decoder

    def forward(
        self, compact_g, frontier, ufeat, ifeat, possible_rating_values
    ):
        """Encode features on ``frontier`` and return per-edge rating logits
        for the edges of ``compact_g``."""
        user_emb, movie_emb = self.encoder(frontier, ufeat, ifeat)
        return self.decoder(compact_g, user_emb, movie_emb)

def load_subtensor(input_nodes, pair_graph, blocks, dataset, parent_graph):
    """Gather mini-batch input features and normalization constants.

    Parameters
    ----------
    input_nodes : dict
        Sampled input node IDs keyed by node type ("user", "movie").
    pair_graph : DGLGraph
        The edge-subgraph holding the target edges of this batch.
    blocks : list of DGLBlock
        Sampled message-flow blocks for the encoder.
    dataset : MovieLens
        Provides ``user_feature`` / ``movie_feature`` matrices (or None).
    parent_graph : DGLGraph
        Full encoder graph holding the "ci"/"cj" normalization node data.

    Returns
    -------
    (head_feat, tail_feat, blocks)
        Features are the raw node IDs when the dataset has no feature
        matrix (one-hot mode), otherwise rows sliced from the matrices;
        each block is annotated with "ci"/"cj" copied from the parent.
    """
    # Note: the original computed pair_graph.ndata[dgl.NID] into an unused
    # local; that dead statement has been removed.

    # When the feature matrix is None, the node IDs themselves serve as
    # (implicit one-hot) features.
    head_feat = (
        input_nodes["user"]
        if dataset.user_feature is None
        else dataset.user_feature[input_nodes["user"]]
    )
    tail_feat = (
        input_nodes["movie"]
        if dataset.movie_feature is None
        else dataset.movie_feature[input_nodes["movie"]]
    )

    # Copy per-node degree-normalization constants from the full graph onto
    # each sampled block: "ci" for destination nodes, "cj" for source nodes.
    for block in blocks:
        for ntype in ("user", "movie"):
            parent_data = parent_graph.nodes[ntype].data
            block.dstnodes[ntype].data["ci"] = parent_data["ci"][
                block.dstnodes[ntype].data[dgl.NID]
            ]
            block.srcnodes[ntype].data["cj"] = parent_data["cj"][
                block.srcnodes[ntype].data[dgl.NID]
            ]

    return head_feat, tail_feat, blocks

def flatten_etypes(pair_graph, dataset, segment):
    """Merge the per-rating edge types of ``pair_graph`` into a single
    ("user", "rate", "movie") graph.

    The returned graph carries two edge features: "rating" (the original
    rating value) and "label" (the class index of that rating among
    ``dataset.possible_rating_values``). ``segment`` is unused but kept
    for interface compatibility.
    """
    n_users = pair_graph.number_of_nodes("user")
    n_movies = pair_graph.number_of_nodes("movie")

    src_list = []
    dst_list = []
    label_list = []
    rating_list = []

    for rating in dataset.possible_rating_values:
        u, v = pair_graph.edges(order="eid", etype=to_etype_name(rating))
        src_list.append(u)
        dst_list.append(v)
        # Class label = index of this rating within the sorted rating set.
        cls = np.searchsorted(dataset.possible_rating_values, rating)
        rating_list.append(th.LongTensor(np.full_like(u, rating)))
        label_list.append(th.LongTensor(np.full_like(u, cls)))

    flat_graph = dgl.heterograph(
        {("user", "rate", "movie"): (th.cat(src_list), th.cat(dst_list))},
        num_nodes_dict={"user": n_users, "movie": n_movies},
    )
    flat_graph.edata["rating"] = th.cat(rating_list)
    flat_graph.edata["label"] = th.cat(label_list)

    return flat_graph

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
133
134

def evaluate(args, dev_id, net, dataset, dataloader, segment="valid"):
    """Compute RMSE of ``net`` over the validation or test edges.

    Parameters
    ----------
    segment : str
        Either "valid" or "test"; selects which encoder graph is used for
        normalization constants and which ground-truth ratings to compare
        against.

    Returns
    -------
    float
        Root-mean-squared error of the expected predicted ratings.
    """
    possible_rating_values = dataset.possible_rating_values
    nd_possible_rating_values = th.FloatTensor(possible_rating_values).to(
        dev_id
    )

    # Fix: switch to inference mode so dropout is disabled during
    # evaluation (the training loop re-enables train mode each epoch).
    net.eval()

    real_pred_ratings = []
    true_rel_ratings = []
    for input_nodes, pair_graph, blocks in dataloader:
        head_feat, tail_feat, blocks = load_subtensor(
            input_nodes,
            pair_graph,
            blocks,
            dataset,
            dataset.valid_enc_graph
            if segment == "valid"
            else dataset.test_enc_graph,
        )
        frontier = blocks[0]
        true_relation_ratings = (
            dataset.valid_truths[pair_graph.edata[dgl.EID]]
            if segment == "valid"
            else dataset.test_truths[pair_graph.edata[dgl.EID]]
        )

        frontier = frontier.to(dev_id)
        head_feat = head_feat.to(dev_id)
        tail_feat = tail_feat.to(dev_id)
        pair_graph = pair_graph.to(dev_id)
        with th.no_grad():
            pred_ratings = net(
                pair_graph,
                frontier,
                head_feat,
                tail_feat,
                possible_rating_values,
            )
        # Expected rating = sum over classes of softmax prob * rating value.
        batch_pred_ratings = (
            th.softmax(pred_ratings, dim=1)
            * nd_possible_rating_values.view(1, -1)
        ).sum(dim=1)
        real_pred_ratings.append(batch_pred_ratings)
        true_rel_ratings.append(true_relation_ratings)

    real_pred_ratings = th.cat(real_pred_ratings, dim=0)
    true_rel_ratings = th.cat(true_rel_ratings, dim=0).to(dev_id)
    rmse = ((real_pred_ratings - true_rel_ratings) ** 2.0).mean().item()
    rmse = np.sqrt(rmse)
    return rmse

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
184

185
def config():
    """Parse command-line arguments for GCMC training and prepare the
    saving directory.

    Returns
    -------
    argparse.Namespace
        Parsed arguments, with ``save_dir`` pointing at an existing
        directory under ``log/``.
    """

    def _str2bool(value):
        # Fix: argparse's ``type=bool`` treats every non-empty string
        # (including "False") as True; parse the text explicitly instead.
        if isinstance(value, bool):
            return value
        lowered = value.lower()
        if lowered in ("yes", "true", "t", "y", "1"):
            return True
        if lowered in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(
            "Boolean value expected, got %r" % value
        )

    parser = argparse.ArgumentParser(description="GCMC")
    parser.add_argument("--seed", default=123, type=int)
    parser.add_argument("--gpu", type=str, default="0")
    parser.add_argument("--save_dir", type=str, help="The saving directory")
    parser.add_argument("--save_id", type=int, help="The saving log id")
    parser.add_argument("--silent", action="store_true")
    parser.add_argument(
        "--data_name",
        default="ml-1m",
        type=str,
        help="The dataset name: ml-100k, ml-1m, ml-10m",
    )
    parser.add_argument(
        "--data_test_ratio", type=float, default=0.1
    )  ## for ml-100k the test ratio is 0.2
    parser.add_argument("--data_valid_ratio", type=float, default=0.1)
    parser.add_argument("--use_one_hot_fea", action="store_true", default=False)
    parser.add_argument("--model_activation", type=str, default="leaky")
    parser.add_argument("--gcn_dropout", type=float, default=0.7)
    parser.add_argument("--gcn_agg_norm_symm", type=_str2bool, default=True)
    parser.add_argument("--gcn_agg_units", type=int, default=500)
    parser.add_argument("--gcn_agg_accum", type=str, default="sum")
    parser.add_argument("--gcn_out_units", type=int, default=75)
    parser.add_argument("--gen_r_num_basis_func", type=int, default=2)
    parser.add_argument("--train_max_epoch", type=int, default=1000)
    parser.add_argument("--train_log_interval", type=int, default=1)
    parser.add_argument("--train_valid_interval", type=int, default=1)
    parser.add_argument("--train_optimizer", type=str, default="adam")
    parser.add_argument("--train_grad_clip", type=float, default=1.0)
    parser.add_argument("--train_lr", type=float, default=0.01)
    parser.add_argument("--train_min_lr", type=float, default=0.0001)
    parser.add_argument("--train_lr_decay_factor", type=float, default=0.5)
    parser.add_argument("--train_decay_patience", type=int, default=25)
    parser.add_argument("--train_early_stopping_patience", type=int, default=50)
    parser.add_argument("--share_param", default=False, action="store_true")
    parser.add_argument("--mix_cpu_gpu", default=False, action="store_true")
    parser.add_argument("--minibatch_size", type=int, default=20000)
    parser.add_argument("--num_workers_per_gpu", type=int, default=8)

    args = parser.parse_args()
    ### configure save_dir to save all the info
    if args.save_dir is None:
        # Random 2-character suffix keeps concurrent runs from colliding.
        args.save_dir = (
            args.data_name
            + "_"
            + "".join(
                random.choices(string.ascii_uppercase + string.digits, k=2)
            )
        )
    if args.save_id is None:
        args.save_id = np.random.randint(20)
    args.save_dir = os.path.join("log", args.save_dir)
    if not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)

    return args

def run(proc_id, n_gpus, args, devices, dataset):
    """Train GCMC on one process; process 0 also runs validation/testing.

    Parameters
    ----------
    proc_id : int
        Index of this process among the spawned workers.
    n_gpus : int
        Number of training GPUs; 0 means CPU-only.
    args : argparse.Namespace
        Configuration produced by ``config()``.
    devices : list
        Per-process device IDs; ``devices[proc_id]`` is ours.
    dataset : MovieLens
        Preprocessed dataset whose full graphs live in CPU memory.
    """
    dev_id = devices[proc_id]
    if n_gpus > 1:
        dist_init_method = "tcp://{master_ip}:{master_port}".format(
            master_ip="127.0.0.1", master_port="12345"
        )
        world_size = n_gpus
        th.distributed.init_process_group(
            backend="nccl",
            init_method=dist_init_method,
            world_size=world_size,
            # Fix: the distributed rank is the process index, not the CUDA
            # device ID (they differ when e.g. ``--gpu 2,3`` is used).
            rank=proc_id,
        )
    if n_gpus > 0:
        th.cuda.set_device(dev_id)

    # Map every rating etype to its reverse etype and vice versa.
    # NOTE(review): currently unused; the upstream example passes this to
    # the edge sampler for reverse-edge exclusion -- confirm whether
    # exclusion is intended here.
    reverse_types = {
        to_etype_name(k): "rev-" + to_etype_name(k)
        for k in dataset.possible_rating_values
    }
    reverse_types.update({v: k for k, v in reverse_types.items()})

    # Full-neighborhood (single-layer) sampler wrapped for edge prediction.
    sampler = dgl.dataloading.MultiLayerNeighborSampler(
        [None], return_eids=True
    )
    sampler = dgl.dataloading.as_edge_prediction_sampler(sampler)
    dataloader = dgl.dataloading.DataLoader(
        dataset.train_enc_graph,
        {
            to_etype_name(k): th.arange(
                dataset.train_enc_graph.number_of_edges(etype=to_etype_name(k))
            )
            for k in dataset.possible_rating_values
        },
        sampler,
        use_ddp=n_gpus > 1,
        batch_size=args.minibatch_size,
        shuffle=True,
        drop_last=False,
    )

    # Only the first process evaluates.
    if proc_id == 0:
        valid_dataloader = dgl.dataloading.DataLoader(
            dataset.valid_dec_graph,
            th.arange(dataset.valid_dec_graph.number_of_edges()),
            sampler,
            g_sampling=dataset.valid_enc_graph,
            batch_size=args.minibatch_size,
            shuffle=False,
            drop_last=False,
        )
        test_dataloader = dgl.dataloading.DataLoader(
            dataset.test_dec_graph,
            th.arange(dataset.test_dec_graph.number_of_edges()),
            sampler,
            g_sampling=dataset.test_enc_graph,
            batch_size=args.minibatch_size,
            shuffle=False,
            drop_last=False,
        )

    nd_possible_rating_values = th.FloatTensor(dataset.possible_rating_values)
    nd_possible_rating_values = nd_possible_rating_values.to(dev_id)

    net = Net(args=args, dev_id=dev_id)
    net = net.to(dev_id)
    if n_gpus > 1:
        net = DistributedDataParallel(
            net, device_ids=[dev_id], output_device=dev_id
        )
    rating_loss_net = nn.CrossEntropyLoss()
    learning_rate = args.train_lr
    optimizer = get_optimizer(args.train_optimizer)(
        net.parameters(), lr=learning_rate
    )
    print("Loading network finished ...\n")

    ### declare the loss information
    best_valid_rmse = np.inf
    # Fix: initialize so the final summary never hits an unbound variable,
    # even if validation never improves.
    best_test_rmse = np.inf
    no_better_valid = 0
    best_epoch = -1
    count_rmse = 0
    count_num = 0
    count_loss = 0
    print("Start training ...")
    iter_idx = 1

    for epoch in range(1, args.train_max_epoch):
        if n_gpus > 1:
            # Reshuffle the DDP edge partition differently every epoch.
            dataloader.set_epoch(epoch)
        if epoch > 1:
            t0 = time.time()
        net.train()
        with tqdm.tqdm(dataloader) as tq:
            for step, (input_nodes, pair_graph, blocks) in enumerate(tq):
                head_feat, tail_feat, blocks = load_subtensor(
                    input_nodes,
                    pair_graph,
                    blocks,
                    dataset,
                    dataset.train_enc_graph,
                )
                frontier = blocks[0]
                # Merge per-rating etypes so the decoder sees one edge set.
                compact_g = flatten_etypes(pair_graph, dataset, "train").to(
                    dev_id
                )
                true_relation_labels = compact_g.edata["label"]
                true_relation_ratings = compact_g.edata["rating"]

                head_feat = head_feat.to(dev_id)
                tail_feat = tail_feat.to(dev_id)
                frontier = frontier.to(dev_id)

                pred_ratings = net(
                    compact_g,
                    frontier,
                    head_feat,
                    tail_feat,
                    dataset.possible_rating_values,
                )
                loss = rating_loss_net(
                    pred_ratings, true_relation_labels.to(dev_id)
                ).mean()
                count_loss += loss.item()
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(net.parameters(), args.train_grad_clip)
                optimizer.step()

                if proc_id == 0 and iter_idx == 1:
                    print(
                        "Total #Param of net: %d" % (torch_total_param_num(net))
                    )

                # Expected rating = sum over classes of softmax prob * value.
                real_pred_ratings = (
                    th.softmax(pred_ratings, dim=1)
                    * nd_possible_rating_values.view(1, -1)
                ).sum(dim=1)
                rmse = (
                    (real_pred_ratings - true_relation_ratings.to(dev_id)) ** 2
                ).sum()
                count_rmse += rmse.item()
                count_num += pred_ratings.shape[0]

                tq.set_postfix(
                    {
                        "loss": "{:.4f}".format(count_loss / iter_idx),
                        "rmse": "{:.4f}".format(count_rmse / count_num),
                    },
                    refresh=False,
                )
                iter_idx += 1

        if epoch > 1:
            epoch_time = time.time() - t0
            print("Epoch {} time {}".format(epoch, epoch_time))

        if epoch % args.train_valid_interval == 0:
            if n_gpus > 1:
                th.distributed.barrier()
            if proc_id == 0:
                valid_rmse = evaluate(
                    args=args,
                    dev_id=dev_id,
                    net=net,
                    dataset=dataset,
                    dataloader=valid_dataloader,
                    segment="valid",
                )
                logging_str = "Val RMSE={:.4f}".format(valid_rmse)

                if valid_rmse < best_valid_rmse:
                    best_valid_rmse = valid_rmse
                    no_better_valid = 0
                    best_epoch = epoch
                    test_rmse = evaluate(
                        args=args,
                        dev_id=dev_id,
                        net=net,
                        dataset=dataset,
                        dataloader=test_dataloader,
                        segment="test",
                    )
                    best_test_rmse = test_rmse
                    logging_str += ", Test RMSE={:.4f}".format(test_rmse)
                else:
                    no_better_valid += 1
                    if (
                        no_better_valid > args.train_early_stopping_patience
                        and learning_rate <= args.train_min_lr
                    ):
                        logging.info(
                            "Early stopping threshold reached. Stop training."
                        )
                        break
                    if no_better_valid > args.train_decay_patience:
                        new_lr = max(
                            learning_rate * args.train_lr_decay_factor,
                            args.train_min_lr,
                        )
                        if new_lr < learning_rate:
                            logging.info("\tChange the LR to %g" % new_lr)
                            learning_rate = new_lr
                            for p in optimizer.param_groups:
                                p["lr"] = learning_rate
                            no_better_valid = 0
                            print("Change the LR to %g" % new_lr)
                # Fix: print only when a validation actually ran; previously
                # this executed every epoch and raised UnboundLocalError
                # whenever train_valid_interval > 1.
                print(logging_str)
            # sync on evaluation
            if n_gpus > 1:
                th.distributed.barrier()

    if proc_id == 0:
        print(
            "Best epoch Idx={}, Best Valid RMSE={:.4f}, Best Test RMSE={:.4f}".format(
                best_epoch, best_valid_rmse, best_test_rmse
            )
        )

if __name__ == "__main__":
471
472
    args = config()

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
473
    devices = list(map(int, args.gpu.split(",")))
474
475
476
477
    n_gpus = len(devices)

    # For GCMC based on sampling, we require node has its own features.
    # Otherwise (node_id is the feature), the model can not scale
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
478
479
480
481
482
483
484
485
486
    dataset = MovieLens(
        args.data_name,
        "cpu",
        mix_cpu_gpu=args.mix_cpu_gpu,
        use_one_hot_fea=args.use_one_hot_fea,
        symm=args.gcn_agg_norm_symm,
        test_ratio=args.data_test_ratio,
        valid_ratio=args.data_valid_ratio,
    )
487
488
489
490
491
492
493
494
    print("Loading data finished ...\n")

    args.src_in_units = dataset.user_feature_shape[1]
    args.dst_in_units = dataset.movie_feature_shape[1]
    args.rating_vals = dataset.possible_rating_values

    # cpu
    if devices[0] == -1:
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
495
        run(0, 0, args, ["cpu"], dataset)
496
497
498
499
500
    # gpu
    elif n_gpus == 1:
        run(0, n_gpus, args, devices, dataset)
    # multi gpu
    else:
501
502
        # Create csr/coo/csc formats before launching training processes with multi-gpu.
        # This avoids creating certain formats in each sub-process, which saves momory and CPU.
503
504
        dataset.train_enc_graph.create_formats_()
        dataset.train_dec_graph.create_formats_()
505
        mp.spawn(run, args=(n_gpus, args, devices, dataset), nprocs=n_gpus)