"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "52040517b8abc55cdb4ba2f2549164a91acb44cc"
Unverified Commit ba973342 authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #1553 from WilliamTambellini/timeSquadInference

Add speed log to examples/run_squad.py
parents 237fad33 0919389d
...@@ -22,6 +22,7 @@ import logging ...@@ -22,6 +22,7 @@ import logging
import os import os
import random import random
import glob import glob
import timeit
import numpy as np import numpy as np
import torch import torch
...@@ -221,6 +222,7 @@ def evaluate(args, model, tokenizer, prefix=""): ...@@ -221,6 +222,7 @@ def evaluate(args, model, tokenizer, prefix=""):
logger.info(" Num examples = %d", len(dataset)) logger.info(" Num examples = %d", len(dataset))
logger.info(" Batch size = %d", args.eval_batch_size) logger.info(" Batch size = %d", args.eval_batch_size)
all_results = [] all_results = []
start_time = timeit.default_timer()
for batch in tqdm(eval_dataloader, desc="Evaluating"): for batch in tqdm(eval_dataloader, desc="Evaluating"):
model.eval() model.eval()
batch = tuple(t.to(args.device) for t in batch) batch = tuple(t.to(args.device) for t in batch)
...@@ -253,6 +255,9 @@ def evaluate(args, model, tokenizer, prefix=""): ...@@ -253,6 +255,9 @@ def evaluate(args, model, tokenizer, prefix=""):
end_logits = to_list(outputs[1][i])) end_logits = to_list(outputs[1][i]))
all_results.append(result) all_results.append(result)
evalTime = timeit.default_timer() - start_time
logger.info(" Evaluation done in total %f secs (%f sec per example)", evalTime, evalTime / len(dataset))
# Compute predictions # Compute predictions
output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix)) output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix)) output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment