chenpangpang/transformers · Commits

Commit f0435f5a (unverified)
Authored Nov 17, 2020 by Stas Bekman; committed via GitHub on Nov 17, 2020
Parent: 36a19915

these should run fine on multi-gpu (#8582)

Showing 3 changed files, with 3 additions and 33 deletions (+3, -33):

- examples/seq2seq/test_bash_script.py (+1, -4)
- examples/seq2seq/test_fsmt_bleu_score.py (+1, -8)
- examples/seq2seq/test_seq2seq_examples.py (+1, -21)
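The commit drops the `require_torch_non_multi_gpu_but_fix_me` guard from every test below, so the tests are no longer skipped on multi-GPU machines. For reference, a skip guard of this kind can be written as the following minimal sketch; it illustrates the pattern and is not the library's exact source:

```python
import unittest

import torch


def require_torch_non_multi_gpu(test_case):
    # Sketch of the guard being removed: skip the wrapped test whenever
    # more than one CUDA device is visible, otherwise run it unchanged.
    # (Illustrative; transformers.testing_utils may implement it differently.)
    if torch.cuda.device_count() > 1:
        return unittest.skip("test requires 0 or 1 GPU")(test_case)
    return test_case
```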
examples/seq2seq/test_bash_script.py

```diff
@@ -13,7 +13,7 @@ from distillation import SummarizationDistiller, distill_main
 from finetune import SummarizationModule, main
 from transformers import MarianMTModel
 from transformers.file_utils import cached_path
-from transformers.testing_utils import TestCasePlus, require_torch_gpu, require_torch_non_multi_gpu_but_fix_me, slow
+from transformers.testing_utils import TestCasePlus, require_torch_gpu, slow
 from utils import load_json
@@ -32,7 +32,6 @@ class TestMbartCc25Enro(TestCasePlus):
     @slow
     @require_torch_gpu
-    @require_torch_non_multi_gpu_but_fix_me
     def test_model_download(self):
         """This warms up the cache so that we can time the next test without including download time, which varies between machines."""
         MarianMTModel.from_pretrained(MARIAN_MODEL)
@@ -40,7 +39,6 @@ class TestMbartCc25Enro(TestCasePlus):
     # @timeout_decorator.timeout(1200)
     @slow
     @require_torch_gpu
-    @require_torch_non_multi_gpu_but_fix_me
     def test_train_mbart_cc25_enro_script(self):
         env_vars_to_replace = {
             "$MAX_LEN": 64,
@@ -129,7 +127,6 @@ class TestDistilMarianNoTeacher(TestCasePlus):
     @timeout_decorator.timeout(600)
     @slow
     @require_torch_gpu
-    @require_torch_non_multi_gpu_but_fix_me
     def test_opus_mt_distill_script(self):
         data_dir = f"{self.test_file_dir_str}/test_data/wmt_en_ro"
         env_vars_to_replace = {
```
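These tests subclass `TestCasePlus` and lean on helpers such as `self.test_file_dir_str` and `get_auto_remove_tmp_dir()`. A rough sketch of what such helpers can look like, under the assumption that they resolve paths relative to the test file and clean up temporary directories (simplified; not the library's actual source):

```python
import inspect
import shutil
import tempfile
import unittest
from pathlib import Path


class TestCasePlusSketch(unittest.TestCase):
    """Simplified stand-in for transformers.testing_utils.TestCasePlus."""

    @property
    def test_file_dir_str(self):
        # Directory of the file defining the test class, used above to
        # locate fixtures: f"{self.test_file_dir_str}/test_data/wmt_en_ro".
        return str(Path(inspect.getfile(self.__class__)).parent)

    def get_auto_remove_tmp_dir(self):
        # Temp dir that is deleted automatically when the test finishes.
        tmp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, tmp_dir, ignore_errors=True)
        return tmp_dir
```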
examples/seq2seq/test_fsmt_bleu_score.py

```diff
@@ -19,13 +19,7 @@ import unittest
 from parameterized import parameterized
 from transformers import FSMTForConditionalGeneration, FSMTTokenizer
-from transformers.testing_utils import (
-    get_tests_dir,
-    require_torch,
-    require_torch_non_multi_gpu_but_fix_me,
-    slow,
-    torch_device,
-)
+from transformers.testing_utils import get_tests_dir, require_torch, slow, torch_device
 from utils import calculate_bleu
@@ -54,7 +48,6 @@ class ModelEvalTester(unittest.TestCase):
         ]
     )
     @slow
-    @require_torch_non_multi_gpu_but_fix_me
     def test_bleu_scores(self, pair, min_bleu_score):
         # note: this test is not testing the best performance since it only evals a small batch
         # but it should be enough to detect a regression in the output quality
```
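`test_bleu_scores` is generated once per language pair by `parameterized.expand`. A self-contained sketch of that mechanism, with made-up pairs and thresholds (the real test evaluates an FSMT model and checks the actual BLEU score):

```python
import unittest

from parameterized import parameterized


class BleuFloorSketch(unittest.TestCase):
    # Each tuple becomes its own test case, named after its arguments,
    # e.g. test_bleu_scores_0_en_ru. Values below are illustrative only.
    @parameterized.expand(
        [
            ("en-ru", 26.0),
            ("ru-en", 22.0),
        ]
    )
    def test_bleu_scores(self, pair, min_bleu_score):
        fake_bleu = 30.0  # stand-in for a real calculate_bleu(...) result
        self.assertGreaterEqual(fake_bleu, min_bleu_score, f"regression on {pair}")
```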
examples/seq2seq/test_seq2seq_examples.py

```diff
@@ -19,14 +19,7 @@ from run_eval import generate_summaries_or_translations, run_generate
 from run_eval_search import run_search
 from transformers import AutoConfig, AutoModelForSeq2SeqLM
 from transformers.hf_api import HfApi
-from transformers.testing_utils import (
-    CaptureStderr,
-    CaptureStdout,
-    TestCasePlus,
-    require_torch_gpu,
-    require_torch_non_multi_gpu_but_fix_me,
-    slow,
-)
+from transformers.testing_utils import CaptureStderr, CaptureStdout, TestCasePlus, require_torch_gpu, slow
 from utils import ROUGE_KEYS, label_smoothed_nll_loss, lmap, load_json
@@ -135,7 +128,6 @@ class TestSummarizationDistiller(TestCasePlus):
     @slow
     @require_torch_gpu
-    @require_torch_non_multi_gpu_but_fix_me
     def test_hub_configs(self):
         """I put require_torch_gpu cause I only want this to run with self-scheduled."""
@@ -153,12 +145,10 @@ class TestSummarizationDistiller(TestCasePlus):
             failures.append(m)
         assert not failures, f"The following models could not be loaded through AutoConfig: {failures}"

-    @require_torch_non_multi_gpu_but_fix_me
     def test_distill_no_teacher(self):
         updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
         self._test_distiller_cli(updates)

-    @require_torch_non_multi_gpu_but_fix_me
     def test_distill_checkpointing_with_teacher(self):
         updates = dict(
             student_encoder_layers=2,
@@ -183,7 +173,6 @@ class TestSummarizationDistiller(TestCasePlus):
         convert_pl_to_hf(ckpts[0], transformer_ckpts[0].parent, out_path_new)
         assert os.path.exists(os.path.join(out_path_new, "pytorch_model.bin"))

-    @require_torch_non_multi_gpu_but_fix_me
     def test_loss_fn(self):
         model = AutoModelForSeq2SeqLM.from_pretrained(BART_TINY)
         input_ids, mask = model.dummy_inputs["input_ids"], model.dummy_inputs["attention_mask"]
@@ -204,7 +193,6 @@ class TestSummarizationDistiller(TestCasePlus):
         # TODO: understand why this breaks
         self.assertEqual(nll_loss, model_computed_loss)

-    @require_torch_non_multi_gpu_but_fix_me
     def test_distill_mbart(self):
         updates = dict(
             student_encoder_layers=2,
@@ -229,7 +217,6 @@ class TestSummarizationDistiller(TestCasePlus):
         assert len(all_files) > 2
         self.assertEqual(len(transformer_ckpts), 2)

-    @require_torch_non_multi_gpu_but_fix_me
     def test_distill_t5(self):
         updates = dict(
             student_encoder_layers=1,
@@ -241,7 +228,6 @@ class TestSummarizationDistiller(TestCasePlus):
         )
         self._test_distiller_cli(updates)

-    @require_torch_non_multi_gpu_but_fix_me
     def test_distill_different_base_models(self):
         updates = dict(
             teacher=T5_TINY,
@@ -321,21 +307,18 @@ class TestTheRest(TestCasePlus):
     # test one model to quickly (no-@slow) catch simple problems and do an
     # extensive testing of functionality with multiple models as @slow separately
-    @require_torch_non_multi_gpu_but_fix_me
     def test_run_eval(self):
         self.run_eval_tester(T5_TINY)

     # any extra models should go into the list here - can be slow
     @parameterized.expand([BART_TINY, MBART_TINY])
     @slow
-    @require_torch_non_multi_gpu_but_fix_me
     def test_run_eval_slow(self, model):
         self.run_eval_tester(model)

     # testing with 2 models to validate: 1. translation (t5) 2. summarization (mbart)
     @parameterized.expand([T5_TINY, MBART_TINY])
     @slow
-    @require_torch_non_multi_gpu_but_fix_me
     def test_run_eval_search(self, model):
         input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source"
         output_file_name = input_file_name.parent / "utest_output.txt"
@@ -386,7 +369,6 @@ class TestTheRest(TestCasePlus):
     @parameterized.expand(
         [T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY],
     )
-    @require_torch_non_multi_gpu_but_fix_me
     def test_finetune(self, model):
         args_d: dict = CHEAP_ARGS.copy()
         task = "translation" if model in [MBART_TINY, MARIAN_TINY, FSMT_TINY] else "summarization"
@@ -438,7 +420,6 @@ class TestTheRest(TestCasePlus):
         assert isinstance(example_batch, dict)
         assert len(example_batch) >= 4

-    @require_torch_non_multi_gpu_but_fix_me
     def test_finetune_extra_model_args(self):
         args_d: dict = CHEAP_ARGS.copy()
@@ -489,7 +470,6 @@ class TestTheRest(TestCasePlus):
         model = main(args)
         assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"

-    @require_torch_non_multi_gpu_but_fix_me
     def test_finetune_lr_schedulers(self):
         args_d: dict = CHEAP_ARGS.copy()
```
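Even after this change, all the touched tests stay behind `@slow` (and most behind `@require_torch_gpu`), so they still run only on opt-in CI jobs. A sketch of that opt-in gate, assuming an environment-variable switch like the library's `RUN_SLOW` (illustrative, not the exact implementation):

```python
import os
import unittest


def slow(test_case):
    # Sketch: run the wrapped test only when RUN_SLOW=1 is exported;
    # otherwise skip it so the default test suite stays fast.
    if os.environ.get("RUN_SLOW", "0") != "1":
        return unittest.skip("slow test; set RUN_SLOW=1 to enable")(test_case)
    return test_case
```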