Unverified commit 2dd652d7, authored by Patrick von Platen, committed by GitHub

[RAG] Add missing doc and attention_mask to rag (#7382)

* add docs

* add missing docs and attention_mask in fine-tune
parent 7cdd9da5
...@@ -265,6 +265,7 @@ class GenerativeQAModule(BaseTransformer):
         start_time = time.time()
         generated_ids = self.model.generate(
             batch["input_ids"],
+            attention_mask=batch["attention_mask"],
             do_deduplication=False,  # rag specific parameter
             use_cache=True,
             min_length=1,
......
...@@ -831,6 +831,14 @@ class RagSequenceForGeneration(RagPreTrainedModel):
             input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
                 The sequence used as a prompt for the generation. If :obj:`input_ids` is not passed, then
                 :obj:`context_input_ids` has to be provided.
+            attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+                Mask to avoid performing attention on padding token indices. Mask values selected in
+                ``[0, 1]``:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                `What are attention masks? <../glossary.html#attention-mask>`__
             context_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`):
                 Input IDs post-processed from the retrieved documents and the question encoder input_ids by the
                 retriever.
...@@ -1207,6 +1215,14 @@ class RagTokenForGeneration(RagPreTrainedModel):
             input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
                 The sequence used as a prompt for the generation. If :obj:`input_ids` is not passed, then
                 :obj:`context_input_ids` has to be provided.
+            attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+                Mask to avoid performing attention on padding token indices. Mask values selected in
+                ``[0, 1]``:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                `What are attention masks? <../glossary.html#attention-mask>`__
             context_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`):
                 Input IDs post-processed from the retrieved documents and the question encoder :obj:`input_ids` by the
                 retriever.
......
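For context on the semantics the new docstrings describe: the fine-tuning hunk above forwards `batch["attention_mask"]` to `model.generate()`, where 1 marks a real token and 0 marks padding. A minimal standalone sketch of how such a mask relates to padded `input_ids` is below; `build_attention_mask` and `pad_token_id=0` are hypothetical illustration choices, not part of this patch (real tokenizers produce the mask themselves and expose the pad id as `tokenizer.pad_token_id`).

```python
# Hypothetical helper (not part of the patch): derive the 0/1 attention mask
# from right-padded input_ids. pad_token_id=0 is an assumed value for
# illustration only.
def build_attention_mask(input_ids, pad_token_id=0):
    """Return 1 for tokens that are not masked, 0 for padding tokens (masked)."""
    return [[int(tok != pad_token_id) for tok in seq] for seq in input_ids]

# Two right-padded sequences of different real lengths.
batch_input_ids = [
    [101, 2054, 2003, 102, 0, 0],
    [101, 2129, 102, 0, 0, 0],
]
print(build_attention_mask(batch_input_ids))
# [[1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0]]
```

Without the mask, `generate()` would attend over the padding positions as if they were real tokens, which is why the fine-tuning script now passes it alongside `input_ids`.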