import dgl from dgl.data import AsLinkPredDataset import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader import numpy as np {{ data_import_code }} {{ edge_model_code }} {{ node_model_code}} class Model(nn.Module): def __init__(self, node_model, edge_model, neg_sampler, eval_batch_size): super().__init__() self.node_model = node_model self.edge_model = edge_model self.neg_sampler = neg_sampler self.eval_batch_size = eval_batch_size def inference(self, g, x, edges): src, dst = edges h = self.node_model(g, x) eid_dataloader = DataLoader( range( src.shape[-1]), batch_size=self.eval_batch_size) score_list = [] for eids in eid_dataloader: score = self.edge_model(h[src[eids]], h[dst[eids]]) score_list.append(score) return torch.cat(score_list, dim=0) def calc_hitsk(y_pred_pos, y_pred_neg, k): kth_score_in_negative_edges = torch.topk(y_pred_neg.flatten(), k)[0][-1] hitsK = (y_pred_pos > kth_score_in_negative_edges).float().mean() return hitsK.item() def train(cfg, pipeline_cfg, device, dataset, model, optimizer, loss_fcn): train_g = dataset.train_graph train_g = train_g.to(device) node_feat = train_g.ndata['feat'] train_src, train_dst = train_g.edges() for epoch in range(pipeline_cfg['num_epochs']): model.train() eid_dataloader = DataLoader(range(train_g.num_edges()), batch_size = pipeline_cfg["train_batch_size"], shuffle=True) for eids in eid_dataloader: h = model.node_model(train_g, node_feat) eids = eids.to(device) src, dst = train_src[eids], train_dst[eids] pos_score = model.edge_model(h[src], h[dst]) neg_src, neg_dst = model.neg_sampler(train_g, eids) neg_score = model.edge_model(h[neg_src], h[neg_dst]) loss = loss_fcn(torch.cat([pos_score, neg_score]), torch.cat( [torch.ones_like(pos_score), torch.zeros_like(neg_score)])) optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() with torch.no_grad(): model.eval() val_neg_edges = dataset.val_edges[1] val_neg_score = model.inference(train_g, node_feat, val_neg_edges) train_hits = calc_hitsk(pos_score, val_neg_score, k=50) print("Epoch {:05d} | Loss {:.4f} | Train Hits@50 {:.4f}".format(epoch, loss, train_hits)) if epoch != 0 and epoch % pipeline_cfg['eval_period'] == 0: with torch.no_grad(): model.eval() val_pos_edge, val_neg_edges = dataset.val_edges pos_result = model.inference(train_g, node_feat, val_pos_edge) neg_result = model.inference(train_g, node_feat, val_neg_edges) val_hits = calc_hitsk(pos_result, neg_result, k=50) print("Epoch {:05d} | Val Hits@50 {:.4f}".format(epoch, val_hits)) with torch.no_grad(): model.eval() test_pos_edge, test_neg_edges = dataset.test_edges pos_result = model.inference(train_g, node_feat, test_pos_edge) neg_result = model.inference(train_g, node_feat, test_neg_edges) test_hits = calc_hitsk(pos_result, neg_result, k=50) print("Test Hits@50 {:.4f}".format(test_hits)) return test_hits def main(): {{user_cfg_str}} device = cfg['device'] pipeline_cfg = cfg['general_pipeline'] dataset = AsLinkPredDataset({{ data_initialize_code }}) if 'feat' not in dataset[0].ndata: assert cfg["node_model"]["embed_size"] > 0, "Need to specify embed size if graph doesn't have feat in ndata" cfg["node_model"]["data_info"] = { "in_size": cfg["node_model"]['embed_size'] if cfg["node_model"]['embed_size'] > 0 else dataset[0].ndata['feat'].shape[1], "out_size": pipeline_cfg['hidden_size'], "num_nodes": dataset[0].num_nodes() } cfg["edge_model"]["data_info"] = { "in_size": pipeline_cfg['hidden_size'], "out_size": 1 # output each edge score } node_model = {{node_model_class_name}}(**cfg["node_model"]) edge_model = {{edge_model_class_name}}(**cfg["edge_model"]) neg_sampler = dgl.dataloading.negative_sampler.{{ neg_sampler_name }}(**cfg["neg_sampler"]) model = Model(node_model, edge_model, neg_sampler, pipeline_cfg["eval_batch_size"]) model = model.to(device) params = model.parameters() loss = torch.nn.{{ loss }}() optimizer = torch.optim.Adam(params, **pipeline_cfg["optimizer"]) test_hits = train(cfg, pipeline_cfg, device, dataset, model, optimizer, loss) return test_hits if __name__ == '__main__': all_acc = [] num_runs = {{ user_cfg.general_pipeline.num_runs }} for run in range(num_runs): print(f'Run experiment #{run}') test_acc = main() print("Test Hits@50 {:.4f}".format(test_acc)) all_acc.append(test_acc) avg_acc = np.round(np.mean(all_acc), 6) std_acc = np.round(np.std(all_acc), 6) print(f'Test Hits@50 across {num_runs} runs: {avg_acc} ± {std_acc}')