chenpangpang / transformers · Commits

Commit 61b7ba93 (unverified)
Authored Aug 31, 2020 by Sam Shleifer; committed by GitHub, Aug 31, 2020

Marian distill scripts + integration test (#6799)
Parent: 02d09c8f
Showing 4 changed files with 132 additions and 14 deletions (+132 -14):
examples/seq2seq/distil_marian_enro_teacher.sh   +21  -0
examples/seq2seq/distil_marian_no_teacher.sh     +17  -0
examples/seq2seq/test_bash_script.py             +91  -13
examples/test_examples.py                        +3   -1
examples/seq2seq/distil_marian_enro_teacher.sh  (new file, mode 0 → 100755)
#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"
export WANDB_PROJECT=dmar
# export MAX_LEN=128
python distillation.py \
    --learning_rate=3e-4 \
    --do_train \
    --do_predict \
    --fp16 \
    --val_check_interval 0.25 \
    --teacher Helsinki-NLP/opus-mt-en-ro --data_dir $ENRO_DIR \
    --max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
    --student_decoder_layers 3 --student_encoder_layers 6 \
    --freeze_encoder --freeze_embeds \
    --model_name_or_path IGNORED \
    --alpha_hid=3. \
    --train_batch_size=$BS --eval_batch_size=$BS \
    --tokenizer_name Helsinki-NLP/opus-mt-en-ro \
    --warmup_steps 500 --sortish_sampler --logger_name wandb \
    --gpus 1 --fp16_opt_level O1 --task translation \
    "$@"
examples/seq2seq/distil_marian_no_teacher.sh  (new file, mode 0 → 100755)
#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"
export WANDB_PROJECT=dmar
python distillation.py \
    --learning_rate=3e-4 \
    --do_train \
    --do_predict \
    --fp16 --no_teacher \
    --val_check_interval 0.25 \
    --data_dir $ENRO_DIR \
    --max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
    --freeze_encoder --freeze_embeds \
    --train_batch_size=$BS --eval_batch_size=$BS \
    --tokenizer_name $m --model_name_or_path $m \
    --warmup_steps 500 --sortish_sampler --logger_name wandb \
    --gpus 1 --fp16_opt_level=O1 --task translation \
    "$@"
examples/seq2seq/test_bash_script.py
@@ -10,9 +10,10 @@ import pytorch_lightning as pl
 import timeout_decorator
 import torch
 
-from transformers import BartForConditionalGeneration
+from transformers import BartForConditionalGeneration, MarianMTModel
 from transformers.testing_utils import slow
 from .distillation import BartSummarizationDistiller, distill_main
 from .finetune import SummarizationModule, main
 from .test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
 from .utils import load_json
@@ -20,6 +21,7 @@ from .utils import load_json
 MODEL_NAME = MBART_TINY
 # TODO(SS): MODEL_NAME = "sshleifer/student_mbart_en_ro_1_1"
+MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
 
 @slow
@@ -27,6 +29,7 @@ MODEL_NAME = MBART_TINY
 def test_model_download():
     """This warms up the cache so that we can time the next test without including download time, which varies between machines."""
     BartForConditionalGeneration.from_pretrained(MODEL_NAME)
+    MarianMTModel.from_pretrained(MARIAN_MODEL)
 
 @timeout_decorator.timeout(120)
@@ -35,34 +38,30 @@ def test_model_download():
 def test_train_mbart_cc25_enro_script():
     data_dir = "examples/seq2seq/test_data/wmt_en_ro"
     env_vars_to_replace = {
-        "$MAX_LEN": 200,
+        "--fp16_opt_level=O1": "",
+        "$MAX_LEN": 128,
         "$BS": 4,
         "$GAS": 1,
         "$ENRO_DIR": data_dir,
         "facebook/mbart-large-cc25": MODEL_NAME,
         # 1 encoder and 1 decoder layer from finetuned mbart en-ro. Should be able to start >0 and improve quickly.
-        # Download is 600MB in previous test.
+        # Download is 120MB in previous test.
         "val_check_interval=0.25": "val_check_interval=1.0",
     }
     # Clean up bash script
     bash_script = Path("examples/seq2seq/train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
-    bash_script = bash_script.replace("\\\n", "").strip().replace("$@", "")
+    bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
     for k, v in env_vars_to_replace.items():
         bash_script = bash_script.replace(k, str(v))
-    output_dir = tempfile.mkdtemp(prefix="output")
+    output_dir = tempfile.mkdtemp(prefix="output_mbart")
-    if CUDA_AVAILABLE:
-        gpus = 1  # torch.cuda.device_count()
-    else:
-        gpus = 0
-    bash_script = bash_script.replace("--fp16", "")
+    bash_script = bash_script.replace("--fp16 ", "")
     testargs = (
         ["finetune.py"]
         + bash_script.split()
         + [
             f"--output_dir={output_dir}",
-            f"--gpus={gpus}",
+            "--gpus=1",
             "--learning_rate=3e-1",
             "--warmup_steps=0",
             "--val_check_interval=1.0",
@@ -82,7 +81,86 @@ def test_train_mbart_cc25_enro_script():
     metrics = load_json(model.metrics_save_path)
     first_step_stats = metrics["val"][0]
     last_step_stats = metrics["val"][-1]
-    assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval)  # +1 accounts for val_sanity_check
+    assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1  # +1 accounts for val_sanity_check
     assert last_step_stats["val_avg_gen_time"] >= 0.01
     assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"]  # model learned nothing
     assert 1.0 >= last_step_stats["val_avg_gen_time"]  # model hanging on generate. Maybe bad config was saved.
+    assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
+
+    # check lightning ckpt can be loaded and has a reasonable statedict
+    contents = os.listdir(output_dir)
+    ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
+    full_path = os.path.join(args.output_dir, ckpt_path)
+    ckpt = torch.load(full_path, map_location="cpu")
+    expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
+    assert expected_key in ckpt["state_dict"]
+    assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32
+
+    # TODO(SS): turn on args.do_predict when PL bug fixed.
+    if args.do_predict:
+        contents = {os.path.basename(p) for p in contents}
+        assert "test_generations.txt" in contents
+        assert "test_results.txt" in contents
+        # assert len(metrics["val"]) == desired_n_evals
+        assert len(metrics["test"]) == 1
+
+
+@timeout_decorator.timeout(600)
+@slow
+@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+def test_opus_mt_distill_script():
+    data_dir = "examples/seq2seq/test_data/wmt_en_ro"
+    env_vars_to_replace = {
+        "--fp16_opt_level=O1": "",
+        "$MAX_LEN": 128,
+        "$BS": 16,
+        "$GAS": 1,
+        "$ENRO_DIR": data_dir,
+        "$m": "sshleifer/student_marian_en_ro_6_1",
+        "val_check_interval=0.25": "val_check_interval=1.0",
+    }
+    # Clean up bash script
+    bash_script = (
+        Path("examples/seq2seq/distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip()
+    )
+    bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
+    bash_script = bash_script.replace("--fp16 ", " ")
+    for k, v in env_vars_to_replace.items():
+        bash_script = bash_script.replace(k, str(v))
+    output_dir = tempfile.mkdtemp(prefix="marian_output")
+    bash_script = bash_script.replace("--fp16", "")
+    epochs = 6
+    testargs = (
+        ["distillation.py"]
+        + bash_script.split()
+        + [
+            f"--output_dir={output_dir}",
+            "--gpus=1",
+            "--learning_rate=1e-3",
+            f"--num_train_epochs={epochs}",
+            "--warmup_steps=10",
+            "--val_check_interval=1.0",
+        ]
+    )
+    with patch.object(sys, "argv", testargs):
+        parser = argparse.ArgumentParser()
+        parser = pl.Trainer.add_argparse_args(parser)
+        parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
+        args = parser.parse_args()
+        args.do_predict = False
+        # assert args.gpus == gpus THIS BREAKS for multigpu
+        model = distill_main(args)
+
+    # Check metrics
+    metrics = load_json(model.metrics_save_path)
+    first_step_stats = metrics["val"][0]
+    last_step_stats = metrics["val"][-1]
+    assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1  # +1 accounts for val_sanity_check
+    assert last_step_stats["val_avg_gen_time"] >= 0.01
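
The new test_opus_mt_distill_script test is marked @slow and skipped when CUDA is unavailable. Assuming the usual transformers convention that @slow tests are gated behind the RUN_SLOW environment variable, it could be exercised on a GPU machine roughly as follows:

# Sketch, assuming RUN_SLOW enables tests decorated with @slow (transformers.testing_utils).
RUN_SLOW=1 pytest examples/seq2seq/test_bash_script.py -k test_opus_mt_distill_script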
examples/test_examples.py
@@ -114,7 +114,9 @@ class ExamplesTests(TestCasePlus):
             --max_seq_length=128
             """.split()
         if torch.cuda.is_available():
-            testargs += ["--fp16", "--gpus=1"]
+            testargs += ["--gpus=1"]
+        if is_cuda_and_apex_avaliable():
+            testargs.append("--fp16")
         with patch.object(sys, "argv", testargs):
             result = run_pl_glue.main()