import pickle
import argparse
import os

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchtext
import dgl
import tqdm

import layers
import sampler as sampler_module
import evaluation


class PinSAGEModel(nn.Module):
    def __init__(self, full_graph, ntype, textsets, hidden_dims, n_layers):
        super().__init__()

        self.proj = layers.LinearProjector(full_graph, ntype, textsets, hidden_dims)
        self.sage = layers.SAGENet(hidden_dims, n_layers)
        self.scorer = layers.ItemToItemScorer(full_graph, ntype)

    def forward(self, pos_graph, neg_graph, blocks):
        h_item = self.get_repr(blocks)
        pos_score = self.scorer(pos_graph, h_item)
        neg_score = self.scorer(neg_graph, h_item)
        # Per-triplet max-margin loss: max(0, 1 - (pos_score - neg_score)).
        return (neg_score - pos_score + 1).clamp(min=0)

    def get_repr(self, blocks):
        # Project the input features of the outermost block, run the SAGE
        # layers, and add the projection of the seed (destination) nodes.
        h_item = self.proj(blocks[0].srcdata)
        h_item_dst = self.proj(blocks[-1].dstdata)
        return h_item_dst + self.sage(blocks, h_item)


def train(dataset, args):
    g = dataset['train-graph']
    val_matrix = dataset['val-matrix'].tocsr()
    test_matrix = dataset['test-matrix'].tocsr()
    item_texts = dataset['item-texts']
    user_ntype = dataset['user-type']
    item_ntype = dataset['item-type']
    user_to_item_etype = dataset['user-to-item-type']
    timestamp = dataset['timestamp-edge-column']

    device = torch.device(args.device)

    # Assign user and movie IDs and use them as features (to learn an
    # individual trainable embedding for each entity).
    g.nodes[user_ntype].data['id'] = torch.arange(g.number_of_nodes(user_ntype))
    g.nodes[item_ntype].data['id'] = torch.arange(g.number_of_nodes(item_ntype))

    # Prepare torchtext dataset and vocabulary (requires a torchtext version
    # that still ships the legacy API, i.e. pre-0.12).
    fields = {}
    examples = []
    for key, texts in item_texts.items():
        fields[key] = torchtext.legacy.data.Field(
            include_lengths=True, lower=True, batch_first=True)
    for i in range(g.number_of_nodes(item_ntype)):
        example = torchtext.legacy.data.Example.fromlist(
            [item_texts[key][i] for key in item_texts.keys()],
            [(key, fields[key]) for key in item_texts.keys()])
        examples.append(example)
    textset = torchtext.legacy.data.Dataset(examples, fields)
    for key, field in fields.items():
        field.build_vocab(getattr(textset, key))
        # field.build_vocab(getattr(textset, key), vectors='fasttext.simple.300d')

    # Sampler
    batch_sampler = sampler_module.ItemToItemBatchSampler(
        g, user_ntype, item_ntype, args.batch_size)
    neighbor_sampler = sampler_module.NeighborSampler(
        g, user_ntype, item_ntype, args.random_walk_length,
        args.random_walk_restart_prob, args.num_random_walks,
        args.num_neighbors, args.num_layers)
    collator = sampler_module.PinSAGECollator(neighbor_sampler, g, item_ntype, textset)
    dataloader = DataLoader(
        batch_sampler,
        collate_fn=collator.collate_train,
        num_workers=args.num_workers)
    dataloader_test = DataLoader(
        torch.arange(g.number_of_nodes(item_ntype)),
        batch_size=args.batch_size,
        collate_fn=collator.collate_test,
        num_workers=args.num_workers)
    dataloader_it = iter(dataloader)

    # Model
    model = PinSAGEModel(g, item_ntype, textset, args.hidden_dims, args.num_layers).to(device)
    # Optimizer
    opt = torch.optim.Adam(model.parameters(), lr=args.lr)

    # For each batch of head-tail-negative triplets...
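    # Each iteration consumes one such triplet batch from the sampler module:
    # pos_graph connects every head item to a co-interacted (positive) tail
    # item and neg_graph connects it to a sampled negative item, so the margin
    # loss above pushes pos_score at least 1 above neg_score.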
    for epoch_id in range(args.num_epochs):
        model.train()
        for batch_id in tqdm.trange(args.batches_per_epoch):
            pos_graph, neg_graph, blocks = next(dataloader_it)
            # Copy to GPU
            for i in range(len(blocks)):
                blocks[i] = blocks[i].to(device)
            pos_graph = pos_graph.to(device)
            neg_graph = neg_graph.to(device)

            loss = model(pos_graph, neg_graph, blocks).mean()
            opt.zero_grad()
            loss.backward()
            opt.step()

        # Evaluate: compute representations of all items, then report
        # nearest-neighbor recommendation metrics.
        model.eval()
        with torch.no_grad():
            h_item_batches = []
            for blocks in dataloader_test:
                for i in range(len(blocks)):
                    blocks[i] = blocks[i].to(device)
                h_item_batches.append(model.get_repr(blocks))
            h_item = torch.cat(h_item_batches, 0)

            print(evaluation.evaluate_nn(dataset, h_item, args.k, args.batch_size))


if __name__ == '__main__':
    # Arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('dataset_path', type=str)
    parser.add_argument('--random-walk-length', type=int, default=2)
    parser.add_argument('--random-walk-restart-prob', type=float, default=0.5)
    parser.add_argument('--num-random-walks', type=int, default=10)
    parser.add_argument('--num-neighbors', type=int, default=3)
    parser.add_argument('--num-layers', type=int, default=2)
    parser.add_argument('--hidden-dims', type=int, default=16)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--device', type=str, default='cpu')  # can also be "cuda:0"
    parser.add_argument('--num-epochs', type=int, default=1)
    parser.add_argument('--batches-per-epoch', type=int, default=20000)
    parser.add_argument('--num-workers', type=int, default=0)
    parser.add_argument('--lr', type=float, default=3e-5)
    parser.add_argument('-k', type=int, default=10)
    args = parser.parse_args()

    # Load dataset
    data_info_path = os.path.join(args.dataset_path, 'data.pkl')
    with open(data_info_path, 'rb') as f:
        dataset = pickle.load(f)
    train_g_path = os.path.join(args.dataset_path, 'train_g.bin')
    g_list, _ = dgl.load_graphs(train_g_path)
    dataset['train-graph'] = g_list[0]

    train(dataset, args)
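# Example invocation (script and path names are illustrative; the dataset
# directory must contain the data.pkl and train_g.bin files loaded above):
#
#   python model.py ./data_processed --device cuda:0 --num-epochs 10 --hidden-dims 64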