from collections import defaultdict as ddict

import numpy as np
import torch
from ordered_set import OrderedSet
from torch.utils.data import DataLoader, Dataset

import dgl


class TrainDataset(Dataset):
    """
    Training Dataset class.

    Parameters
    ----------
    triples: The triples used for training the model
    num_ent: Number of entities in the knowledge graph
    lbl_smooth: Label smoothing

    Returns
    -------
    A training Dataset class instance used by DataLoader
    """

    def __init__(self, triples, num_ent, lbl_smooth):
        self.triples = triples
        self.num_ent = num_ent
        self.lbl_smooth = lbl_smooth
        self.entities = np.arange(self.num_ent, dtype=np.int32)

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        ele = self.triples[idx]
        triple, label = torch.LongTensor(ele["triple"]), np.int32(ele["label"])
        trp_label = self.get_label(label)
        # label smoothing
        if self.lbl_smooth != 0.0:
            trp_label = (1.0 - self.lbl_smooth) * trp_label + (
                1.0 / self.num_ent
            )

        return triple, trp_label

    @staticmethod
    def collate_fn(data):
        triples = []
        labels = []
        for triple, label in data:
            triples.append(triple)
            labels.append(label)
        triple = torch.stack(triples, dim=0)
        trp_label = torch.stack(labels, dim=0)
        return triple, trp_label

    # for edges that exist in the graph, the entry is 1.0; otherwise the entry is 0.0
    def get_label(self, label):
        y = np.zeros([self.num_ent], dtype=np.float32)
        for e2 in label:
            y[e2] = 1.0
        return torch.FloatTensor(y)


class TestDataset(Dataset):
    """
    Evaluation Dataset class.

    Parameters
    ----------
    triples: The triples used for evaluating the model
    num_ent: Number of entities in the knowledge graph

    Returns
    -------
    An evaluation Dataset class instance used by DataLoader for model evaluation
    """

    def __init__(self, triples, num_ent):
        self.triples = triples
        self.num_ent = num_ent

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        ele = self.triples[idx]
        triple, label = torch.LongTensor(ele["triple"]), np.int32(ele["label"])
        label = self.get_label(label)

        return triple, label

    @staticmethod
    def collate_fn(data):
        triples = []
        labels = []
        for triple, label in data:
            triples.append(triple)
            labels.append(label)
        triple = torch.stack(triples, dim=0)
        label = torch.stack(labels, dim=0)
        return triple, label

    # for edges that exist in the graph, the entry is 1.0; otherwise the entry is 0.0
    def get_label(self, label):
        y = np.zeros([self.num_ent], dtype=np.float32)
        for e2 in label:
            y[e2] = 1.0
        return torch.FloatTensor(y)
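
# A minimal sketch (not part of the original pipeline) illustrating how
# TrainDataset builds and smooths the multi-hot label vector. The toy
# triple dict below is hypothetical; real entries are produced by Data
# further down in this module.
def _demo_label_smoothing():
    num_ent, lbl_smooth = 5, 0.1
    # One query (sub=0, rel=0) whose known objects are entities 1 and 3.
    toy_triples = [{"triple": (0, 0, -1), "label": [1, 3]}]
    ds = TrainDataset(toy_triples, num_ent, lbl_smooth)
    triple, trp_label = ds[0]
    # get_label yields [0, 1, 0, 1, 0]; smoothing computes
    # (1 - 0.1) * y + 1 / 5, i.e. [0.2, 1.1, 0.2, 1.1, 0.2] up to float
    # rounding. Note the constant 1/num_ent is added to every entry, so
    # positive entries can exceed 1.0 -- this matches the formula above.
    print(triple.tolist(), trp_label.tolist())
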
class Data(object):
    def __init__(self, dataset, lbl_smooth, num_workers, batch_size):
        """
        Reads in the raw triples and converts them into a standard format.

        Parameters
        ----------
        dataset: The name of the dataset
        lbl_smooth: Label smoothing
        num_workers: Number of workers of dataloaders
        batch_size: Batch size of dataloaders

        Returns
        -------
        self.ent2id: Entity to unique identifier mapping
        self.rel2id: Relation to unique identifier mapping
        self.id2ent: Inverse mapping of self.ent2id
        self.id2rel: Inverse mapping of self.rel2id
        self.num_ent: Number of entities in the knowledge graph
        self.num_rel: Number of relations in the knowledge graph
        self.g: The dgl graph constructed from the edges in the training set and all the entities in the knowledge graph
        self.data['train']: Stores the triples corresponding to the training dataset
        self.data['valid']: Stores the triples corresponding to the validation dataset
        self.data['test']: Stores the triples corresponding to the test dataset
        self.data_iter: The dataloader for different data splits
        """
        self.dataset = dataset
        self.lbl_smooth = lbl_smooth
        self.num_workers = num_workers
        self.batch_size = batch_size

        # read in raw data and get mappings
        ent_set, rel_set = OrderedSet(), OrderedSet()
        for split in ["train", "test", "valid"]:
            for line in open("./{}/{}.txt".format(self.dataset, split)):
                sub, rel, obj = map(str.lower, line.strip().split("\t"))
                ent_set.add(sub)
                rel_set.add(rel)
                ent_set.add(obj)

        self.ent2id = {ent: idx for idx, ent in enumerate(ent_set)}
        self.rel2id = {rel: idx for idx, rel in enumerate(rel_set)}
        self.rel2id.update(
            {
                rel + "_reverse": idx + len(self.rel2id)
                for idx, rel in enumerate(rel_set)
            }
        )

        self.id2ent = {idx: ent for ent, idx in self.ent2id.items()}
        self.id2rel = {idx: rel for rel, idx in self.rel2id.items()}

        self.num_ent = len(self.ent2id)
        self.num_rel = len(self.rel2id) // 2

        # read in ids of subjects, relations, and objects for train/test/valid
        self.data = ddict(list)  # stores the triples
        sr2o = ddict(
            set
        )  # the key of sr2o is (subject, relation); the value is the set of all objects following (subject, relation)
        src = []
        dst = []
        rels = []
        inver_src = []
        inver_dst = []
        inver_rels = []

        for split in ["train", "test", "valid"]:
            for line in open("./{}/{}.txt".format(self.dataset, split)):
                sub, rel, obj = map(str.lower, line.strip().split("\t"))
                sub_id, rel_id, obj_id = (
                    self.ent2id[sub],
                    self.rel2id[rel],
                    self.ent2id[obj],
                )
                self.data[split].append((sub_id, rel_id, obj_id))

                if split == "train":
                    sr2o[(sub_id, rel_id)].add(obj_id)
                    sr2o[(obj_id, rel_id + self.num_rel)].add(
                        sub_id
                    )  # add the reversed edges
                    src.append(sub_id)
                    dst.append(obj_id)
                    rels.append(rel_id)
                    inver_src.append(obj_id)
                    inver_dst.append(sub_id)
                    inver_rels.append(rel_id + self.num_rel)

        # construct dgl graph
        src = src + inver_src
        dst = dst + inver_dst
        rels = rels + inver_rels
        self.g = dgl.graph((src, dst), num_nodes=self.num_ent)
        self.g.edata["etype"] = torch.Tensor(rels).long()

        # identify in and out edges
        in_edges_mask = [True] * (self.g.num_edges() // 2) + [False] * (
            self.g.num_edges() // 2
        )
        out_edges_mask = [False] * (self.g.num_edges() // 2) + [True] * (
            self.g.num_edges() // 2
        )
        self.g.edata["in_edges_mask"] = torch.Tensor(in_edges_mask)
        self.g.edata["out_edges_mask"] = torch.Tensor(out_edges_mask)

        # prepare train/valid/test data
        self.data = dict(self.data)
        self.sr2o = {
            k: list(v) for k, v in sr2o.items()
        }  # stores only the train data

        for split in ["test", "valid"]:
            for sub, rel, obj in self.data[split]:
                sr2o[(sub, rel)].add(obj)
                sr2o[(obj, rel + self.num_rel)].add(sub)

        self.sr2o_all = {
            k: list(v) for k, v in sr2o.items()
        }  # stores all the data

        self.triples = ddict(list)
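        # Build the per-split triple lists consumed by the Dataset classes.
        # Each training sample is a (sub, rel, -1) query labeled with every
        # gold object for that (sub, rel) pair; head prediction is covered by
        # the reverse relations (rel + num_rel) added above, so training needs
        # no separate head/tail splits, while evaluation gets both.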
self.triples["train"].append( {"triple": (sub, rel, -1), "label": self.sr2o[(sub, rel)]} ) for split in ["test", "valid"]: for sub, rel, obj in self.data[split]: rel_inv = rel + self.num_rel self.triples["{}_{}".format(split, "tail")].append( { "triple": (sub, rel, obj), "label": self.sr2o_all[(sub, rel)], } ) self.triples["{}_{}".format(split, "head")].append( { "triple": (obj, rel_inv, sub), "label": self.sr2o_all[(obj, rel_inv)], } ) self.triples = dict(self.triples) def get_train_data_loader(split, batch_size, shuffle=True): return DataLoader( TrainDataset( self.triples[split], self.num_ent, self.lbl_smooth ), batch_size=batch_size, shuffle=shuffle, num_workers=max(0, self.num_workers), collate_fn=TrainDataset.collate_fn, ) def get_test_data_loader(split, batch_size, shuffle=True): return DataLoader( TestDataset(self.triples[split], self.num_ent), batch_size=batch_size, shuffle=shuffle, num_workers=max(0, self.num_workers), collate_fn=TestDataset.collate_fn, ) # train/valid/test dataloaders self.data_iter = { "train": get_train_data_loader("train", self.batch_size), "valid_head": get_test_data_loader("valid_head", self.batch_size), "valid_tail": get_test_data_loader("valid_tail", self.batch_size), "test_head": get_test_data_loader("test_head", self.batch_size), "test_tail": get_test_data_loader("test_tail", self.batch_size), }