Commit 88def24c authored by patrickvonplaten

merge conflicts - renamed to previous_token singular

parent 822f725a
@@ -556,46 +556,81 @@ class PreTrainedModel(nn.Module):
         length_penalty=None,
         num_return_sequences=None,
     ):
-        """ Sequence generator for models with a LM head.
-        The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
+        r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
         and beam-search.
-        Adapted in part from Facebook's XLM beam search code: https://github.com/facebookresearch/XLM
+        Adapted in part from `Facebook's XLM beam search code`_.
+
+        .. _`Facebook's XLM beam search code`:
+            https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529
 
-        Params:
-            **input_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length)
+        Parameters:
+            input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
                 The sequence used as a prompt for the generation. If `None` the method initializes
-                it as an empty `torch.LongTensor` of shape (1,)
-            **max_length**: (`optional`) int
+                it as an empty `torch.LongTensor` of shape `(1,)`.
+            max_length: (`optional`) int
                 The max length of the sequence to be generated. Between 1 and infinity. Default to 20.
-            **do_sample**: (`optional`) bool
-                If set to `False` we use greedy decoding; otherwise sampling. Default to greedy sampling.
-            **num_beams**: (`optional`) int
-                Number of beams for beam search. 1 means no beam serach. Default to 1.
-            **temperature**: (`optional`) float
-                The value used to module the next token probabilities.
-            **top_k**: (`optional`) int
+            do_sample: (`optional`) bool
+                If set to `False` greedy decoding is used. Otherwise sampling is used. Default to greedy sampling.
+            num_beams: (`optional`) int
+                Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.
+            temperature: (`optional`) float
+                The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0.
+            top_k: (`optional`) int
                 The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
-            **top_p**: (`optional`) float
+            top_p: (`optional`) float
                 The cumulative probability of the highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
-            **repetition_penalty**: (`optional`) float
-                The parameter for repetition penalty. Between 1.0 and + infinity. 1.0 means no penalty. Default to 1.
-            **bos_token_id**: (`optional`) int
+            repetition_penalty: (`optional`) float
+                The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
+            bos_token_id: (`optional`) int
                 Beginning of sentence token if no prompt is provided. Default to 0.
-            **eos_token_ids**: (`optional`) int or list of int
+            eos_token_ids: (`optional`) int or list of int
                 End of sequence token or list of tokens to stop the generation. Default to 0.
-            **length_penalty**: (`optional`) float
+            length_penalty: (`optional`) float
                 Exponential penalty to the length. Default to 1.
-            **num_return_sequences**: (`optional`) int
-                The number of independantly computed returned sequences for each element in the batch. Default to 1.
+            num_return_sequences: (`optional`) int
+                The number of independently computed returned sequences for each element in the batch. Default to 1.
+
+        Examples::
+
+            tokenizer = AutoTokenizer.from_pretrained('distilgpt2')    # Initialize tokenizer
+            model = AutoModelWithLMHead.from_pretrained('distilgpt2')  # Download model and configuration from S3 and cache.
+            outputs = model.generate(max_length=40, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id)  # do greedy decoding without beam search
+            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
+
+            tokenizer = AutoTokenizer.from_pretrained('openai-gpt')    # Initialize tokenizer
+            model = AutoModelWithLMHead.from_pretrained('openai-gpt')  # Download model and configuration from S3 and cache.
+            input_context = 'The dog'
+            input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0)  # encode input context
+            outputs = model.generate(input_ids=input_ids, do_sample=True, num_beams=5, num_return_sequences=3)  # generate 3 independent sequences using beam search decoding (5 beams) from initial context 'The dog'
+            for i in range(3):  # 3 output sequences were generated
+                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[0][i], skip_special_tokens=True)))
+
+            tokenizer = AutoTokenizer.from_pretrained('distilgpt2')    # Initialize tokenizer
+            model = AutoModelWithLMHead.from_pretrained('distilgpt2')  # Download model and configuration from S3 and cache.
+            input_context = 'The dog'
+            input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0)  # encode input context
+            outputs = model.generate(input_ids=input_ids, max_length=40, do_sample=True, temperature=0.7, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id, num_beams=3)  # generate sequences using beam search decoding (3 beams)
+            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
         """
         # We cannot generate if the model does not have a LM head
         if self.get_output_embeddings() is None:
             raise AttributeError(
                 "You tried to generate sequences with a model that does not have a LM Head."
-                "Please use another model class (e.g. `OpenAIGPTLMHeadModel`)"
+                "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`)"
             )
         max_length = max_length if max_length is not None else self.config.max_length
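The docstring above covers the full parameter surface; the simplest path through the method, greedy decoding (`do_sample=False`, `num_beams=1`), boils down to an argmax loop. A minimal sketch of that path, assuming a causal LM whose forward pass returns logits as its first output (the helper name is illustrative, not part of this diff):

    import torch

    def greedy_decode_sketch(model, input_ids, max_length, eos_token_id):
        # Repeatedly append the single most likely next token until EOS or max_length
        # (a sketch of the do_sample=False, num_beams=1 case, not the actual implementation).
        while input_ids.shape[1] < max_length:
            next_token_logits = model(input_ids)[0][:, -1, :]  # logits at the last position
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            if (next_token == eos_token_id).all():
                break
        return input_ids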
@@ -623,7 +658,7 @@ class PreTrainedModel(nn.Module):
         assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
         assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
         assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
-        # assert temperature >= 0, "`temperature` should be positive."
+        assert temperature > 0, "`temperature` should be strictly positive."
         assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
         assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
         assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
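These checks run before any decoding starts, so an invalid setting now fails fast with a clear message. Previously the temperature check was commented out, and the `temperature > 0 and ...` guards removed further down let `temperature=0` slip through unscaled. For example, assuming a loaded model and `input_ids` as in the docstring examples:

    model.generate(input_ids=input_ids, do_sample=True, temperature=0.0)
    # AssertionError: `temperature` should be strictly positive.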
@@ -725,16 +760,16 @@ class PreTrainedModel(nn.Module):
             # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
             if repetition_penalty != 1.0:
                 for i in range(batch_size):
-                    for previous_tokens in set(input_ids[i].tolist()):
+                    for previous_token in set(input_ids[i].tolist()):
                         # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
-                        if next_token_logits[i, previous_tokens] < 0:
-                            next_token_logits[i, previous_tokens] *= repetition_penalty
+                        if next_token_logits[i, previous_token] < 0:
+                            next_token_logits[i, previous_token] *= repetition_penalty
                         else:
-                            next_token_logits[i, previous_tokens] /= repetition_penalty
+                            next_token_logits[i, previous_token] /= repetition_penalty
             if do_sample:
                 # Temperature (higher temperature => more likely to sample low probability tokens)
-                if temperature > 0 and temperature != 1.0:
+                if temperature != 1.0:
                     next_token_logits = next_token_logits / temperature
                 # Top-p/top-k filtering
                 next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
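The sign test above is the CTRL paper's penalty: logits can be negative, and dividing a negative logit by a penalty greater than 1 would move it toward zero and raise its probability, so negative logits are multiplied instead. The same idea as a standalone sketch (a hypothetical helper for illustration, not part of this diff):

    import torch

    def apply_repetition_penalty_sketch(next_token_logits, input_ids, penalty):
        # Push down the probability of every token already present in each sequence
        # (CTRL, arXiv:1909.05858): multiply negative logits, divide positive ones.
        for i in range(input_ids.shape[0]):
            for token in set(input_ids[i].tolist()):
                if next_token_logits[i, token] < 0:
                    next_token_logits[i, token] *= penalty
                else:
                    next_token_logits[i, token] /= penalty
        return next_token_logits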
@@ -808,16 +843,16 @@ class PreTrainedModel(nn.Module):
             # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
             if repetition_penalty != 1.0:
                 for i in range(batch_size * num_beams):
-                    for previous_tokens in set(input_ids[i].tolist()):
+                    for previous_token in set(input_ids[i].tolist()):
                         # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
-                        if scores[i, previous_tokens] < 0:
-                            scores[i, previous_tokens] *= repetition_penalty
+                        if scores[i, previous_token] < 0:
+                            scores[i, previous_token] *= repetition_penalty
                         else:
-                            scores[i, previous_tokens] /= repetition_penalty
+                            scores[i, previous_token] /= repetition_penalty
             if do_sample:
                 # Temperature (higher temperature => more likely to sample low probability tokens)
-                if temperature > 0 and temperature != 1.0:
+                if temperature != 1.0:
                     scores = scores / temperature
                 # Top-p/top-k filtering
                 scores = top_k_top_p_filtering(
...
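Both the sampling and beam-search branches funnel the adjusted logits through `top_k_top_p_filtering`, which is defined elsewhere in this file. Roughly, it masks everything outside the k highest-scoring tokens and outside the top-p probability mass; a simplified sketch of that logic (details such as in-place mutation and edge cases may differ from the library's actual helper):

    import torch
    import torch.nn.functional as F

    def top_k_top_p_filtering_sketch(logits, top_k=0, top_p=1.0, filter_value=-float("inf")):
        # Top-k: drop every token scoring below the k-th best logit.
        if top_k > 0:
            kth_best = torch.topk(logits, top_k)[0][..., -1, None]
            logits[logits < kth_best] = filter_value
        # Top-p (nucleus): drop the tail whose cumulative probability exceeds top_p.
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift the mask right so the first token above the threshold is kept.
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = False
            # Map the mask from sorted order back to the original vocabulary order.
            indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
            logits[indices_to_remove] = filter_value
        return logits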