Unverified Commit f5bad031 authored by Sylvain Gugger, committed by GitHub

Use generators tqdm progressbars (#6696)

parent a99d09c6
@@ -641,8 +641,8 @@ class Trainer:
         logging_loss = 0.0
         model.zero_grad()
         disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
-        train_iterator = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
-        for epoch in train_iterator:
+        train_pbar = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
+        for epoch in range(epochs_trained, int(np.ceil(num_train_epochs))):
             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                 train_dataloader.sampler.set_epoch(epoch)
@@ -650,19 +650,21 @@ class Trainer:
                 parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                     self.args.device
                 )
-                epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=disable_tqdm)
+                epoch_iterator = parallel_loader
             else:
-                epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=disable_tqdm)
+                epoch_iterator = train_dataloader

             # Reset the past mems state at the beginning of each epoch if necessary.
             if self.args.past_index >= 0:
                 self._past = None

+            epoch_pbar = tqdm(epoch_iterator, desc="Iteration", disable=disable_tqdm)
             for step, inputs in enumerate(epoch_iterator):

                 # Skip past any already trained steps if resuming training
                 if steps_trained_in_current_epoch > 0:
                     steps_trained_in_current_epoch -= 1
+                    epoch_pbar.update(1)
                     continue

                 tr_loss += self.training_step(model, inputs)
@@ -745,11 +747,12 @@ class Trainer:
                         torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                         torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

+                epoch_pbar.update(1)
                 if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
-                    epoch_iterator.close()
                     break
+            epoch_pbar.close()
+            train_pbar.update(1)
             if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
-                train_iterator.close()
                 break
             if self.args.tpu_metrics_debug or self.args.debug:
                 if is_torch_tpu_available():
@@ -761,6 +764,7 @@ class Trainer:
                         "configured. Check your training configuration if this is unexpected."
                     )

+        train_pbar.close()
         if self.tb_writer:
             self.tb_writer.close()
         if self.args.past_index and hasattr(self, "_past"):
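
The pattern the diff adopts, shown in isolation: rather than wrapping the data iterator in tqdm (which ties the progress bar's lifetime to iteration), the iterator stays plain and dedicated tqdm/trange bars are advanced and closed by hand. The following is a minimal, self-contained sketch of that structure using only tqdm's public API; fake_dataloader, steps_to_skip, and the omitted training step are hypothetical stand-ins, not Trainer code.

from tqdm.auto import tqdm, trange

num_epochs = 2
fake_dataloader = [{"step": i} for i in range(10)]  # stand-in for a DataLoader
steps_to_skip = 3                                    # e.g. resuming mid-epoch

train_pbar = trange(num_epochs, desc="Epoch")        # epoch-level bar, updated manually
for epoch in range(num_epochs):
    # The bar is built from the iterable (for its length) but never iterated itself.
    epoch_pbar = tqdm(fake_dataloader, desc="Iteration")
    for step, inputs in enumerate(fake_dataloader):
        if epoch == 0 and step < steps_to_skip:
            epoch_pbar.update(1)                     # skipped steps still advance the bar
            continue
        # ... the actual training step would run here ...
        epoch_pbar.update(1)
    epoch_pbar.close()
    train_pbar.update(1)
train_pbar.close()

Because the bars are decoupled from the loops, a resumed run can still advance the iteration bar for steps it skips, and closing a bar no longer depends on how the loop exits.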