Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
825925df
Unverified
Commit
825925df
authored
Oct 28, 2020
by
Stas Bekman
Committed by
GitHub
Oct 28, 2020
Browse files
[s2s test] cleanup (#8131)
parent
e477eb91
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
1 addition
and
101 deletions
+1
-101
examples/seq2seq/test_seq2seq_examples_multi_gpu.py
examples/seq2seq/test_seq2seq_examples_multi_gpu.py
+1
-101
No files found.
examples/seq2seq/test_seq2seq_examples_multi_gpu.py
View file @
825925df
# Because of their complexity, multi-GPU tests could impact other tests; to aid debugging, we keep them in a separate module.
# Because of their complexity, multi-GPU tests could impact other tests; to aid debugging, we keep them in a separate module.
import
logging
import
os
import
os
import
sys
import
sys
from
pathlib
import
Path
from
transformers
import
is_torch_available
from
transformers.testing_utils
import
TestCasePlus
,
execute_subprocess_async
,
require_torch_multigpu
from
transformers.testing_utils
import
TestCasePlus
,
execute_subprocess_async
,
require_torch_multigpu
from
.test_seq2seq_examples
import
CHEAP_ARGS
,
make_test_data_dir
from
.utils
import
load_json
from
.utils
import
load_json
# torch is imported only when `is_torch_available()` says it can be.
if is_torch_available():
    import torch

# Configure the root logger to emit everything from DEBUG upwards.
logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()

# `torch` is only bound when the conditional import above ran, so the CUDA probe
# must be guarded as well; otherwise importing this module raises NameError on
# installs without torch. Short-circuiting keeps the result a plain bool.
CUDA_AVAILABLE = is_torch_available() and torch.cuda.is_available()
# Baseline argument values: one epoch, batch size 2, 12-token sequences and a
# tiny random BART checkpoint. Key order is kept exactly as originally written.
CHEAP_ARGS = {
    "max_tokens_per_batch": None,
    "supervise_forward": True,
    "normalize_hidden": True,
    "label_smoothing": 0.2,
    "eval_max_gen_length": None,
    "eval_beams": 1,
    "val_metric": "loss",
    "save_top_k": 1,
    "adafactor": True,
    "early_stopping_patience": 2,
    "logger_name": "default",
    "length_penalty": 0.5,
    "cache_dir": "",
    "task": "summarization",
    "num_workers": 2,
    "alpha_hid": 0,
    "freeze_embeds": True,
    "enc_only": False,
    "tgt_suffix": "",
    "resume_from_checkpoint": None,
    "sortish_sampler": True,
    "student_decoder_layers": 1,
    "val_check_interval": 1.0,
    "output_dir": "",
    "fp16": False,  # TODO(SS): set this to CUDA_AVAILABLE if ci installs apex or start using native amp
    "no_teacher": False,
    "fp16_opt_level": "O1",
    # Use a single GPU when CUDA is present, otherwise none.
    "gpus": 1 if CUDA_AVAILABLE else 0,
    "n_tpu_cores": 0,
    "max_grad_norm": 1.0,
    "do_train": True,
    "do_predict": True,
    "accumulate_grad_batches": 1,
    "server_ip": "",
    "server_port": "",
    "seed": 42,
    "model_name_or_path": "sshleifer/bart-tiny-random",
    "config_name": "",
    "tokenizer_name": "facebook/bart-large",
    "do_lower_case": False,
    "learning_rate": 0.3,
    "lr_scheduler": "linear",
    "weight_decay": 0.0,
    "adam_epsilon": 1e-08,
    "warmup_steps": 0,
    "max_epochs": 1,
    "train_batch_size": 2,
    "eval_batch_size": 2,
    "max_source_length": 12,
    "max_target_length": 12,
    "val_max_target_length": 12,
    "test_max_target_length": 12,
    "fast_dev_run": False,
    "no_cache": False,
    "n_train": -1,
    "n_val": -1,
    "n_test": -1,
    "student_encoder_layers": 1,
    "freeze_encoder": False,
    "auto_scale_batch_size": False,
}
def
_dump_articles
(
path
:
Path
,
articles
:
list
):
content
=
"
\n
"
.
join
(
articles
)
Path
(
path
).
open
(
"w"
).
writelines
(
content
)
# Two source texts and two target texts used as fixture data.
ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."]
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]

# Model name constants.
T5_TINY = "patrickvonplaten/t5-tiny-random"
BART_TINY = "sshleifer/bart-tiny-random"
MBART_TINY = "sshleifer/tiny-mbart"
MARIAN_TINY = "sshleifer/tiny-marian-en-de"

# Attach a stdout handler to the module's logger, then suppress every record
# at CRITICAL level and below.
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
logging.disable(logging.CRITICAL)  # remove noisy download output from tracebacks
def make_test_data_dir(tmp_dir):
    """Populate *tmp_dir* with ``.source``/``.target`` files for each split.

    For each of ``train``, ``val`` and ``test``, ``ARTICLES`` is written to
    ``<split>.source`` and ``SUMMARIES`` to ``<split>.target``.

    Args:
        tmp_dir: Directory in which the files are created.

    Returns:
        ``tmp_dir``, unchanged.
    """
    for split_name in ("train", "val", "test"):
        source_file = os.path.join(tmp_dir, f"{split_name}.source")
        target_file = os.path.join(tmp_dir, f"{split_name}.target")
        _dump_articles(source_file, ARTICLES)
        _dump_articles(target_file, SUMMARIES)
    return tmp_dir
class
TestSummarizationDistillerMultiGPU
(
TestCasePlus
):
class
TestSummarizationDistillerMultiGPU
(
TestCasePlus
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
logging
.
disable
(
logging
.
CRITICAL
)
# remove noisy download output from tracebacks
return
cls
return
cls
@
require_torch_multigpu
@
require_torch_multigpu
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment