import argparse
import os
import pickle

import dgl

import evaluation
import layers
import numpy as np
import sampler as sampler_module
import torch
import torch.nn as nn
import torchtext
import tqdm
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator


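# PinSAGE item-embedding model: projects raw node features into a common
# hidden space, refines them with SAGE convolutions over the sampled
# blocks, and scores item-item pairs.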
class PinSAGEModel(nn.Module):
    def __init__(self, full_graph, ntype, textsets, hidden_dims, n_layers):
        super().__init__()

        self.proj = layers.LinearProjector(
            full_graph, ntype, textsets, hidden_dims
        )
        self.sage = layers.SAGENet(hidden_dims, n_layers)
        self.scorer = layers.ItemToItemScorer(full_graph, ntype)

    def forward(self, pos_graph, neg_graph, blocks):
        h_item = self.get_repr(blocks)
        pos_score = self.scorer(pos_graph, h_item)
        neg_score = self.scorer(neg_graph, h_item)
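        # Max-margin (hinge) loss per pair: penalize whenever the negative
        # pair scores within a margin of 1 of the positive pair.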
        return (neg_score - pos_score + 1).clamp(min=0)

    def get_repr(self, blocks):
        h_item = self.proj(blocks[0].srcdata)
        h_item_dst = self.proj(blocks[-1].dstdata)
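        # Add the seed items' own projected features to the GNN output as a
        # residual connection.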
        return h_item_dst + self.sage(blocks, h_item)


def train(dataset, args):
    g = dataset["train-graph"]
    val_matrix = dataset["val-matrix"].tocsr()
    test_matrix = dataset["test-matrix"].tocsr()
    item_texts = dataset["item-texts"]
    user_ntype = dataset["user-type"]
    item_ntype = dataset["item-type"]
    user_to_item_etype = dataset["user-to-item-type"]
    timestamp = dataset["timestamp-edge-column"]

    device = torch.device(args.device)

    # Assign user and movie IDs and use them as features (to learn an
    # individual trainable embedding for each entity).
    g.nodes[user_ntype].data["id"] = torch.arange(g.num_nodes(user_ntype))
    g.nodes[item_ntype].data["id"] = torch.arange(g.num_nodes(item_ntype))

    # Prepare the torchtext dataset and vocabulary
    textset = {}
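    # get_tokenizer(None) falls back to simple whitespace splitting.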
    tokenizer = get_tokenizer(None)

    textlist = []
    batch_first = True

    for i in range(g.num_nodes(item_ntype)):
        for key in item_texts.keys():
            tokens = tokenizer(item_texts[key][i].lower())
            textlist.append(tokens)
    for key in item_texts.keys():
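        # Build a vocabulary over all tokenized texts; each textset entry
        # bundles the texts, the vocabulary, the padding token index, and
        # the batch_first flag.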
        vocab = build_vocab_from_iterator(
            textlist, specials=["<unk>", "<pad>"]
        )
        textset[key] = (
            textlist,
            vocab,
            vocab.get_stoi()["<pad>"],
            batch_first,
        )

    # Sampler: yields (heads, tails, neg_tails) item triplets, where tails
    # share a user interaction with heads and neg_tails are random negatives
    batch_sampler = sampler_module.ItemToItemBatchSampler(
        g, user_ntype, item_ntype, args.batch_size
    )
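    # The neighbor sampler picks each item's neighbors via random walks that
    # bounce between items and users, keeping the most frequently visited
    # items.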
    neighbor_sampler = sampler_module.NeighborSampler(
        g,
        user_ntype,
        item_ntype,
        args.random_walk_length,
        args.random_walk_restart_prob,
        args.num_random_walks,
        args.num_neighbors,
        args.num_layers,
    )
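    # The collator turns sampled node IDs into DGL blocks and attaches the
    # node features (including the item texts) for each batch.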
    collator = sampler_module.PinSAGECollator(
        neighbor_sampler, g, item_ntype, textset
    )
    dataloader = DataLoader(
        batch_sampler,
        collate_fn=collator.collate_train,
        num_workers=args.num_workers,
    )
    dataloader_test = DataLoader(
        torch.arange(g.num_nodes(item_ntype)),
        batch_size=args.batch_size,
        collate_fn=collator.collate_test,
        num_workers=args.num_workers,
    )
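    # The training batch sampler streams batches indefinitely, so pull
    # batches from a manual iterator with next() rather than a for-loop.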
    dataloader_it = iter(dataloader)

    # Model
    model = PinSAGEModel(
        g, item_ntype, textset, args.hidden_dims, args.num_layers
    ).to(device)
    # Optimizer
    opt = torch.optim.Adam(model.parameters(), lr=args.lr)

    # For each batch of head-tail-negative triplets...
    for epoch_id in range(args.num_epochs):
        model.train()
        for batch_id in tqdm.trange(args.batches_per_epoch):
            pos_graph, neg_graph, blocks = next(dataloader_it)
            # Copy to GPU
            for i in range(len(blocks)):
                blocks[i] = blocks[i].to(device)
            pos_graph = pos_graph.to(device)
            neg_graph = neg_graph.to(device)

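            # The model returns one hinge loss per pair; average over the batch.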
            loss = model(pos_graph, neg_graph, blocks).mean()
            opt.zero_grad()
            loss.backward()
            opt.step()

        # Evaluate
        model.eval()
        with torch.no_grad():
            h_item_batches = []
            for blocks in dataloader_test:
                for i in range(len(blocks)):
                    blocks[i] = blocks[i].to(device)

                h_item_batches.append(model.get_repr(blocks))
            h_item = torch.cat(h_item_batches, 0)

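            # Recommend by nearest-neighbor search over the learned item
            # embeddings and report the evaluation metric.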
            print(
                evaluation.evaluate_nn(dataset, h_item, args.k, args.batch_size)
            )


if __name__ == "__main__":
    # Arguments
    parser = argparse.ArgumentParser()
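    # dataset_path is the directory produced by the dataset preprocessing
    # script; it must contain data.pkl and train_g.bin (loaded below).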
    parser.add_argument("dataset_path", type=str)
    parser.add_argument("--random-walk-length", type=int, default=2)
    parser.add_argument("--random-walk-restart-prob", type=float, default=0.5)
    parser.add_argument("--num-random-walks", type=int, default=10)
    parser.add_argument("--num-neighbors", type=int, default=3)
    parser.add_argument("--num-layers", type=int, default=2)
    parser.add_argument("--hidden-dims", type=int, default=16)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument(
        "--device", type=str, default="cpu"
    )  # can also be "cuda:0"
    parser.add_argument("--num-epochs", type=int, default=1)
    parser.add_argument("--batches-per-epoch", type=int, default=20000)
    parser.add_argument("--num-workers", type=int, default=0)
    parser.add_argument("--lr", type=float, default=3e-5)
    parser.add_argument("-k", type=int, default=10)
    args = parser.parse_args()

    # Load dataset
    data_info_path = os.path.join(args.dataset_path, "data.pkl")
    with open(data_info_path, "rb") as f:
        dataset = pickle.load(f)
    train_g_path = os.path.join(args.dataset_path, "train_g.bin")
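    # dgl.load_graphs returns a list of graphs plus a tensor dict; the
    # training graph is the first entry.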
    g_list, _ = dgl.load_graphs(train_g_path)
    dataset["train-graph"] = g_list[0]
    train(dataset, args)