Commit 38f740a1 authored by VictorSanh's avatar VictorSanh
Browse files

Fix bug writing predictions in run_squad_pytorch

parent ee29871f
...@@ -910,8 +910,8 @@ def main(): ...@@ -910,8 +910,8 @@ def main():
#end_logits = [x.item() for x in end_logits] #end_logits = [x.item() for x in end_logits]
end_logits = [x.view(-1).detach().cpu().numpy() for x in end_logits] end_logits = [x.view(-1).detach().cpu().numpy() for x in end_logits]
for idx, i in enumerate(unique_id): for idx, i in enumerate(unique_id):
s = start_logits[idx] s = [float(x) for x in start_logits[idx]]
e = end_logits[idx] e = [float(x) for x in end_logits[idx]]
all_results.append( all_results.append(
RawResult( RawResult(
unique_id=i, unique_id=i,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment