Commit 8cbe7d6a authored by VictorSanh's avatar VictorSanh
Browse files

FIX errors in loading eval Dataset in `run_squad_pytorch`

parent 833c3a7a
......@@ -865,10 +865,11 @@ def main():
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
#all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_example_index)
#eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_example_index)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
if args.local_rank == -1:
eval_sampler = SequentialSampler(eval_data)
else:
......@@ -877,7 +878,8 @@ def main():
model.eval()
all_results = []
for input_ids, input_mask, segment_ids, label_ids, example_index in eval_dataloader:
#for input_ids, input_mask, segment_ids, label_ids, example_index in eval_dataloader:
for input_ids, input_mask, segment_ids, example_index in eval_dataloader:
if len(all_results) % 1000 == 0:
logger.info("Processing example: %d" % (len(all_results)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment