"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7f9b7b3f0e181e2393fc5b5c70ad2da2a190aa08"
Unverified Commit fd7b6a52 authored by Wissam Antoun's avatar Wissam Antoun Committed by GitHub
Browse files

fixed JSON error in run_qa with fp16 (#9186)

parent 66a14a2f
...@@ -206,7 +206,7 @@ def postprocess_qa_predictions( ...@@ -206,7 +206,7 @@ def postprocess_qa_predictions(
# Make `predictions` JSON-serializable by casting np.float back to float. # Make `predictions` JSON-serializable by casting np.float back to float.
all_nbest_json[example["id"]] = [ all_nbest_json[example["id"]] = [
{k: (float(v) if isinstance(v, (np.float32, np.float64)) else v) for k, v in pred.items()} {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()}
for pred in predictions for pred in predictions
] ]
...@@ -394,7 +394,7 @@ def postprocess_qa_predictions_with_beam_search( ...@@ -394,7 +394,7 @@ def postprocess_qa_predictions_with_beam_search(
# Make `predictions` JSON-serializable by casting np.float back to float. # Make `predictions` JSON-serializable by casting np.float back to float.
all_nbest_json[example["id"]] = [ all_nbest_json[example["id"]] = [
{k: (float(v) if isinstance(v, (np.float32, np.float64)) else v) for k, v in pred.items()} {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()}
for pred in predictions for pred in predictions
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment