Unverified commit 9d2ce253, authored by Lysandre Debut, committed by GitHub

TPU hangs when saving optimizer/scheduler (#4467)

* TPU hangs when saving optimizer/scheduler

* Style

* ParallelLoader is not a DataLoader

* Style

* Addressing @julien-c's comments
parent 49296533
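
Why the hang: previously the whole checkpoint block, including the optimizer/scheduler save, ran only under `self.is_world_master()`. On TPU, `xm.save()` is meant to be entered by every process (it synchronizes internally and lets only the master ordinal write), so with that guard the non-master ordinals skipped the call and the master waited indefinitely. The diff below therefore runs the TPU save path from all processes, preceded by an explicit `xm.rendezvous`. A minimal standalone sketch of that pattern, assuming `torch_xla` is installed; the helper name and boolean flags are illustrative, not part of the `Trainer` API:

import os

import torch
import torch_xla.core.xla_model as xm


def save_optimizer_and_scheduler(output_dir, optimizer, scheduler, on_tpu, is_world_master):
    # Illustrative helper (not part of Trainer): save optimizer/scheduler state
    # without deadlocking a multi-process TPU run.
    if on_tpu:
        # Reached by *every* TPU process: rendezvous first so no ordinal races
        # ahead, then let xm.save() write from the master ordinal only.
        xm.rendezvous("saving_optimizer_states")
        xm.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
        xm.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
    elif is_world_master:
        # CPU/GPU (possibly torch.distributed) path: only the main process writes.
        torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
        torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))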
@@ -242,9 +242,6 @@ class Trainer:
             collate_fn=self.data_collator.collate_batch,
         )

-        if is_tpu_available():
-            data_loader = pl.ParallelLoader(data_loader, [self.args.device]).per_device_loader(self.args.device)
-
         return data_loader

     def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
@@ -269,9 +266,6 @@ class Trainer:
             collate_fn=self.data_collator.collate_batch,
         )

-        if is_tpu_available():
-            data_loader = pl.ParallelLoader(data_loader, [self.args.device]).per_device_loader(self.args.device)
-
         return data_loader

     def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
@@ -292,9 +286,6 @@ class Trainer:
             collate_fn=self.data_collator.collate_batch,
         )

-        if is_tpu_available():
-            data_loader = pl.ParallelLoader(data_loader, [self.args.device]).per_device_loader(self.args.device)
-
         return data_loader

     def get_optimizers(
@@ -351,15 +342,11 @@ class Trainer:
                 self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
             )

-    def num_examples(self, dataloader: Union[DataLoader, "pl.PerDeviceLoader"]) -> int:
+    def num_examples(self, dataloader: DataLoader) -> int:
         """
         Helper to get num of examples from a DataLoader, by accessing its Dataset.
         """
-        if is_tpu_available():
-            assert isinstance(dataloader, pl.PerDeviceLoader)
-            return len(dataloader._loader._loader.dataset)
-        else:
-            return len(dataloader.dataset)
+        return len(dataloader.dataset)

     def train(self, model_path: Optional[str] = None):
         """
@@ -466,7 +453,14 @@ class Trainer:
             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                 train_dataloader.sampler.set_epoch(epoch)

-            epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master())
+            if is_tpu_available():
+                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
+                    self.args.device
+                )
+                epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_master())
+            else:
+                epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master())

             for step, inputs in enumerate(epoch_iterator):

                 # Skip past any already trained steps if resuming training
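
As the commit message notes, `ParallelLoader` is not a `DataLoader`: the object returned by `per_device_loader()` is tied to one pass over the data and does not expose attributes such as `.dataset` or `.batch_size` (the old code had to dig through `dataloader._loader._loader`). The dataloader getters therefore return plain `DataLoader`s again, and the TPU wrapping now happens here, once per epoch. A rough sketch of that per-epoch wrapping outside of `Trainer`, assuming `torch_xla` is available and that `train_dataloader`, `model`, `optimizer` and `num_epochs` are defined elsewhere; the loss handling follows the convention of models of that era (first output is the loss when labels are passed) and is only meant to show where the wrapping sits:

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

device = xm.xla_device()

for epoch in range(num_epochs):
    # Re-wrap the ordinary DataLoader each epoch; per_device_loader() yields
    # batches that are already placed on the XLA device.
    epoch_loader = pl.ParallelLoader(train_dataloader, [device]).per_device_loader(device)
    for batch in epoch_loader:
        loss = model(**batch)[0]
        loss.backward()
        xm.optimizer_step(optimizer)  # reduces gradients across TPU cores, then optimizer.step()
        optimizer.zero_grad()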
@@ -514,24 +508,28 @@ class Trainer:
                         if self.args.evaluate_during_training:
                             self.evaluate()

-                    if self.is_world_master():
-                        if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
-                            # In all cases (even distributed/parallel), self.model is always a reference
-                            # to the model we want to save.
-                            if hasattr(model, "module"):
-                                assert model.module is self.model
-                            else:
-                                assert model is self.model
-                            # Save model checkpoint
-                            output_dir = os.path.join(
-                                self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
-                            )
-
-                            self.save_model(output_dir)
-                            self._rotate_checkpoints()
-
-                            torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
-                            torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
-                            logger.info("Saving optimizer and scheduler states to %s", output_dir)
+                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
+                        # In all cases (even distributed/parallel), self.model is always a reference
+                        # to the model we want to save.
+                        if hasattr(model, "module"):
+                            assert model.module is self.model
+                        else:
+                            assert model is self.model
+                        # Save model checkpoint
+                        output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")
+
+                        self.save_model(output_dir)
+
+                        if self.is_world_master():
+                            self._rotate_checkpoints()
+
+                        if is_tpu_available():
+                            xm.rendezvous("saving_optimizer_states")
+                            xm.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
+                            xm.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
+                        elif self.is_world_master():
+                            torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
+                            torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

             if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                 epoch_iterator.close()
@@ -713,6 +711,7 @@ class Trainer:
             In that case, this method will also return metrics, like in evaluate().
         """
         test_dataloader = self.get_test_dataloader(test_dataset)
+
         return self._prediction_loop(test_dataloader, description="Prediction")

     def _prediction_loop(
@@ -735,10 +734,7 @@ class Trainer:
         # Note: in torch.distributed mode, there's no point in wrapping the model
         # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

-        if is_tpu_available():
-            batch_size = dataloader._loader._loader.batch_size
-        else:
-            batch_size = dataloader.batch_size
+        batch_size = dataloader.batch_size
         logger.info("***** Running %s *****", description)
         logger.info("  Num examples = %d", self.num_examples(dataloader))
         logger.info("  Batch size = %d", batch_size)
@@ -747,6 +743,9 @@ class Trainer:
         label_ids: torch.Tensor = None
         model.eval()

+        if is_tpu_available():
+            dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)
+
         for inputs in tqdm(dataloader, desc=description):
             has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])
...
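
The prediction-loop hunks follow the same logic: sizes are read from the plain `DataLoader` first (`dataloader.batch_size`, and `len(dataloader.dataset)` via `num_examples()`), and only then is the loader wrapped for TPU, so no `_loader._loader` internals are needed. A condensed sketch of that ordering, with hypothetical names (`prediction_pass`, `on_tpu`, `args_device`) standing in for the Trainer's own attributes:

import torch_xla.distributed.parallel_loader as pl


def prediction_pass(model, dataloader, args_device, on_tpu, description="Evaluation"):
    # Read metadata before wrapping: the TPU per-device loader exposes
    # neither .dataset nor .batch_size.
    num_examples = len(dataloader.dataset)
    batch_size = dataloader.batch_size
    print(f"***** Running {description}: {num_examples} examples, batch size {batch_size} *****")

    model.eval()
    if on_tpu:
        dataloader = pl.ParallelLoader(dataloader, [args_device]).per_device_loader(args_device)

    for inputs in dataloader:
        ...  # forward pass, collect predictions/labels as in _prediction_loop()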