chenpangpang / transformers · commit 61b7ba93

Marian distill scripts + integration test (#6799)

Unverified commit 61b7ba93, authored Aug 31, 2020 by Sam Shleifer; committed by GitHub, Aug 31, 2020.
Parent: 02d09c8f

Showing 4 changed files with 132 additions and 14 deletions (+132 / -14):
  examples/seq2seq/distil_marian_enro_teacher.sh  +21  -0
  examples/seq2seq/distil_marian_no_teacher.sh    +17  -0
  examples/seq2seq/test_bash_script.py            +91  -13
  examples/test_examples.py                       +3   -1

examples/seq2seq/distil_marian_enro_teacher.sh (new file, mode 100755)

#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"
export WANDB_PROJECT=dmar
# export MAX_LEN=128
python distillation.py \
  --learning_rate=3e-4 \
  --do_train \
  --do_predict \
  --fp16 \
  --val_check_interval 0.25 \
  --teacher Helsinki-NLP/opus-mt-en-ro --data_dir $ENRO_DIR \
  --max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
  --student_decoder_layers 3 --student_encoder_layers 6 \
  --freeze_encoder --freeze_embeds \
  --model_name_or_path IGNORED \
  --alpha_hid=3. \
  --train_batch_size=$BS --eval_batch_size=$BS \
  --tokenizer_name Helsinki-NLP/opus-mt-en-ro \
  --warmup_steps 500 --sortish_sampler --logger_name wandb \
  --gpus 1 --fp16_opt_level O1 --task translation \
  "$@"

examples/seq2seq/distil_marian_no_teacher.sh (new file, mode 100755)

#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"
export WANDB_PROJECT=dmar
python distillation.py \
  --learning_rate=3e-4 \
  --do_train \
  --do_predict \
  --fp16 --no_teacher \
  --val_check_interval 0.25 \
  --data_dir $ENRO_DIR \
  --max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
  --freeze_encoder --freeze_embeds \
  --train_batch_size=$BS --eval_batch_size=$BS \
  --tokenizer_name $m --model_name_or_path $m \
  --warmup_steps 500 --sortish_sampler --logger_name wandb \
  --gpus 1 --fp16_opt_level=O1 --task translation \
  "$@"

examples/seq2seq/test_bash_script.py

@@ -10,9 +10,10 @@ import pytorch_lightning as pl
 import timeout_decorator
 import torch

-from transformers import BartForConditionalGeneration
+from transformers import BartForConditionalGeneration, MarianMTModel
 from transformers.testing_utils import slow

 from .distillation import BartSummarizationDistiller, distill_main
 from .finetune import SummarizationModule, main
 from .test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
 from .utils import load_json

@@ -20,6 +21,7 @@ from .utils import load_json
 MODEL_NAME = MBART_TINY
 # TODO(SS): MODEL_NAME = "sshleifer/student_mbart_en_ro_1_1"
+MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"

 @slow
@@ -27,6 +29,7 @@ MODEL_NAME = MBART_TINY
 def test_model_download():
     """This warms up the cache so that we can time the next test without including download time, which varies between machines."""
     BartForConditionalGeneration.from_pretrained(MODEL_NAME)
+    MarianMTModel.from_pretrained(MARIAN_MODEL)

 @timeout_decorator.timeout(120)
@@ -35,34 +38,30 @@ def test_model_download():
 def test_train_mbart_cc25_enro_script():
     data_dir = "examples/seq2seq/test_data/wmt_en_ro"
     env_vars_to_replace = {
-        "$MAX_LEN": 200,
+        "--fp16_opt_level=O1": "",
+        "$MAX_LEN": 128,
         "$BS": 4,
         "$GAS": 1,
         "$ENRO_DIR": data_dir,
         "facebook/mbart-large-cc25": MODEL_NAME,
-        # 1 encoder and 1 decoder layer from finetuned mbart en-ro. Should be able to start >0 and improve quickly.
-        # Download is 600MB in previous test.
+        # Download is 120MB in previous test.
         "val_check_interval=0.25": "val_check_interval=1.0",
     }
     # Clean up bash script
     bash_script = Path("examples/seq2seq/train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
-    bash_script = bash_script.replace("\\\n", "").strip().replace("$@", "")
+    bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
     for k, v in env_vars_to_replace.items():
         bash_script = bash_script.replace(k, str(v))
-    output_dir = tempfile.mkdtemp(prefix="output")
-    if CUDA_AVAILABLE:
-        gpus = 1  # torch.cuda.device_count()
-    else:
-        gpus = 0
-    bash_script = bash_script.replace("--fp16", "")
+    output_dir = tempfile.mkdtemp(prefix="output_mbart")
+    bash_script = bash_script.replace("--fp16 ", "")
     testargs = (
         ["finetune.py"]
         + bash_script.split()
         + [
             f"--output_dir={output_dir}",
-            f"--gpus={gpus}",
+            "--gpus=1",
             "--learning_rate=3e-1",
             "--warmup_steps=0",
             "--val_check_interval=1.0",
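
The heart of both tests is this cleanup step: take everything in the launcher after the entry-point name, strip the line continuations and the "$@" forwarder, substitute concrete values for the shell placeholders, and split the result into an argv list that can be patched into sys.argv. A self-contained sketch of the technique, using an illustrative script body rather than the repo's actual file:

# A toy launcher in the same shape as the shell scripts above; the content
# is illustrative, not the repo's actual file.
script = '''python distillation.py \\
    --data_dir $ENRO_DIR \\
    --train_batch_size=$BS \\
    "$@"'''

# Keep only the arguments, drop continuations and the "$@" forwarder.
args_part = script.split("distillation.py")[1].strip()
args_part = args_part.replace("\\\n", "").strip().replace('"$@"', "")

# Substitute concrete values for the shell placeholders.
for placeholder, value in {"$ENRO_DIR": "test_data/wmt_en_ro", "$BS": 4}.items():
    args_part = args_part.replace(placeholder, str(value))

argv = ["distillation.py"] + args_part.split()
print(argv)
# ['distillation.py', '--data_dir', 'test_data/wmt_en_ro', '--train_batch_size=4']

Patching the resulting list over sys.argv then lets the argparse-based main() run in-process, so the integration test exercises essentially the same command line the shipped script would produce.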
@@ -82,7 +81,86 @@ def test_train_mbart_cc25_enro_script():
     metrics = load_json(model.metrics_save_path)
     first_step_stats = metrics["val"][0]
     last_step_stats = metrics["val"][-1]
-    assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval)  # +1 accounts for val_sanity_check
+    assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1  # +1 accounts for val_sanity_check
     assert last_step_stats["val_avg_gen_time"] >= 0.01
+    assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"]  # model learned nothing
+    assert 1.0 >= last_step_stats["val_avg_gen_time"]  # model hanging on generate. Maybe bad config was saved.
+    assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
+
+    # check lightning ckpt can be loaded and has a reasonable statedict
+    contents = os.listdir(output_dir)
+    ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
+    full_path = os.path.join(args.output_dir, ckpt_path)
+    ckpt = torch.load(full_path, map_location="cpu")
+    expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
+    assert expected_key in ckpt["state_dict"]
+    assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32
+
+    # TODO(SS): turn on args.do_predict when PL bug fixed.
+    if args.do_predict:
+        contents = {os.path.basename(p) for p in contents}
+        assert "test_generations.txt" in contents
+        assert "test_results.txt" in contents
+        # assert len(metrics["val"]) == desired_n_evals
+        assert len(metrics["test"]) == 1
+
+@timeout_decorator.timeout(600)
+@slow
+@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+def test_opus_mt_distill_script():
+    data_dir = "examples/seq2seq/test_data/wmt_en_ro"
+    env_vars_to_replace = {
+        "--fp16_opt_level=O1": "",
+        "$MAX_LEN": 128,
+        "$BS": 16,
+        "$GAS": 1,
+        "$ENRO_DIR": data_dir,
+        "$m": "sshleifer/student_marian_en_ro_6_1",
+        "val_check_interval=0.25": "val_check_interval=1.0",
+    }
+    # Clean up bash script
+    bash_script = (
+        Path("examples/seq2seq/distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip()
+    )
+    bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
+    bash_script = bash_script.replace("--fp16 ", " ")
+    for k, v in env_vars_to_replace.items():
+        bash_script = bash_script.replace(k, str(v))
+    output_dir = tempfile.mkdtemp(prefix="marian_output")
+    bash_script = bash_script.replace("--fp16", "")
+    epochs = 6
+    testargs = (
+        ["distillation.py"]
+        + bash_script.split()
+        + [
+            f"--output_dir={output_dir}",
+            "--gpus=1",
+            "--learning_rate=1e-3",
+            f"--num_train_epochs={epochs}",
+            "--warmup_steps=10",
+            "--val_check_interval=1.0",
+        ]
+    )
+    with patch.object(sys, "argv", testargs):
+        parser = argparse.ArgumentParser()
+        parser = pl.Trainer.add_argparse_args(parser)
+        parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
+        args = parser.parse_args()
+        args.do_predict = False
+        # assert args.gpus == gpus THIS BREAKS for multigpu
+        model = distill_main(args)
+
+    # Check metrics
+    metrics = load_json(model.metrics_save_path)
+    first_step_stats = metrics["val"][0]
+    last_step_stats = metrics["val"][-1]
+    assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1  # +1 accounts for val_sanity_check
+    assert last_step_stats["val_avg_gen_time"] >= 0.01
...
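
Both tests assert against the JSON written to model.metrics_save_path. The exact schema is not shown in this diff; a hypothetical example of the shape these assertions presuppose, with all numbers invented for illustration:

# Hypothetical contents of model.metrics_save_path; keys mirror the
# assertions above, values are made up.
metrics = {
    "val": [
        {"val_avg_bleu": 0.4, "val_avg_gen_time": 0.02},  # val_sanity_check
        {"val_avg_bleu": 6.9, "val_avg_gen_time": 0.03},  # end of epoch 1
    ],
    "test": [{"test_avg_bleu": 6.8, "test_avg_gen_time": 0.03}],
}

max_epochs, val_check_interval = 1, 1.0
# One entry per validation pass, plus one for the sanity check:
assert len(metrics["val"]) == (max_epochs / val_check_interval) + 1
assert metrics["val"][0]["val_avg_bleu"] < metrics["val"][-1]["val_avg_bleu"]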

examples/test_examples.py

@@ -114,7 +114,9 @@ class ExamplesTests(TestCasePlus):
             --max_seq_length=128
             """.split()
         if torch.cuda.is_available():
-            testargs += ["--fp16", "--gpus=1"]
+            testargs += ["--gpus=1"]
+            if is_cuda_and_apex_avaliable():
+                testargs.append("--fp16")

         with patch.object(sys, "argv", testargs):
             result = run_pl_glue.main()
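
This change stops assuming that every CUDA machine can run --fp16: apex must also be installed. is_cuda_and_apex_avaliable (spelled as it appears in the repo) is defined outside this hunk; a plausible sketch of such a guard, offered as an assumption rather than the repo's actual body:

import importlib.util

import torch

def is_cuda_and_apex_avaliable():
    # Assumed behavior: --fp16 in these Lightning examples goes through NVIDIA
    # apex, so require both a visible GPU and an importable apex package.
    return torch.cuda.is_available() and importlib.util.find_spec("apex") is not None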