"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "acd6d2c42f0fa4fade262e8814279748a544b0ce"
Commit 1ea0bcf4 authored by Zihao Ye, committed by Minjie Wang

[Model] fix link & beam search (#394)

parent ae1806f6
@@ -176,7 +176,7 @@ def get_dataset(dataset):
             ('en', 'de'),
             train='train.tok.clean.bpe.32000',
             valid='newstest2013.tok.bpe.32000',
-            test='newstest2014.tok.bpe.32000',
+            test='newstest2014.tok.bpe.32000.ende',
             vocab='vocab.bpe.32000')
     else:
         raise KeyError()
@@ -4,7 +4,7 @@ import os
 from dgl.data.utils import *
 _urls = {
-    'wmt': 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/wmt16_en_de.tar.gz',
+    'wmt': 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/wmt14bpe_de_en.zip',
     'scripts': 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/transformer_scripts.zip',
 }
@@ -23,7 +23,7 @@ def prepare_dataset(dataset_name):
     if dataset_name == 'multi30k':
         os.system('bash scripts/prepare-multi30k.sh')
     elif dataset_name == 'wmt14':
-        download(_urls['wmt'], path='wmt16_en_de.tar.gz')
+        download(_urls['wmt'], path='wmt14.zip')
         os.system('bash scripts/prepare-wmt14.sh')
     elif dataset_name == 'copy' or dataset_name == 'tiny_copy':
         train_size = 9000
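For context, a minimal sketch of the download flow after this change (the `prepare_wmt14` helper name is hypothetical; `download` is the real `dgl.data.utils` helper already used in the hunk above, and the shell script's contents are outside this diff). The commit title says "fix link", so the old `wmt16_en_de.tar.gz` URL is replaced by the new `.zip` archive:

```python
import os
from dgl.data.utils import download

WMT_URL = 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/wmt14bpe_de_en.zip'

def prepare_wmt14():
    # Fetch the pre-tokenized BPE archive, then hand off to the prepare
    # script to unpack and arrange the files (script not shown in this diff).
    if not os.path.exists('wmt14.zip'):
        download(WMT_URL, path='wmt14.zip')
    os.system('bash scripts/prepare-wmt14.sh')
```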
@@ -128,16 +128,15 @@ class Transformer(nn.Module):
         """
         return self.generator(g.ndata['x'][nids['dec']])

-    def infer(self, graph, max_len, eos_id, k):
+    def infer(self, graph, max_len, eos_id, k, alpha=1.0):
         '''
         This function implements Beam Search in DGL, which is required in the inference phase.
+        Length normalization is given by (5 + len) ^ alpha / 6 ^ alpha. Please refer to https://arxiv.org/pdf/1609.08144.pdf.
         args:
             graph: a `Graph` object defined in `dgl.contrib.transformer.graph`.
             max_len: the maximum length of decoding.
             eos_id: the index of the end-of-sequence symbol.
             k: beam size
+            alpha: the exponent of the length penalty.
         return:
             ret: a list of index arrays corresponding to the input sequences specified by `graph`.
         '''
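A worked example of the length penalty named in the new docstring (helper name and numbers are illustrative, not from the commit). Because beam scores are stored already divided by the penalty at their current length, extending a hypothesis by one token means multiplying the stored score back by the old penalty, adding the new token's log-probability, and dividing by the new penalty; that is exactly what the `norm_old`/`norm_new` terms in the hunk below compute:

```python
def length_penalty(length, alpha):
    # GNMT length penalty from the docstring: ((5 + length) / 6) ** alpha
    return ((5.0 + length) / 6.0) ** alpha

alpha = 0.6
stored = -4.2 / length_penalty(4, alpha)   # normalized score at length 4
step_logp = -0.7                           # log-prob of the candidate next token
# Un-normalize with the old penalty, accumulate, re-normalize with the new one.
updated = (stored * length_penalty(4, alpha) + step_logp) / length_penalty(5, alpha)
print(updated)
```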
@@ -187,30 +186,35 @@ class Transformer(nn.Module):
             out = self.generator(g.ndata['x'][frontiers])
             batch_size = frontiers.shape[0] // k
             vocab_size = out.shape[-1]
+            # Mask output for complete sequence
+            one_hot = th.zeros(vocab_size).fill_(-1e9).to(device)
+            one_hot[eos_id] = 0
+            mask = g.ndata['mask'][frontiers].unsqueeze(-1).float()
+            out = out * (1 - mask) + one_hot.unsqueeze(0) * mask
             if log_prob is None:
                 log_prob, pos = out.view(batch_size, k, -1)[:, 0, :].topk(k, dim=-1)
-                eos = th.zeros(batch_size).byte()
+                eos = th.zeros(batch_size, k).byte()
             else:
-                log_prob, pos = (out.view(batch_size, k, -1) + log_prob.unsqueeze(-1)).view(batch_size, -1).topk(k, dim=-1)
+                norm_old = eos.float().to(device) + (1 - eos.float().to(device)) * np.power((4. + step) / 6, alpha)
+                norm_new = eos.float().to(device) + (1 - eos.float().to(device)) * np.power((5. + step) / 6, alpha)
+                log_prob, pos = ((out.view(batch_size, k, -1) + (log_prob * norm_old).unsqueeze(-1)) / norm_new.unsqueeze(-1)).view(batch_size, -1).topk(k, dim=-1)
             _y = y.view(batch_size * k, -1)
             y = th.zeros_like(_y)
+            _eos = eos.clone()
             for i in range(batch_size):
-                if not eos[i]:
-                    for j in range(k):
-                        _j = pos[i, j].item() // vocab_size
-                        token = pos[i, j].item() % vocab_size
-                        y[i * k + j, :] = _y[i * k + _j, :]
-                        y[i * k + j, step] = token
-                        if j == 0:
-                            eos[i] = eos[i] | (token == eos_id)
-                else:
-                    y[i*k:(i+1)*k, :] = _y[i*k:(i+1)*k, :]
+                for j in range(k):
+                    _j = pos[i, j].item() // vocab_size
+                    token = pos[i, j].item() % vocab_size
+                    y[i*k+j, :] = _y[i*k+_j, :]
+                    y[i*k+j, step] = token
+                    eos[i, j] = _eos[i, _j] | (token == eos_id)
             if eos.all():
                 break
             else:
-                g.ndata['mask'][nids['dec']] = eos.unsqueeze(-1).repeat(1, k * max_len).view(-1).to(device)
+                g.ndata['mask'][nids['dec']] = eos.unsqueeze(-1).repeat(1, 1, max_len).view(-1).to(device)
         return y.view(batch_size, k, -1)[:, 0, :].tolist()
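Two things change in this hunk: finished beams are masked so they can only extend with `eos` at log-probability 0 (freezing their score), and the `eos` flag moves from one per batch element to one per (batch, beam), inherited from the parent beam. A minimal NumPy sketch of the bookkeeping, with hypothetical shapes, showing the flattened top-k and the divmod recovery of parent beam and token used above:

```python
import numpy as np

batch_size, k, vocab_size, eos_id = 1, 3, 6, 2
rng = np.random.default_rng(0)

# Candidate scores for extending each of the k beams with every vocab token.
scores = rng.standard_normal((batch_size, k, vocab_size))

# Flatten the (beam, token) pairs and take a global top-k per batch element,
# mirroring .view(batch_size, -1).topk(k, dim=-1) in the diff.
flat = scores.reshape(batch_size, k * vocab_size)
pos = np.argsort(-flat, axis=-1)[:, :k]

for i in range(batch_size):
    for j in range(k):
        parent = pos[i, j] // vocab_size   # _j in the diff: which beam to copy
        token = pos[i, j] % vocab_size     # token written at position `step`
        finished = token == eos_id         # OR-ed with the parent's eos flag
        print(f"beam {j}: parent={parent} token={token} finished={finished}")
```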
@@ -38,7 +38,7 @@ if __name__ == '__main__':
     test_iter = dataset(graph_pool, mode='test', batch_size=args.batch, devices=[device], k=k)
     for i, g in enumerate(test_iter):
         with th.no_grad():
-            output = model.infer(g, dataset.MAX_LENGTH, dataset.eos_id, k)
+            output = model.infer(g, dataset.MAX_LENGTH, dataset.eos_id, k, alpha=0.6)
         for line in dataset.get_sequence(output):
             if args.print:
                 print(line)
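The translate script passes `alpha=0.6`, which sits in the range the GNMT paper (linked in the docstring above) reports working best, roughly 0.6 to 0.7. A quick numeric check, reusing the penalty formula, shows how much more mildly it discounts long hypotheses than `alpha=1.0`:

```python
def lp(length, alpha):
    return ((5.0 + length) / 6.0) ** alpha

for length in (5, 10, 20, 40):
    # alpha=0 would ignore length entirely; alpha=1 divides by (5+len)/6 outright.
    print(length, round(lp(length, 0.6), 3), round(lp(length, 1.0), 3))
```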
@@ -83,7 +83,7 @@ def main(dev_id, args):
         param.data /= ndev
     # Optimizer
-    model_opt = NoamOpt(dim_model, 1, 4000,
+    model_opt = NoamOpt(dim_model, 0.1, 4000,
                         T.optim.Adam(model.parameters(), lr=1e-3,
                                      betas=(0.9, 0.98), eps=1e-9))
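`NoamOpt` is defined outside this diff; assuming it follows the standard Noam schedule from "Attention Is All You Need" (as popularized by The Annotated Transformer, whose `NoamOpt(model_size, factor, warmup, optimizer)` signature this call matches), the factor scales the entire learning-rate curve, so dropping it from 1 to 0.1 cuts the peak rate by 10x:

```python
def noam_lr(step, model_size, factor, warmup=4000):
    # lr = factor * model_size^-0.5 * min(step^-0.5, step * warmup^-1.5)
    return factor * model_size ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

# Peak lr occurs at step == warmup; e.g. with dim_model = 512:
print(noam_lr(4000, 512, 1.0))   # ~7e-4 with factor 1
print(noam_lr(4000, 512, 0.1))   # ~7e-5 with factor 0.1
```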