Unverified commit 5543b30a authored by Sam Shleifer, committed by GitHub

[pl_examples] default warmup steps=0 (#5316)

parent bf0d12c2
@@ -122,12 +122,9 @@ class BaseTransformer(pl.LightningModule):
         else:
             optimizer.step()
         optimizer.zero_grad()
-        self.lr_scheduler.step()
-
-    def get_tqdm_dict(self):
-        avg_loss = getattr(self.trainer, "avg_loss", 0.0)
-        tqdm_dict = {"loss": "{:.3f}".format(avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}
-        return tqdm_dict
+        self.lr_scheduler.step()  # By default, PL will only step every epoch.
+        lrs = {f"lr_group_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_lr())}
+        self.logger.log_metrics(lrs)
 
     def test_step(self, batch, batch_nb):
         return self.validation_step(batch, batch_nb)
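
A standalone sketch of what the new per-step LR logging produces (illustrative only, not part of the commit; assumes torch and transformers are installed, and uses get_last_lr() rather than get_lr() to avoid the PyTorch warning about calling get_lr() outside of step()):

import torch
from transformers import get_linear_schedule_with_warmup

# Two parameter groups, so the logged dict has lr_group_0 and lr_group_1.
p0, p1 = torch.nn.Parameter(torch.zeros(1)), torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([{"params": [p0], "lr": 5e-5}, {"params": [p1], "lr": 1e-4}])
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=10)

for step in range(3):
    optimizer.step()
    scheduler.step()
    # Same shape of dict the patched optimizer_step hands to self.logger.log_metrics(...)
    lrs = {f"lr_group_{i}": lr for i, lr in enumerate(scheduler.get_last_lr())}
    print(step, lrs)  # e.g. {'lr_group_0': ..., 'lr_group_1': ...}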
@@ -202,7 +199,7 @@ class BaseTransformer(pl.LightningModule):
         parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
         parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
         parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
-        parser.add_argument("--warmup_steps", default=500, type=int, help="Linear warmup over warmup_steps.")
+        parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
         parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
         parser.add_argument(
             "--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."
......
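
The substantive change above is the new --warmup_steps default of 0 (previously 500). A hedged sketch of what that means for the first few optimizer steps, assuming torch and transformers are installed (the helper name first_lrs is made up for illustration):

import torch
from transformers import get_linear_schedule_with_warmup

def first_lrs(warmup_steps, n=3, base_lr=5e-5, total=1000):
    opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=base_lr)
    sched = get_linear_schedule_with_warmup(opt, num_warmup_steps=warmup_steps, num_training_steps=total)
    lrs = []
    for _ in range(n):
        lrs.append(sched.get_last_lr()[0])
        opt.step()
        sched.step()
    return lrs

print(first_lrs(0))    # roughly [5e-05, 4.995e-05, 4.99e-05] -> training starts at the full learning rate
print(first_lrs(500))  # roughly [0.0, 1e-07, 2e-07]          -> linear ramp from 0 over 500 steps

With the default now 0, scripts that still want warmup (such as the distilbart shell scripts further down) have to pass --warmup_steps 500 explicitly.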
@@ -64,6 +64,7 @@ The following command should work on a 16GB GPU:
 Tips:
 - 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB GPU RAM with fp16 on an NVIDIA-V100.
+- since you need to run from `examples/seq2seq`, and likely need to modify code, it is easiest to fork, then clone transformers and run `pip install -e .` before you get started.
 - try `bart-base`, `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr/epoch with bs=8, see the "xsum_shared_task" command below)
 - `fp16_opt_level=O1` (the default works best).
 - If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
......
@@ -3,6 +3,7 @@ import glob
 import logging
 import os
 import time
+import warnings
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Tuple
@@ -216,6 +217,8 @@ class SummarizationModule(BaseTransformer):
         scheduler = get_linear_schedule_with_warmup(
             self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
         )
+        if max(scheduler.get_last_lr()) <= 0:
+            warnings.warn("All learning rates are 0")
         self.lr_scheduler = scheduler
         return dataloader
......
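
A minimal reproduction of the idea behind the new sanity check (illustrative, not project code; assumes torch and transformers are installed):

import warnings
import torch
from transformers import get_linear_schedule_with_warmup

model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=0.0)  # e.g. a mistyped --learning_rate 0
scheduler = get_linear_schedule_with_warmup(opt, num_warmup_steps=0, num_training_steps=100)

if max(scheduler.get_last_lr()) <= 0:
    warnings.warn("All learning rates are 0")

Note that with a nonzero warmup_steps the very first scheduled LR is also 0 (the ramp starts at 0), so a check like this is only a reliable misconfiguration signal when warmup is disabled.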
@@ -193,13 +193,12 @@ class TestSummarizationDistiller(unittest.TestCase):
 @pytest.mark.parametrize(["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
 def test_run_eval_bart(model):
-    tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo"
-    output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo"
+    input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
+    output_file_name = input_file_name.parent / "utest_output.txt"
     assert not output_file_name.exists()
     articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
-    _dump_articles(tmp, articles)
-    testargs = ["run_eval.py", str(tmp), str(output_file_name), model]  # TODO: test score_path
+    _dump_articles(input_file_name, articles)
+    testargs = ["run_eval.py", str(input_file_name), str(output_file_name), model]  # TODO: test score_path
     with patch.object(sys, "argv", testargs):
         run_generate()
         assert Path(output_file_name).exists()
......
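
The test relies on patching sys.argv so the script's CLI entry point sees the temporary files. A self-contained illustration of that pattern (fake_cli and the file/model names here are placeholders, not project code):

import sys
from unittest.mock import patch

def fake_cli():
    # Stand-in for run_generate(); a real entry point parses sys.argv itself.
    return sys.argv[1:]

testargs = ["run_eval.py", "utest_input.source", "utest_output.txt", "some-tiny-model"]
with patch.object(sys, "argv", testargs):
    assert fake_cli() == ["utest_input.source", "utest_output.txt", "some-tiny-model"]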
@@ -16,9 +16,9 @@ python finetune.py \
     --freeze_encoder --freeze_embeds --data_dir $CNN_DIR \
     --max_target_length 142 --val_max_target_length=142 \
     --train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \
-    --data_dir $CNN_DIR \
     --model_name_or_path sshleifer/student_cnn_12_6 \
     --tokenizer_name facebook/bart-large \
+    --warmup_steps 500 \
     --output_dir distilbart-cnn-12-6 \
     $@
@@ -16,5 +16,6 @@ python distillation.py \
     --alpha_hid=3. --length_penalty=0.5 \
     --train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS --num_train_epochs=6 \
     --tokenizer_name facebook/bart-large \
+    --warmup_steps 500 \
     --output_dir distilbart_xsum_12_6 \
     $@