import torch
import dgl
import numpy as np
import json
import pickle
import random

NODE_TYPE = {'entity': 0, 'root': 1, 'relation': 2}


def write_txt(batch, seqs, w_file, args):
    # convert predicted token ids back to plain text
    ret = []
    for b, seq in enumerate(seqs):
        txt = []
        for token in seq:
            # ids beyond the text vocabulary are copied entities
            if token >= len(args.text_vocab):
                ent_text = batch['raw_ent_text'][b][token - len(args.text_vocab)]
                ent_text = filter(lambda x: x != '<PAD>', ent_text)
                txt.extend(ent_text)
            else:
                if int(token) not in [args.text_vocab(x) for x in ['<PAD>', '<BOS>', '<EOS>']]:
                    txt.append(args.text_vocab(int(token)))
            if int(token) == args.text_vocab('<EOS>'):
                break
        w_file.write(' '.join([str(x) for x in txt]) + '\n')
        ret.append([' '.join([str(x) for x in txt])])
    return ret


def replace_ent(x, ent, V):
    # replace ids >= V (copied entities) with the corresponding entity token ids
    mask = x >= V
    if mask.sum() == 0:
        return x
    nz = mask.nonzero()
    fill_ent = ent[nz, x[mask] - V]
    x = x.masked_scatter(mask, fill_ent)
    return x


def len2mask(lens, device):
    # boolean mask with True at padded positions
    max_len = max(lens)
    mask = torch.arange(max_len, device=device).unsqueeze(0).expand(len(lens), max_len)
    mask = mask >= torch.LongTensor(lens).to(mask).unsqueeze(1)
    return mask


def pad(var_len_list, out_type='list', flatten=False):
    if flatten:
        lens = [len(x) for x in var_len_list]
        var_len_list = sum(var_len_list, [])
    max_len = max([len(x) for x in var_len_list])
    if out_type == 'list':
        if flatten:
            return [x + ['<PAD>'] * (max_len - len(x)) for x in var_len_list], lens
        else:
            return [x + ['<PAD>'] * (max_len - len(x)) for x in var_len_list]
    if out_type == 'tensor':
        if flatten:
            return torch.stack([torch.cat([x,
                torch.zeros([max_len - len(x)] + list(x.shape[1:])).type_as(x)], 0) for x in var_len_list], 0), lens
        else:
            return torch.stack([torch.cat([x,
                torch.zeros([max_len - len(x)] + list(x.shape[1:])).type_as(x)], 0) for x in var_len_list], 0)
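# Quick illustration of the two helpers above (values follow directly from the
# code, assuming a CPU device):
#   len2mask([2, 3], 'cpu')  -> tensor([[False, False,  True],
#                                       [False, False, False]])   # True = padding
#   pad([['a', 'b'], ['c']]) -> [['a', 'b'], ['c', '<PAD>']]
#   pad(list_of_1d_tensors, out_type='tensor') stacks them into one tensor,
#   zero-padding each row to the longest length; with flatten=True it also
#   returns the per-example lengths.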
class Vocab(object):
    def __init__(self, max_vocab=2**31, min_freq=-1, sp=['<PAD>', '<BOS>', '<EOS>', '<UNK>']):
        self.i2s = []
        self.s2i = {}
        self.wf = {}
        self.max_vocab, self.min_freq, self.sp = max_vocab, min_freq, sp

    def __len__(self):
        return len(self.i2s)

    def __str__(self):
        return 'Total ' + str(len(self.i2s)) + ' ' + str(self.i2s[:10])

    def update(self, token):
        if isinstance(token, list):
            for t in token:
                self.update(t)
        else:
            self.wf[token] = self.wf.get(token, 0) + 1

    def build(self):
        self.i2s.extend(self.sp)
        sort_kv = sorted(self.wf.items(), key=lambda x: x[1], reverse=True)
        for k, v in sort_kv:
            if len(self.i2s) < self.max_vocab and v >= self.min_freq and k not in self.sp:
                self.i2s.append(k)
        self.s2i.update(list(zip(self.i2s, range(len(self.i2s)))))

    def __call__(self, x):
        if isinstance(x, int):
            return self.i2s[x]
        else:
            return self.s2i.get(x, self.s2i['<UNK>'])

    def save(self, fname):
        pass

    def load(self, fname):
        pass


def at_least(x):
    # guard against empty token lists in the data
    if len(x) == 0:
        return ['<UNK>']
    else:
        return x


class Example(object):
    def __init__(self, title, ent_text, ent_type, rel, text):
        # one object corresponds to one data sample
        self.raw_title = title.split()
        self.raw_ent_text = [at_least(x.split()) for x in ent_text]
        assert min([len(x) for x in self.raw_ent_text]) > 0, str(self.raw_ent_text)
        self.raw_ent_type = ent_type.split()  # one type tag per entity, e.g. '<method> <task> ...'
        self.raw_rel = []
        for r in rel:
            rel_list = r.split()
            for i in range(len(rel_list)):
                # Assumed relation-string format: "<head tokens> -- LABEL -- <tail tokens>",
                # i.e. the relation label is the single token flanked by '--' on both sides.
                if i > 0 and i < len(rel_list) - 1 and rel_list[i - 1] == '--' and rel_list[i + 1] == '--':
                    head = ' '.join(rel_list[:i - 1])
                    tail = ' '.join(rel_list[i + 2:])
                    self.raw_rel.append([self.ent_index(head), rel_list[i], self.ent_index(tail)])
                    break
        self.raw_text = text.split()
        self.graph = self.build_graph()

    def ent_index(self, surface):
        # map an entity surface form back to its position in raw_ent_text;
        # falls back to entity 0 if the mention cannot be matched exactly
        toks = at_least(surface.split())
        return self.raw_ent_text.index(toks) if toks in self.raw_ent_text else 0

    def __len__(self):
        # used by BucketSampler to bucket samples by target length
        return len(self.raw_text)

    @classmethod
    def from_json(cls, data):
        # The field names below follow the AGENDA-style json files loaded in
        # get_datasets; treat them as assumptions if your data uses other keys.
        ents = data['entities']
        if isinstance(ents, str):  # some releases store the entities as one ';'-separated string
            ents = [e.strip() for e in ents.split(';')]
        return cls(data['title'], ents, data['types'], data['relations'], data['abstract'])

    def build_graph(self):
        # Assumed graph layout: one node per entity, one global root node, then two
        # nodes per relation (forward and inverse), mirroring the order of rel_data
        # in get_tensor ('--root--', rel_0, rel_0_INV, rel_1, ...).  The node-type
        # feature name 'type' is likewise an assumption of this sketch.
        graph = dgl.DGLGraph()
        NE, NR = len(self.raw_ent_text), len(self.raw_rel)
        root_id = NE
        graph.add_nodes(NE + 1 + 2 * NR)
        graph.ndata['type'] = torch.LongTensor(
            [NODE_TYPE['entity']] * NE + [NODE_TYPE['root']] + [NODE_TYPE['relation']] * (2 * NR))
        adj_edges = []
        # connect the root node to every entity node in both directions
        for i in range(NE):
            adj_edges.append([root_id, i])
            adj_edges.append([i, root_id])
        # head -> rel -> tail for the forward node, tail -> rel_INV -> head for the inverse node
        for j, (h, _, t) in enumerate(self.raw_rel):
            fwd = NE + 1 + 2 * j
            inv = fwd + 1
            adj_edges.extend([[h, fwd], [fwd, t], [t, inv], [inv, h]])
        if len(adj_edges) > 0:
            graph.add_edges(*list(map(list, zip(*adj_edges))))
        return graph

    def get_tensor(self, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab):
        if hasattr(self, '_cached_tensor'):
            return self._cached_tensor
        else:
            title_data = ['<BOS>'] + self.raw_title + ['<EOS>']
            title = [title_vocab(x) for x in title_data]
            ent_text = [[ent_text_vocab(y) for y in x] for x in self.raw_ent_text]
            ent_type = [text_vocab(x) for x in self.raw_ent_type]  # text vocab ids, needed at inference time
            rel_data = ['--root--'] + sum([[x[1], x[1] + '_INV'] for x in self.raw_rel], [])
            rel = [rel_vocab(x) for x in rel_data]
            text_data = ['<BOS>'] + self.raw_text + ['<EOS>']
            text = [text_vocab(x) for x in text_data]
            tgt_text = []
            # the input text and the decoding target differ because of the copy mechanism
            for i, str1 in enumerate(text_data):
                if str1[0] == '<' and str1[-1] == '>' and '_' in str1:
                    a, b = str1[1:-1].split('_')
                    text[i] = text_vocab('<' + a + '>')
                    tgt_text.append(len(text_vocab) + int(b))
                else:
                    tgt_text.append(text[i])
            self._cached_tensor = {'title': torch.LongTensor(title), 'ent_text': [torch.LongTensor(x) for x in ent_text],
                                   'ent_type': torch.LongTensor(ent_type), 'rel': torch.LongTensor(rel),
                                   'text': torch.LongTensor(text[:-1]), 'tgt_text': torch.LongTensor(tgt_text[1:]),
                                   'graph': self.graph, 'raw_ent_text': self.raw_ent_text}
            return self._cached_tensor

    def update_vocab(self, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab):
        ent_vocab.update(self.raw_ent_type)
        ent_text_vocab.update(self.raw_ent_text)
        title_vocab.update(self.raw_title)
        rel_vocab.update(['--root--'] + [x[1] for x in self.raw_rel] + [x[1] + '_INV' for x in self.raw_rel])
        text_vocab.update(self.raw_ent_type)
        text_vocab.update(self.raw_text)


class BucketSampler(torch.utils.data.Sampler):
    def __init__(self, data_source, batch_size=32, bucket=3):
        self.data_source = data_source
        self.bucket = bucket
        self.batch_size = batch_size

    def __iter__(self):
        # the magic numbers below come from the original author's code
        perm = torch.randperm(len(self.data_source))
        lens = torch.Tensor([len(x) for x in self.data_source])
        t1 = []
        t2 = []
        t3 = []
        for i in perm:
            l = lens[i]
            if l < 100:
                t1.append(i)
            elif l > 100 and l < 220:
                t2.append(i)
            else:
                t3.append(i)
        datas = [t1, t2, t3]
        random.shuffle(datas)
        idxs = sum(datas, [])
        batch = []
        for idx in idxs:
            batch.append(idx)
            mlen = max([0] + [lens[x] for x in batch])
            if (mlen < 100 and len(batch) == 32) or (mlen > 100 and mlen < 220 and len(batch) >= 24) \
                    or (mlen > 220 and len(batch) >= 8) or len(batch) == 32:
                yield batch
                batch = []
        if len(batch) > 0:
            yield batch

    def __len__(self):
        return (len(self.data_source) + self.batch_size - 1) // self.batch_size
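# BucketSampler yields whole lists of indices, so it is meant to be passed to a
# DataLoader as batch_sampler (not sampler); see the sketch at the bottom of
# this file.  The 100/220 length thresholds and the 32/24/8 batch sizes are the
# "magic numbers" mentioned above and may need tuning for other datasets.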
class GWdataset(torch.utils.data.Dataset):
    def __init__(self, exs, ent_vocab=None, rel_vocab=None, text_vocab=None, ent_text_vocab=None,
                 title_vocab=None, device=None):
        super(GWdataset, self).__init__()
        self.exs = exs
        self.ent_vocab, self.rel_vocab, self.text_vocab, self.ent_text_vocab, self.title_vocab, self.device = \
            ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab, device

    def __iter__(self):
        return iter(self.exs)

    def __getitem__(self, index):
        return self.exs[index]

    def __len__(self):
        return len(self.exs)

    def batch_fn(self, batch_ex):
        # collate a list of Examples into padded batch tensors plus one batched graph
        batch_title, batch_ent_text, batch_ent_type, batch_rel, batch_text, batch_tgt_text, batch_graph = \
            [], [], [], [], [], [], []
        batch_raw_ent_text = []
        for ex in batch_ex:
            ex_data = ex.get_tensor(self.ent_vocab, self.rel_vocab, self.text_vocab,
                                    self.ent_text_vocab, self.title_vocab)
            batch_title.append(ex_data['title'])
            batch_ent_text.append(ex_data['ent_text'])
            batch_ent_type.append(ex_data['ent_type'])
            batch_rel.append(ex_data['rel'])
            batch_text.append(ex_data['text'])
            batch_tgt_text.append(ex_data['tgt_text'])
            batch_graph.append(ex_data['graph'])
            batch_raw_ent_text.append(ex_data['raw_ent_text'])
        batch_title = pad(batch_title, out_type='tensor')
        batch_ent_text, ent_len = pad(batch_ent_text, out_type='tensor', flatten=True)
        batch_ent_type = pad(batch_ent_type, out_type='tensor')
        batch_rel = pad(batch_rel, out_type='tensor')
        batch_text = pad(batch_text, out_type='tensor')
        batch_tgt_text = pad(batch_tgt_text, out_type='tensor')
        batch_graph = dgl.batch(batch_graph)
        batch_graph.to(self.device)
        return {'title': batch_title.to(self.device), 'ent_text': batch_ent_text.to(self.device), 'ent_len': ent_len,
                'ent_type': batch_ent_type.to(self.device), 'rel': batch_rel.to(self.device),
                'text': batch_text.to(self.device), 'tgt_text': batch_tgt_text.to(self.device),
                'graph': batch_graph, 'raw_ent_text': batch_raw_ent_text}


def get_datasets(fnames, min_freq=-1, sep=';', joint_vocab=True, device=None, save='tmp.pickle'):
    # min_freq: not supported yet, since it is very sensitive to the final results;
    #           you can still set it by passing min_freq to the Vocab class.
    # sep: not supported yet
    # joint_vocab: not supported yet
    ent_vocab = Vocab(sp=['<PAD>', '<UNK>'])
    title_vocab = Vocab(min_freq=5)
    rel_vocab = Vocab(sp=['<PAD>', '<UNK>'])
    text_vocab = Vocab(min_freq=5)
    ent_text_vocab = Vocab(sp=['<PAD>', '<UNK>'])
    datasets = []
    for fname in fnames:
        exs = []
        with open(fname) as f:
            json_datas = json.load(f)
        for json_data in json_datas:
            # construct one data example
            ex = Example.from_json(json_data)
            if fname == fnames[0]:
                # only the training set updates the vocabularies
                ex.update_vocab(ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab)
            exs.append(ex)
        datasets.append(exs)
    ent_vocab.build()
    rel_vocab.build()
    text_vocab.build()
    ent_text_vocab.build()
    title_vocab.build()
    datasets = [GWdataset(exs, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab, device)
                for exs in datasets]
    with open(save, 'wb') as f:
        pickle.dump(datasets, f)
    return datasets


if __name__ == '__main__':
    ds = get_datasets(['data/unprocessed.val.json', 'data/unprocessed.val.json', 'data/unprocessed.test.json'])
    print(ds[0].exs[0])
    print(ds[0].exs[0].get_tensor(ds[0].ent_vocab, ds[0].rel_vocab, ds[0].text_vocab,
                                  ds[0].ent_text_vocab, ds[0].title_vocab))
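    # A minimal sketch of how these pieces are wired together, illustrative only:
    # BucketSampler yields lists of indices, so it goes in as batch_sampler, and
    # batch_fn is the collate function.  The CPU device below is an assumption so
    # that batch_fn has somewhere to move the tensors.
    ds[0].device = torch.device('cpu')
    loader = torch.utils.data.DataLoader(ds[0],
                                         batch_sampler=BucketSampler(ds[0]),
                                         collate_fn=ds[0].batch_fn)
    batch = next(iter(loader))
    print({k: (tuple(v.shape) if torch.is_tensor(v) else type(v).__name__) for k, v in batch.items()})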