Unverified commit f5c45a19, authored by Patrick von Platen, committed by GitHub

Fix Rag example docstring (#7872)

* fix rag examples

* fix token generate example
parent 9f7b2b24
@@ -740,10 +740,6 @@ class RagSequenceForGeneration(RagPreTrainedModel):
>>> doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)).squeeze(1)
>>> # 3. Forward to generator
>>> outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=input_dict["labels"])
->>> # or directly generate
->>> generated = model.generate(input_ids=input_dict["input_ids"])
->>> generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)
"""
exclude_bos_score = exclude_bos_score if exclude_bos_score is not None else self.config.exclude_bos_score
reduce_loss = reduce_loss if reduce_loss is not None else self.config.reduce_loss
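The lines removed above called `model.generate(input_ids=...)` at the end of the manual-retrieval example, but in that example the model is not initialized with a retriever, so `generate` cannot fetch documents from `input_ids` alone. For context, here is a minimal, self-contained sketch of the flow the kept docstring lines describe; the checkpoint name `facebook/rag-sequence-nq`, the dummy-dataset retriever, and the `prepare_seq2seq_batch` call are assumptions taken from the standard RAG documentation of this era, not part of this diff.

```python
# Minimal sketch of the manual retrieve-then-forward flow from the docstring.
# Checkpoint names and the dummy index are assumptions from the RAG docs.
import torch
from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer

tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
retriever = RagRetriever.from_pretrained(
    "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
)
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")

input_dict = tokenizer.prepare_seq2seq_batch(
    "How many people live in Paris?",
    "In Paris, there are 10 million people.",
    return_tensors="pt",
)

# 1. Encode the question
question_hidden_states = model.question_encoder(input_dict["input_ids"])[0]
# 2. Retrieve passages for the question
docs_dict = retriever(
    input_dict["input_ids"].numpy(),
    question_hidden_states.detach().numpy(),
    return_tensors="pt",
)
# 3. Score each (question, passage) pair, as in the docstring line above
doc_scores = torch.bmm(
    question_hidden_states.unsqueeze(1),
    docs_dict["retrieved_doc_embeds"].float().transpose(1, 2),
).squeeze(1)
# 4. Forward to the generator with the pre-retrieved context
outputs = model(
    context_input_ids=docs_dict["context_input_ids"],
    context_attention_mask=docs_dict["context_attention_mask"],
    doc_scores=doc_scores,
    decoder_input_ids=input_dict["labels"],
)
```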
@@ -1125,7 +1121,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
>>> outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=input_dict["labels"])
>>> # or directly generate
->>> generated = model.generate(input_ids=input_dict["input_ids"])
+>>> generated = model.generate(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores)
>>> generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)
"""
do_marginalize = do_marginalize if do_marginalize is not None else self.config.do_marginalize
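For RagTokenForGeneration the direct-generate line is kept but corrected: without an attached retriever, `generate` must receive the pre-retrieved `context_input_ids`, `context_attention_mask`, and `doc_scores` rather than the raw `input_ids`. A hedged sketch of the corrected call, reusing the `docs_dict` and `doc_scores` from the sketch above (in practice they should come from the matching rag-token checkpoint's question encoder):

```python
from transformers import RagTokenForGeneration

model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq")

# Generate directly from the pre-retrieved context (the fixed call):
generated = model.generate(
    context_input_ids=docs_dict["context_input_ids"],
    context_attention_mask=docs_dict["context_attention_mask"],
    doc_scores=doc_scores,
)
generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)
```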
@@ -1307,9 +1303,6 @@ class RagTokenForGeneration(RagPreTrainedModel):
else self.config.generator.decoder_start_token_id
)
-# batch_size
-batch_size = input_ids.shape[0]
# retrieve docs
if self.retriever is not None and context_input_ids is None:
question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
@@ -1336,6 +1329,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
1
)
+# batch_size
+batch_size = context_input_ids.shape[0] // self.config.n_docs
encoder = self.rag.generator.get_encoder()
encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True)
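The two `batch_size` hunks are one move: `input_ids` may be `None` when the caller passes pre-retrieved `context_input_ids` to `generate` directly (as in the fixed example above), so `input_ids.shape[0]` would raise. `context_input_ids` stacks `n_docs` passages per question into one batch dimension, shape `(batch_size * n_docs, seq_len)`, so the question-level batch size is recovered by integer division. A small shape check with hypothetical sizes:

```python
import torch

n_docs = 5             # hypothetical: 5 retrieved passages per question
batch_size, seq_len = 2, 300

# context_input_ids flattens questions x passages into one batch dimension
context_input_ids = torch.zeros(batch_size * n_docs, seq_len, dtype=torch.long)

# The relocated computation recovers batch_size even when input_ids is None:
assert context_input_ids.shape[0] // n_docs == batch_size  # 10 // 5 == 2
```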