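"""PinSAGE item-to-item recommendation example with a sparse item embedding.

In addition to projected node features, every item gets a learnable embedding
stored in a sparse ``nn.Embedding`` table and trained with ``SparseAdam``.
The positional ``dataset_path`` argument must point to a directory containing
``data.pkl`` (dataset metadata, item texts and evaluation matrices) and
``train_g.bin`` (the training graph in DGL's binary format), e.g.::

    python model_sparse.py /path/to/dataset --device cuda:0 --num-epochs 1
"""
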
import argparse
import os
import pickle

import evaluation
import layers
import numpy as np
import sampler as sampler_module
import torch
import torch.nn as nn
import torchtext
import tqdm
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

import dgl


class PinSAGEModel(nn.Module):
    def __init__(self, full_graph, ntype, textsets, hidden_dims, n_layers):
        super().__init__()

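        # Project raw node features (including the tokenized item texts) into
        # the hidden space.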
        self.proj = layers.LinearProjector(
            full_graph, ntype, textsets, hidden_dims
        )
        # n_layers of PinSAGE convolutions, followed by a scorer for
        # item-item pairs on the full graph.
        self.sage = layers.SAGENet(hidden_dims, n_layers)
        self.scorer = layers.ItemToItemScorer(full_graph, ntype)

    def forward(self, pos_graph, neg_graph, blocks, item_emb):
        h_item = self.get_repr(blocks, item_emb)
        pos_score = self.scorer(pos_graph, h_item)
        neg_score = self.scorer(neg_graph, h_item)
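        # Max-margin (hinge) loss with margin 1: positive pairs should score
        # at least 1 higher than negative pairs.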
        return (neg_score - pos_score + 1).clamp(min=0)

    def get_repr(self, blocks, item_emb):
        # project features
        h_item = self.proj(blocks[0].srcdata)
        h_item_dst = self.proj(blocks[-1].dstdata)

        # add the learnable per-item embedding; the sparse embedding table is
        # kept on CPU, so indices are moved to CPU for the lookup and the
        # result is moved back to the feature device
        h_item = h_item + item_emb(blocks[0].srcdata[dgl.NID].cpu()).to(h_item)
        h_item_dst = h_item_dst + item_emb(
            blocks[-1].dstdata[dgl.NID].cpu()
        ).to(h_item_dst)

        return h_item_dst + self.sage(blocks, h_item)


def train(dataset, args):
    g = dataset["train-graph"]
    val_matrix = dataset["val-matrix"].tocsr()
    test_matrix = dataset["test-matrix"].tocsr()
    item_texts = dataset["item-texts"]
    user_ntype = dataset["user-type"]
    item_ntype = dataset["item-type"]
    user_to_item_etype = dataset["user-to-item-type"]
    timestamp = dataset["timestamp-edge-column"]

    device = torch.device(args.device)

    # Prepare torchtext dataset and vocabulary
    textset = {}
    tokenizer = get_tokenizer(None)

    textlist = []
    batch_first = True

    # Tokenize every text field of every item node into one shared token list.
    for i in range(g.num_nodes(item_ntype)):
        for key in item_texts.keys():
            tokens = tokenizer(item_texts[key][i].lower())
            textlist.append(tokens)
    # For each text column, store (tokenized texts, vocabulary, padding index,
    # batch_first) for the collator; all columns share the same token list.
    for key in item_texts.keys():
        vocab = build_vocab_from_iterator(
            textlist, specials=["<unk>", "<pad>"]
        )
        textset[key] = (
            textlist,
            vocab,
            vocab.get_stoi()["<pad>"],
            batch_first,
        )

    # Sampler
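    # ItemToItemBatchSampler yields (head, tail, negative-tail) item triplets;
    # NeighborSampler expands them into message-flow blocks using PinSAGE-style
    # random-walk neighbor sampling.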
    batch_sampler = sampler_module.ItemToItemBatchSampler(
        g, user_ntype, item_ntype, args.batch_size
    )
    neighbor_sampler = sampler_module.NeighborSampler(
        g,
        user_ntype,
        item_ntype,
        args.random_walk_length,
        args.random_walk_restart_prob,
        args.num_random_walks,
        args.num_neighbors,
        args.num_layers,
    )
    collator = sampler_module.PinSAGECollator(
        neighbor_sampler, g, item_ntype, textset
    )
    dataloader = DataLoader(
        batch_sampler,
        collate_fn=collator.collate_train,
        num_workers=args.num_workers,
    )
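    # Test-time loader: iterate over all item IDs to compute a representation
    # for every item.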
    dataloader_test = DataLoader(
        torch.arange(g.num_nodes(item_ntype)),
        batch_size=args.batch_size,
        collate_fn=collator.collate_test,
        num_workers=args.num_workers,
    )
    dataloader_it = iter(dataloader)

    # Model
    model = PinSAGEModel(
        g, item_ntype, textset, args.hidden_dims, args.num_layers
    ).to(device)
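    # Learnable per-item embedding table; sparse=True (and kept on CPU) so it
    # can be updated with SparseAdam without touching the whole table.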
    item_emb = nn.Embedding(
        g.num_nodes(item_ntype), args.hidden_dims, sparse=True
    )
    # Optimizers: Adam for the dense model parameters, SparseAdam for the
    # sparse embedding table.
    opt = torch.optim.Adam(model.parameters(), lr=args.lr)
    opt_emb = torch.optim.SparseAdam(item_emb.parameters(), lr=args.lr)

    # For each batch of head-tail-negative triplets...
    for epoch_id in range(args.num_epochs):
        model.train()
        for batch_id in tqdm.trange(args.batches_per_epoch):
            pos_graph, neg_graph, blocks = next(dataloader_it)
            # Copy to GPU
            for i in range(len(blocks)):
                blocks[i] = blocks[i].to(device)
            pos_graph = pos_graph.to(device)
            neg_graph = neg_graph.to(device)

            loss = model(pos_graph, neg_graph, blocks, item_emb).mean()
            opt.zero_grad()
            opt_emb.zero_grad()
            loss.backward()
            opt.step()
            opt_emb.step()

        # Evaluate
        model.eval()
        with torch.no_grad():
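            # Compute representations for all items, then run the
            # nearest-neighbor evaluation.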
            h_item_batches = []
            for blocks in tqdm.tqdm(dataloader_test):
                for i in range(len(blocks)):
                    blocks[i] = blocks[i].to(device)

                h_item_batches.append(model.get_repr(blocks, item_emb))
            h_item = torch.cat(h_item_batches, 0)

            print(
                evaluation.evaluate_nn(dataset, h_item, args.k, args.batch_size)
            )


if __name__ == "__main__":
    # Arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset_path", type=str)
    parser.add_argument("--random-walk-length", type=int, default=2)
    parser.add_argument("--random-walk-restart-prob", type=float, default=0.5)
    parser.add_argument("--num-random-walks", type=int, default=10)
    parser.add_argument("--num-neighbors", type=int, default=3)
    parser.add_argument("--num-layers", type=int, default=2)
    parser.add_argument("--hidden-dims", type=int, default=16)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument(
        "--device", type=str, default="cpu"
    )  # can also be "cuda:0"
    parser.add_argument("--num-epochs", type=int, default=1)
    parser.add_argument("--batches-per-epoch", type=int, default=20000)
    parser.add_argument("--num-workers", type=int, default=0)
    parser.add_argument("--lr", type=float, default=3e-5)
    parser.add_argument("-k", type=int, default=10)
    args = parser.parse_args()

    # Load dataset
    data_info_path = os.path.join(args.dataset_path, "data.pkl")
    with open(data_info_path, "rb") as f:
        dataset = pickle.load(f)
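    # The training graph itself is stored separately in DGL's binary graph format.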
    train_g_path = os.path.join(args.dataset_path, "train_g.bin")
    g_list, _ = dgl.load_graphs(train_g_path)
    dataset["train-graph"] = g_list[0]
    train(dataset, args)