Unverified Commit 9d2ce253 authored by Lysandre Debut, committed by GitHub

TPU hangs when saving optimizer/scheduler (#4467)

* TPU hangs when saving optimizer/scheduler

* Style

* ParallelLoader is not a DataLoader

* Style

* Addressing @julien-c's comments
parent 49296533
@@ -242,9 +242,6 @@ class Trainer:
             collate_fn=self.data_collator.collate_batch,
         )
-        if is_tpu_available():
-            data_loader = pl.ParallelLoader(data_loader, [self.args.device]).per_device_loader(self.args.device)
-
         return data_loader
 
     def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
@@ -269,9 +266,6 @@ class Trainer:
             collate_fn=self.data_collator.collate_batch,
         )
-        if is_tpu_available():
-            data_loader = pl.ParallelLoader(data_loader, [self.args.device]).per_device_loader(self.args.device)
-
         return data_loader
 
     def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
@@ -292,9 +286,6 @@ class Trainer:
             collate_fn=self.data_collator.collate_batch,
         )
-        if is_tpu_available():
-            data_loader = pl.ParallelLoader(data_loader, [self.args.device]).per_device_loader(self.args.device)
-
         return data_loader
 
     def get_optimizers(
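The three hunks above make the same change: the dataloader getters stop wrapping the freshly built loader in `pl.ParallelLoader(...).per_device_loader(...)` and instead return a plain `DataLoader`, leaving any TPU wrapping to the code that actually iterates it (see the training-loop and prediction-loop hunks below). A minimal sketch of the resulting contract, with an illustrative dataset and a standalone `get_train_dataloader` that is not the Trainer's own method; as the old private `_loader._loader` accesses removed further down suggest, a `PerDeviceLoader` does not expose `dataset`, `batch_size` or `sampler` directly:

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

# Illustrative data; the real Trainer builds this from its train_dataset and data_collator.
dataset = TensorDataset(torch.randn(32, 4), torch.randint(0, 2, (32,)))

def get_train_dataloader() -> DataLoader:
    # Return a plain DataLoader even when a TPU is present; the per-device
    # wrapping now happens only at iteration time.
    return DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=8)

loader = get_train_dataloader()
# Because the getter returns a real DataLoader, callers can keep relying on
# these attributes (used by num_examples, the sampler epoch reset, and logging):
print(len(loader.dataset), loader.batch_size, type(loader.sampler).__name__)
```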
@@ -351,15 +342,11 @@ class Trainer:
                 self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
             )
 
-    def num_examples(self, dataloader: Union[DataLoader, "pl.PerDeviceLoader"]) -> int:
+    def num_examples(self, dataloader: DataLoader) -> int:
         """
         Helper to get num of examples from a DataLoader, by accessing its Dataset.
         """
-        if is_tpu_available():
-            assert isinstance(dataloader, pl.PerDeviceLoader)
-            return len(dataloader._loader._loader.dataset)
-        else:
-            return len(dataloader.dataset)
+        return len(dataloader.dataset)
 
     def train(self, model_path: Optional[str] = None):
         """
@@ -466,7 +453,14 @@ class Trainer:
             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                 train_dataloader.sampler.set_epoch(epoch)
 
-            epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master())
+            if is_tpu_available():
+                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
+                    self.args.device
+                )
+                epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_master())
+            else:
+                epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master())
+
             for step, inputs in enumerate(epoch_iterator):
 
                 # Skip past any already trained steps if resuming training
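Instead of being baked into the dataloader getter, the `ParallelLoader` wrap now happens at the top of every epoch, which matches the usual torch_xla training-loop pattern and keeps `train_dataloader` a real `DataLoader`, so the `isinstance`/`set_epoch` check above still applies. A rough standalone sketch of the per-epoch wrapping, with placeholder data and a local `is_tpu_available` helper that merely mirrors the library's importability check:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

def is_tpu_available() -> bool:
    # Illustrative stand-in: "is torch_xla importable?"
    try:
        import torch_xla.core.xla_model  # noqa: F401
        return True
    except ImportError:
        return False

train_dataloader = DataLoader(TensorDataset(torch.randn(64, 4)), batch_size=8, shuffle=True)

for epoch in range(2):
    if is_tpu_available():
        import torch_xla.core.xla_model as xm
        import torch_xla.distributed.parallel_loader as pl

        device = xm.xla_device()
        # Re-create the per-device loader each epoch; it feeds batches already
        # placed on the XLA device.
        epoch_iterator = tqdm(
            pl.ParallelLoader(train_dataloader, [device]).per_device_loader(device), desc="Iteration"
        )
    else:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
    for step, inputs in enumerate(epoch_iterator):
        pass  # forward/backward/optimizer step would go here
```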
@@ -514,24 +508,28 @@ class Trainer:
                         if self.args.evaluate_during_training:
                             self.evaluate()
 
-                    if self.is_world_master():
-                        if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
-                            # In all cases (even distributed/parallel), self.model is always a reference
-                            # to the model we want to save.
-                            if hasattr(model, "module"):
-                                assert model.module is self.model
-                            else:
-                                assert model is self.model
-                            # Save model checkpoint
-                            output_dir = os.path.join(
-                                self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
-                            )
-
-                            self.save_model(output_dir)
-                            self._rotate_checkpoints()
-
-                            torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
-                            torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
-                            logger.info("Saving optimizer and scheduler states to %s", output_dir)
+                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
+                        # In all cases (even distributed/parallel), self.model is always a reference
+                        # to the model we want to save.
+                        if hasattr(model, "module"):
+                            assert model.module is self.model
+                        else:
+                            assert model is self.model
+                        # Save model checkpoint
+                        output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")
+
+                        self.save_model(output_dir)
+
+                        if self.is_world_master():
+                            self._rotate_checkpoints()
+
+                        if is_tpu_available():
+                            xm.rendezvous("saving_optimizer_states")
+                            xm.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
+                            xm.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
+                        elif self.is_world_master():
+                            torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
+                            torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
+                            logger.info("Saving optimizer and scheduler states to %s", output_dir)
 
                 if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                     epoch_iterator.close()
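This hunk is the actual fix for the hang in the title. On TPU, `xm.save()` synchronizes the processes and lets only the master ordinal write, so every process must reach it; the old code kept the whole block under `self.is_world_master()`, leaving the master stuck at a barrier the other ordinals never reached. The new code calls `self.save_model()` from every process, keeps only checkpoint rotation master-only, and adds an explicit `xm.rendezvous("saving_optimizer_states")` so all ordinals line up before saving, while the CPU/GPU path keeps the master-only `torch.save`. A hedged standalone sketch of that pattern; `save_optimizer_and_scheduler` and the paths are illustrative, not the Trainer API:

```python
import os
import torch

def save_optimizer_and_scheduler(optimizer, scheduler, output_dir, is_world_master=True):
    # Sketch of the save logic above, assuming this runs inside every worker process.
    os.makedirs(output_dir, exist_ok=True)
    try:
        import torch_xla.core.xla_model as xm
        on_tpu = True
    except ImportError:
        on_tpu = False

    if on_tpu:
        # All TPU processes must reach these calls: rendezvous is a barrier keyed
        # by its tag, and xm.save synchronizes before the master ordinal writes.
        xm.rendezvous("saving_optimizer_states")
        xm.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
        xm.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
    elif is_world_master:
        # CPU/GPU path: only the main process writes, exactly as before the commit.
        torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
        torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

# Throwaway usage example:
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
save_optimizer_and_scheduler(optimizer, scheduler, "/tmp/checkpoint-example")
```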
@@ -713,6 +711,7 @@ class Trainer:
         In that case, this method will also return metrics, like in evaluate().
         """
         test_dataloader = self.get_test_dataloader(test_dataset)
+
         return self._prediction_loop(test_dataloader, description="Prediction")
 
     def _prediction_loop(
@@ -735,10 +734,7 @@ class Trainer:
         # Note: in torch.distributed mode, there's no point in wrapping the model
         # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
 
-        if is_tpu_available():
-            batch_size = dataloader._loader._loader.batch_size
-        else:
-            batch_size = dataloader.batch_size
+        batch_size = dataloader.batch_size
         logger.info("***** Running %s *****", description)
         logger.info("  Num examples = %d", self.num_examples(dataloader))
         logger.info("  Batch size = %d", batch_size)
@@ -747,6 +743,9 @@ class Trainer:
         label_ids: torch.Tensor = None
         model.eval()
 
+        if is_tpu_available():
+            dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)
+
         for inputs in tqdm(dataloader, desc=description):
             has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])
......
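The last two hunks apply the same split on the evaluation side: `batch_size` and `num_examples` are read from the plain `DataLoader` first, and the `ParallelLoader` wrap is applied only right before the prediction loop iterates. A compact sketch of that ordering, again with illustrative objects:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

dataloader = DataLoader(TensorDataset(torch.randn(40, 4)), batch_size=10)

# Read sizes from the plain DataLoader *before* any TPU wrapping.
batch_size = dataloader.batch_size
num_examples = len(dataloader.dataset)
print(f"***** Running Evaluation *****  {num_examples} examples, batch size {batch_size}")

try:
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl

    device = xm.xla_device()
    # Swap in the per-device loader only for the actual pass over the data.
    dataloader = pl.ParallelLoader(dataloader, [device]).per_device_loader(device)
except ImportError:
    pass  # no TPU runtime available; keep the plain DataLoader

for inputs in tqdm(dataloader, desc="Evaluation"):
    pass  # model forward pass under torch.no_grad() would go here
```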