Unverified Commit 4362ee29 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

correct (#13304)

parent 4046e66e
...@@ -988,6 +988,9 @@ class RagModelIntegrationTests(unittest.TestCase): ...@@ -988,6 +988,9 @@ class RagModelIntegrationTests(unittest.TestCase):
torch_device torch_device
) )
if torch_device == "cuda":
rag_token.half()
input_dict = tokenizer( input_dict = tokenizer(
self.test_data_questions, self.test_data_questions,
return_tensors="pt", return_tensors="pt",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment