Unverified commit 5e24982e authored by Sean Naren, committed by GitHub

Upgrade PyTorch Lightning to 1.0.2 (#7852)


Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
parent 1b6c8d48
@@ -337,7 +337,7 @@ def add_generic_args(parser, root_dir) -> None:
 def generic_train(
     model: BaseTransformer,
     args: argparse.Namespace,
-    early_stopping_callback=False,
+    early_stopping_callback=None,
     logger=True,  # can pass WandbLogger() here
     extra_callbacks=[],
     checkpoint_callback=None,
@@ -355,6 +355,8 @@ def generic_train(
         checkpoint_callback = pl.callbacks.ModelCheckpoint(
             filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
         )
+    if early_stopping_callback:
+        extra_callbacks.append(early_stopping_callback)
     if logging_callback is None:
         logging_callback = LoggingCallback()
@@ -376,7 +378,6 @@ def generic_train(
         callbacks=[logging_callback] + extra_callbacks,
         logger=logger,
         checkpoint_callback=checkpoint_callback,
-        early_stop_callback=early_stopping_callback,
         **train_params,
     )
......
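The first file's change tracks a PyTorch Lightning 1.0 API break: Trainer no longer accepts an early_stop_callback argument, so generic_train now takes an optional callback object and appends it to the ordinary callbacks list. A minimal standalone sketch of the new pattern, assuming Lightning >= 1.0; the monitored metric and patience below are illustrative choices, not values from this commit:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    # Lightning >= 1.0: early stopping is just another entry in `callbacks`;
    # the old Trainer(early_stop_callback=...) argument no longer exists.
    early_stopping = EarlyStopping(monitor="val_loss", mode="min", patience=3)
    trainer = Trainer(max_epochs=1, callbacks=[early_stopping])

With generic_train, the same object would simply be passed as early_stopping_callback=early_stopping and forwarded through extra_callbacks.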
@@ -5,7 +5,7 @@ psutil
 sacrebleu
 rouge-score
 tensorflow_datasets
-pytorch-lightning==0.9.0
+pytorch-lightning==1.0.4
 matplotlib
 git-python==1.0.3
 faiss-cpu
......
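The examples' requirements move from pytorch-lightning==0.9.0 to a pinned 1.0.4 (the PR title says 1.0.2; the pin in the requirements file is 1.0.4). If a script wanted to fail fast on an older installation, a small guard like the following could be used; this check is an illustration, not part of the commit:

    import pytorch_lightning as pl
    from packaging import version

    # Illustrative guard: the updated examples assume the >= 1.0 callback and
    # checkpoint APIs used throughout this commit.
    if version.parse(pl.__version__) < version.parse("1.0.2"):
        raise RuntimeError(f"pytorch-lightning >= 1.0.2 required, found {pl.__version__}")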
@@ -102,7 +102,6 @@ def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=False):
         monitor=f"val_{metric}",
         mode="min" if "loss" in metric else "max",
         save_top_k=save_top_k,
-        period=0,  # maybe save a checkpoint every time val is run, not just end of epoch.
     )
     return checkpoint_callback
......
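Dropping period=0 removes a 0.9-era setting whose inline comment suggests it was meant to checkpoint on every validation run; after the upgrade the callback relies on Lightning's defaults, keyed only on the monitored val_<metric>. A sketch of an equivalent construction under Lightning 1.0.x; the helper name and the rouge2 metric are illustrative, not taken from the repository:

    from pytorch_lightning.callbacks import ModelCheckpoint

    # Illustrative stand-in for get_checkpoint_callback() after the upgrade:
    # no `period` argument, so Lightning's default checkpoint cadence applies.
    def make_metric_checkpoint(metric: str, save_top_k: int = 1) -> ModelCheckpoint:
        return ModelCheckpoint(
            monitor=f"val_{metric}",
            mode="min" if "loss" in metric else "max",  # lower is better only for losses
            save_top_k=save_top_k,
        )

    rouge_checkpoint = make_metric_checkpoint("rouge2")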
@@ -182,7 +182,6 @@ class SummarizationModule(BaseTransformer):
         return self._generative_step(batch)
     def validation_epoch_end(self, outputs, prefix="val") -> Dict:
         self.step_count += 1
         losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
         loss = losses["loss"]
@@ -252,7 +251,7 @@ class SummarizationModule(BaseTransformer):
     def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
         dataset = self.get_dataset(type_path)
-        if self.hparams.sortish_sampler and type_path != "test":
+        if self.hparams.sortish_sampler and type_path != "test" and type_path != "val":
             sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
             return DataLoader(
                 dataset,
@@ -263,7 +262,7 @@ class SummarizationModule(BaseTransformer):
                 sampler=sampler,
             )
-        elif self.hparams.max_tokens_per_batch is not None and type_path != "test":
+        elif self.hparams.max_tokens_per_batch is not None and type_path != "test" and type_path != "val":
             batch_sampler = dataset.make_dynamic_sampler(
                 self.hparams.max_tokens_per_batch, distributed=self.hparams.gpus > 1
             )
......
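In the get_dataloader change above, the sortish sampler and the dynamic (max-tokens) batch sampler are now restricted to the training split: both branches exclude "val" as well as "test". The DataLoader rule behind the two branches is that a sampler is combined with a fixed batch_size, while a batch_sampler yields whole batches and therefore replaces batch_size, shuffle, and sampler. A self-contained sketch of that distinction, with a toy dataset standing in for the module's own dataset class:

    import torch
    from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

    dataset = TensorDataset(torch.arange(100))  # toy stand-in for the seq2seq dataset

    # Branch 1 style: a per-example sampler combined with a fixed batch_size.
    loader_a = DataLoader(dataset, batch_size=8, sampler=SequentialSampler(dataset))

    # Branch 2 style: a batch sampler that emits whole batches itself, so
    # batch_size/shuffle/sampler must not be passed alongside it.
    batches = BatchSampler(SequentialSampler(dataset), batch_size=8, drop_last=False)
    loader_b = DataLoader(dataset, batch_sampler=batches)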
@@ -144,6 +144,7 @@ class TestAll(TestCasePlus):
                 f"--num_train_epochs={epochs}",
                 "--warmup_steps=10",
                 "--val_check_interval=1.0",
+                "--do_predict",
             ]
         )
         with patch.object(sys, "argv", testargs):
@@ -151,7 +152,6 @@ class TestAll(TestCasePlus):
             parser = pl.Trainer.add_argparse_args(parser)
             parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
             args = parser.parse_args()
-            args.do_predict = False
             # assert args.gpus == gpus THIS BREAKS for multigpu
             model = distill_main(args)
......
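The test now places --do_predict on the simulated command line instead of flipping args.do_predict after parsing, so the flag exercises the same argparse path as a real run. A small sketch of that pattern, assuming --do_predict is registered as a store_true flag (as the example scripts' shared argument helper is expected to do):

    import argparse
    import sys
    from unittest.mock import patch

    parser = argparse.ArgumentParser()
    parser.add_argument("--do_predict", action="store_true")  # assumed flag definition

    testargs = ["distillation.py", "--do_predict"]
    with patch.object(sys, "argv", testargs):
        args = parser.parse_args()  # parse_args() reads sys.argv[1:]
    assert args.do_predict is True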
@@ -176,7 +176,6 @@ class TestSummarizationDistillerMultiGPU(TestCasePlus):
         print(metrics)
         last_step_stats = metrics["val"][-1]
         self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
-        self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"])
         self.assertIsInstance(last_step_stats[f"val_avg_{val_metric}"], float)
         self.assertEqual(len(metrics["test"]), 1)
         desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) / 2 + 1)
......
@@ -192,7 +192,7 @@ def main():
     # Optionally, predict on dev set and write to output_dir
     if args.do_predict:
-        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
+        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
         model = model.load_from_checkpoint(checkpoints[-1])
         return trainer.test(model)
......
@@ -207,9 +207,9 @@ if __name__ == "__main__":
     if args.do_predict:
         # See https://github.com/huggingface/transformers/issues/3159
-        # pl use this format to create a checkpoint:
+        # pl use this default format to create a checkpoint:
         # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
-        # /pytorch_lightning/callbacks/model_checkpoint.py#L169
+        # /pytorch_lightning/callbacks/model_checkpoint.py#L322
-        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
+        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
         model = model.load_from_checkpoint(checkpoints[-1])
         trainer.test(model)
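The glob pattern changes from checkpointepoch=*.ckpt to checkpoint-epoch=*.ckpt because Lightning 1.0's ModelCheckpoint joins the prefix and the default {epoch}-based filename with a dash (hence the updated model_checkpoint.py#L322 reference). A sketch of the "load the newest checkpoint" step built on that naming; the numeric sort key is an extra precaution, not something this commit adds, since plain sorted() would rank epoch=9 after epoch=10:

    import glob
    import os
    import re

    def latest_checkpoint(output_dir: str) -> str:
        # Lightning 1.0 with ModelCheckpoint(prefix="checkpoint") writes files
        # such as "checkpoint-epoch=3.ckpt", matching the new glob above.
        paths = glob.glob(os.path.join(output_dir, "checkpoint-epoch=*.ckpt"))
        if not paths:
            raise FileNotFoundError(f"no checkpoints found in {output_dir}")
        return max(paths, key=lambda p: int(re.search(r"epoch=(\d+)", p).group(1)))

A caller could then use latest_checkpoint(args.output_dir) in place of the sorted(...)[-1] indexing shown above.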