Commit 8a2be93b authored by thomwolf

fix merge

parent 562f8640
@@ -374,24 +374,6 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
     if args.local_rank == 0 and not evaluate:
         torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache
 
-    # Convert to Tensors and build dataset
-    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
-    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
-    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
-    all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
-    all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
-    if evaluate:
-        all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
-        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
-                                all_example_index, all_cls_index, all_p_mask)
-    else:
-        all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
-        all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
-        all_is_impossible = torch.tensor([1.0 if f.is_impossible == True else 0.0 for f in features], dtype=torch.float)
-        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
-                                all_start_positions, all_end_positions,
-                                all_cls_index, all_p_mask, all_is_impossible)
-
     if output_examples:
         return dataset, examples, features
     return dataset
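
Note: the hunk above only removes lines; given the commit message, the deleted "Convert to Tensors and build dataset" block was presumably a duplicate introduced by the merge, and an identical conversion earlier in load_and_cache_examples is kept. For reference, a minimal sketch of how a dataset built here is typically consumed during evaluation: args, tokenizer and load_and_cache_examples come from the surrounding script, args.eval_batch_size is assumed for illustration, and the unpacking order matches the evaluate branch of the TensorDataset shown above.

    from torch.utils.data import DataLoader, SequentialSampler

    # Build the evaluation dataset (evaluate=True adds the example_index column).
    dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)

    # Iterate sequentially; tensor order follows the TensorDataset construction above.
    eval_sampler = SequentialSampler(dataset)
    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
    for batch in eval_dataloader:
        input_ids, input_mask, segment_ids, example_indices, cls_index, p_mask = batch
        # example_indices map each batch row back to its entry in examples / features.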