"tests/vscode:/vscode.git/clone" did not exist on "87282cb73c8797b8c54b1d5482b84873eedae6c6"
Unverified Commit 5f7a07c0 authored by Derrick Blakely's avatar Derrick Blakely Committed by GitHub
Browse files

use return dict for rag encoder (#9363)

parent ae333d04
......@@ -1437,7 +1437,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
batch_size = context_input_ids.shape[0] // n_docs
encoder = self.rag.generator.get_encoder()
encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask)
encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True)
input_ids = torch.full(
(batch_size * num_beams, 1),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment