Unverified Commit 2dd652d7 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[RAG] Add missing doc and attention_mask to rag (#7382)

* add docs

* add missing docs and attention_mask in fine-tune
parent 7cdd9da5
......@@ -265,6 +265,7 @@ class GenerativeQAModule(BaseTransformer):
start_time = time.time()
generated_ids = self.model.generate(
batch["input_ids"],
attention_mask=batch["attention_mask"],
do_deduplication=False, # rag specific parameter
use_cache=True,
min_length=1,
......
......@@ -831,6 +831,14 @@ class RagSequenceForGeneration(RagPreTrainedModel):
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
The sequence used as a prompt for the generation. If :obj:`input_ids` is not passed, then
:obj:`context_input_ids` has to be provided.
attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **maked**.
`What are attention masks? <../glossary.html#attention-mask>`__
context_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`):
Input IDs post-processed from the retrieved documents and the question encoder input_ids by the
retriever.
......@@ -1207,6 +1215,14 @@ class RagTokenForGeneration(RagPreTrainedModel):
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
The sequence used as a prompt for the generation. If :obj:`input_ids` is not passed, then
:obj:`context_input_ids` has to be provided.
attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **maked**.
`What are attention masks? <../glossary.html#attention-mask>`__
context_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`):
Input IDs post-processed from the retrieved documents and the question encoder :obj:`input_ids` by the
retriever.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment