chenpangpang / transformers · Commits

Unverified commit 3212b885, authored Jul 29, 2020 by Stas Bekman, committed via GitHub on Jul 30, 2020
[s2s] add support for overriding config params (#6149)
Parent: 54f9fbef

Showing 4 changed files with 106 additions and 17 deletions (+106 −17)
- examples/lightning_base.py (+23 −0)
- examples/seq2seq/README.md (+30 −17)
- examples/seq2seq/finetune.sh (+4 −0)
- examples/seq2seq/test_seq2seq_examples.py (+49 −0)
examples/lightning_base.py (view file @ 3212b885)

@@ -70,6 +70,13 @@ class BaseTransformer(pl.LightningModule):
             )
         else:
             self.config: PretrainedConfig = config
+
+        extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
+        for p in extra_model_params:
+            if getattr(self.hparams, p, None):
+                assert hasattr(self.config, p), f"model config doesn't have a `{p}` attribute"
+                setattr(self.config, p, getattr(self.hparams, p))
+
         if tokenizer is None:
             self.tokenizer = AutoTokenizer.from_pretrained(
                 self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
@@ -182,6 +189,22 @@ class BaseTransformer(pl.LightningModule):
             type=str,
             help="Where do you want to store the pre-trained models downloaded from s3",
         )
+        parser.add_argument(
+            "--encoder_layerdrop",
+            type=float,
+            help="Encoder layer dropout probability (Optional). Goes into model.config",
+        )
+        parser.add_argument(
+            "--decoder_layerdrop",
+            type=float,
+            help="Decoder layer dropout probability (Optional). Goes into model.config",
+        )
+        parser.add_argument(
+            "--dropout", type=float, help="Dropout probability (Optional). Goes into model.config",
+        )
+        parser.add_argument(
+            "--attention_dropout", type=float, help="Attention dropout probability (Optional). Goes into model.config",
+        )
         parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
         parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
         parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
...
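The mechanism the commit adds is a small guarded copy from the parsed hyperparameters onto the model config: a flag is only applied if it was passed, and applying it to a config that lacks the attribute fails loudly rather than being silently ignored. A minimal standalone sketch of that logic (the helper name `apply_config_overrides` is hypothetical; the parameter tuple and the assert message come from the diff above):

```python
from argparse import Namespace

def apply_config_overrides(config, hparams: Namespace):
    """Copy the optional dropout-style hparams onto the model config, if they were passed."""
    extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
    for p in extra_model_params:
        if getattr(hparams, p, None):
            # Fail loudly if the target config has no such attribute,
            # rather than silently dropping the user's flag.
            assert hasattr(config, p), f"model config doesn't have a `{p}` attribute"
            setattr(config, p, getattr(hparams, p))
    return config
```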
examples/seq2seq/README.md (view file @ 3212b885)

@@ -66,6 +66,19 @@ Summarization Tips:
 **Datasets: Seq2SeqDataset will be used for all models besides MBart, for which MBartDataset will be used.**
 A new dataset is needed to support multilingual tasks.
 
+### Finetuning Training Params
+
+To override the pretrained model's training params, you can pass them to `./finetune.sh`:
+
+```bash
+./finetune.sh \
+    [...]
+    --encoder_layerdrop 0.1 \
+    --decoder_layerdrop 0.1 \
+    --dropout 0.1 \
+    --attention_dropout 0.1 \
+```
+
 ### Summarization Finetuning
 Run/modify `finetune.sh`
...
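Whether all four overrides apply depends on the checkpoint's config class: BART-style configs expose `encoder_layerdrop`, `decoder_layerdrop`, `dropout`, and `attention_dropout`, while other architectures may not. A quick way to check before passing the flags, as a sketch assuming the `transformers` `AutoConfig` API (`sshleifer/bart-tiny-random` is assumed to match the `BART_TINY` constant used in the tests added below):

```python
from transformers import AutoConfig

# Inspect which of the four override targets this config actually defines.
config = AutoConfig.from_pretrained("sshleifer/bart-tiny-random")
for p in ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout"):
    print(p, getattr(config, p, "<not defined by this config>"))
```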
examples/seq2seq/finetune.sh (view file @ 3212b885)

@@ -10,4 +10,8 @@ python finetune.py \
     --do_predict \
     --n_val 1000 \
     --val_check_interval 0.1 \
+    --encoder_layerdrop 0.1 \
+    --decoder_layerdrop 0.1 \
+    --dropout 0.1 \
+    --attention_dropout 0.1 \
     $@
examples/seq2seq/test_seq2seq_examples.py (view file @ 3212b885)

@@ -277,6 +277,55 @@ def test_finetune(model):
     assert bart.decoder.embed_tokens == bart.shared
 
 
+def test_finetune_extra_model_args():
+    args_d: dict = CHEAP_ARGS.copy()
+
+    task = "summarization"
+    tmp_dir = make_test_data_dir()
+
+    args_d.update(
+        data_dir=tmp_dir,
+        tokenizer_name=None,
+        train_batch_size=2,
+        eval_batch_size=2,
+        do_predict=False,
+        task=task,
+        src_lang="en_XX",
+        tgt_lang="ro_RO",
+        freeze_encoder=True,
+        freeze_embeds=True,
+    )
+
+    # test models whose config includes the extra_model_args
+    model = BART_TINY
+    output_dir = tempfile.mkdtemp(prefix="output_1_")
+    args_d1 = args_d.copy()
+    args_d1.update(
+        model_name_or_path=model, output_dir=output_dir,
+    )
+    extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
+    for p in extra_model_params:
+        args_d1[p] = 0.5
+    args = argparse.Namespace(**args_d1)
+    model = main(args)
+    for p in extra_model_params:
+        assert getattr(model.config, p) == 0.5, f"failed to override the model config for param {p}"
+
+    # test models whose config doesn't include the extra_model_args
+    model = T5_TINY
+    output_dir = tempfile.mkdtemp(prefix="output_2_")
+    args_d2 = args_d.copy()
+    args_d2.update(
+        model_name_or_path=model, output_dir=output_dir,
+    )
+    unsupported_param = "encoder_layerdrop"
+    args_d2[unsupported_param] = 0.5
+    args = argparse.Namespace(**args_d2)
+    with pytest.raises(Exception) as excinfo:
+        model = main(args)
+    assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"
+
+
 def test_pack_dataset():
     tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
...
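The second half of the new test exercises the failure path: T5's config does not define `encoder_layerdrop`, so the guarded override raises instead of silently dropping the flag. A minimal sketch of that check in isolation, using the public `t5-small` config as a stand-in for the tests' `T5_TINY` checkpoint:

```python
from transformers import AutoConfig

# T5 configs expose `dropout_rate` but no `encoder_layerdrop`, so this assert trips.
p = "encoder_layerdrop"
config = AutoConfig.from_pretrained("t5-small")  # stand-in for T5_TINY
assert hasattr(config, p), f"model config doesn't have a `{p}` attribute"  # raises AssertionError
```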