Commit fff3dd95 authored by Guo Qipeng, committed by Zihao Ye

[Example] add graphwriter pytorch example (#1068)

* upd

* fix edgebatch edges

* add test

* trigger

* add graphwriter pytorch example

* fix line break in graphwriter README

* upd

* fix
parent 35653ddd
@@ -28,6 +28,7 @@ A summary of part of the model accuracy and training speed with the Pytorch backend
| [JTNN](https://arxiv.org/abs/1802.04364) | 96.44% | 96.44% | [1826s (Pytorch)](https://github.com/wengong-jin/icml18-jtnn) | 743s | 2.5x |
| [LGNN](https://arxiv.org/abs/1705.08415) | 94% | 94% | n/a | 1.45s | n/a |
| [DGMG](https://arxiv.org/pdf/1803.03324.pdf) | 84% | 90% | n/a | 238s | n/a |
| [GraphWriter](https://www.aclweb.org/anthology/N19-1238.pdf) | 14.3(BLEU) | 14.31(BLEU) | [1970s (PyTorch)](https://github.com/rikdz/GraphWriter) | 1192s | 1.65x |
With the MXNet/Gluon backend, we scaled a graph of 50M nodes and 150M edges on a P3.8xlarge instance,
with 160s per epoch, on SSE ([Stochastic Steady-state Embedding](https://www.cc.gatech.edu/~hdai8/pdf/equilibrium_embedding.pdf)),
@@ -21,3 +21,4 @@ Here is a summary of the model accuracy and training speed. Our testbed is Amazo
| [JTNN](https://arxiv.org/abs/1802.04364) | 96.44% | 96.44% | [1826s (Pytorch)](https://github.com/wengong-jin/icml18-jtnn) | 743s | 2.5x |
| [LGNN](https://arxiv.org/abs/1705.08415) | 94% | 94% | n/a | 1.45s | n/a |
| [DGMG](https://arxiv.org/pdf/1803.03324.pdf) | 84% | 90% | n/a | 238s | n/a |
| [GraphWriter](https://www.aclweb.org/anthology/N19-1238.pdf) | 14.3(BLEU) | 14.31(BLEU) | 1970s | 1192s | 1.65x |
# GraphWriter-DGL
In this example we implement GraphWriter, from [Text Generation from Knowledge Graphs with Graph Transformers](https://arxiv.org/abs/1904.02342), in DGL. The [author's original implementation](https://github.com/rikdz/GraphWriter) is also available for reference.
## Dependencies
- PyTorch >= 1.2
- tqdm
- pycocoevalcap (only for testing)
- multi-bleu.perl and other scripts from mosesdecoder (only for testing)
## Usage
```
# download data
sh prepare_data.sh
# training
sh run.sh
# testing
sh test.sh
```
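The scripts above cover the usual workflow. For programmatic use, here is a minimal decoding sketch (a hypothetical standalone script, assuming the default `data.pickle` dataset cache and the `tmp_model.ptbest` checkpoint produced by `run.sh`); it mirrors the `--test` branch of `train.py`:
```
import pickle
import torch

# load the dataset cache written by a previous training run,
# plus the best checkpoint saved by eval_it()
train_ds, valid_ds, test_ds = pickle.load(open('data.pickle', 'rb'))
model = torch.load('tmp_model.ptbest')
model.eval()

loader = torch.utils.data.DataLoader(test_ds, batch_size=32, shuffle=False,
                                     collate_fn=test_ds.batch_fn)
with torch.no_grad():
    for batch in loader:
        # beam_size=1 decodes greedily; beam_size>1 runs beam search
        # with the length penalty stored in model.args.lp
        seq = model(batch, beam_size=4)
        break
```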
## Result on AGENDA

| | BLEU | METEOR | Training time per epoch |
|-|-|-|-|
| Author's implementation | 14.3 ± 1.01 | 18.8 ± 0.28 | 1970s |
| DGL implementation | 14.31 ± 0.34 | 19.74 ± 0.69 | 1192s |

We use the author's code for the speed test; our testbed is a single V100 GPU.

Results for different decoding settings and model depths:

| | BLEU | Detokenized BLEU | METEOR |
|-|-|-|-|
| Greedy, two layers | 13.97 ± 0.40 | 13.78 ± 0.46 | 18.76 ± 0.36 |
| Beam 4, length penalty 1.0, two layers | 14.66 ± 0.65 | 14.53 ± 0.52 | 19.50 ± 0.49 |
| Beam 4, length penalty 0.0, two layers | 14.33 ± 0.39 | 14.09 ± 0.39 | 18.63 ± 0.52 |
| Greedy, six layers | 14.17 ± 0.46 | 14.01 ± 0.51 | 19.18 ± 0.49 |
| Beam 4, length penalty 1.0, six layers | 14.31 ± 0.34 | 14.35 ± 0.36 | 19.74 ± 0.69 |
| Beam 4, length penalty 0.0, six layers | 14.40 ± 0.85 | 14.15 ± 0.84 | 18.86 ± 0.78 |

Each experiment is repeated five times; we report mean ± standard deviation.
### Examples
We also provide the output of our implementation on the test set, together with the reference text.
- [GraphWriter's output](https://s3.us-east-2.amazonaws.com/dgl.ai/models/graphwriter/tmp_pred.txt)
- [Reference text](https://s3.us-east-2.amazonaws.com/dgl.ai/models/graphwriter/tmp_gold.txt)
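To re-score these files without rerunning inference, a minimal sketch (assuming `pycocoevalcap` is importable, as in `train.py`) builds the same index-to-list dictionaries that `test()` feeds to the scorers:
```
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.meteor.meteor import Meteor

# one tokenized sentence per line; each value is a one-element list,
# matching the format that write_txt() in utlis.py produces
hyp = {i: [line.strip()] for i, line in enumerate(open('tmp_pred.txt'))}
ref = {i: [line.strip()] for i, line in enumerate(open('tmp_gold.txt'))}
print('BLEU  ', Bleu(4).compute_score(ref, hyp)[0])
print('METEOR', Meteor().compute_score(ref, hyp)[0])
```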
import torch
from modules import MSA, BiLSTM, GraphTrans
from utlis import *
from torch import nn
import dgl
class GraphWriter(nn.Module):
def __init__(self, args):
super(GraphWriter, self).__init__()
self.args = args
if args.title:
self.title_emb = nn.Embedding(len(args.title_vocab), args.nhid, padding_idx=0)
self.title_enc = BiLSTM(args, enc_type='title')
self.title_attn = MSA(args)
self.ent_emb = nn.Embedding(len(args.ent_text_vocab), args.nhid, padding_idx=0)
self.tar_emb = nn.Embedding(len(args.text_vocab), args.nhid, padding_idx=0)
if args.title:
nn.init.xavier_normal_(self.title_emb.weight)
nn.init.xavier_normal_(self.ent_emb.weight)
self.rel_emb = nn.Embedding(len(args.rel_vocab), args.nhid, padding_idx=0)
nn.init.xavier_normal_(self.rel_emb.weight)
self.decode_lstm = nn.LSTMCell(args.dec_ninp, args.nhid)
self.ent_enc = BiLSTM(args, enc_type='entity')
self.graph_enc = GraphTrans(args)
self.ent_attn = MSA(args)
self.copy_attn = MSA(args, mode='copy')
self.copy_fc = nn.Linear(args.dec_ninp, 1)
self.pred_v_fc = nn.Linear(args.dec_ninp, len(args.text_vocab))
def enc_forward(self, batch, ent_mask, ent_text_mask, ent_len, rel_mask, title_mask):
title_enc = None
if self.args.title:
title_enc = self.title_enc(self.title_emb(batch['title']), title_mask)
ent_enc = self.ent_enc(self.ent_emb(batch['ent_text']), ent_text_mask, ent_len = batch['ent_len'])
rel_emb = self.rel_emb(batch['rel'])
g_ent, g_root = self.graph_enc(ent_enc, ent_mask, ent_len, rel_emb, rel_mask, batch['graph'])
return g_ent, g_root, title_enc, ent_enc
def forward(self, batch, beam_size=-1):
ent_mask = len2mask(batch['ent_len'], self.args.device)
ent_text_mask = batch['ent_text']==0
rel_mask = batch['rel']==0 # 0 means the <PAD>
title_mask = batch['title']==0
g_ent, g_root, title_enc, ent_enc = self.enc_forward(batch, ent_mask, ent_text_mask, batch['ent_len'], rel_mask, title_mask)
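# initialize the decoder LSTM state from the graph-level root node; ctx below is
# an attentive read over the contextualized entity nodes (plus the title encoding when --title is set)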
_h, _c = g_root, g_root.clone().detach()
ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask)
if self.args.title:
attn = _h + self.title_attn(_h, title_enc, mask=title_mask)
ctx = torch.cat([ctx, attn], 1)
if beam_size<1:
# training
outs = []
tar_inp = self.tar_emb(batch['text'].transpose(0,1))
for t, xt in enumerate(tar_inp):
_xt = torch.cat([ctx, xt], 1)
_h, _c = self.decode_lstm(_xt, (_h, _c))
ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask)
if self.args.title:
attn = _h + self.title_attn(_h, title_enc, mask=title_mask)
ctx = torch.cat([ctx, attn], 1)
outs.append(torch.cat([_h, ctx], 1))
outs = torch.stack(outs, 1)
copy_gate = torch.sigmoid(self.copy_fc(outs))
EPSI = 1e-6
# copy
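# pointer-generator style mixture: pred_v scores the text vocabulary and pred_c scores
# copying an input entity; after concatenation, indices >= len(text_vocab) denote
# copied entities (decoded back to words by write_txt)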
pred_v = torch.log(copy_gate+EPSI) + torch.log_softmax(self.pred_v_fc(outs), -1)
pred_c = torch.log((1. - copy_gate)+EPSI) + torch.log_softmax(self.copy_attn(outs, ent_enc, mask=ent_mask), -1)
pred = torch.cat([pred_v, pred_c], -1)
return pred
else:
if beam_size==1:
# greedy
device = g_ent.device
B = g_ent.shape[0]
ent_type = batch['ent_type'].view(B, -1)
seq = (torch.ones(B,).long().to(device) * self.args.text_vocab('<BOS>')).unsqueeze(1)
for t in range(self.args.beam_max_len):
_inp = replace_ent(seq[:,-1], ent_type, len(self.args.text_vocab))
xt = self.tar_emb(_inp)
_xt = torch.cat([ctx, xt], 1)
_h, _c = self.decode_lstm(_xt, (_h, _c))
ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask)
if self.args.title:
attn = _h + self.title_attn(_h, title_enc, mask=title_mask)
ctx = torch.cat([ctx, attn], 1)
_y = torch.cat([_h, ctx], 1)
copy_gate = torch.sigmoid(self.copy_fc(_y))
pred_v = torch.log(copy_gate) + torch.log_softmax(self.pred_v_fc(_y), -1)
pred_c = torch.log((1. - copy_gate)) + torch.log_softmax(self.copy_attn(_y.unsqueeze(1), ent_enc, mask=ent_mask).squeeze(1), -1)
pred = torch.cat([pred_v, pred_c], -1).view(B,-1)
for ban_item in ['<BOS>', '<PAD>', '<UNK>']:
pred[:, self.args.text_vocab(ban_item)] = -1e8
_, word = pred.max(-1)
seq = torch.cat([seq, word.unsqueeze(1)], 1)
return seq
else:
# beam search
device = g_ent.device
B = g_ent.shape[0]
BSZ = B * beam_size
_h = _h.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
_c = _c.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
ent_mask = ent_mask.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
if self.args.title:
title_mask = title_mask.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
title_enc = title_enc.view(B, 1, title_enc.size(1), -1).repeat(1, beam_size, 1, 1).view(BSZ, title_enc.size(1), -1)
ctx = ctx.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
ent_type = batch['ent_type'].view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
g_ent = g_ent.view(B, 1, g_ent.size(1), -1).repeat(1, beam_size, 1, 1).view(BSZ, g_ent.size(1), -1)
ent_enc = ent_enc.view(B, 1, ent_enc.size(1), -1).repeat(1, beam_size, 1, 1).view(BSZ, ent_enc.size(1), -1)
beam_best = torch.zeros(B).to(device) - 1e9
beam_best_seq = [None] * B
beam_seq = (torch.ones(B, beam_size).long().to(device) * self.args.text_vocab('<BOS>')).unsqueeze(-1)
beam_score = torch.zeros(B, beam_size).to(device)
done_flag = torch.zeros(B, beam_size)
for t in range(self.args.beam_max_len):
_inp = replace_ent(beam_seq[:,:,-1].view(-1), ent_type, len(self.args.text_vocab))
xt = self.tar_emb(_inp)
_xt = torch.cat([ctx, xt], 1)
_h, _c = self.decode_lstm(_xt, (_h, _c))
ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask)
if self.args.title:
attn = _h + self.title_attn(_h, title_enc, mask=title_mask)
ctx = torch.cat([ctx, attn], 1)
_y = torch.cat([_h, ctx], 1)
copy_gate = torch.sigmoid(self.copy_fc(_y))
pred_v = torch.log(copy_gate) + torch.log_softmax(self.pred_v_fc(_y), -1)
pred_c = torch.log((1. - copy_gate)) + torch.log_softmax(self.copy_attn(_y.unsqueeze(1), ent_enc, mask=ent_mask).squeeze(1), -1)
pred = torch.cat([pred_v, pred_c], -1).view(B, beam_size, -1)
for ban_item in ['<BOS>', '<PAD>', '<UNK>']:
pred[:, :, self.args.text_vocab(ban_item)] = -1e8
if t==self.args.beam_max_len-1: # force ending
tt = pred[:, :, self.args.text_vocab('<EOS>')]
pred = pred*0-1e8
pred[:, :, self.args.text_vocab('<EOS>')] = tt
cum_score = beam_score.view(B,beam_size,1) + pred
score, word = cum_score.topk(dim=-1, k=beam_size) # B, beam_size, beam_size
score, word = score.view(B,-1), word.view(B,-1)
eos_idx = self.args.text_vocab('<EOS>')
if beam_seq.size(2)==1:
new_idx = torch.arange(beam_size).to(word)
new_idx = new_idx[None,:].repeat(B,1)
else:
_, new_idx = score.topk(dim=-1, k=beam_size)
new_src, new_score, new_word, new_done = [], [], [], []
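# GNMT-style length penalty: finished hypotheses are ranked by cumulative score / length**lp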
LP = beam_seq.size(2) ** self.args.lp
for i in range(B):
for j in range(beam_size):
tmp_score = score[i][new_idx[i][j]]
tmp_word = word[i][new_idx[i][j]]
src_idx = new_idx[i][j]//beam_size
new_src.append(src_idx)
if tmp_word == eos_idx:
new_score.append(-1e8)
else:
new_score.append(tmp_score)
new_word.append(tmp_word)
if tmp_word == eos_idx and done_flag[i][src_idx]==0 and tmp_score/LP>beam_best[i]:
beam_best[i] = tmp_score/LP
beam_best_seq[i] = beam_seq[i][src_idx]
if tmp_word == eos_idx:
new_done.append(1)
else:
new_done.append(done_flag[i][src_idx])
new_score = torch.Tensor(new_score).view(B,beam_size).to(beam_score)
new_word = torch.Tensor(new_word).view(B,beam_size).to(beam_seq)
new_src = torch.LongTensor(new_src).view(B,beam_size).to(device)
new_done = torch.Tensor(new_done).view(B,beam_size).to(done_flag)
beam_score = new_score
done_flag = new_done
beam_seq = beam_seq.view(B,beam_size,-1)[torch.arange(B)[:,None].to(device), new_src]
beam_seq = torch.cat([beam_seq, new_word.unsqueeze(2)], 2)
_h = _h.view(B,beam_size,-1)[torch.arange(B)[:,None].to(device), new_src].view(BSZ,-1)
_c = _c.view(B,beam_size,-1)[torch.arange(B)[:,None].to(device), new_src].view(BSZ,-1)
ctx = ctx.view(B,beam_size,-1)[torch.arange(B)[:,None].to(device), new_src].view(BSZ,-1)
return beam_best_seq
import torch
import math
import dgl.function as fn
from dgl.nn.pytorch import edge_softmax
from utlis import *
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence,pad_packed_sequence
class MSA(nn.Module):
# Multi-head attention used in three ways:
# 1. mode='copy': a single head over entities, producing the copy scores.
# 2. mode='normal' with two sequence inputs: standard multi-head attention.
# 3. mode='normal' with a single query vector and a sequence: attentive pooling (gather).
def __init__(self, args, mode='normal'):
super(MSA, self).__init__()
if mode=='copy':
nhead, head_dim = 1, args.nhid
qninp, kninp = args.dec_ninp, args.nhid
if mode=='normal':
nhead, head_dim = args.nhead, args.head_dim
qninp, kninp = args.nhid, args.nhid
self.attn_drop = nn.Dropout(0.1)
self.WQ = nn.Linear(qninp, nhead*head_dim, bias=True if mode=='copy' else False)
if mode!='copy':
self.WK = nn.Linear(kninp, nhead*head_dim, bias=False)
self.WV = nn.Linear(kninp, nhead*head_dim, bias=False)
self.args, self.nhead, self.head_dim, self.mode = args, nhead, head_dim, mode
def forward(self, inp1, inp2, mask=None):
B, L2, H = inp2.shape
NH, HD = self.nhead, self.head_dim
if self.mode=='copy':
q, k, v = self.WQ(inp1), inp2, inp2
else:
q, k, v = self.WQ(inp1), self.WK(inp2), self.WV(inp2)
L1 = 1 if inp1.ndim==2 else inp1.shape[1]
if self.mode!='copy':
q = q / math.sqrt(H)
q = q.view(B, L1, NH, HD).permute(0, 2, 1, 3)
k = k.view(B, L2, NH, HD).permute(0, 2, 3, 1)
v = v.view(B, L2, NH, HD).permute(0, 2, 1, 3)
pre_attn = torch.matmul(q,k)
if mask is not None:
pre_attn = pre_attn.masked_fill(mask[:,None,None,:], -1e8)
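# in copy mode the masked, unnormalized scores are returned directly;
# the caller turns them into a copy distribution with log_softmax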
if self.mode=='copy':
return pre_attn.squeeze(1)
else:
alpha = self.attn_drop(torch.softmax(pre_attn, -1))
attn = torch.matmul(alpha, v).permute(0, 2, 1, 3).contiguous().view(B,L1,NH*HD)
ret = attn
if inp1.ndim==2:
return ret.squeeze(1)
else:
return ret
class BiLSTM(nn.Module):
# for entity encoding or the title encoding
def __init__(self, args, enc_type='title'):
super(BiLSTM, self).__init__()
self.enc_type = enc_type
self.drop = nn.Dropout(args.emb_drop)
self.bilstm = nn.LSTM(args.nhid, args.nhid//2, bidirectional=True, \
num_layers=args.enc_lstm_layers, batch_first=True)
def forward(self, inp, mask, ent_len=None):
inp = self.drop(inp)
lens = (mask==0).sum(-1).long().tolist()
pad_seq = pack_padded_sequence(inp, lens, batch_first=True, enforce_sorted=False)
y, (_h, _c) = self.bilstm(pad_seq)
if self.enc_type=='title':
y = pad_packed_sequence(y, batch_first=True)[0]
return y
if self.enc_type=='entity':
_h = _h.transpose(0,1).contiguous()
_h = _h[:,-2:].view(_h.size(0), -1) # two directions of the top-layer
ret = pad(_h.split(ent_len), out_type='tensor')
return ret
class GAT(nn.Module):
# a graph attention network with dot-product attention
def __init__(self,
in_feats,
out_feats,
num_heads,
ffn_drop=0.,
attn_drop=0.,
trans=True):
super(GAT, self).__init__()
self._num_heads = num_heads
self._in_feats = in_feats
self._out_feats = out_feats
self.q_proj = nn.Linear(in_feats, num_heads*out_feats, bias=False)
self.k_proj = nn.Linear(in_feats, num_heads*out_feats, bias=False)
self.v_proj = nn.Linear(in_feats, num_heads*out_feats, bias=False)
self.attn_drop = nn.Dropout(0.1)
self.ln1 = nn.LayerNorm(in_feats)
self.ln2 = nn.LayerNorm(in_feats)
if trans:
self.FFN = nn.Sequential(
nn.Linear(in_feats, 4*in_feats),
nn.PReLU(4*in_feats),
nn.Linear(4*in_feats, in_feats),
nn.Dropout(0.1),
)
# a strange FFN, see the author's code
self._trans = trans
def forward(self, graph, feat):
graph = graph.local_var()
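# keys and values are computed from a detached copy of the node features,
# so gradients flow back only through the queries and the residual connection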
feat_c = feat.clone().detach().requires_grad_(False)
q, k, v = self.q_proj(feat), self.k_proj(feat_c), self.v_proj(feat_c)
q = q.view(-1, self._num_heads, self._out_feats)
k = k.view(-1, self._num_heads, self._out_feats)
v = v.view(-1, self._num_heads, self._out_feats)
graph.ndata.update({'ft': v, 'el': k, 'er': q}) # 'el'/'er' hold k and q (deliberately swapped): edge_softmax normalizes over incoming edges, so the query must live on the destination node
# compute edge attention
graph.apply_edges(fn.u_dot_v('el', 'er', 'e'))
e = graph.edata.pop('e') / math.sqrt(self._out_feats * self._num_heads)
graph.edata['a'] = edge_softmax(graph, e).unsqueeze(-1)
# message passing
graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
fn.sum('m', 'ft2'))
rst = graph.ndata['ft2']
# residual
rst = rst.view(feat.shape) + feat
if self._trans:
rst = self.ln1(rst)
rst = self.ln1(rst+self.FFN(rst))
# use the same layer norm, see the author's code
return rst
class GraphTrans(nn.Module):
def __init__(self,args):
super().__init__()
self.args = args
if args.graph_enc == "gat":
# we only support gtrans, don't use this one
self.gat = nn.ModuleList([GAT(args.nhid, args.nhid//4, 4, attn_drop=args.attn_drop, trans=False) for _ in range(args.prop)]) #untested
else:
self.gat = nn.ModuleList([GAT(args.nhid, args.nhid//4, 4, attn_drop=args.attn_drop, ffn_drop=args.drop, trans=True) for _ in range(args.prop)])
self.prop = args.prop
def forward(self, ent, ent_mask, ent_len, rel, rel_mask, graphs):
device = ent.device
ent_mask = (ent_mask==0) # reverse mask
rel_mask = (rel_mask==0)
init_h = []
for i in range(graphs.batch_size):
init_h.append(ent[i][ent_mask[i]])
init_h.append(rel[i][rel_mask[i]])
init_h = torch.cat(init_h, 0)
feats = init_h
for i in range(self.prop):
feats = self.gat[i](graphs, feats)
g_root = feats.index_select(0, graphs.filter_nodes(lambda x: x.data['type']==NODE_TYPE['root']).to(device))
g_ent = pad(feats.index_select(0, graphs.filter_nodes(lambda x: x.data['type']==NODE_TYPE['entity']).to(device)).split(ent_len), out_type='tensor')
return g_ent, g_root
import torch
import argparse
def fill_config(args):
# dirty work
args.device = torch.device(args.gpu)
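# the decoder LSTM consumes [context; word embedding]: entity context plus word
# embedding (2*nhid), with an extra title context when --title is set (3*nhid)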
args.dec_ninp = args.nhid * 3 if args.title else args.nhid * 2
args.fnames = [args.train_file, args.valid_file, args.test_file]
return args
def vocab_config(args, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab):
# dirty work
args.ent_vocab = ent_vocab
args.rel_vocab = rel_vocab
args.text_vocab = text_vocab
args.ent_text_vocab = ent_text_vocab
args.title_vocab = title_vocab
return args
def get_args():
args = argparse.ArgumentParser(description='Graph Writer in DGL')
args.add_argument('--nhid', default=500, type=int, help='hidden size')
args.add_argument('--nhead', default=4, type=int, help='number of heads')
args.add_argument('--head_dim', default=125, type=int, help='head dim')
args.add_argument('--weight_decay', default=0.0, type=float, help='weight decay')
args.add_argument('--prop', default=6, type=int, help='number of layers of gnn')
args.add_argument('--title', action='store_true', help='use title input')
args.add_argument('--test', action='store_true', help='inference mode')
args.add_argument('--batch_size', default=32, type=int, help='batch_size')
args.add_argument('--beam_size', default=4, type=int, help='beam size, 1 for greedy')
args.add_argument('--epoch', default=20, type=int, help='training epoch')
args.add_argument('--beam_max_len', default=200, type=int, help='max length of the generated text')
args.add_argument('--enc_lstm_layers', default=2, type=int, help='number of layers of lstm')
args.add_argument('--lr', default=1e-1, type=float, help='learning rate')
#args.add_argument('--lr_decay', default=1e-8, type=float, help='')
args.add_argument('--clip', default=1, type=float, help='gradient clip')
args.add_argument('--emb_drop', default=0.0, type=float, help='embedding dropout')
args.add_argument('--attn_drop', default=0.1, type=float, help='attention dropout')
args.add_argument('--drop', default=0.1, type=float, help='dropout')
args.add_argument('--lp', default=1.0, type=float, help='length penalty')
args.add_argument('--graph_enc', default='gtrans', type=str, help='gnn mode, we only support the graph transformer now')
args.add_argument('--train_file', default='data/unprocessed.train.json', type=str, help='training file')
args.add_argument('--valid_file', default='data/unprocessed.val.json', type=str, help='validation file')
args.add_argument('--test_file', default='data/unprocessed.test.json', type=str, help='test file')
args.add_argument('--save_dataset', default='data.pickle', type=str, help='save path of dataset')
args.add_argument('--save_model', default='saved_model.pt', type=str, help='save path of model')
args.add_argument('--gpu', default=0, type=int, help='gpu mode')
args = args.parse_args()
args = fill_config(args)
return args
wget https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/AGENDA.tar.gz
mkdir data
tar -C data/ -xvzf AGENDA.tar.gz
nohup env CUDA_VISIBLE_DEVICES=0 python -u train.py --prop 6 --save_model tmp_model.pt --title > train_1.log 2>&1 &
#nohup env CUDA_VISIBLE_DEVICES=2 python -u train.py --prop 6 --save_model tmp_model1.pt --title > train_2.log 2>&1 &
#nohup env CUDA_VISIBLE_DEVICES=3 python -u train.py --prop 6 --save_model tmp_model2.pt --title > train_3.log 2>&1 &
#nohup env CUDA_VISIBLE_DEVICES=4 python -u train.py --prop 6 --save_model tmp_model3.pt --title > train_4.log 2>&1 &
#nohup env CUDA_VISIBLE_DEVICES=5 python -u train.py --prop 2 --save_model tmp_model4.pt --title > train_5.log 2>&1 &
#nohup env CUDA_VISIBLE_DEVICES=6 python -u train.py --prop 2 --save_model tmp_model5.pt --title > train_6.log 2>&1 &
env CUDA_VISIBLE_DEVICES=0 python -u train.py --save_model tmp_model.ptbest --test --title --lp 1.0 --beam_size 1
# fetch the moses scripts once; multi-bleu-detok.perl is needed by the last line below
if [ ! -f detokenizer.perl ]; then
wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/8c5eaa1a122236bbf927bde4ec610906fea599e6/scripts/tokenizer/detokenizer.perl
fi
if [ ! -f multi-bleu.perl ]; then
wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/8c5eaa1a122236bbf927bde4ec610906fea599e6/scripts/generic/multi-bleu.perl
fi
if [ ! -f multi-bleu-detok.perl ]; then
wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/8c5eaa1a122236bbf927bde4ec610906fea599e6/scripts/generic/multi-bleu-detok.perl
fi
perl detokenizer.perl -l en < tmp_gold.txt > tmp_gold.txt.a
perl detokenizer.perl -l en < tmp_pred.txt > tmp_pred.txt.a
perl multi-bleu.perl tmp_gold.txt < tmp_pred.txt
perl multi-bleu-detok.perl tmp_gold.txt.a < tmp_pred.txt.a
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import time
from tqdm import tqdm
from graphwriter import *
from utlis import *
from opts import *
import os
import sys
sys.path.append('./pycocoevalcap')
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.meteor.meteor import Meteor
def train_one_epoch(model, dataloader, optimizer, args, epoch):
model.train()
tloss = 0.
tcnt = 0.
st_time = time.time()
with tqdm(dataloader, desc='Train Ep '+str(epoch), mininterval=60) as tq:
for batch in tq:
pred = model(batch)
nll_loss = F.nll_loss(pred.view(-1, pred.shape[-1]), batch['tgt_text'].view(-1), ignore_index=0)
loss = nll_loss
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), args.clip)
optimizer.step()
loss = loss.item()
if loss!=loss:
raise ValueError('NaN appear')
tloss += loss * len(batch['tgt_text'])
tcnt += len(batch['tgt_text'])
tq.set_postfix({'loss': tloss/tcnt}, refresh=False)
print('Train Ep ', str(epoch), 'AVG Loss ', tloss/tcnt, 'Steps ', tcnt, 'Time ', time.time()-st_time, 'GPU', torch.cuda.max_memory_cached()/1024.0/1024.0/1024.0)
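# save a full-model checkpoint every epoch; the epoch%100 suffix bounds the number of checkpoint files kept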
torch.save(model, args.save_model+str(epoch%100))
val_loss = 2**31
def eval_it(model, dataloader, args, epoch):
global val_loss
model.eval()
tloss = 0.
tcnt = 0.
st_time = time.time()
with tqdm(dataloader, desc='Eval Ep '+str(epoch), mininterval=60) as tq:
for batch in tq:
with torch.no_grad():
pred = model(batch)
nll_loss = F.nll_loss(pred.view(-1, pred.shape[-1]), batch['tgt_text'].view(-1), ignore_index=0)
loss = nll_loss
loss = loss.item()
tloss += loss * len(batch['tgt_text'])
tcnt += len(batch['tgt_text'])
tq.set_postfix({'loss': tloss/tcnt}, refresh=False)
print('Eval Ep ', str(epoch), 'AVG Loss ', tloss/tcnt, 'Steps ', tcnt, 'Time ', time.time()-st_time)
if tloss/tcnt < val_loss:
print('Saving best model ', 'Ep ', epoch, ' loss ', tloss/tcnt)
torch.save(model, args.save_model+'best')
val_loss = tloss/tcnt
def test(model, dataloader, args):
scorer = Bleu(4)
m_scorer = Meteor()
r_scorer = Rouge()
hyp = []
ref = []
model.eval()
gold_file = open('tmp_gold.txt', 'w')
pred_file = open('tmp_pred.txt', 'w')
with tqdm(dataloader, desc='Test ', mininterval=1) as tq:
for batch in tq:
with torch.no_grad():
seq = model(batch, beam_size=args.beam_size)
r = write_txt(batch, batch['tgt_text'], gold_file, args)
h = write_txt(batch, seq, pred_file, args)
hyp.extend(h)
ref.extend(r)
hyp = dict(zip(range(len(hyp)), hyp))
ref = dict(zip(range(len(ref)), ref))
print(hyp[0], ref[0])
print('BLEU INP', len(hyp), len(ref))
print('BLEU', scorer.compute_score(ref, hyp)[0])
print('METEOR', m_scorer.compute_score(ref, hyp)[0])
print('ROUGE_L', r_scorer.compute_score(ref, hyp)[0])
gold_file.close()
pred_file.close()
def main(args):
if os.path.exists(args.save_dataset):
train_dataset, valid_dataset, test_dataset = pickle.load(open(args.save_dataset, 'rb'))
else:
train_dataset, valid_dataset, test_dataset = get_datasets(args.fnames, device=args.device, save=args.save_dataset)
args = vocab_config(args, train_dataset.ent_vocab, train_dataset.rel_vocab, train_dataset.text_vocab, train_dataset.ent_text_vocab, train_dataset.title_vocab)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_sampler = BucketSampler(train_dataset, batch_size=args.batch_size), \
collate_fn=train_dataset.batch_fn)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, \
shuffle=False, collate_fn=train_dataset.batch_fn)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, \
shuffle=False, collate_fn=train_dataset.batch_fn)
model = GraphWriter(args)
model.to(args.device)
if args.test:
model = torch.load(args.save_model)
model.args = args
print(model)
test(model, test_dataloader, args)
else:
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9)
print(model)
for epoch in range(args.epoch):
train_one_epoch(model, train_dataloader, optimizer, args, epoch)
eval_it(model, valid_dataloader, args, epoch)
if __name__ == '__main__':
args = get_args()
main(args)
import torch
import dgl
import numpy as np
import json
import pickle
import random
NODE_TYPE = {'entity': 0, 'root': 1, 'relation':2}
def write_txt(batch, seqs, w_file, args):
# converting the prediction to real text.
ret = []
for b, seq in enumerate(seqs):
txt = []
for token in seq:
# copy the entity
if token>=len(args.text_vocab):
ent_text = batch['raw_ent_text'][b][token-len(args.text_vocab)]
ent_text = filter(lambda x:x!='<PAD>', ent_text)
txt.extend(ent_text)
else:
if int(token) not in [args.text_vocab(x) for x in ['<PAD>', '<BOS>', '<EOS>']]:
txt.append(args.text_vocab(int(token)))
if int(token) == args.text_vocab('<EOS>'):
break
w_file.write(' '.join([str(x) for x in txt])+'\n')
ret.append([' '.join([str(x) for x in txt])])
return ret
def replace_ent(x, ent, V):
# tokens with id >= V are copy pointers; replace them with the type token of the
# pointed-to entity (from ent) so they can be embedded by the decoder
mask = x>=V
if mask.sum()==0:
return x
nz = mask.nonzero()
fill_ent = ent[nz, x[mask]-V]
x = x.masked_scatter(mask, fill_ent)
return x
def len2mask(lens, device):
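# lens: list of sequence lengths -> BoolTensor of shape (batch, max_len), True at padded positions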
max_len = max(lens)
mask = torch.arange(max_len, device=device).unsqueeze(0).expand(len(lens), max_len)
mask = mask >= torch.LongTensor(lens).to(mask).unsqueeze(1)
return mask
def pad(var_len_list, out_type='list', flatten=False):
if flatten:
lens = [len(x) for x in var_len_list]
var_len_list = sum(var_len_list, [])
max_len = max([len(x) for x in var_len_list])
if out_type=='list':
if flatten:
return [x+['<PAD>']*(max_len-len(x)) for x in var_len_list], lens
else:
return [x+['<PAD>']*(max_len-len(x)) for x in var_len_list]
if out_type=='tensor':
if flatten:
return torch.stack([torch.cat([x, \
torch.zeros([max_len-len(x)]+list(x.shape[1:])).type_as(x)], 0) for x in var_len_list], 0), lens
else:
return torch.stack([torch.cat([x, \
torch.zeros([max_len-len(x)]+list(x.shape[1:])).type_as(x)], 0) for x in var_len_list], 0)
class Vocab(object):
def __init__(self, max_vocab=2**31, min_freq=-1, sp=['<PAD>', '<BOS>', '<EOS>', '<UNK>']):
self.i2s = []
self.s2i = {}
self.wf = {}
self.max_vocab, self.min_freq, self.sp = max_vocab, min_freq, sp
def __len__(self):
return len(self.i2s)
def __str__(self):
return 'Total ' + str(len(self.i2s)) + str(self.i2s[:10])
def update(self, token):
if isinstance(token, list):
for t in token:
self.update(t)
else:
self.wf[token] = self.wf.get(token, 0) + 1
def build(self):
self.i2s.extend(self.sp)
sort_kv = sorted(self.wf.items(), key=lambda x:x[1], reverse=True)
for k,v in sort_kv:
if len(self.i2s)<self.max_vocab and v>=self.min_freq and k not in self.sp:
self.i2s.append(k)
self.s2i.update(list(zip(self.i2s, range(len(self.i2s)))))
def __call__(self, x):
if isinstance(x, int):
return self.i2s[x]
else:
return self.s2i.get(x, self.s2i['<UNK>'])
def save(self, fname):
pass
def load(self, fname):
pass
def at_least(x):
# handling the illegal data
if len(x) == 0:
return ['<UNK>']
else:
return x
class Example(object):
def __init__(self, title, ent_text, ent_type, rel, text):
# one object corresponds to a data sample
self.raw_title = title.split()
self.raw_ent_text = [at_least(x.split()) for x in ent_text]
assert min([len(x) for x in self.raw_ent_text])>0, str(self.raw_ent_text)
self.raw_ent_type = ent_type.split() # <method> .. <>
self.raw_rel = []
for r in rel:
rel_list = r.split()
for i in range(len(rel_list)):
if i>0 and i<len(rel_list)-1 and rel_list[i-1]=='--' and rel_list[i]!=rel_list[i].lower() and rel_list[i+1]=='--':
self.raw_rel.append([rel_list[:i-1], rel_list[i-1]+rel_list[i]+rel_list[i+1], rel_list[i+2:]])
break
self.raw_text = text.split()
self.graph = self.build_graph()
def __str__(self):
return '\n'.join([str(k)+':\t'+str(v) for k, v in self.__dict__.items()])
def __len__(self):
return len(self.raw_text)
@staticmethod
def from_json(json_data):
return Example(json_data['title'], json_data['entities'], json_data['types'], json_data['relations'], json_data['abstract'])
def build_graph(self):
graph = dgl.DGLGraph()
ent_len = len(self.raw_ent_text)
rel_len = len(self.raw_rel) # treat the repeated relation as different nodes, refer to the author's code
graph.add_nodes(ent_len, {'type': torch.ones(ent_len) * NODE_TYPE['entity']})
graph.add_nodes(1, {'type': torch.ones(1) * NODE_TYPE['root']})
graph.add_nodes(rel_len*2, {'type': torch.ones(rel_len*2) * NODE_TYPE['relation']})
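# connect the root node (index ent_len) to every entity in both directions, give every
# node a self-loop, and link each relation's two nodes (forward and inverse) to its
# head and tail entities below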
graph.add_edges(ent_len, torch.arange(ent_len))
graph.add_edges(torch.arange(ent_len), ent_len)
graph.add_edges(torch.arange(ent_len+1+rel_len*2), torch.arange(ent_len+1+rel_len*2))
adj_edges = []
for i, r in enumerate(self.raw_rel):
assert len(r)==3, str(r)
st, rt, ed = r
st_ent, ed_ent = self.raw_ent_text.index(st), self.raw_ent_text.index(ed)
# according to the edge_softmax operator, we need to reverse the graph
adj_edges.append([ent_len+1+2*i, st_ent])
adj_edges.append([ed_ent, ent_len+1+2*i])
adj_edges.append([ent_len+1+2*i+1, ed_ent])
adj_edges.append([st_ent, ent_len+1+2*i+1])
if len(adj_edges)>0:
graph.add_edges(*list(map(list, zip(*adj_edges))))
return graph
def get_tensor(self, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab):
if hasattr(self, '_cached_tensor'):
return self._cached_tensor
else:
title_data = ['<BOS>'] + self.raw_title + ['<EOS>']
title = [title_vocab(x) for x in title_data]
ent_text = [[ent_text_vocab(y) for y in x] for x in self.raw_ent_text]
ent_type = [text_vocab(x) for x in self.raw_ent_type] # for inference
rel_data = ['--root--'] + sum([[x[1],x[1]+'_INV'] for x in self.raw_rel], [])
rel = [rel_vocab(x) for x in rel_data]
text_data = ['<BOS>'] + self.raw_text + ['<EOS>']
text = [text_vocab(x) for x in text_data]
tgt_text = []
# the decoder input and target differ because of the copy mechanism: an entity
# placeholder like <method_0> becomes the plain type token <method> in the input,
# but the copy index len(text_vocab)+0 in the target
for i, str1 in enumerate(text_data):
if str1[0]=='<' and str1[-1]=='>' and '_' in str1:
a, b = str1[1:-1].split('_')
text[i] = text_vocab('<'+a+'>')
tgt_text.append(len(text_vocab)+int(b))
else:
tgt_text.append(text[i])
self._cached_tensor = {'title': torch.LongTensor(title), 'ent_text': [torch.LongTensor(x) for x in ent_text], \
'ent_type': torch.LongTensor(ent_type), 'rel': torch.LongTensor(rel), \
'text': torch.LongTensor(text[:-1]), 'tgt_text': torch.LongTensor(tgt_text[1:]), 'graph': self.graph, 'raw_ent_text': self.raw_ent_text}
return self._cached_tensor
def update_vocab(self, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab):
ent_vocab.update(self.raw_ent_type)
ent_text_vocab.update(self.raw_ent_text)
title_vocab.update(self.raw_title)
rel_vocab.update(['--root--']+[x[1] for x in self.raw_rel]+[x[1]+'_INV' for x in self.raw_rel])
text_vocab.update(self.raw_ent_type)
text_vocab.update(self.raw_text)
class BucketSampler(torch.utils.data.Sampler):
def __init__(self, data_source, batch_size=32, bucket=3):
self.data_source = data_source
self.bucket = bucket
self.batch_size = batch_size
def __iter__(self):
# bucket examples by target length; the thresholds (100/220) and per-bucket batch sizes come from the author's code
perm = torch.randperm(len(self.data_source))
lens = torch.Tensor([len(x) for x in self.data_source])
lens = lens[perm]
t1 = []
t2 = []
t3 = []
for i, l in enumerate(lens):
if (l<100):
t1.append(perm[i])
elif (l>100 and l<220):
t2.append(perm[i])
else:
t3.append(perm[i])
datas = [t1,t2,t3]
random.shuffle(datas)
idxs = sum(datas, [])
batch = []
for idx in idxs:
batch.append(idx)
mlen = max([0]+[lens[x] for x in batch])
if (mlen<100 and len(batch) == 32) or (mlen>100 and mlen<220 and len(batch) >= 24) or (mlen>220 and len(batch)>=8) or len(batch)==32:
yield batch
batch = []
if len(batch) > 0:
yield batch
def __len__(self):
return (len(self.data_source)+self.batch_size-1)//self.batch_size
class GWdataset(torch.utils.data.Dataset):
def __init__(self, exs, ent_vocab=None, rel_vocab=None, text_vocab=None, ent_text_vocab=None, title_vocab=None, device=None):
super(GWdataset, self).__init__()
self.exs = exs
self.ent_vocab, self.rel_vocab, self.text_vocab, self.ent_text_vocab, self.title_vocab, self.device = \
ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab, device
def __iter__(self):
return iter(self.exs)
def __getitem__(self, index):
return self.exs[index]
def __len__(self):
return len(self.exs)
def batch_fn(self, batch_ex):
batch_title, batch_ent_text, batch_ent_type, batch_rel, batch_text, batch_tgt_text, batch_graph = \
[], [], [], [], [], [], []
batch_raw_ent_text = []
for ex in batch_ex:
ex_data = ex.get_tensor(self.ent_vocab, self.rel_vocab, self.text_vocab, self.ent_text_vocab, self.title_vocab)
batch_title.append(ex_data['title'])
batch_ent_text.append(ex_data['ent_text'])
batch_ent_type.append(ex_data['ent_type'])
batch_rel.append(ex_data['rel'])
batch_text.append(ex_data['text'])
batch_tgt_text.append(ex_data['tgt_text'])
batch_graph.append(ex_data['graph'])
batch_raw_ent_text.append(ex_data['raw_ent_text'])
batch_title = pad(batch_title, out_type='tensor')
batch_ent_text, ent_len = pad(batch_ent_text, out_type='tensor', flatten=True)
batch_ent_type = pad(batch_ent_type, out_type='tensor')
batch_rel = pad(batch_rel, out_type='tensor')
batch_text = pad(batch_text, out_type='tensor')
batch_tgt_text = pad(batch_tgt_text, out_type='tensor')
batch_graph = dgl.batch(batch_graph)
batch_graph.to(self.device)
return {'title': batch_title.to(self.device), 'ent_text': batch_ent_text.to(self.device), 'ent_len': ent_len, \
'ent_type': batch_ent_type.to(self.device), 'rel': batch_rel.to(self.device), 'text': batch_text.to(self.device), \
'tgt_text': batch_tgt_text.to(self.device), 'graph': batch_graph, 'raw_ent_text': batch_raw_ent_text}
def get_datasets(fnames, min_freq=-1, sep=';', joint_vocab=True, device=None, save='tmp.pickle'):
# min_freq: not exposed as an option yet since the results are very sensitive to it, but you can set it by passing min_freq to the Vocab class.
# sep: not supported yet
# joint_vocab: not supported yet
ent_vocab = Vocab(sp=['<PAD>', '<UNK>'])
title_vocab = Vocab(min_freq=5)
rel_vocab = Vocab(sp=['<PAD>', '<UNK>'])
text_vocab = Vocab(min_freq=5)
ent_text_vocab = Vocab(sp=['<PAD>', '<UNK>'])
datasets = []
for fname in fnames:
exs = []
json_datas = json.loads(open(fname).read())
for json_data in json_datas:
# construct one data example
ex = Example.from_json(json_data)
if fname == fnames[0]: # only training set
ex.update_vocab(ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab)
exs.append(ex)
datasets.append(exs)
ent_vocab.build()
rel_vocab.build()
text_vocab.build()
ent_text_vocab.build()
title_vocab.build()
datasets = [GWdataset(exs, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab, device) for exs in datasets]
with open(save, 'wb') as f:
pickle.dump(datasets, f)
return datasets
if __name__ == '__main__' :
ds = get_datasets(['data/unprocessed.val.json', 'data/unprocessed.val.json', 'data/unprocessed.test.json'])
print(ds[0].exs[0])
print(ds[0].exs[0].get_tensor(ds[0].ent_vocab, ds[0].rel_vocab, ds[0].text_vocab, ds[0].ent_text_vocab, ds[0].title_vocab))