Unverified Commit 3e0c62b6 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[RAG] fix generate (#10094)



* fix rag generate and tests

* put back adjust_logits_during_generation

* tests are okay
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 226973a9
......@@ -1306,6 +1306,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
eos_token_id=None,
length_penalty=None,
no_repeat_ngram_size=None,
encoder_no_repeat_ngram_size=None,
repetition_penalty=None,
bad_words_ids=None,
num_return_sequences=None,
......@@ -1372,6 +1373,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
order to encourage the model to produce longer sequences.
no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
If set to int > 0, all ngrams of that size can only occur once.
encoder_no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
If set to int > 0, all ngrams of that size that occur in the ``encoder_input_ids`` cannot occur in the
``decoder_input_ids``.
bad_words_ids(:obj:`List[int]`, `optional`):
List of token ids that are not allowed to be generated. In order to get the tokens of the words that
should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`.
......@@ -1490,6 +1494,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
pre_processor = self._get_logits_processor(
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
encoder_input_ids=context_input_ids,
bad_words_ids=bad_words_ids,
min_length=min_length,
eos_token_id=eos_token_id,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment