Commit 023f0f37 (unverified)
Authored Oct 22, 2020 by Stas Bekman; committed by GitHub on Oct 22, 2020
Parent: 64b24bb3

[s2s trainer] tests to use distributed on multi-gpu machine (#7965)

Showing 3 changed files with 121 additions and 78 deletions:

  examples/seq2seq/test_finetune_trainer.py            +40  -4
  examples/seq2seq/test_seq2seq_examples_multi_gpu.py  +12  -74
  examples/seq2seq/utils.py                            +69  -0
examples/seq2seq/test_finetune_trainer.py (view file @ 023f0f37)
import os
import sys
from pathlib import Path
from unittest.mock import patch

import pytest

from transformers import is_torch_available
from transformers.testing_utils import TestCasePlus, slow
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed

from .finetune_trainer import main
from .test_seq2seq_examples import MBART_TINY
from .utils import execute_async_std


if is_torch_available():
    import torch

set_seed(42)
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"

...
@@ -25,7 +33,7 @@ class TestFinetuneTrainer(TestCasePlus):
    @slow
    def test_finetune_trainer_slow(self):
        # There is a missing call to __init__process_group somewhere
-       output_dir = self.run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)
+       output_dir = self.run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=10)

        # Check metrics
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history

...
@@ -43,6 +51,8 @@ class TestFinetuneTrainer(TestCasePlus):
        assert "test_results.json" in contents

    def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
+       # XXX: remove hardcoded path
+       data_dir = "examples/seq2seq/test_data/wmt_en_ro"
        output_dir = self.get_auto_remove_tmp_dir()
        argv = f"""

...
@@ -77,8 +87,34 @@ class TestFinetuneTrainer(TestCasePlus):
"""
.
split
()
# --eval_beams 2
testargs
=
[
"finetune_trainer.py"
]
+
argv
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
main
()
n_gpu
=
torch
.
cuda
.
device_count
()
if
n_gpu
>
1
:
path
=
Path
(
__file__
).
resolve
()
cur_path
=
path
.
parents
[
0
]
path
=
Path
(
__file__
).
resolve
()
examples_path
=
path
.
parents
[
1
]
src_path
=
f
"
{
path
.
parents
[
2
]
}
/src"
env
=
os
.
environ
.
copy
()
env
[
"PYTHONPATH"
]
=
f
"
{
examples_path
}
:
{
src_path
}
:
{
env
.
get
(
'PYTHONPATH'
,
''
)
}
"
distributed_args
=
(
f
"-m torch.distributed.launch --nproc_per_node=
{
n_gpu
}
{
cur_path
}
/finetune_trainer.py"
.
split
()
)
cmd
=
[
sys
.
executable
]
+
distributed_args
+
argv
print
(
"
\n
Running: "
,
" "
.
join
(
cmd
))
result
=
execute_async_std
(
cmd
,
env
=
env
,
stdin
=
None
,
timeout
=
180
,
quiet
=
False
,
echo
=
False
)
assert
result
.
stdout
,
"produced no output"
if
result
.
returncode
>
0
:
pytest
.
fail
(
f
"failed with returncode
{
result
.
returncode
}
"
)
else
:
# 0 or 1 gpu
testargs
=
[
"finetune_trainer.py"
]
+
argv
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
main
()
return
output_dir
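
For reference, on a hypothetical 2-GPU machine the n_gpu > 1 branch above composes a launcher command along these lines; the repo path and the trailing arguments here are placeholders for illustration, not values taken from the commit:

# Sketch only: what the multi-GPU branch builds, with made-up paths/args.
import sys

n_gpu = 2                            # pretend torch.cuda.device_count() == 2
cur_path = "/repo/examples/seq2seq"  # placeholder for Path(__file__).resolve().parents[0]
argv = ["--output_dir", "/tmp/out", "--num_train_epochs", "10"]  # illustrative tail
distributed_args = f"-m torch.distributed.launch --nproc_per_node={n_gpu} {cur_path}/finetune_trainer.py".split()
cmd = [sys.executable] + distributed_args + argv
print(" ".join(cmd))
# -> python -m torch.distributed.launch --nproc_per_node=2 /repo/examples/seq2seq/finetune_trainer.py --output_dir /tmp/out --num_train_epochs 10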
examples/seq2seq/test_seq2seq_examples_multi_gpu.py (view file @ 023f0f37)

...
@@ -6,11 +6,15 @@ import sys
from pathlib import Path

import pytest
-import torch

+from transformers import is_torch_available
from transformers.testing_utils import TestCasePlus, require_torch_multigpu

-from .utils import load_json
+from .utils import execute_async_std, load_json

+if is_torch_available():
+    import torch

logging.basicConfig(level=logging.DEBUG)

...
@@ -106,73 +110,6 @@ def make_test_data_dir(tmp_dir):
    return tmp_dir

(Removed from this file: the async-I/O helpers previously defined here, namely the
"# XXX: a candidate for testing_utils (python>=3.6)" note, `import asyncio`, the
RunOutput class, _read_stream, _stream_subprocess, and execute_async_std. They move,
essentially unchanged, into examples/seq2seq/utils.py below, where the result class
is renamed _RunOutput.)
class TestSummarizationDistillerMultiGPU(TestCasePlus):
    @classmethod
    def setUpClass(cls):

...
@@ -220,17 +157,18 @@ class TestSummarizationDistillerMultiGPU(TestCasePlus):
                return f"--{k}"
            return f"--{k}={v}"

-       cli_args = [x for x in (convert(k, v) for k, v in args_d.items()) if len(x)]
-       cmd = [sys.executable, "./examples/seq2seq/distillation.py"] + cli_args
-       print("\nRunning: ", " ".join(cmd))
+       path = Path(__file__).resolve()
+       cur_path = path.parents[0]
+       examples_path = path.parents[1]
+       src_path = f"{path.parents[2]}/src"
+       env = os.environ.copy()
+       env["PYTHONPATH"] = f"{examples_path}:{src_path}:{env.get('PYTHONPATH', '')}"
+       cli_args = [x for x in (convert(k, v) for k, v in args_d.items()) if len(x)]
+       cmd = [sys.executable, f"{cur_path}/distillation.py"] + cli_args
+       print("\nRunning: ", " ".join(cmd))
+       result = execute_async_std(cmd, env=env, stdin=None, timeout=180, quiet=False, echo=False)
+       assert result.stdout, "produced no output"

...
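
Both test files now build the subprocess environment the same way. A minimal standalone sketch of that PYTHONPATH composition, assuming the file lives at <repo>/examples/seq2seq/ (the directory comments are added here for illustration):

# Sketch of the PYTHONPATH composition pattern used in both tests above.
import os
from pathlib import Path

path = Path(__file__).resolve()
cur_path = path.parents[0]            # <repo>/examples/seq2seq
examples_path = path.parents[1]       # <repo>/examples
src_path = f"{path.parents[2]}/src"   # <repo>/src

env = os.environ.copy()
# Prepend both paths so the spawned script can import `transformers` from source
# and resolve the examples-local helper modules.
env["PYTHONPATH"] = f"{examples_path}:{src_path}:{env.get('PYTHONPATH', '')}"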
examples/seq2seq/utils.py (view file @ 023f0f37)

...
@@ -5,6 +5,7 @@ import math
import os
import pickle
import socket
+import sys
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Tuple, Union

...
@@ -643,3 +644,71 @@ def check_output_dir(args, expected_items=0):
"has {len(os.listdir(args.output_dir))} items in it (expected {expected_items} items). "
"Use --overwrite_output_dir to overcome."
)
# the following code deals with async io between processes
# adapted from https://stackoverflow.com/a/59041913/9201239

import asyncio  # noqa


class _RunOutput:
    def __init__(self, returncode, stdout, stderr):
        self.returncode = returncode
        self.stdout = stdout
        self.stderr = stderr


async def _read_stream(stream, callback):
    while True:
        line = await stream.readline()
        if line:
            callback(line)
        else:
            break


async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
    if echo:
        print(cmd)
    p = await asyncio.create_subprocess_exec(
        cmd[0],
        *cmd[1:],
        stdin=stdin,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        env=env,
    )

    out = []
    err = []

    def tee(line, sink, pipe, label=""):
        line = line.decode("utf-8").rstrip()
        sink.append(line)
        if not quiet:
            print(label, line, file=pipe)

    await asyncio.wait(
        [
            _read_stream(p.stdout, lambda l: tee(l, out, sys.stdout)),
            _read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:")),
        ],
        timeout=timeout,
    )
    # XXX: warning for a possible deadlock when using `wait` with huge amounts of data in the pipe
    # https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait
    #
    # If it starts hanging, will need to switch s/wait/communicate/ - so perhaps for debug we will enable
    # `wait` as it's easier to see in real time, but for normal runs use `communicate`
    return _RunOutput(await p.wait(), out, err)


def execute_async_std(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
    loop = asyncio.get_event_loop()
    result = loop.run_until_complete(
        _stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
    )
    return result
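
A minimal usage sketch for the relocated helper; the spawned one-liner is illustrative, and within the test files the import is `from .utils import execute_async_std` rather than the flat import shown here:

# Usage sketch: run a child process, capturing stdout/stderr line by line.
import sys

from utils import execute_async_std  # i.e. examples/seq2seq/utils.py

result = execute_async_std(
    [sys.executable, "-c", "print('out'); import sys; print('err', file=sys.stderr)"],
    env=None, stdin=None, timeout=60, quiet=True, echo=False,
)
assert result.returncode == 0
assert result.stdout == ["out"]   # lines are decoded and newline-stripped
assert result.stderr == ["err"]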