Unverified Commit 3212b885 authored by Stas Bekman, committed by GitHub

[s2s] add support for overriding config params (#6149)

parent 54f9fbef
@@ -70,6 +70,13 @@ class BaseTransformer(pl.LightningModule):
)
else:
self.config: PretrainedConfig = config
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
for p in extra_model_params:
if getattr(self.hparams, p, None):
assert hasattr(self.config, p), f"model config doesn't have a `{p}` attribute"
setattr(self.config, p, getattr(self.hparams, p))
if tokenizer is None:
self.tokenizer = AutoTokenizer.from_pretrained(
self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
@@ -182,6 +189,22 @@ class BaseTransformer(pl.LightningModule):
type=str,
help="Where do you want to store the pre-trained models downloaded from s3",
)
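The new block above can be tried in isolation. A minimal sketch of the same override pattern, using a stand-in `TinyConfig` and a plain `Namespace` in place of the real `PretrainedConfig`/Lightning hparams pair (these names are illustrative, not part of the patch):

```python
# Sketch only: mirrors the loop added above. `TinyConfig` and `hparams` are
# stand-ins for PretrainedConfig and the Lightning hparams namespace.
from argparse import Namespace

class TinyConfig:
    encoder_layerdrop = 0.0
    decoder_layerdrop = 0.0
    dropout = 0.1
    attention_dropout = 0.1

# flags that were not passed on the command line stay None and are skipped
hparams = Namespace(encoder_layerdrop=None, decoder_layerdrop=0.2, dropout=None, attention_dropout=0.3)
config = TinyConfig()

extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
for p in extra_model_params:
    if getattr(hparams, p, None):  # truthiness check: None means "leave the pretrained value alone"
        assert hasattr(config, p), f"model config doesn't have a `{p}` attribute"
        setattr(config, p, getattr(hparams, p))

print(config.decoder_layerdrop, config.attention_dropout)  # 0.2 0.3
```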
parser.add_argument(
"--encoder_layerdrop",
type=float,
help="Encoder layer dropout probability (Optional). Goes into model.config",
)
parser.add_argument(
"--decoder_layerdrop",
type=float,
help="Decoder layer dropout probability (Optional). Goes into model.config",
)
parser.add_argument(
"--dropout", type=float, help="Dropout probability (Optional). Goes into model.config",
)
parser.add_argument(
"--attention_dropout", type=float, help="Attention dropout probability (Optional). Goes into model.config",
)
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
......
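Note that the four new flags above are declared without a `default`, so argparse leaves them as `None` when unspecified and the truthiness check in `__init__` then leaves the pretrained config untouched. A small sketch of that behaviour (the parser here is illustrative, not the real `add_model_specific_args`):

```python
# Sketch only: shows why unspecified flags do not override the config.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--dropout", type=float, help="Dropout probability (Optional). Goes into model.config")
parser.add_argument("--attention_dropout", type=float, help="Attention dropout probability (Optional). Goes into model.config")

args = parser.parse_args(["--dropout", "0.1"])
print(args.dropout, args.attention_dropout)  # 0.1 None -> only dropout would be copied onto model.config
```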
...@@ -66,6 +66,19 @@ Summarization Tips: ...@@ -66,6 +66,19 @@ Summarization Tips:
Datasets: Seq2SeqDataset will be used for all models besides MBart, for which MBartDataset will be used.** Datasets: Seq2SeqDataset will be used for all models besides MBart, for which MBartDataset will be used.**
A new dataset is needed to support multilingual tasks. A new dataset is needed to support multilingual tasks.
### Finetuning Training Params
To override the pretrained model's training params, you can pass them to `./finetune.sh`:
```bash
./finetune.sh \
[...]
--encoder_layerdrop 0.1 \
--decoder_layerdrop 0.1 \
--dropout 0.1 \
--attention_dropout 0.1 \
```
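These flags map one-to-one onto attributes of the model's `PretrainedConfig`. For comparison, a minimal sketch of setting the same attributes directly in Python, assuming network access; `sshleifer/bart-tiny-random` is just an example checkpoint:

```python
# Sketch only: the finetune.sh flags above correspond to these config attributes.
from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "sshleifer/bart-tiny-random",  # example checkpoint; BART-family configs expose these attributes
    encoder_layerdrop=0.1,
    decoder_layerdrop=0.1,
    dropout=0.1,
    attention_dropout=0.1,
)
print(config.dropout, config.attention_dropout)  # 0.1 0.1
```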
### Summarization Finetuning
Run/modify `finetune.sh`
...
@@ -10,4 +10,8 @@ python finetune.py \
--do_predict \
--n_val 1000 \
--val_check_interval 0.1 \
--encoder_layerdrop 0.1 \
--decoder_layerdrop 0.1 \
--dropout 0.1 \
--attention_dropout 0.1 \
$@
@@ -277,6 +277,55 @@ def test_finetune(model):
assert bart.decoder.embed_tokens == bart.shared
def test_finetune_extra_model_args():
args_d: dict = CHEAP_ARGS.copy()
task = "summarization"
tmp_dir = make_test_data_dir()
args_d.update(
data_dir=tmp_dir,
tokenizer_name=None,
train_batch_size=2,
eval_batch_size=2,
do_predict=False,
task=task,
src_lang="en_XX",
tgt_lang="ro_RO",
freeze_encoder=True,
freeze_embeds=True,
)
# test models whose config includes the extra_model_args
model = BART_TINY
output_dir = tempfile.mkdtemp(prefix="output_1_")
args_d1 = args_d.copy()
args_d1.update(
model_name_or_path=model, output_dir=output_dir,
)
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
for p in extra_model_params:
args_d1[p] = 0.5
args = argparse.Namespace(**args_d1)
model = main(args)
for p in extra_model_params:
assert getattr(model.config, p) == 0.5, f"failed to override the model config for param {p}"
# test models whose config doesn't include the extra_model_args
model = T5_TINY
output_dir = tempfile.mkdtemp(prefix="output_2_")
args_d2 = args_d.copy()
args_d2.update(
model_name_or_path=model, output_dir=output_dir,
)
unsupported_param = "encoder_layerdrop"
args_d2[unsupported_param] = 0.5
args = argparse.Namespace(**args_d2)
with pytest.raises(Exception) as excinfo:
model = main(args)
assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"
def test_pack_dataset():
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
...