Unverified commit 3e0c62b6, authored by Suraj Patil, committed by GitHub
Browse files

[RAG] fix generate (#10094)



* fix rag generate and tests

* put back adjust_logits_during_generation

* tests are okay
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 226973a9
...@@ -1306,6 +1306,7 @@ class RagTokenForGeneration(RagPreTrainedModel): ...@@ -1306,6 +1306,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
eos_token_id=None, eos_token_id=None,
length_penalty=None, length_penalty=None,
no_repeat_ngram_size=None, no_repeat_ngram_size=None,
encoder_no_repeat_ngram_size=None,
repetition_penalty=None, repetition_penalty=None,
bad_words_ids=None, bad_words_ids=None,
num_return_sequences=None, num_return_sequences=None,
...@@ -1372,6 +1373,9 @@ class RagTokenForGeneration(RagPreTrainedModel): ...@@ -1372,6 +1373,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
order to encourage the model to produce longer sequences. order to encourage the model to produce longer sequences.
no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0): no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
If set to int > 0, all ngrams of that size can only occur once. If set to int > 0, all ngrams of that size can only occur once.
encoder_no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
If set to int > 0, all ngrams of that size that occur in the ``encoder_input_ids`` cannot occur in the
``decoder_input_ids``.
bad_words_ids(:obj:`List[int]`, `optional`): bad_words_ids(:obj:`List[int]`, `optional`):
List of token ids that are not allowed to be generated. In order to get the tokens of the words that List of token ids that are not allowed to be generated. In order to get the tokens of the words that
should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`. should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`.
...@@ -1490,6 +1494,8 @@ class RagTokenForGeneration(RagPreTrainedModel): ...@@ -1490,6 +1494,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
pre_processor = self._get_logits_processor( pre_processor = self._get_logits_processor(
repetition_penalty=repetition_penalty, repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size, no_repeat_ngram_size=no_repeat_ngram_size,
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
encoder_input_ids=context_input_ids,
bad_words_ids=bad_words_ids, bad_words_ids=bad_words_ids,
min_length=min_length, min_length=min_length,
eos_token_id=eos_token_id, eos_token_id=eos_token_id,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment