Commit 8cbe7d6a authored by VictorSanh's avatar VictorSanh
Browse files

FIX errors in loading eval Dataset in `run_squad_pytorch`

parent 833c3a7a
...@@ -865,10 +865,11 @@ def main(): ...@@ -865,10 +865,11 @@ def main():
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long) #all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_example_index) #eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_example_index)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
if args.local_rank == -1: if args.local_rank == -1:
eval_sampler = SequentialSampler(eval_data) eval_sampler = SequentialSampler(eval_data)
else: else:
...@@ -877,7 +878,8 @@ def main(): ...@@ -877,7 +878,8 @@ def main():
model.eval() model.eval()
all_results = [] all_results = []
for input_ids, input_mask, segment_ids, label_ids, example_index in eval_dataloader: #for input_ids, input_mask, segment_ids, label_ids, example_index in eval_dataloader:
for input_ids, input_mask, segment_ids, example_index in eval_dataloader:
if len(all_results) % 1000 == 0: if len(all_results) % 1000 == 0:
logger.info("Processing example: %d" % (len(all_results))) logger.info("Processing example: %d" % (len(all_results)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment