Commit 7816f792 authored by thomwolf's avatar thomwolf
Browse files

clean up distributed training logging in run_squad example

parent 1135f238
......@@ -985,7 +985,7 @@ def main():
model.train()
for _ in trange(int(args.num_train_epochs), desc="Epoch"):
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
if n_gpu == 1:
batch = tuple(t.to(device) for t in batch) # multi-gpu does scattering it-self
input_ids, input_mask, segment_ids, start_positions, end_positions = batch
......@@ -1058,7 +1058,7 @@ def main():
model.eval()
all_results = []
logger.info("Start evaluating")
for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"):
for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating", disable=args.local_rank not in [-1, 0]):
if len(all_results) % 1000 == 0:
logger.info("Processing example: %d" % (len(all_results)))
input_ids = input_ids.to(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment