Commit 92e0ad5a authored by thomwolf's avatar thomwolf
Browse files

no numpy

parent 4e6edc32
...@@ -456,7 +456,7 @@ def main(): ...@@ -456,7 +456,7 @@ def main():
preds = np.argmax(preds, axis=1) preds = np.argmax(preds, axis=1)
elif output_mode == "regression": elif output_mode == "regression":
preds = np.squeeze(preds) preds = np.squeeze(preds)
result = compute_metrics(task_name, preds, out_label_ids.numpy()) result = compute_metrics(task_name, preds, out_label_ids)
if args.local_rank != -1: if args.local_rank != -1:
# Average over distributed nodes if needed # Average over distributed nodes if needed
...@@ -533,7 +533,7 @@ def main(): ...@@ -533,7 +533,7 @@ def main():
eval_loss = eval_loss / nb_eval_steps eval_loss = eval_loss / nb_eval_steps
preds = preds[0] preds = preds[0]
preds = np.argmax(preds, axis=1) preds = np.argmax(preds, axis=1)
result = compute_metrics(task_name, preds, out_label_ids.numpy()) result = compute_metrics(task_name, preds, out_label_ids)
if args.local_rank != -1: if args.local_rank != -1:
# Average over distributed nodes if needed # Average over distributed nodes if needed
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment