"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "821de121e86574504ec648f76ccb924e38125b52"
Commit 833c3a7a authored by VictorSanh's avatar VictorSanh
Browse files

FIX errors in loading Dataset in `run_squad_pytorch`

parent 72d69a4e
...@@ -818,9 +818,12 @@ def main(): ...@@ -818,9 +818,12 @@ def main():
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long) #all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) #train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_start_positions, all_end_positions)
if args.local_rank == -1: if args.local_rank == -1:
train_sampler = RandomSampler(train_data) train_sampler = RandomSampler(train_data)
else: else:
...@@ -829,13 +832,16 @@ def main(): ...@@ -829,13 +832,16 @@ def main():
model.train() model.train()
for epoch in range(int(args.num_train_epochs)): for epoch in range(int(args.num_train_epochs)):
for input_ids, input_mask, segment_ids, label_ids in train_dataloader: #for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
for input_ids, input_mask, segment_ids, start_positions, end_positions in train_dataloader:
input_ids = input_ids.to(device) input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device) input_mask = input_mask.float().to(device)
segment_ids = segment_ids.to(device) segment_ids = segment_ids.to(device)
label_ids = label_ids.to(device) #label_ids = label_ids.to(device)
start_positions = start_positions.to(device)
end_positions = start_positions.to(device)
loss, _ = model(input_ids, segment_ids, input_mask, label_ids) loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
loss.backward() loss.backward()
optimizer.step() optimizer.step()
global_step += 1 global_step += 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment