Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
1652ddad
Unverified
Commit
1652ddad
authored
Oct 16, 2020
by
Stas Bekman
Committed by
GitHub
Oct 16, 2020
Browse files
[seq2seq testing] improve readability (#7845)
parent
466115b2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
32 additions
and
52 deletions
+32
-52
examples/seq2seq/test_finetune_trainer.py
examples/seq2seq/test_finetune_trainer.py
+32
-52
No files found.
examples/seq2seq/test_finetune_trainer.py
View file @
1652ddad
...
@@ -47,58 +47,38 @@ def test_finetune_trainer_slow():
...
@@ -47,58 +47,38 @@ def test_finetune_trainer_slow():
def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
    """Invoke finetune_trainer's ``main()`` with a synthetic command line.

    Builds an argv for a tiny en_XX -> ro_RO translation fine-tuning run over
    the bundled WMT test data, patches ``sys.argv``, and runs ``main()``.

    Args:
        eval_steps: step interval used for both ``--save_steps`` and ``--eval_steps``.
        max_len: value for ``--max_source_length`` / ``--max_target_length`` /
            ``--val_max_target_length`` (kept as ``str`` so it interpolates
            directly into the argv string).
        model_name: model identifier or path for ``--model_name_or_path``.
        num_train_epochs: number of training epochs.

    Side effects: runs train/eval/predict via ``main()`` and writes results to
    a fresh temporary output directory.
    """
    data_dir = "examples/seq2seq/test_data/wmt_en_ro"
    output_dir = tempfile.mkdtemp(prefix="test_output")
    # The argv is written as one readable f-string and split on whitespace;
    # every interpolated value must therefore be free of internal whitespace.
    argv = f"""
        --model_name_or_path {model_name}
        --data_dir {data_dir}
        --output_dir {output_dir}
        --overwrite_output_dir
        --n_train 8
        --n_val 8
        --max_source_length {max_len}
        --max_target_length {max_len}
        --val_max_target_length {max_len}
        --do_train
        --do_eval
        --do_predict
        --num_train_epochs {str(num_train_epochs)}
        --per_device_train_batch_size 4
        --per_device_eval_batch_size 4
        --learning_rate 3e-4
        --warmup_steps 8
        --evaluate_during_training
        --predict_with_generate
        --logging_steps 0
        --save_steps {str(eval_steps)}
        --eval_steps {str(eval_steps)}
        --sortish_sampler
        --label_smoothing 0.1
        --adafactor
        --task translation
        --tgt_lang ro_RO
        --src_lang en_XX
        """.split()
    # --eval_beams 2

    testargs = ["finetune_trainer.py"] + argv
    # Temporarily replace sys.argv so main() parses our synthetic CLI.
    with patch.object(sys, "argv", testargs):
        main()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment