Commit 7816f792 authored by thomwolf's avatar thomwolf
Browse files

clean up distributed training logging in run_squad example

parent 1135f238
...@@ -985,7 +985,7 @@ def main(): ...@@ -985,7 +985,7 @@ def main():
model.train() model.train()
for _ in trange(int(args.num_train_epochs), desc="Epoch"): for _ in trange(int(args.num_train_epochs), desc="Epoch"):
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
if n_gpu == 1: if n_gpu == 1:
batch = tuple(t.to(device) for t in batch) # multi-gpu does scattering it-self batch = tuple(t.to(device) for t in batch) # multi-gpu does scattering it-self
input_ids, input_mask, segment_ids, start_positions, end_positions = batch input_ids, input_mask, segment_ids, start_positions, end_positions = batch
...@@ -1058,7 +1058,7 @@ def main(): ...@@ -1058,7 +1058,7 @@ def main():
model.eval() model.eval()
all_results = [] all_results = []
logger.info("Start evaluating") logger.info("Start evaluating")
for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"): for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating", disable=args.local_rank not in [-1, 0]):
if len(all_results) % 1000 == 0: if len(all_results) % 1000 == 0:
logger.info("Processing example: %d" % (len(all_results))) logger.info("Processing example: %d" % (len(all_results)))
input_ids = input_ids.to(device) input_ids = input_ids.to(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment