Unverified Commit 28c77ddf authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix QA pipeline on Windows (#8947)

parent 72d6c9c6
...@@ -1883,6 +1883,8 @@ class QuestionAnsweringPipeline(Pipeline): ...@@ -1883,6 +1883,8 @@ class QuestionAnsweringPipeline(Pipeline):
with torch.no_grad(): with torch.no_grad():
# Retrieve the score for the context tokens only (removing question tokens) # Retrieve the score for the context tokens only (removing question tokens)
fw_args = {k: torch.tensor(v, device=self.device) for (k, v) in fw_args.items()} fw_args = {k: torch.tensor(v, device=self.device) for (k, v) in fw_args.items()}
# On Windows, the default int type in numpy is np.int32 so we get some non-long tensors.
fw_args = {k: v.long() if v.dtype == torch.int32 else v for (k, v) in fw_args.items()}
start, end = self.model(**fw_args)[:2] start, end = self.model(**fw_args)[:2]
start, end = start.cpu().numpy(), end.cpu().numpy() start, end = start.cpu().numpy(), end.cpu().numpy()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment