Commit ee29871f authored by VictorSanh's avatar VictorSanh
Browse files

Debug run_squad_pytorch

parent 101eabff
......@@ -909,11 +909,21 @@ def main():
start_logits = [x.view(-1).detach().cpu().numpy() for x in start_logits]
#end_logits = [x.item() for x in end_logits]
end_logits = [x.view(-1).detach().cpu().numpy() for x in end_logits]
all_results.append(
RawResult(
unique_id=unique_id,
start_logits=start_logits,
end_logits=end_logits))
for idx, i in enumerate(unique_id):
s = start_logits[idx]
e = end_logits[idx]
all_results.append(
RawResult(
unique_id=i,
start_logits=s,
end_logits=e
)
)
# all_results.append(
# RawResult(
# unique_id=unique_id,
# start_logits=start_logits,
# end_logits=end_logits))
output_prediction_file = os.path.join(args.output_dir, "predictions.json")
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment