Commit 88def24c authored by patrickvonplaten

merge conflicts - renamed to previous_token singular

parent 822f725a
@@ -556,46 +556,81 @@ class PreTrainedModel(nn.Module):
length_penalty=None,
num_return_sequences=None,
):
""" Sequence generator for models with a LM head.
The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
and beam-search.
Adapted in part from Facebook's XLM beam search code: https://github.com/facebookresearch/XLM
Adapted in part from `Facebook's XLM beam search code`_.
.. _`Facebook's XLM beam search code`:
https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529
Parameters:
Params:
input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
The sequence used as a prompt for the generation. If `None` the method initializes
it as an empty `torch.LongTensor` of shape `(1,)`.
max_length: (`optional`) int
The max length of the sequence to be generated. Between 1 and infinity. Default to 20.
do_sample: (`optional`) bool
If set to `False` greedy decoding is used. Otherwise sampling is used. Default to greedy sampling.
num_beams: (`optional`) int
Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.
temperature: (`optional`) float
The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0.
top_k: (`optional`) int
The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
top_p: (`optional`) float
The cumulative probability threshold for nucleus sampling: only the most probable vocabulary tokens with probabilities that add up to `top_p` or higher are kept. Must be between 0 and 1. Default to 1.
repetition_penalty: (`optional`) float
The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
bos_token_id: (`optional`) int
Beginning of sentence token if no prompt is provided. Default to 0.
eos_token_ids: (`optional`) int or list of int
End of sequence token or list of tokens to stop the generation. Default to 0.
length_penalty: (`optional`) float
Exponential penalty to the length. Default to 1.
num_return_sequences: (`optional`) int
The number of independently computed returned sequences for each element in the batch. Default to 1.
Examples::

    tokenizer = AutoTokenizer.from_pretrained('distilgpt2')  # Initialize tokenizer
    model = AutoModelWithLMHead.from_pretrained('distilgpt2')  # Download model and configuration from S3 and cache.
    outputs = model.generate(max_length=40, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id)  # do greedy decoding without beam search
    print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

    tokenizer = AutoTokenizer.from_pretrained('openai-gpt')  # Initialize tokenizer
    model = AutoModelWithLMHead.from_pretrained('openai-gpt')  # Download model and configuration from S3 and cache.
    input_context = 'The dog'
    input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0)  # encode input context
    outputs = model.generate(input_ids=input_ids, do_sample=True, num_beams=5, num_return_sequences=3)  # generate 3 independent sequences using beam search decoding (5 beams)
    for i in range(3):  # 3 output sequences were generated
        print('Generated {}: {}'.format(i, tokenizer.decode(outputs[0][i], skip_special_tokens=True)))

    tokenizer = AutoTokenizer.from_pretrained('distilgpt2')  # Initialize tokenizer
    model = AutoModelWithLMHead.from_pretrained('distilgpt2')  # Download model and configuration from S3 and cache.
    input_context = 'The dog'
    input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0)  # encode input context
    outputs = model.generate(input_ids=input_ids, max_length=40, do_sample=True, temperature=0.7, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id, num_beams=3)  # generate sequences using beam search decoding (3 beams)
    print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
"""
# We cannot generate if the model does not have a LM head
if self.get_output_embeddings() is None:
    raise AttributeError(
        "You tried to generate sequences with a model that does not have a LM Head. "
        "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`)"
    )
max_length = max_length if max_length is not None else self.config.max_length
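The diff collapses the rest of the default resolution, which presumably mirrors the `max_length` fallback just above, replacing each `None` argument with the same-named attribute on `self.config`. A minimal standalone sketch of that pattern (the helper name and the same-named config attributes are assumptions, not part of the commit):

def resolve_generation_defaults(config, **kwargs):
    # For every argument passed as None, fall back to the config attribute
    # of the same name, mirroring the visible max_length fallback.
    return {name: (value if value is not None else getattr(config, name))
            for name, value in kwargs.items()}

class DummyConfig:  # stand-in for a real model config
    max_length, temperature, top_k = 20, 1.0, 50

resolved = resolve_generation_defaults(DummyConfig(), max_length=None, temperature=0.7, top_k=None)
print(resolved)  # {'max_length': 20, 'temperature': 0.7, 'top_k': 50}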
@@ -623,7 +658,7 @@ class PreTrainedModel(nn.Module):
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
assert temperature > 0, "`temperature` should be strictly positive."
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
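A quick usage sketch of these guards (hypothetical call; any model with a LM head behaves the same way):

from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained('gpt2')
try:
    model.generate(temperature=0.0)  # violates the strictly-positive check
except AssertionError as err:
    print(err)  # `temperature` should be strictly positive.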
@@ -725,16 +760,16 @@ class PreTrainedModel(nn.Module):
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
    for i in range(batch_size):
        for previous_token in set(input_ids[i].tolist()):
            # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
            if next_token_logits[i, previous_token] < 0:
                next_token_logits[i, previous_token] *= repetition_penalty
            else:
                next_token_logits[i, previous_token] /= repetition_penalty
if do_sample:
    # Temperature (higher temperature => more likely to sample low probability tokens)
    if temperature != 1.0:
        next_token_logits = next_token_logits / temperature
    # Top-p/top-k filtering
    next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
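To make the multiply-versus-divide branch and the temperature step concrete, here is a self-contained toy rerun of the same logic on a 6-token vocabulary. The top-k filter is a minimal re-implementation for illustration, not the library's `top_k_top_p_filtering`:

import torch
import torch.nn.functional as F

next_token_logits = torch.tensor([[2.0, 1.0, 0.5, -0.5, -1.0, -3.0]])
previous_tokens = {0, 3}  # pretend tokens 0 and 3 were already generated
repetition_penalty = 1.3

# CTRL-style penalty: dividing shrinks positive logits, multiplying pushes
# negative logits further down -- both reduce the token's probability.
for previous_token in previous_tokens:
    if next_token_logits[0, previous_token] < 0:
        next_token_logits[0, previous_token] *= repetition_penalty
    else:
        next_token_logits[0, previous_token] /= repetition_penalty

# Temperature: dividing by T < 1 sharpens the distribution, T > 1 flattens it.
next_token_logits = next_token_logits / 0.7

# Minimal top-k: keep the k largest logits, mask the rest to -inf.
top_k = 3
kth_best = torch.topk(next_token_logits, top_k).values[..., -1, None]
next_token_logits = next_token_logits.masked_fill(next_token_logits < kth_best, float('-inf'))

probs = F.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # sample one token id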
@@ -808,16 +843,16 @@ class PreTrainedModel(nn.Module):
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
    for i in range(batch_size * num_beams):
        for previous_token in set(input_ids[i].tolist()):
            # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
            if scores[i, previous_token] < 0:
                scores[i, previous_token] *= repetition_penalty
            else:
                scores[i, previous_token] /= repetition_penalty
if do_sample:
    # Temperature (higher temperature => more likely to sample low probability tokens)
    if temperature != 1.0:
        scores = scores / temperature
    # Top-p/top-k filtering
    scores = top_k_top_p_filtering(
......
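In the beam-search branch the docstring's `length_penalty` typically enters when finished hypotheses are re-ranked. A sketch of that scoring, assuming the common `sum_logprobs / length ** length_penalty` form (the exact formula lives in the collapsed part of the diff and is an assumption here):

def beam_score(sum_logprobs, length, length_penalty=1.0):
    # Log-probabilities are negative, so a larger length ** length_penalty
    # divisor shrinks the magnitude: length_penalty > 1.0 favors longer
    # hypotheses, < 1.0 favors shorter ones.
    return sum_logprobs / length ** length_penalty

print(beam_score(-4.0, 4))                      # -1.0
print(beam_score(-8.0, 8))                      # -1.0 (tie at equal per-token quality)
print(beam_score(-8.0, 8, length_penalty=2.0))  # -0.125 (longer hypothesis now wins)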