Commit b6938916 authored by thomwolf

adding beam search

parent a468870f
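For context, the diff below rewrites `PreTrainedModel.generate()` around beam search. A minimal sketch of how the new API might be called, based on the signature and return value introduced in this commit (the GPT-2 checkpoint and tokenizer calls are illustrative assumptions, not part of the commit, and the snippet is untested against this exact revision):

    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2')

    # Prompt of shape (batch_size, sequence_length), as expected by generate()
    input_ids = torch.tensor([tokenizer.encode("The quick brown fox")])

    # Beam search over 3 beams; unspecified arguments fall back to the new
    # generate_* defaults stored on the model config (see the configuration diff below).
    decoded, tgt_len = model.generate(
        input_ids=input_ids,
        max_length=20,
        num_beams=3,
        length_penalty=1.0,
        eos_token_ids=[tokenizer.eos_token_id],
    )

    # Per the implementation below, `decoded` is time-major: column i holds sentence i,
    # padded to the longest hypothesis, and `tgt_len[i]` is its length including EOS.
    print(tokenizer.decode(decoded[:tgt_len[0], 0].tolist()))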
@@ -62,13 +62,18 @@ class PretrainedConfig(object):
         self.is_decoder = kwargs.pop('is_decoder', False)
         # Parameters for sequence generation
-        self.generate_length = kwargs.pop('generate_length', 10)
+        self.generate_max_length = kwargs.pop('generate_max_length', 20)
         self.generate_do_sample = kwargs.pop('generate_do_sample', False)
         self.generate_num_beams = kwargs.pop('generate_num_beams', 1)
         self.generate_temperature = kwargs.pop('generate_temperature', 1.0)
         self.generate_top_k = kwargs.pop('generate_top_k', 50)
-        self.generate_top_p = kwargs.pop('generate_top_p', 0.0)
+        self.generate_top_p = kwargs.pop('generate_top_p', 1.0)
         self.generate_repetition_penalty = kwargs.pop('generate_repetition_penalty', 1.0)
+        self.generate_bos_token_id = kwargs.pop('generate_bos_token_id', 0)
+        self.generate_pad_token_id = kwargs.pop('generate_pad_token_id', 0)
+        self.generate_eos_token_ids = kwargs.pop('generate_eos_token_ids', 0)
+        self.generate_batch_size = kwargs.pop('generate_batch_size', 1)
+        self.generate_length_penalty = kwargs.pop('generate_length_penalty', 1.)

     def save_pretrained(self, save_directory):
         """ Save a configuration object to the directory `save_directory`, so that it
...
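These new `generate_*` attributes act as model-level defaults: in the rewritten `generate()` below, every argument left as `None` falls back to the corresponding config value. A small sketch of that fallback pattern (the top-level `PretrainedConfig` import is assumed; the `resolve` helper is illustrative, not part of the commit):

    from transformers import PretrainedConfig

    config = PretrainedConfig(generate_num_beams=5)  # override one generation default at construction time

    def resolve(argument, default):
        # Same pattern as in generate(): an explicit argument wins, else the config default applies.
        return argument if argument is not None else default

    print(resolve(None, config.generate_num_beams))  # 5   -> taken from the config
    print(resolve(40, config.generate_max_length))   # 40  -> explicit call-site override
    print(resolve(None, config.generate_top_p))      # 1.0 -> library default from this diff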
 # coding=utf-8
-# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
 # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -488,63 +488,252 @@ class PreTrainedModel(nn.Module):
     return model

+    def prepare_inputs_for_generation(self, input_ids, **kwargs):
+        return {"input_ids": input_ids}
+
-    def generate(self, input_ids=None, length=None, do_sample=False, num_beams=None,
-                 temperature=None, top_k=None, top_p=None, repetition_penalty=None,
-                 **model_kwargs):
-        """ Generic sequence generator for single-stack models with a LM head.
-
-        The method currently supports greedy decoding and sampling. See the
-        documentation of the `Sampler` class for more information about the
-        parameters related to sampling.
-
-        Params:
-            **input_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length)
-                The sequence used as a prompt for the generation. If `None` the method initializes
-                it as an empty `torch.LongTensor` of shape (1,)
-            **length**: (`optional`) int
-                The length of the sequence to be generated.
-            **do_sample**: (`optional`) bool
-                If set to `False` we use greedy decoding; otherwise sampling.
-            **temperature**: (`optional`) float
-                The value used to module the next token probabilities.
-            **k**: (`optional`) int
-                The parameter used for k-filtering.
-            **p**: (`optional`) float
-                The parameter for nucleus sampling. Must be between 0 and 1.
-            **repetition_penalty**: (`optional`) float
-                The parameter for repetition penalty.
-        """
-        if input_ids is None:
-            input_ids = torch.tensor([[]], dtype=torch.long, device=next(self.parameters()).device)
-
-        # We cannot generate if the model does not have a LM head
-        if self.get_output_embeddings() is None:
-            raise AttributeError("You tried do generated sequences with a model that does not have a LM Head.")
-
-        sampler_config = {
-            "k": k,
-            "p": p,
-            "do_sample": do_sample,
-            "temperature": temperature,
-            "repetition_penalty": repetition_penalty,
-        }
-        sampler = Sampler(**sampler_config)
-        generated_sequence = input_ids
-        for _ in trange(length):
-            arguments = self._prepare_inputs_for_decoding(generated_sequence, **model_kwargs)
-            outputs = self(**arguments)
-            next_tokens_logits = outputs[0][:, -1, :]
-            next_tokens = sampler.get_one_token(
-                next_tokens_logits, generated_sequence
-            )
-            generated_sequence = torch.cat((generated_sequence, next_tokens), dim=1)
-
-        return generated_sequence.squeeze(0)
-
-    def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs):
-        return model_kwargs.update({"input_ids": input_ids})
+    def generate(self, input_ids=None, max_length=None, do_sample=None, num_beams=None,
+                 temperature=None, top_k=None, top_p=None, repetition_penalty=None,
+                 bos_token_id=None, pad_token_id=None, eos_token_ids=None, batch_size=None,
+                 length_penalty=None, **kwargs):
+        """ Sequence generator for models with a LM head.
+
+        The method currently supports greedy or penalized greedy decoding, sampling with top-k
+        or nucleus sampling, and beam search.
+
+        Adapted in part from Facebook's XLM beam search code: https://github.com/facebookresearch/XLM
+
+        Params:
+            **input_ids**: (`optional`) `torch.LongTensor` of shape (batch_size, sequence_length)
+                The sequence used as a prompt for the generation. If `None` the method initializes
+                it with `bos_token_id` and a batch of size `batch_size`.
+            **max_length**: (`optional`) int
+                The maximum length of the sequence to be generated. Between 1 and infinity. Defaults to 20.
+            **do_sample**: (`optional`) bool
+                If set to `False`, greedy decoding is used; otherwise sampling. Defaults to greedy decoding.
+            **num_beams**: (`optional`) int
+                Number of beams for beam search. 1 means no beam search. Defaults to 1.
+            **temperature**: (`optional`) float
+                The value used to modulate the next token probabilities.
+            **top_k**: (`optional`) int
+                The number of highest-probability vocabulary tokens to keep for top-k filtering. Between 1 and infinity. Defaults to 50.
+            **top_p**: (`optional`) float
+                The cumulative probability of the highest-probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Defaults to 1.
+            **repetition_penalty**: (`optional`) float
+                The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Defaults to 1.0.
+            **bos_token_id**: (`optional`) int
+                Id of the beginning-of-sequence token, used to initialize the generation when `input_ids` is `None`. Defaults to 0.
+            **pad_token_id**: (`optional`) int
+                Id of the padding token, used to pad finished sentences during beam search. Defaults to 0.
+            **eos_token_ids**: (`optional`) int or list of int
+                Id(s) of the end-of-sequence token(s) that terminate a hypothesis. Defaults to 0.
+            **batch_size**: (`optional`) int
+                Number of sequences to generate when `input_ids` is `None`. Defaults to 1.
+            **length_penalty**: (`optional`) float
+                Exponential penalty applied to the hypothesis length when scoring finished beams. Defaults to 1.0.
+        """
+        # We cannot generate if the model does not have a LM head
+        if self.get_output_embeddings() is None:
+            raise AttributeError("You tried to generate sequences with a model that does not have a LM Head.")
+
+        max_length = max_length if max_length is not None else self.config.generate_max_length
+        do_sample = do_sample if do_sample is not None else self.config.generate_do_sample
+        num_beams = num_beams if num_beams is not None else self.config.generate_num_beams
+        temperature = temperature if temperature is not None else self.config.generate_temperature
+        top_k = top_k if top_k is not None else self.config.generate_top_k
+        top_p = top_p if top_p is not None else self.config.generate_top_p
+        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.generate_repetition_penalty
+        bos_token_id = bos_token_id if bos_token_id is not None else self.config.generate_bos_token_id
+        pad_token_id = pad_token_id if pad_token_id is not None else self.config.generate_pad_token_id
+        eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.generate_eos_token_ids
+        batch_size = batch_size if batch_size is not None else self.config.generate_batch_size
+        length_penalty = length_penalty if length_penalty is not None else self.config.generate_length_penalty
+
+        if input_ids is not None:
+            batch_size = input_ids.shape[0]  # overridden by the input batch_size
+        if isinstance(eos_token_ids, int):
+            eos_token_ids = [eos_token_ids]
+
+        assert isinstance(max_length, int) and 0 < max_length, "`max_length` should be a strictly positive integer."
+        assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
+        assert isinstance(num_beams, int) and 0 < num_beams, "`num_beams` should be a strictly positive integer."
+        assert 0 < temperature, "`temperature` should be positive."
+        assert isinstance(top_k, int) and 0 < top_k, "`top_k` should be a strictly positive integer."
+        assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
+        assert 0 < repetition_penalty, "`repetition_penalty` should be strictly positive."
+        assert isinstance(bos_token_id, int) and 0 <= bos_token_id, "`bos_token_id` should be a non-negative integer."
+        assert isinstance(pad_token_id, int) and 0 <= pad_token_id, "`pad_token_id` should be a non-negative integer."
+        assert isinstance(eos_token_ids, (list, tuple)) and all(isinstance(e, int) and 0 <= e for e in eos_token_ids), \
+            "`eos_token_ids` should be a non-negative integer or a list/tuple of non-negative integers."
+        assert isinstance(batch_size, int) and 0 < batch_size, "`batch_size` should be a strictly positive integer."
+        assert 0 < length_penalty, "`length_penalty` should be strictly positive."
+
+        if input_ids is None:
+            input_ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device)
+        else:
+            assert input_ids.dim() == 2, "`input_ids` should have two dimensions: (batch_size, sequence_length)."
+
+        # current position and vocab size
+        cur_len = input_ids.shape[1]
+        vocab_size = self.config.vocab_size
+
+        # Expand input to num beams
+        input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
+        input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len)  # (batch_size * num_beams, cur_len)
+
+        # generated hypotheses
+        generated_hyps = [BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)]
+
+        # scores for each sentence in the beam
+        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
+        beam_scores[:, 1:] = -1e9
+        beam_scores = beam_scores.view(-1)
+
+        # cache compute states
+        pasts = None  # self.prepare_pasts()
+
+        # done sentences
+        done = [False for _ in range(batch_size)]
+
+        while cur_len < max_length:
+            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
+            scores = self(**model_inputs)[0]        # (batch_size * num_beams, cur_len, vocab_size)
+            scores = scores[:, -1, :]               # (batch_size * num_beams, vocab_size)
+            scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
+            assert scores.size() == (batch_size * num_beams, vocab_size)
+
+            # select next words with scores
+            _scores = scores + beam_scores[:, None].expand_as(scores)   # (batch_size * num_beams, vocab_size)
+            _scores = _scores.view(batch_size, num_beams * vocab_size)  # (batch_size, num_beams * vocab_size)
+
+            next_scores, next_words = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
+            assert next_scores.size() == next_words.size() == (batch_size, 2 * num_beams)
+
+            # next batch beam content
+            # list of (batch_size * num_beams) tuple(next hypothesis score, next word, current position in the batch)
+            next_batch_beam = []
+
+            # for each sentence
+            for sent_id in range(batch_size):
+
+                # if we are done with this sentence
+                done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item())
+                if done[sent_id]:
+                    next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
+                    continue
+
+                # next sentence beam content
+                next_sent_beam = []
+
+                # next words for this sentence
+                for idx, value in zip(next_words[sent_id], next_scores[sent_id]):
+
+                    # get beam and word IDs
+                    beam_id = idx // vocab_size
+                    word_id = idx % vocab_size
+
+                    # end of sentence, or next word
+                    if word_id.item() in eos_token_ids or cur_len + 1 == max_length:
+                        generated_hyps[sent_id].add(input_ids[sent_id * num_beams + beam_id, :cur_len].clone(), value.item())
+                    else:
+                        next_sent_beam.append((value, word_id, sent_id * num_beams + beam_id))
+
+                    # the beam for next step is full
+                    if len(next_sent_beam) == num_beams:
+                        break
+
+                # update next beam content
+                assert len(next_sent_beam) == (0 if cur_len + 1 == max_length else num_beams)
+                if len(next_sent_beam) == 0:
+                    next_sent_beam = [(0, pad_token_id, 0)] * num_beams  # pad the batch
+                next_batch_beam.extend(next_sent_beam)
+                assert len(next_batch_beam) == num_beams * (sent_id + 1)
+
+            # sanity check / prepare next batch
+            assert len(next_batch_beam) == batch_size * num_beams
+            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
+            beam_words = input_ids.new([x[1] for x in next_batch_beam])
+            beam_idx = input_ids.new([x[2] for x in next_batch_beam])
+
+            # re-order batch and internal states
+            input_ids = input_ids[beam_idx, :]
+            input_ids = torch.cat([input_ids, beam_words.unsqueeze(1)], dim=-1)
+
+            # TODO: Activate cache
+            # for k in cache.keys():
+            #     if k != 'slen':
+            #         cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx])
+
+            # update current length
+            cur_len = cur_len + 1
+
+            # stop when we are done with each sentence
+            if all(done):
+                break
+
+        # visualize hypotheses
+        # print([len(x) for x in generated_hyps], cur_len)
+        # globals().update( locals() );
+        # !import code; code.interact(local=vars())
+        # for ii in range(batch_size):
+        #     for ss, ww in sorted(generated_hyps[ii].hyp, key=lambda x: x[0], reverse=True):
+        #         print("%.3f " % ss + " ".join(self.dico[x] for x in ww.tolist()))
+        #     print("")
+
+        # select the best hypotheses
+        tgt_len = input_ids.new(batch_size)
+        best = []
+
+        for i, hypotheses in enumerate(generated_hyps):
+            best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
+            tgt_len[i] = len(best_hyp) + 1  # +1 for the <EOS> symbol
+            best.append(best_hyp)
+
+        # generate target batch
+        decoded = input_ids.new(tgt_len.max().item(), batch_size).fill_(pad_token_id)
+        for i, hypo in enumerate(best):
+            decoded[:tgt_len[i] - 1, i] = hypo
+            decoded[tgt_len[i] - 1, i] = eos_token_ids[0]
+
+        # sanity check: every sentence ends with an end-of-sequence token
+        assert all(decoded[tgt_len[i] - 1, i].item() == eos_token_ids[0] for i in range(batch_size))
+
+        return decoded, tgt_len
+
+
+class BeamHypotheses(object):
+
+    def __init__(self, n_hyp, max_length, length_penalty, early_stopping):
+        """
+        Initialize n-best list of hypotheses.
+        """
+        self.max_length = max_length - 1  # ignoring bos_token
+        self.length_penalty = length_penalty
+        self.early_stopping = early_stopping
+        self.n_hyp = n_hyp
+        self.hyp = []
+        self.worst_score = 1e9
+
+    def __len__(self):
+        """
+        Number of hypotheses in the list.
+        """
+        return len(self.hyp)
+
+    def add(self, hyp, sum_logprobs):
+        """
+        Add a new hypothesis to the list.
+        """
+        score = sum_logprobs / len(hyp) ** self.length_penalty
+        if len(self) < self.n_hyp or score > self.worst_score:
+            self.hyp.append((score, hyp))
+            if len(self) > self.n_hyp:
+                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
+                del self.hyp[sorted_scores[0][1]]
+                self.worst_score = sorted_scores[1][0]
+            else:
+                self.worst_score = min(score, self.worst_score)
+
+    def is_done(self, best_sum_logprobs):
+        """
+        If there are enough hypotheses and none of the hypotheses being generated
+        can become better than the worst one in the heap, then we are done with this sentence.
+        """
+        if len(self) < self.n_hyp:
+            return False
+        elif self.early_stopping:
+            return True
+        else:
+            return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty
+
 
 class Sampler(object):
...
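The `BeamHypotheses` helper above keeps the `n_hyp` best finished hypotheses, scored by their summed log-probability normalized with the length penalty. A standalone sketch of its behaviour (the import path is an assumption based on where the class is defined in this diff):

    import torch
    from transformers.modeling_utils import BeamHypotheses  # assumed import location

    hyps = BeamHypotheses(n_hyp=2, max_length=10, length_penalty=1.0, early_stopping=False)

    # Add three finished hypotheses; only the two best length-normalized scores are kept.
    hyps.add(torch.tensor([5, 8, 3]), sum_logprobs=-2.1)     # score -2.1 / 3 = -0.7
    hyps.add(torch.tensor([5, 8, 9, 4]), sum_logprobs=-1.6)  # score -1.6 / 4 = -0.4
    hyps.add(torch.tensor([5, 7]), sum_logprobs=-4.0)        # score -2.0, worse than the current worst -> rejected

    print(len(hyps))         # 2
    print(hyps.worst_score)  # -0.7, the weaker of the two kept hypotheses

    # With the list full and no pending beam able to beat the worst kept score,
    # is_done() reports that this sentence can stop.
    print(hyps.is_done(best_sum_logprobs=-9.0))  # True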