Unverified Commit 3212b885 authored by Stas Bekman, committed by GitHub

[s2s] add support for overriding config params (#6149)

parent 54f9fbef
...@@ -70,6 +70,13 @@ class BaseTransformer(pl.LightningModule):
            )
        else:
            self.config: PretrainedConfig = config

        extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
        for p in extra_model_params:
            if getattr(self.hparams, p, None):
                assert hasattr(self.config, p), f"model config doesn't have a `{p}` attribute"
                setattr(self.config, p, getattr(self.hparams, p))

        if tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
...@@ -182,6 +189,22 @@ class BaseTransformer(pl.LightningModule):
            type=str,
            help="Where do you want to store the pre-trained models downloaded from s3",
        )
        parser.add_argument(
            "--encoder_layerdrop",
            type=float,
            help="Encoder layer dropout probability (Optional). Goes into model.config",
        )
        parser.add_argument(
            "--decoder_layerdrop",
            type=float,
            help="Decoder layer dropout probability (Optional). Goes into model.config",
        )
        parser.add_argument(
            "--dropout", type=float, help="Dropout probability (Optional). Goes into model.config",
        )
        parser.add_argument(
            "--attention_dropout", type=float, help="Attention dropout probability (Optional). Goes into model.config",
        )
        parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
        parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
        parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
...
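
To make the mechanism above concrete, here is a minimal standalone sketch of the same override pattern (not the `BaseTransformer` code itself; `sshleifer/bart-tiny-random` is just an example checkpoint):

```python
import argparse

from transformers import AutoConfig

extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")

parser = argparse.ArgumentParser()
for name in extra_model_params:
    parser.add_argument(f"--{name}", type=float, help="Optional override; goes into model.config")
args = parser.parse_args(["--dropout", "0.1", "--attention_dropout", "0.1"])

# load a pretrained config, then overwrite only the params that were passed on the CLI
config = AutoConfig.from_pretrained("sshleifer/bart-tiny-random")
for p in extra_model_params:
    if getattr(args, p, None):
        assert hasattr(config, p), f"model config doesn't have a `{p}` attribute"
        setattr(config, p, getattr(args, p))

assert config.dropout == 0.1 and config.attention_dropout == 0.1
```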
...@@ -3,7 +3,7 @@
This directory contains examples for finetuning and evaluating transformers on summarization and translation tasks.
Summarization support is more mature than translation support.
Please tag @sshleifer with any issues/unexpected behaviors, or send a PR!
For `bertabs` instructions, see [`bertabs/README.md`](bertabs/README.md).

### Data
...@@ -35,23 +35,23 @@ export ENRO_DIR=${PWD}/wmt_en_ro
this should make a directory called `wmt_en_ro/` with files like `test.source`.
```

If you are using your own data, it must be formatted as one directory with 6 files: train.source, train.target, val.source, val.target, test.source, test.target.
The `.source` files are the input, the `.target` files are the desired output.
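
A tiny sanity check for that layout might look like this (the `data_dir` value is just a placeholder for wherever your files live):

```python
from pathlib import Path

data_dir = Path("wmt_en_ro")  # placeholder: your dataset directory
for split in ("train", "val", "test"):
    for ext in ("source", "target"):
        path = data_dir / f"{split}.{ext}"
        assert path.exists(), f"missing {path}"
```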
### Tips and Tricks

General Tips:
- since you need to run from `examples/seq2seq`, and likely need to modify code, the easiest workflow is to fork transformers, clone your fork, and run `pip install -e .` before you get started.
- try `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr per epoch with bs=8, see the "xsum_shared_task" command below)
- `fp16_opt_level=O1` (the default) works best.
- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved.
Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr')` (see the short loading sketch after these tips).
- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code.
- This warning can be safely ignored:
    > "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']"
- Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start).
- Read scripts before you run them!
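
As referenced in the tips above, a minimal loading sketch for the saved transformers checkpoint (assuming a BART model; `output_dir` is whatever you passed to finetuning):

```python
from transformers import BartForConditionalGeneration

output_dir = "xsum_results"  # placeholder: your --output_dir
# best_tfmr is the transformers-format copy of the best checkpoint
model = BartForConditionalGeneration.from_pretrained(f"{output_dir}/best_tfmr")
```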
Summarization Tips:
- (summ) 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB GPU RAM with fp16 on an NVIDIA-V100.
...@@ -60,12 +60,25 @@ Summarization Tips:
- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100` is a reasonable setting for XSUM.
- `wandb` can be used by specifying `--logger_name wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task.
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
(It rarely makes sense to start from `bart-large` unless you are researching finetuning methods).

**Update 2020-07-18**
Datasets: Seq2SeqDataset will be used for all models besides MBart, for which MBartDataset will be used.
A new dataset is needed to support multilingual tasks.

### Finetuning Training Params
To override the pretrained model's training params, you can pass them to `./finetune.sh`:
```bash
./finetune.sh \
[...]
--encoder_layerdrop 0.1 \
--decoder_layerdrop 0.1 \
--dropout 0.1 \
--attention_dropout 0.1 \
```
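
Since these flags go into `model.config`, the overridden values also end up in the saved checkpoint's `config.json`, so you can sanity-check them after training. A minimal sketch (the path is a placeholder):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("output_dir/best_tfmr")  # placeholder path
for p in ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout"):
    print(p, getattr(config, p))
```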

### Summarization Finetuning
Run/modify `finetune.sh`
...@@ -90,7 +103,7 @@ Best performing command:
```bash
# optionally
export ENRO_DIR='wmt_en_ro' # Download instructions above
# export WANDB_PROJECT="MT" # optional
export MAX_LEN=200
export BS=4
export GAS=8 # gradient accumulation steps
...@@ -109,8 +122,8 @@ export BS=4
export GAS=1 # gradient accumulation steps
./train_mbart_cc25_enro.sh --output_dir enro_finetune_baseline --gpus 8 --logger_name wandb
```

### Finetuning Outputs
As you train, `output_dir` will be filled with files that look kind of like this (comments are mine).
Some of them are metrics, some of them are checkpoints, some of them are metadata. Here is a quick tour:
```bash
...@@ -128,8 +141,8 @@ output_dir
├── student  # this is a huggingface checkpoint generated by SummarizationDistiller. It is the student before it gets finetuned.
│   ├── config.json
│   └── pytorch_model.bin
├── test_generations.txt
# ^^ are the summaries or translations produced by your best checkpoint on the test data. Populated when training is done
├── test_results.txt  # a convenience file with the test set metrics. This data is also in metrics.json['test']
├── hparams.pkl  # the command line args passed after some light preprocessing. Should be saved fairly quickly.
```
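
Since the metrics are also written to `metrics.json`, they are easy to inspect programmatically; a minimal sketch (path is a placeholder):

```python
import json

with open("output_dir/metrics.json") as f:  # placeholder path
    metrics = json.load(f)
print(metrics["test"])  # same numbers as test_results.txt
```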
...@@ -191,7 +204,7 @@ python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_
![DBART](https://huggingface.co/front/thumbnails/distilbart_large.png)

For the CNN/DailyMail dataset (relatively longer, more extractive summaries), we found a simple technique that works:
you just copy alternating layers from `bart-large-cnn` and finetune more on the same data.
For the XSUM dataset, that didn’t work as well, so we used that same initialization strategy followed by a combination of DistilBERT’s ce_loss and the hidden states MSE loss used in the TinyBERT paper.
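
A rough sketch of the "copy alternating layers" idea described here (not the repo's actual distillation helper, and the choice of which layers to keep is up to you):

```python
from transformers import BartConfig, BartForConditionalGeneration

teacher = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

# student: same architecture, but only 6 decoder layers (the 12-6 configuration)
student_config = BartConfig.from_pretrained("facebook/bart-large-cnn", decoder_layers=6)
student = BartForConditionalGeneration(student_config)

# start from the teacher's weights (embeddings, encoder, ...); keys for teacher-only layers are ignored
student.load_state_dict(teacher.state_dict(), strict=False)

# then copy every other teacher decoder layer into the student's 6 decoder layers
kept_teacher_layers = [0, 2, 4, 6, 8, 10]  # an illustrative choice, not necessarily the distilbart mapping
for student_idx, teacher_idx in enumerate(kept_teacher_layers):
    student.model.decoder.layers[student_idx].load_state_dict(
        teacher.model.decoder.layers[teacher_idx].state_dict()
    )
```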
...@@ -207,7 +220,7 @@ They are initialized by copying layers from the associated `bart-large-{cnn|xsum
The command that produced `sshleifer/distilbart-cnn-12-6` is
```bash
./train_distilbart_cnn.sh
```
runtime: 6H on NVIDIA RTX 24GB GPU

*Note*: You can get the same simple distillation logic by using `./run_distiller.sh --no_teacher` followed by identical arguments as the ones in `train_distilbart_cnn.sh`.
...@@ -223,15 +236,15 @@ This is how `sshleifer/distilbart-xsum*` checkpoints were produced.
The command that produced `sshleifer/distilbart-xsum-12-6` is:
```bash
./train_distilbart_xsum.sh
```
runtime: 13H on V-100 16GB GPU.

### Contributing
- follow the standard contributing guidelines and code of conduct.
- add tests to `test_seq2seq_examples.py`
- To run only the seq2seq tests, you must be in the root of the repository and run:
```bash
pytest examples/seq2seq/
```
...@@ -10,4 +10,8 @@ python finetune.py \
    --do_predict \
    --n_val 1000 \
    --val_check_interval 0.1 \
    --encoder_layerdrop 0.1 \
    --decoder_layerdrop 0.1 \
    --dropout 0.1 \
    --attention_dropout 0.1 \
    $@
...@@ -277,6 +277,55 @@ def test_finetune(model):
    assert bart.decoder.embed_tokens == bart.shared


def test_finetune_extra_model_args():
    args_d: dict = CHEAP_ARGS.copy()

    task = "summarization"
    tmp_dir = make_test_data_dir()

    args_d.update(
        data_dir=tmp_dir,
        tokenizer_name=None,
        train_batch_size=2,
        eval_batch_size=2,
        do_predict=False,
        task=task,
        src_lang="en_XX",
        tgt_lang="ro_RO",
        freeze_encoder=True,
        freeze_embeds=True,
    )

    # test models whose config includes the extra_model_args
    model = BART_TINY
    output_dir = tempfile.mkdtemp(prefix="output_1_")
    args_d1 = args_d.copy()
    args_d1.update(
        model_name_or_path=model, output_dir=output_dir,
    )
    extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
    for p in extra_model_params:
        args_d1[p] = 0.5
    args = argparse.Namespace(**args_d1)
    model = main(args)
    for p in extra_model_params:
        assert getattr(model.config, p) == 0.5, f"failed to override the model config for param {p}"

    # test models whose config doesn't include the extra_model_args
    model = T5_TINY
    output_dir = tempfile.mkdtemp(prefix="output_2_")
    args_d2 = args_d.copy()
    args_d2.update(
        model_name_or_path=model, output_dir=output_dir,
    )
    unsupported_param = "encoder_layerdrop"
    args_d2[unsupported_param] = 0.5
    args = argparse.Namespace(**args_d2)
    with pytest.raises(Exception) as excinfo:
        model = main(args)
    assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"


def test_pack_dataset():
    tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
...