Unverified Commit 3212b885, authored Jul 29, 2020 by Stas Bekman, committed by GitHub Jul 30, 2020
[s2s] add support for overriding config params (#6149)
parent 54f9fbef
Showing 4 changed files with 106 additions and 17 deletions:

    examples/lightning_base.py                 +23   -0
    examples/seq2seq/README.md                 +30  -17
    examples/seq2seq/finetune.sh                +4   -0
    examples/seq2seq/test_seq2seq_examples.py  +49   -0
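For orientation, here is a minimal, self-contained sketch of the override mechanism this commit adds to `BaseTransformer`: hyperparameters passed on the command line are copied onto the model's config, but only when the user actually set them and only when the config defines that attribute. This snippet is not part of the commit; `DummyConfig` and the sample values are made up stand-ins for a `transformers` `PretrainedConfig` and parsed CLI args.

```python
from argparse import Namespace

class DummyConfig:
    # stand-in for a PretrainedConfig that defines the relevant fields
    encoder_layerdrop = 0.0
    decoder_layerdrop = 0.0
    dropout = 0.1
    attention_dropout = 0.1

# pretend these came from the command line; unset flags are None
hparams = Namespace(encoder_layerdrop=0.2, decoder_layerdrop=None, dropout=0.3, attention_dropout=None)
config = DummyConfig()

extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
for p in extra_model_params:
    if getattr(hparams, p, None):  # only override params the user actually set
        assert hasattr(config, p), f"model config doesn't have a `{p}` attribute"
        setattr(config, p, getattr(hparams, p))

print(config.encoder_layerdrop, config.dropout)  # -> 0.2 0.3; unset params keep their pretrained values
```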
examples/lightning_base.py
...
...
@@ -70,6 +70,13 @@ class BaseTransformer(pl.LightningModule):
            )
        else:
            self.config: PretrainedConfig = config

        extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
        for p in extra_model_params:
            if getattr(self.hparams, p, None):
                assert hasattr(self.config, p), f"model config doesn't have a `{p}` attribute"
                setattr(self.config, p, getattr(self.hparams, p))

        if tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
...
...
@@ -182,6 +189,22 @@ class BaseTransformer(pl.LightningModule):
            type=str,
            help="Where do you want to store the pre-trained models downloaded from s3",
        )
        parser.add_argument(
            "--encoder_layerdrop",
            type=float,
            help="Encoder layer dropout probability (Optional). Goes into model.config",
        )
        parser.add_argument(
            "--decoder_layerdrop",
            type=float,
            help="Decoder layer dropout probability (Optional). Goes into model.config",
        )
        parser.add_argument(
            "--dropout",
            type=float,
            help="Dropout probability (Optional). Goes into model.config",
        )
        parser.add_argument(
            "--attention_dropout",
            type=float,
            help="Attention dropout probability (Optional). Goes into model.config",
        )
        parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
        parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
        parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
...
...
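A small aside (hypothetical snippet, not part of the diff): the four new arguments deliberately have no `default`, so an unset flag parses to `None` and the override loop above leaves the pretrained config value untouched. Note that the truthiness check also means an explicit `0.0` would be skipped.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--dropout", type=float, help="Dropout probability (Optional). Goes into model.config")

print(parser.parse_args(["--dropout", "0.1"]).dropout)  # 0.1  -> copied onto model.config
print(parser.parse_args([]).dropout)                    # None -> override loop skips it
```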
examples/seq2seq/README.md
...
...
@@ -66,6 +66,19 @@ Summarization Tips:
**Datasets: Seq2SeqDataset will be used for all models besides MBart, for which MBartDataset will be used.**
A new dataset is needed to support multilingual tasks.

### Finetuning Training Params

To override the pretrained model's training params, you can pass them to `./finetune.sh`:

```bash
./finetune.sh \
    [...]
    --encoder_layerdrop 0.1 \
    --decoder_layerdrop 0.1 \
    --dropout 0.1 \
    --attention_dropout 0.1 \
```

### Summarization Finetuning

Run/modify `finetune.sh`
...
...
examples/seq2seq/finetune.sh
...
...
@@ -10,4 +10,8 @@ python finetune.py \
    --do_predict \
    --n_val 1000 \
    --val_check_interval 0.1 \
    --encoder_layerdrop 0.1 \
    --decoder_layerdrop 0.1 \
    --dropout 0.1 \
    --attention_dropout 0.1 \
    $@
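Because the script forwards `$@` to `finetune.py`, any extra flags appended on the command line (as in the README example above) are passed straight through; since argparse's default `store` action keeps the last occurrence of a repeated flag, appending e.g. `--dropout 0.2` should take precedence over the value hard-coded here.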
examples/seq2seq/test_seq2seq_examples.py
...
...
@@ -277,6 +277,55 @@ def test_finetune(model):
    assert bart.decoder.embed_tokens == bart.shared


def test_finetune_extra_model_args():
    args_d: dict = CHEAP_ARGS.copy()

    task = "summarization"
    tmp_dir = make_test_data_dir()

    args_d.update(
        data_dir=tmp_dir,
        tokenizer_name=None,
        train_batch_size=2,
        eval_batch_size=2,
        do_predict=False,
        task=task,
        src_lang="en_XX",
        tgt_lang="ro_RO",
        freeze_encoder=True,
        freeze_embeds=True,
    )

    # test models whose config includes the extra_model_args
    model = BART_TINY
    output_dir = tempfile.mkdtemp(prefix="output_1_")
    args_d1 = args_d.copy()
    args_d1.update(
        model_name_or_path=model, output_dir=output_dir,
    )
    extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
    for p in extra_model_params:
        args_d1[p] = 0.5
    args = argparse.Namespace(**args_d1)
    model = main(args)
    for p in extra_model_params:
        assert getattr(model.config, p) == 0.5, f"failed to override the model config for param {p}"

    # test models whose config doesn't include the extra_model_args
    model = T5_TINY
    output_dir = tempfile.mkdtemp(prefix="output_2_")
    args_d2 = args_d.copy()
    args_d2.update(
        model_name_or_path=model, output_dir=output_dir,
    )
    unsupported_param = "encoder_layerdrop"
    args_d2[unsupported_param] = 0.5
    args = argparse.Namespace(**args_d2)
    with pytest.raises(Exception) as excinfo:
        model = main(args)
    assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"


def test_pack_dataset():
    tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
...
...
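To exercise just the new test locally, an invocation along the lines of `pytest examples/seq2seq/test_seq2seq_examples.py -k extra_model_args` should work (assuming the examples requirements are installed); the command is an illustration, not part of the commit.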