Commit bb0a5103 authored by VictorSanh

Add debug prints to run_squad

parent c84315ec
@@ -818,6 +818,7 @@ def main():
         logger.info("  Batch size = %d", args.train_batch_size)
         logger.info("  Num steps = %d", num_train_steps)
+        logger.info("HHHHH Loading data")
         all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
         all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
         all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
@@ -825,14 +826,17 @@ def main():
         all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
         all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)
logger.info("HHHHH Creating dataset")
#train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) #train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_start_positions, all_end_positions) train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_start_positions, all_end_positions)
if args.local_rank == -1: if args.local_rank == -1:
train_sampler = RandomSampler(train_data) train_sampler = RandomSampler(train_data)
else: else:
train_sampler = DistributedSampler(train_data) train_sampler = DistributedSampler(train_data)
logger.info("HHHHH Dataloader")
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
logger.info("HHHHH Starting Traing")
model.train() model.train()
for epoch in range(int(args.num_train_epochs)): for epoch in range(int(args.num_train_epochs)):
#for input_ids, input_mask, segment_ids, label_ids in train_dataloader: #for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
@@ -847,10 +851,14 @@ def main():
                 start_positions = start_positions.view(-1, 1)
                 end_positions = end_positions.view(-1, 1)
+                logger.info("HHHHH Forward")
                 loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
+                logger.info("HHHHH Backward")
                 loss.backward()
logger.info("HHHHH Loading data")
                 optimizer.step()
                 global_step += 1
+                logger.info("Done %s steps", global_step)
     if args.do_predict:
         eval_examples = read_squad_examples(
@@ -884,6 +892,7 @@ def main():
         model.eval()
         all_results = []
logger.info("Start evaulating")
         #for input_ids, input_mask, segment_ids, label_ids, example_index in eval_dataloader:
         for input_ids, input_mask, segment_ids, example_index in eval_dataloader:
             if len(all_results) % 1000 == 0:
...
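
For context, the lines this commit adds instrument the standard PyTorch pipeline in run_squad: build a TensorDataset from the feature tensors, pick RandomSampler (single process) or DistributedSampler (distributed), wrap it in a DataLoader, and log around the forward/backward/optimizer-step of each batch. Below is a minimal, self-contained sketch of that pattern with the same style of debug logging; the synthetic tensor shapes, DummyQAModel, and optimizer choice are illustrative assumptions, not code from run_squad.py.

# Sketch of the dataset -> sampler -> dataloader -> train loop pattern
# that the commit instruments. DummyQAModel and the synthetic data are
# stand-ins for BertForQuestionAnswering and the SQuAD features.
import logging

import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class DummyQAModel(torch.nn.Module):
    """Stand-in for BertForQuestionAnswering: returns a scalar loss."""

    def __init__(self, seq_len=32):
        super().__init__()
        self.linear = torch.nn.Linear(seq_len, 2)

    def forward(self, input_ids, segment_ids, input_mask, start_positions, end_positions):
        logits = self.linear(input_ids.float())
        # Toy loss: distance between predicted and gold span positions.
        targets = torch.cat([start_positions, end_positions], dim=1).float()
        return torch.nn.functional.mse_loss(logits, targets)


batch_size, num_examples, seq_len = 4, 16, 32
all_input_ids = torch.randint(0, 100, (num_examples, seq_len))
all_input_mask = torch.ones(num_examples, seq_len, dtype=torch.long)
all_segment_ids = torch.zeros(num_examples, seq_len, dtype=torch.long)
all_start_positions = torch.randint(0, seq_len, (num_examples,))
all_end_positions = torch.randint(0, seq_len, (num_examples,))

logger.info("Creating dataset")
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                           all_start_positions, all_end_positions)
# Single-process run; a distributed run would use DistributedSampler instead.
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

model = DummyQAModel(seq_len)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

logger.info("Starting training")
model.train()
global_step = 0
for input_ids, input_mask, segment_ids, start_positions, end_positions in train_dataloader:
    start_positions = start_positions.view(-1, 1)
    end_positions = end_positions.view(-1, 1)
    logger.info("Forward")
    loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
    logger.info("Backward")
    loss.backward()
    logger.info("Optimizer step")
    optimizer.step()
    optimizer.zero_grad()
    global_step += 1
logger.info("Done %d steps", global_step)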