Commit a1126237 authored by thomwolf's avatar thomwolf
Browse files

clean up logits extraction logic

parent 2a97fe22
...@@ -908,7 +908,7 @@ def main(): ...@@ -908,7 +908,7 @@ def main():
model.eval() model.eval()
all_results = [] all_results = []
logger.info("Start evaluating") logger.info("Start evaluating")
for input_ids, input_mask, segment_ids, example_index in tqdm(eval_dataloader, desc="Evaluating"): for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"):
if len(all_results) % 1000 == 0: if len(all_results) % 1000 == 0:
logger.info("Processing example: %d" % (len(all_results))) logger.info("Processing example: %d" % (len(all_results)))
...@@ -916,21 +916,18 @@ def main(): ...@@ -916,21 +916,18 @@ def main():
input_mask = input_mask.to(device) input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device) segment_ids = segment_ids.to(device)
start_logits, end_logits = model(input_ids, segment_ids, input_mask) with torch.no_grad():
batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask)
unique_id = [int(eval_features[e.item()].unique_id) for e in example_index]
start_logits = [x.view(-1).detach().cpu().numpy() for x in start_logits] for i, example_index in enumerate(example_indices):
end_logits = [x.view(-1).detach().cpu().numpy() for x in end_logits] start_logits = batch_start_logits[i].detach().cpu().tolist()
for idx, i in enumerate(unique_id): end_logits = batch_end_logits[i].detach().cpu().tolist()
s = [float(x) for x in start_logits[idx]]
e = [float(x) for x in end_logits[idx]] eval_feature = eval_features[example_index.item()]
all_results.append( unique_id = int(eval_feature.unique_id)
RawResult( all_results.append(RawResult(unique_id=unique_id,
unique_id=i, start_logits=start_logits,
start_logits=s, end_logits=end_logits))
end_logits=e
)
)
output_prediction_file = os.path.join(args.output_dir, "predictions.json") output_prediction_file = os.path.join(args.output_dir, "predictions.json")
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json") output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment