Unverified commit 376c02e9, authored by Stas Bekman, committed by GitHub

[WIP] lightning_base: support --lr_scheduler with multiple possibilities (#6232)

* support --lr_scheduler with multiple possibilities

* correct the error message

* add a note about supported schedulers

* cleanup

* cleanup2

* needs the argument default

* style

* add another assert in the test

* implement requested changes

* cleanups

* fix relative import

* cleanup
parent d89acd07

@@ -20,6 +20,10 @@ from transformers import (
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedTokenizer,
)
from transformers.optimization import (
    get_cosine_schedule_with_warmup,
    get_cosine_with_hard_restarts_schedule_with_warmup,
    get_linear_schedule_with_warmup,
)

@@ -39,6 +43,19 @@ MODEL_MODES = {
}

# update this and the import above to support new schedulers from transformers.optimization
arg_to_scheduler = {
    "linear": get_linear_schedule_with_warmup,
    "cosine": get_cosine_schedule_with_warmup,
    "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
    # "polynomial": "",  # TODO
    # "": get_constant_schedule,  # not supported for now
    # "": get_constant_schedule_with_warmup,  # not supported for now
}
arg_to_scheduler_choices = sorted(arg_to_scheduler.keys())
arg_to_scheduler_metavar = "{" + ", ".join(arg_to_scheduler_choices) + "}"
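
The registry above drives both argument validation and scheduler construction, so supporting another scheduler is one import plus one dict entry, provided the factory accepts the uniform (optimizer, num_warmup_steps, num_training_steps) call made in get_lr_scheduler below; that is presumably why the constant schedules, which take no num_training_steps, are left commented out. A hedged sketch of wiring one in anyway through a small adapter (the adapter and the "constant_w_warmup" key are hypothetical, not part of this commit):

from transformers.optimization import get_constant_schedule_with_warmup

def get_constant_schedule_with_warmup_adapted(optimizer, num_warmup_steps, num_training_steps=None):
    # Hypothetical adapter: swallow num_training_steps so the factory matches
    # the keyword call made in get_lr_scheduler.
    return get_constant_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps)

arg_to_scheduler["constant_w_warmup"] = get_constant_schedule_with_warmup_adapted
# Note: arg_to_scheduler_choices and arg_to_scheduler_metavar are computed at
# import time, so they would need re-deriving after this registration.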

class BaseTransformer(pl.LightningModule):
    def __init__(
        self,

@@ -97,6 +114,14 @@ class BaseTransformer(pl.LightningModule):
    def load_hf_checkpoint(self, *args, **kwargs):
        self.model = self.model_type.from_pretrained(*args, **kwargs)

    def get_lr_scheduler(self):
        get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler]
        scheduler = get_schedule_func(
            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps
        )
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return scheduler
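
The dict returned here is PyTorch Lightning's scheduler-configuration format: "interval": "step" makes Lightning call scheduler.step() after every optimizer step rather than once per epoch, and "frequency": 1 means on every such step. A minimal self-contained illustration of the per-step behavior this requests (toy model and step counts are made up, not taken from this file):

import torch
from transformers.optimization import get_linear_schedule_with_warmup

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=2, num_training_steps=10)

for step in range(10):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    scheduler.step()  # this per-batch call is what "interval": "step" asks Lightning to make
    optimizer.zero_grad()
    print(step, scheduler.get_last_lr())  # ramps up for 2 steps, then decays linearly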

    def configure_optimizers(self):
        """Prepare optimizer and schedule (warmup followed by the selected decay schedule)"""
        model = self.model

@@ -114,10 +139,8 @@ class BaseTransformer(pl.LightningModule):
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
        self.opt = optimizer
-        scheduler = get_linear_schedule_with_warmup(
-            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps
-        )
-        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
+        scheduler = self.get_lr_scheduler()
        return [optimizer], [scheduler]

    def test_step(self, batch, batch_nb):

@@ -203,6 +226,14 @@
            "--attention_dropout", type=float, help="Attention dropout probability (Optional). Goes into model.config",
        )
        parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
        parser.add_argument(
            "--lr_scheduler",
            default="linear",
            choices=arg_to_scheduler_choices,
            metavar=arg_to_scheduler_metavar,
            type=str,
            help="Learning rate scheduler",
        )
        parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
        parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
        parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")

@@ -8,15 +8,17 @@ from pathlib import Path
from unittest.mock import patch

import pytest
import pytorch_lightning as pl
import torch
from pytest import param
from torch.utils.data import DataLoader

import lightning_base
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, MBartTokenizer
from transformers.testing_utils import require_multigpu

from .distillation import distill_main, evaluate_checkpoint
-from .finetune import main
+from .finetune import SummarizationModule, main
from .pack_dataset import pack_data_dir
from .run_eval import generate_summaries_or_translations, run_generate
from .utils import MBartDataset, Seq2SeqDataset, label_smoothed_nll_loss, lmap, load_json

@@ -60,6 +62,7 @@ CHEAP_ARGS = {
    "tokenizer_name": "facebook/bart-large",
    "do_lower_case": False,
    "learning_rate": 0.3,
    "lr_scheduler": "linear",
    "weight_decay": 0.0,
    "adam_epsilon": 1e-08,
    "warmup_steps": 0,

@@ -326,6 +329,65 @@ def test_finetune_extra_model_args():
    assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"

def test_finetune_lr_schedulers(capsys):
    args_d: dict = CHEAP_ARGS.copy()
    task = "summarization"
    tmp_dir = make_test_data_dir()
    model = BART_TINY
    output_dir = tempfile.mkdtemp(prefix="output_1_")
    args_d.update(
        data_dir=tmp_dir,
        model_name_or_path=model,
        output_dir=output_dir,
        tokenizer_name=None,
        train_batch_size=2,
        eval_batch_size=2,
        do_predict=False,
        task=task,
        src_lang="en_XX",
        tgt_lang="ro_RO",
        freeze_encoder=True,
        freeze_embeds=True,
    )

    # emulate finetune.py
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
    args = ["--help"]

    # --help test
    with pytest.raises(SystemExit) as excinfo:
        args = parser.parse_args(args)
        assert False, "--help is expected to sys.exit"
    assert excinfo.type == SystemExit
    captured = capsys.readouterr()
    expected = lightning_base.arg_to_scheduler_metavar
    assert expected in captured.out, "--help is expected to list the supported schedulers"

    # --lr_scheduler=non_existing_scheduler test
    unsupported_param = "non_existing_scheduler"
    args = [f"--lr_scheduler={unsupported_param}"]
    with pytest.raises(SystemExit) as excinfo:
        args = parser.parse_args(args)
        assert False, "invalid argument is expected to sys.exit"
    assert excinfo.type == SystemExit
    captured = capsys.readouterr()
    expected = f"invalid choice: '{unsupported_param}'"
    assert expected in captured.err, f"should have bailed on invalid choice of scheduler {unsupported_param}"

    # --lr_scheduler=existing_scheduler test
    supported_param = "cosine"
    args_d1 = args_d.copy()
    args_d1["lr_scheduler"] = supported_param
    args = argparse.Namespace(**args_d1)
    model = main(args)
    assert getattr(model.hparams, "lr_scheduler") == supported_param, f"lr_scheduler={supported_param} shouldn't fail"
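
The final block above checks one supported value end to end; a hedged sketch of extending that coverage to every registered scheduler with pytest parametrization (the test name and the inline reuse of CHEAP_ARGS are assumptions, not code from this commit):

@pytest.mark.parametrize("supported_param", lightning_base.arg_to_scheduler_choices)
def test_finetune_each_lr_scheduler(supported_param):
    # Hypothetical follow-up test: args_d1 would also need the same data_dir,
    # model and output_dir overrides applied at the top of test_finetune_lr_schedulers.
    args_d1 = {**CHEAP_ARGS, "lr_scheduler": supported_param}
    model = main(argparse.Namespace(**args_d1))
    assert model.hparams.lr_scheduler == supported_param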

def test_pack_dataset():
    tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")