Commit 1ea0bcf4 authored by Zihao Ye, committed by Minjie Wang

[Model] fix link & beam search (#394)

parent ae1806f6
@@ -176,7 +176,7 @@ def get_dataset(dataset):
             ('en', 'de'),
             train='train.tok.clean.bpe.32000',
             valid='newstest2013.tok.bpe.32000',
-            test='newstest2014.tok.bpe.32000',
+            test='newstest2014.tok.bpe.32000.ende',
             vocab='vocab.bpe.32000')
     else:
         raise KeyError()
@@ -4,7 +4,7 @@ import os
 from dgl.data.utils import *
 _urls = {
-    'wmt': 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/wmt16_en_de.tar.gz',
+    'wmt': 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/wmt14bpe_de_en.zip',
     'scripts': 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/transformer_scripts.zip',
 }
@@ -23,7 +23,7 @@ def prepare_dataset(dataset_name):
     if dataset_name == 'multi30k':
         os.system('bash scripts/prepare-multi30k.sh')
     elif dataset_name == 'wmt14':
-        download(_urls['wmt'], path='wmt16_en_de.tar.gz')
+        download(_urls['wmt'], path='wmt14.zip')
         os.system('bash scripts/prepare-wmt14.sh')
     elif dataset_name == 'copy' or dataset_name == 'tiny_copy':
         train_size = 9000
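For context, the wmt14 branch now fetches a zip instead of the old tar.gz. A hedged sketch of the fetch-and-unpack flow (the extract step is an assumption for illustration; in the commit itself the unpacking is handled by scripts/prepare-wmt14.sh):

from dgl.data.utils import download, extract_archive

url = 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/wmt14bpe_de_en.zip'
download(url, path='wmt14.zip')         # fetch the BPE'd WMT14 de-en archive
extract_archive('wmt14.zip', 'wmt14')   # assumed equivalent of what the
                                        # prepare script does with the zip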
@@ -128,16 +128,15 @@ class Transformer(nn.Module):
         """
         return self.generator(g.ndata['x'][nids['dec']])
-    def infer(self, graph, max_len, eos_id, k):
+    def infer(self, graph, max_len, eos_id, k, alpha=1.0):
         '''
         This function implements beam search in DGL, which is required in the inference phase.
+        Length normalization is given by (5 + len)^alpha / 6^alpha. Please refer to https://arxiv.org/pdf/1609.08144.pdf.
         args:
             graph: a `Graph` object defined in `dgl.contrib.transformer.graph`.
             max_len: the maximum length of decoding.
             eos_id: the index of the end-of-sequence symbol.
             k: beam size
         return:
             ret: a list of index arrays corresponding to the input sequences specified by `graph`.
         '''
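The normalization added above is the GNMT length penalty, lp(len) = ((5 + len) / 6)^alpha, and the norm_old / norm_new pair below re-normalizes a running score when a hypothesis grows by one token. A minimal sketch in plain Python (values assumed for illustration; not part of the commit):

import numpy as np

def length_penalty(length, alpha=0.6):
    # GNMT length penalty; alpha = 0 disables normalization entirely.
    return np.power((5.0 + length) / 6.0, alpha)

# A stored score is sum(log p) / lp(len). To extend a hypothesis from length
# L to L + 1: multiply the old normalizer back in, add the new token's
# log-probability, divide by the new normalizer -- exactly what the
# (log_prob * norm_old) / norm_new expression in the diff computes.
score, L, token_logp = -4.2, 3, -1.1   # assumed values
new_score = (score * length_penalty(L) + token_logp) / length_penalty(L + 1)
# Finished beams get norm_old = norm_new = 1, so their scores stay frozen.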
@@ -187,30 +186,35 @@ class Transformer(nn.Module):
             out = self.generator(g.ndata['x'][frontiers])
             batch_size = frontiers.shape[0] // k
             vocab_size = out.shape[-1]
             # Mask output for complete sequence
             one_hot = th.zeros(vocab_size).fill_(-1e9).to(device)
             one_hot[eos_id] = 0
             mask = g.ndata['mask'][frontiers].unsqueeze(-1).float()
             out = out * (1 - mask) + one_hot.unsqueeze(0) * mask
             if log_prob is None:
                 log_prob, pos = out.view(batch_size, k, -1)[:, 0, :].topk(k, dim=-1)
-                eos = th.zeros(batch_size).byte()
+                eos = th.zeros(batch_size, k).byte()
             else:
-                log_prob, pos = (out.view(batch_size, k, -1) + log_prob.unsqueeze(-1)).view(batch_size, -1).topk(k, dim=-1)
+                norm_old = eos.float().to(device) + (1 - eos.float().to(device)) * np.power((4. + step) / 6, alpha)
+                norm_new = eos.float().to(device) + (1 - eos.float().to(device)) * np.power((5. + step) / 6, alpha)
+                log_prob, pos = ((out.view(batch_size, k, -1) + (log_prob * norm_old).unsqueeze(-1)) / norm_new.unsqueeze(-1)).view(batch_size, -1).topk(k, dim=-1)
             _y = y.view(batch_size * k, -1)
             y = th.zeros_like(_y)
+            _eos = eos.clone()
             for i in range(batch_size):
-                if not eos[i]:
-                    for j in range(k):
-                        _j = pos[i, j].item() // vocab_size
-                        token = pos[i, j].item() % vocab_size
-                        y[i * k + j, :] = _y[i * k + _j, :]
-                        y[i * k + j, step] = token
-                        if j == 0:
-                            eos[i] = eos[i] | (token == eos_id)
-                else:
-                    y[i*k:(i+1)*k, :] = _y[i*k:(i+1)*k, :]
+                for j in range(k):
+                    _j = pos[i, j].item() // vocab_size
+                    token = pos[i, j].item() % vocab_size
+                    y[i*k+j, :] = _y[i*k+_j, :]
+                    y[i*k+j, step] = token
+                    eos[i, j] = _eos[i, _j] | (token == eos_id)
             if eos.all():
                 break
             else:
-                g.ndata['mask'][nids['dec']] = eos.unsqueeze(-1).repeat(1, k * max_len).view(-1).to(device)
+                g.ndata['mask'][nids['dec']] = eos.unsqueeze(-1).repeat(1, 1, max_len).view(-1).to(device)
         return y.view(batch_size, k, -1)[:, 0, :].tolist()

     def _register_att_map(self, g, enc_ids, dec_ids):
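The rewritten loop above tracks the end-of-sequence flag per beam rather than per batch element: hypothesis j in batch i inherits parent beam _j = pos[i, j] // vocab_size, appends token pos[i, j] % vocab_size, and carries the parent's finished flag forward. A self-contained sketch of the same bookkeeping outside DGL (shapes and names assumed for illustration; length normalization omitted):

import torch as th

batch_size, k, vocab_size, max_len, step, eos_id = 2, 3, 10, 8, 4, 0
out = th.randn(batch_size * k, vocab_size).log_softmax(-1)  # step log-probs
log_prob = th.randn(batch_size, k)                          # running scores
y = th.zeros(batch_size * k, max_len, dtype=th.long)        # partial outputs
eos = th.zeros(batch_size, k, dtype=th.bool)                # finished flags

# Score all k * vocab_size continuations per batch element, keep the top k.
scores, pos = (out.view(batch_size, k, -1) + log_prob.unsqueeze(-1)) \
    .view(batch_size, -1).topk(k, dim=-1)

_y, _eos = y.clone(), eos.clone()   # snapshot parents before overwriting
for i in range(batch_size):
    for j in range(k):
        _j = pos[i, j].item() // vocab_size    # parent beam index
        token = pos[i, j].item() % vocab_size  # chosen token id
        y[i * k + j] = _y[i * k + _j]          # copy the parent's prefix
        y[i * k + j, step] = token
        eos[i, j] = _eos[i, _j] | (token == eos_id)
log_prob = scores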
@@ -38,7 +38,7 @@ if __name__ == '__main__':
     test_iter = dataset(graph_pool, mode='test', batch_size=args.batch, devices=[device], k=k)
     for i, g in enumerate(test_iter):
         with th.no_grad():
-            output = model.infer(g, dataset.MAX_LENGTH, dataset.eos_id, k)
+            output = model.infer(g, dataset.MAX_LENGTH, dataset.eos_id, k, alpha=0.6)
             for line in dataset.get_sequence(output):
                 if args.print:
                     print(line)
@@ -83,7 +83,7 @@ def main(dev_id, args):
         param.data /= ndev
     # Optimizer
-    model_opt = NoamOpt(dim_model, 1, 4000,
+    model_opt = NoamOpt(dim_model, 0.1, 4000,
                         T.optim.Adam(model.parameters(), lr=1e-3,
                                      betas=(0.9, 0.98), eps=1e-9))
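Assuming NoamOpt follows the "Attention is All You Need" schedule (as in the annotated-transformer convention), the second argument scales the whole learning-rate curve. A rough sketch of the rate it produces (d_model = 512 assumed for illustration):

def noam_lr(step, d_model=512, factor=0.1, warmup=4000):
    # lr = factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5);
    # dropping factor from 1 to 0.1 lowers the peak (at step == warmup) 10x.
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

print(noam_lr(4000))   # peak learning rate under the new factor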