Unverified commit 529850ae authored by Nathan Raw, committed by GitHub

Lightning Updates for v0.8.5 (#5798)


Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
parent 615be03f
import argparse
import logging
import os
import random
from pathlib import Path
from typing import Any, Dict
import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.utilities import rank_zero_info
from transformers import (
AdamW,
@@ -42,14 +39,6 @@ MODEL_MODES = {
}
def set_seed(args: argparse.Namespace):
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.gpus > 0:
torch.cuda.manual_seed_all(args.seed)
class BaseTransformer(pl.LightningModule):
def __init__(
self,
@@ -63,7 +52,11 @@ class BaseTransformer(pl.LightningModule):
):
"""Initialize a model, tokenizer and config."""
super().__init__()
self.hparams = hparams # TODO: move to self.save_hyperparameters()
# TODO: move to self.save_hyperparameters()
# self.save_hyperparameters()
# can also expand arguments into trainer signature for easier reading
self.hparams = hparams
self.step_count = 0
self.tfmr_ckpts = {}
self.output_dir = Path(self.hparams.output_dir)
@@ -114,17 +107,12 @@ class BaseTransformer(pl.LightningModule):
]
optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
self.opt = optimizer
return [optimizer]
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
if self.trainer.use_tpu:
xm.optimizer_step(optimizer)
else:
optimizer.step()
optimizer.zero_grad()
self.lr_scheduler.step() # By default, PL will only step every epoch.
lrs = {f"lr_group_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_lr())}
self.logger.log_metrics(lrs)
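# PL 0.8.5 lets configure_optimizers return the LR scheduler as a dict; "interval": "step"
# makes Lightning step the scheduler every batch, replacing the manual optimizer_step /
# lr_scheduler.step() override removed above.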
scheduler = get_linear_schedule_with_warmup(
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps
)
scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
return [optimizer], [scheduler]
def test_step(self, batch, batch_nb):
return self.validation_step(batch, batch_nb)
@@ -132,26 +120,24 @@ class BaseTransformer(pl.LightningModule):
def test_epoch_end(self, outputs):
return self.validation_end(outputs)
def train_dataloader(self):
def setup(self, step):
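# Lightning calls setup() on every process before the dataloaders are requested, so the
# train loader and total_steps are computed once here and reused below.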
train_batch_size = self.hparams.train_batch_size
dataloader = self.load_dataset("train", train_batch_size)
t_total = (
(len(dataloader.dataset) // (train_batch_size * max(1, self.hparams.n_gpu)))
// self.hparams.gradient_accumulation_steps
* float(self.hparams.num_train_epochs)
)
scheduler = get_linear_schedule_with_warmup(
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
dataloader = self.get_dataloader("train", train_batch_size)
self.train_loader = dataloader
self.total_steps = (
(len(dataloader.dataset) // (train_batch_size * max(1, self.hparams.gpus)))
// self.hparams.accumulate_grad_batches
* float(self.hparams.max_epochs)
)
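# total_steps counts optimizer updates: dataset size divided by the effective batch size
# (per-device batch * GPUs, then // accumulate_grad_batches), times max_epochs; it is used
# as num_training_steps for the linear warmup schedule in configure_optimizers.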
self.lr_scheduler = scheduler
return dataloader
def train_dataloader(self):
return self.train_loader
def val_dataloader(self):
return self.load_dataset("dev", self.hparams.eval_batch_size)
return self.get_dataloader("dev", self.hparams.eval_batch_size)
def test_dataloader(self):
return self.load_dataset("test", self.hparams.eval_batch_size)
return self.get_dataloader("test", self.hparams.eval_batch_size)
def _feature_file(self, mode):
return os.path.join(
@@ -201,16 +187,16 @@ class BaseTransformer(pl.LightningModule):
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
parser.add_argument(
"--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."
)
parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int)
parser.add_argument("--train_batch_size", default=32, type=int)
parser.add_argument("--eval_batch_size", default=32, type=int)
class LoggingCallback(pl.Callback):
@rank_zero_only
def on_batch_end(self, trainer, pl_module):
lrs = {f"lr_group_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_lr())}
pl_module.logger.log_metrics(lrs)
def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
rank_zero_info("***** Validation results *****")
metrics = trainer.callback_metrics
@@ -219,16 +205,15 @@ class LoggingCallback(pl.Callback):
if key not in ["log", "progress_bar"]:
rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
@rank_zero_only
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
logger.info("***** Test results *****")
rank_zero_info("***** Test results *****")
metrics = trainer.callback_metrics
# Log and save results to file
output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
with open(output_test_results_file, "w") as writer:
for key in sorted(metrics):
if key not in ["log", "progress_bar"]:
logger.info("{} = {}\n".format(key, str(metrics[key])))
rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
writer.write("{} = {}\n".format(key, str(metrics[key])))
@@ -251,26 +236,23 @@ def add_generic_args(parser, root_dir) -> None:
parser.add_argument(
"--fp16_opt_level",
type=str,
default="O1",
default="O2",
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html",
)
parser.add_argument("--fast_dev_run", action="store_true")
parser.add_argument("--gpus", type=int, default=1)
parser.add_argument("--n_tpu_cores", type=int, default=0)
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int, default=0)
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
parser.add_argument(
"--gradient_accumulation_steps",
dest="accumulate_grad_batches",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
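# The dest= renames above (tpu_cores, gradient_clip_val, accumulate_grad_batches) map the old
# flag names onto the attribute names pl.Trainer.from_argparse_args reads, so one parsed
# Namespace can configure both the module and the Trainer.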
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
parser.add_argument("--resume_from_checkpoint", type=str, default=None)
parser.add_argument("--val_check_interval", default=1.0, type=float)
def generic_train(
@@ -283,10 +265,13 @@ def generic_train(
logging_callback=None,
**extra_train_kwargs
):
pl.seed_everything(args.seed)
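# seed_everything seeds python's random, numpy and torch (CPU and CUDA) in one call,
# replacing the removed set_seed() helper.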
# init model
set_seed(args)
odir = Path(model.hparams.output_dir)
odir.mkdir(exist_ok=True)
# add custom checkpoints
if checkpoint_callback is None:
checkpoint_callback = pl.callbacks.ModelCheckpoint(
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
@@ -296,38 +281,25 @@
train_params = {}
# TODO: remove with PyTorch 1.6 since pl uses native amp
if args.fp16:
train_params["use_amp"] = args.fp16
train_params["precision"] = 16
train_params["amp_level"] = args.fp16_opt_level
if args.n_tpu_cores > 0:
global xm
import torch_xla.core.xla_model as xm
train_params["num_tpu_cores"] = args.n_tpu_cores
train_params["gpus"] = 0
if args.gpus > 1:
train_params["distributed_backend"] = "ddp"
trainer = pl.Trainer(
trainer = pl.Trainer.from_argparse_args(
args,
weights_summary=None,
callbacks=[logging_callback] + extra_callbacks,
logger=logger,
accumulate_grad_batches=args.gradient_accumulation_steps,
gpus=args.gpus,
max_epochs=args.num_train_epochs,
early_stop_callback=early_stopping_callback,
gradient_clip_val=args.max_grad_norm,
checkpoint_callback=checkpoint_callback,
callbacks=[logging_callback] + extra_callbacks,
fast_dev_run=args.fast_dev_run,
val_check_interval=args.val_check_interval,
weights_summary=None,
resume_from_checkpoint=args.resume_from_checkpoint,
early_stop_callback=early_stopping_callback,
**train_params,
)
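# from_argparse_args picks up Trainer options (gpus, max_epochs, accumulate_grad_batches,
# gradient_clip_val, tpu_cores, ...) directly from args; the explicit keyword arguments
# passed here take precedence over the parsed values.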
if args.do_train:
trainer.fit(model)
trainer.logger.log_hyperparams(args)
trainer.logger.save()
return trainer
@@ -5,7 +5,7 @@ psutil
sacrebleu
rouge-score
tensorflow_datasets
pytorch-lightning==0.8.1
pytorch-lightning==0.8.5
matplotlib
git-python==1.0.3
faiss
@@ -60,7 +60,7 @@ Summarization Tips:
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, rerun `calculate_rouge` on the `{output_dir}/test_generations.txt` file saved by `trainer.test()` (see the sketch after these tips).
- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 ` is a reasonable setting for XSUM.
- `wandb` can be used by specifying `--logger wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task.
- `wandb` can be used by specifying `--logger_name wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task.
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
(It rarely makes sense to start from `bart-large` unless you are researching finetuning methods.)
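A minimal sketch of that post-hoc rouge recalculation, assuming `calculate_rouge` in this folder's `utils.py` takes a list of predicted summaries and a list of reference summaries (the paths below are illustrative):

```python
# Recompute rouge on the saved generations without label truncation (sketch, not the official script).
from pathlib import Path

from utils import calculate_rouge  # helper assumed to live next to finetune.py

output_dir = Path("xsum_results")   # hypothetical --output_dir of your run
data_dir = Path("xsum")             # hypothetical --data_dir containing test.target
preds = (output_dir / "test_generations.txt").read_text().splitlines()
refs = (data_dir / "test.target").read_text().splitlines()
print(calculate_rouge(preds, refs))
```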
@@ -124,7 +124,7 @@ model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
```
#### XSUM Shared Task
Compare XSUM results with others by using `--logger wandb_shared`. This requires `wandb` registration.
Compare XSUM results with others by using `--logger_name wandb_shared`. This requires `wandb` registration.
Here is an example command, but you can do whatever you want. Hopefully this will make debugging and collaboration easier!
```bash
@@ -135,7 +135,7 @@ WANDB_PROJECT='hf_xsum' ./finetune.sh \
--train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
--num_train_epochs 6 \
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
--logger wandb
--logger_name wandb
```
You can see your wandb logs [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-)
@@ -221,8 +221,8 @@ class SummarizationModule(BaseTransformer):
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
t_total = (
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
// self.hparams.gradient_accumulation_steps
* float(self.hparams.num_train_epochs)
// self.hparams.accumulate_grad_batches
* float(self.hparams.max_epochs)
)
scheduler = get_linear_schedule_with_warmup(
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
@@ -279,7 +279,7 @@ class SummarizationModule(BaseTransformer):
parser.add_argument("--freeze_encoder", action="store_true")
parser.add_argument("--freeze_embeds", action="store_true")
parser.add_argument("--sortish_sampler", action="store_true", default=False)
parser.add_argument("--logger", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
@@ -288,7 +288,6 @@ class SummarizationModule(BaseTransformer):
)
parser.add_argument("--src_lang", type=str, default="", required=False)
parser.add_argument("--tgt_lang", type=str, default="", required=False)
return parser
@@ -318,22 +317,24 @@ def main(args, model=None) -> SummarizationModule:
model: SummarizationModule = SummarizationModule(args)
else:
model: SummarizationModule = TranslationModule(args)
dataset = Path(args.data_dir).name
if (
args.logger == "default"
args.logger_name == "default"
or args.fast_dev_run
or str(args.output_dir).startswith("/tmp")
or str(args.output_dir).startswith("/var")
):
logger = True # don't pollute wandb logs unnecessarily
elif args.logger == "wandb":
elif args.logger_name == "wandb":
from pytorch_lightning.loggers import WandbLogger
logger = WandbLogger(name=model.output_dir.name)
logger = WandbLogger(name=model.output_dir.name, project=dataset)
elif args.logger == "wandb_shared":
elif args.logger_name == "wandb_shared":
from pytorch_lightning.loggers import WandbLogger
logger = WandbLogger(name=model.output_dir.name)
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
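# The wandb project defaults to the data_dir folder name (e.g. "xsum"), so runs on the same
# dataset share a project; "wandb_shared" prefixes it with "hf_" for the shared task.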
trainer: pl.Trainer = generic_train(
model,
args,
@@ -352,13 +353,17 @@ def main(args, model=None) -> SummarizationModule:
model.hparams.test_checkpoint = checkpoints[-1]
trainer.resume_from_checkpoint = checkpoints[-1]
trainer.logger.log_hyperparams(model.hparams)
trainer.test(model) # this breaks in DDP, known lightning issue. See evaluate_checkpoint to recover metrics.
# test() without a model tests using the best checkpoint automatically
trainer.test()
return model
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args()
main(args)
@@ -10,5 +10,4 @@ python finetune.py \
--do_predict \
--n_val 1000 \
--val_check_interval 0.1 \
--sortish_sampler \
$@
@@ -26,7 +26,7 @@ logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
CUDA_AVAILABLE = torch.cuda.is_available()
CHEAP_ARGS = {
"logger": "default",
"logger_name": "default",
"length_penalty": 0.5,
"cache_dir": "",
"task": "summarization",
@@ -48,7 +48,7 @@ CHEAP_ARGS = {
"max_grad_norm": 1.0,
"do_train": True,
"do_predict": True,
"gradient_accumulation_steps": 1,
"accumulate_grad_batches": 1,
"server_ip": "",
"server_port": "",
"seed": 42,
@@ -60,7 +60,7 @@ CHEAP_ARGS = {
"weight_decay": 0.0,
"adam_epsilon": 1e-08,
"warmup_steps": 0,
"num_train_epochs": 1,
"max_epochs": 1,
"train_batch_size": 2,
"eval_batch_size": 2,
"max_source_length": 12,
@@ -122,7 +122,7 @@ class TestSummarizationDistiller(unittest.TestCase):
updates = dict(
student_encoder_layers=2,
student_decoder_layers=1,
num_train_epochs=4,
max_epochs=4,
val_check_interval=0.25,
alpha_hid=2.0,
model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
@@ -156,7 +156,7 @@ class TestSummarizationDistiller(unittest.TestCase):
default_updates = dict(
train_batch_size=1,
eval_batch_size=2,
num_train_epochs=2,
max_epochs=2,
alpha_mlm=0.2,
alpha_ce=0.8,
do_predict=True,
@@ -187,7 +187,7 @@ class TestSummarizationDistiller(unittest.TestCase):
self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"])
self.assertIsInstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
desired_n_evals = int(args_d["num_train_epochs"] * (1 / args_d["val_check_interval"]) + 1)
desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) + 1)
self.assertEqual(len(metrics["val"]), desired_n_evals)
self.assertEqual(len(metrics["test"]), 1)
return model
@@ -17,5 +17,5 @@ python finetune.py \
--model_name_or_path facebook/mbart-large-cc25 \
--task translation \
--warmup_steps 500 \
--logger wandb --sortish_sampler \
--logger_name wandb --sortish_sampler \
$@