chenpangpang / transformers · Commits

Unverified commit c0554776, authored Jun 08, 2020 by Patrick von Platen, committed by GitHub on Jun 08, 2020
Parent: e8177479

    fix PR (#4810)

4 changed files with 191 additions and 46 deletions (+191, -46):

    src/transformers/benchmark/benchmark.py              +99  -32
    src/transformers/benchmark/benchmark_args_utils.py    +6   -0
    src/transformers/benchmark/benchmark_utils.py         +35  -11
    tests/test_benchmark.py                               +51   -3
src/transformers/benchmark/benchmark.py  (view file @ c0554776)

@@ -18,8 +18,8 @@
 """

 import inspect
 import logging
 import os
 import timeit

 from transformers import MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING, PretrainedConfig, is_torch_available
@@ -52,46 +52,79 @@ class PyTorchBenchmark(Benchmark):
             model.to(self.args.device)
             model.train()

+            # encoder-decoder has vocab size saved differently
+            vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
             input_ids = torch.randint(
-                model.config.vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device
+                vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device
             )

-            def compute_loss_and_backprob():
-                # TODO: Not all models call labels argument labels => this hack using the function signature should be corrected once all models have a common name for labels
-                function_argument_names = inspect.getfullargspec(model.forward).args
-                if "labels" in function_argument_names:
-                    loss = model(input_ids, labels=input_ids)[0]
-                elif "lm_labels" in function_argument_names:
-                    loss = model(input_ids, lm_labels=input_ids)[0]
-                elif "masked_lm_labels" in function_argument_names:
-                    loss = model(input_ids, masked_lm_labels=input_ids)[0]
-                else:
-                    NotImplementedError(f"{model_name} does not seem to allow training with labels")
+            def compute_loss_and_backprob_encoder():
+                loss = model(input_ids, labels=input_ids)[0]
                 loss.backward()
                 model.zero_grad()

+            def compute_loss_and_backprob_encoder_decoder():
+                loss = model(input_ids, decoder_input_ids=input_ids, labels=input_ids)[0]
+                loss.backward()
+                model.zero_grad()
+
+            _train = (
+                compute_loss_and_backprob_encoder_decoder
+                if config.is_encoder_decoder
+                else compute_loss_and_backprob_encoder
+            )
+
             if trace_memory is True:
                 if self.args.trace_memory_line_by_line or self.args.n_gpu == 0:
                     if self.args.trace_memory_line_by_line:
                         trace = start_memory_tracing("transformers")
                 else:
                     # clear cuda cache
                     if self.args.n_gpu > 0:
                         # clear gpu cache
                         torch.cuda.empty_cache()
                         if hasattr(torch.cuda, "max_memory_reserved"):
                             torch.cuda.reset_peak_memory_stats()
                         else:
                             logger.info(
                                 "Please consider updating PyTorch to version 1.4 to get more accuracy on GPU memory usage"
                             )
                             torch.cuda.reset_max_memory_cached()

                 # calculate loss and do backpropagation
-                compute_loss_and_backprob()
+                _train()

                 if self.args.trace_memory_line_by_line or self.args.n_gpu == 0:
                     if self.args.trace_memory_line_by_line:
                         summary = stop_memory_tracing(trace)
                         memory = summary.total
                     else:
                         summary = None

                 if self.args.n_gpu > 0:
                     # gpu
                     if hasattr(torch.cuda, "max_memory_reserved"):
                         memory = Memory(torch.cuda.max_memory_reserved())
                     else:
                         logger.info(
                             "Please consider updating PyTorch to version 1.4 to get more accuracy on GPU memory usage"
                         )
                         memory = Memory(torch.cuda.max_memory_cached())
-                    memory = Memory(torch.cuda.max_memory_reserved())
                 else:
                     # cpu
                     try:
                         import psutil
                     except (ImportError):
                         logger.warning(
                             "Psutil not installed, we won't log CPU memory usage. "
                             "Install psutil (pip install psutil) to use CPU memory tracing."
                         )
                         memory = "N/A"
                     else:
                         process = psutil.Process(os.getpid())
                         memory = Memory(process.memory_info().rss)

-                return memory
+                return memory, summary
             else:
                 # as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average
-                runtimes = timeit.repeat(lambda: compute_loss_and_backprob(), repeat=self.args.repeat, number=10,)
+                runtimes = timeit.repeat(_train, repeat=self.args.repeat, number=10,)
                 return min(runtimes) / 10.0
         except RuntimeError as e:
             self.print_fn("Doesn't fit on GPU. {}".format(e))
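The hunk above swaps the argument-name inspection hack for two fixed closures and selects one via config.is_encoder_decoder. Below is a minimal sketch of that dispatch pattern, using a hypothetical stand-in module (TinyLM) rather than a transformers model; only the (loss, logits) return convention is mimicked.

    # Sketch of the closure-dispatch pattern from the new train() code.
    # TinyLM is a hypothetical stand-in, not part of the library.
    import torch
    import torch.nn as nn

    class TinyLM(nn.Module):
        def __init__(self, vocab_size=32, hidden=16):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, hidden)
            self.head = nn.Linear(hidden, vocab_size)

        def forward(self, input_ids, labels=None):
            logits = self.head(self.embed(input_ids))
            loss = None
            if labels is not None:
                loss = nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
            return (loss, logits)

    model = TinyLM()
    model.train()
    input_ids = torch.randint(32, (1, 8), dtype=torch.long)

    def compute_loss_and_backprob_encoder():
        loss = model(input_ids, labels=input_ids)[0]
        loss.backward()
        model.zero_grad()

    # an encoder-decoder model would instead get a closure that also passes decoder_input_ids
    is_encoder_decoder = False  # stands in for config.is_encoder_decoder
    _train = compute_loss_and_backprob_encoder
    _train()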
@@ -100,18 +133,36 @@ class PyTorchBenchmark(Benchmark):
     def inference(self, model_name, batch_size, sequence_length, trace_memory=False):
         try:
             config = self.config_dict[model_name]
             if self.args.with_lm_head:
                 model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)
             else:
                 model = MODEL_MAPPING[config.__class__](config)
             model.to(self.args.device)
             model.eval()

+            # encoder-decoder has vocab size saved differently
+            vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
             input_ids = torch.randint(
-                config.vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device
+                vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device
             )

+            def encoder_decoder_forward():
+                model(input_ids, decoder_input_ids=input_ids)
+
+            def encoder_forward():
+                model(input_ids)
+
+            _forward = encoder_decoder_forward if config.is_encoder_decoder else encoder_forward
+
             if trace_memory is True:
                 if self.args.trace_memory_line_by_line or self.args.n_gpu == 0:
                     if self.args.trace_memory_line_by_line:
                         trace = start_memory_tracing("transformers")
                 else:
                     # clear cuda cache
                     if self.args.n_gpu > 0:
                         # clear gpu cache
                         torch.cuda.empty_cache()
                         if hasattr(torch.cuda, "max_memory_reserved"):
                             torch.cuda.reset_peak_memory_stats()
@@ -121,12 +172,15 @@ class PyTorchBenchmark(Benchmark):
                             )
                             torch.cuda.reset_max_memory_cached()

-                model(input_ids)
+                _forward()

                 if self.args.trace_memory_line_by_line or self.args.n_gpu == 0:
                     if self.args.trace_memory_line_by_line:
                         summary = stop_memory_tracing(trace)
                         memory = summary.total
                     else:
                         summary = None

                 if self.args.n_gpu > 0:
                     # gpu
                     if hasattr(torch.cuda, "max_memory_reserved"):
                         memory = Memory(torch.cuda.max_memory_reserved())
                     else:
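The peak-memory numbers in the traced branch come from standard PyTorch and psutil calls. Below is a standalone sketch of the same measurement idea; the GPU branch assumes PyTorch >= 1.4 (mirroring the hasattr check above), and psutil is optional, exactly as in the diff.

    import os

    import torch

    def measure_peak_memory():
        # GPU: reset and read the CUDA allocator's peak statistics (PyTorch >= 1.4)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            # ... run the workload here ...
            return torch.cuda.max_memory_reserved()
        # CPU: fall back to the current process RSS via psutil
        try:
            import psutil
        except ImportError:
            return "N/A"  # same fallback value the benchmark uses
        return psutil.Process(os.getpid()).memory_info().rss

    print(measure_peak_memory())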
@@ -134,11 +188,24 @@ class PyTorchBenchmark(Benchmark):
                             "Please consider updating PyTorch to version 1.4 to get more accuracy on GPU memory usage"
                         )
                         memory = Memory(torch.cuda.max_memory_cached())
                 else:
                     # cpu
                     try:
                         import psutil
                     except (ImportError):
                         logger.warning(
                             "Psutil not installed, we won't log CPU memory usage. "
                             "Install psutil (pip install psutil) to use CPU memory tracing."
                         )
                         memory = "N/A"
                     else:
                         process = psutil.Process(os.getpid())
                         memory = Memory(process.memory_info().rss)

-                return memory
+                return memory, summary
             else:
                 # as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average
-                runtimes = timeit.repeat(lambda: model(input_ids), repeat=self.args.repeat, number=10,)
+                runtimes = timeit.repeat(_forward, repeat=self.args.repeat, number=10,)
                 return min(runtimes) / 10.0
         except RuntimeError as e:
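On the speed path, both train() and inference() time the selected closure with timeit.repeat and report min(runtimes) / 10.0, i.e. the best of self.args.repeat runs of 10 calls each, following the timeit documentation's advice to take the minimum rather than the average. A self-contained sketch of that pattern with a dummy workload:

    import timeit

    def _forward():
        # stand-in for a model forward pass
        sum(i * i for i in range(1000))

    repeat = 3  # plays the role of self.args.repeat
    runtimes = timeit.repeat(_forward, repeat=repeat, number=10)

    # each entry is the total time for 10 calls, so divide by 10 for a per-call figure
    print(min(runtimes) / 10.0)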
src/transformers/benchmark/benchmark_args_utils.py  (view file @ c0554776)
@@ -61,6 +61,12 @@ class BenchmarkArguments:
     save_to_csv: bool = field(default=False, metadata={"help": "Save result to a CSV file"})
+    log_print: bool = field(default=False, metadata={"help": "Save all print statements in a log file"})
+    no_env_print: bool = field(default=False, metadata={"help": "Don't print environment information"})
+    with_lm_head: bool = field(
+        default=False,
+        metadata={"help": "Use model with its language model head (MODEL_WITH_LM_HEAD_MAPPING instead of MODEL_MAPPING)"},
+    )
     inference_time_csv_file: str = field(
         default=f"inference_time_{round(time())}.csv",
         metadata={"help": "CSV filename used if saving time results to csv."},
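The three new options follow the existing BenchmarkArguments pattern: dataclass fields whose metadata["help"] strings are surfaced as command-line help (typically via transformers' HfArgumentParser). A stdlib-only sketch of the field pattern; the class name here is illustrative, not the library's:

    from dataclasses import dataclass, field, fields

    @dataclass
    class ExampleBenchmarkArguments:  # hypothetical stand-in for BenchmarkArguments
        save_to_csv: bool = field(default=False, metadata={"help": "Save result to a CSV file"})
        log_print: bool = field(default=False, metadata={"help": "Save all print statements in a log file"})
        with_lm_head: bool = field(
            default=False,
            metadata={"help": "Use model with its language model head"},
        )

    args = ExampleBenchmarkArguments(log_print=True)
    for f in fields(args):
        print(f.name, getattr(args, f.name), "-", f.metadata["help"])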
src/transformers/benchmark/benchmark_utils.py  (view file @ c0554776)
@@ -36,7 +36,15 @@ logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
 _is_memory_tracing_enabled = False

 BenchmarkOutput = namedtuple(
-    "BenchmarkOutput", ["time_inference_result", "memory_inference_result", "time_train_result", "memory_train_result"]
+    "BenchmarkOutput",
+    [
+        "time_inference_result",
+        "memory_inference_result",
+        "time_train_result",
+        "memory_train_result",
+        "inference_summary",
+        "train_summary",
+    ],
 )
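Because BenchmarkOutput grows from four to six fields, any caller that unpacked it positionally needs the two extra slots; named access is unaffected. A small sketch of the extended namedtuple with placeholder values:

    from collections import namedtuple

    BenchmarkOutput = namedtuple(
        "BenchmarkOutput",
        [
            "time_inference_result",
            "memory_inference_result",
            "time_train_result",
            "memory_train_result",
            "inference_summary",
            "train_summary",
        ],
    )

    # placeholder values purely for illustration; run() fills these with result dicts and trace summaries
    out = BenchmarkOutput({}, {}, {}, {}, None, None)
    print(out._fields)
    print(out.inference_summary)  # None unless trace_memory_line_by_line was enabled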
@@ -401,15 +409,10 @@ class Benchmark(ABC):
     def print_fn(self):
         if self._print_fn is None:
             if self.args.log_print:
-                logging.basicConfig(
-                    level=logging.DEBUG,
-                    filename=self.args.log_filename,
-                    filemode="a+",
-                    format="%(asctime)-15s %(levelname)-8s %(message)s",
-                )

                 def print_and_log(*args):
-                    logging.info(*args)
+                    with open(self.args.log_filename, "a") as log_file:
+                        log_file.write(str(*args) + "\n")
                     print(*args)

                 self._print_fn = print_and_log
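The reworked print_fn drops logging.basicConfig and simply appends every message to the log file before printing it. A standalone sketch of that tee pattern; the temp-file path is just an example:

    import os
    import tempfile

    log_filename = os.path.join(tempfile.mkdtemp(), "log.txt")  # example path

    def print_and_log(*args):
        # append the message to the log file, then print it as usual
        # (like the diff, str(*args) assumes a single positional argument)
        with open(log_filename, "a") as log_file:
            log_file.write(str(*args) + "\n")
        print(*args)

    print_and_log("======= INFERENCE - SPEED - RESULT =======")
    with open(log_filename) as log_file:
        print(log_file.read())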
@@ -454,11 +457,15 @@ class Benchmark(ABC):
             train_result_time[model_name] = copy.deepcopy(model_dict)
             train_result_memory[model_name] = copy.deepcopy(model_dict)

+        inference_summary = train_summary = None
+
         for batch_size in self.args.batch_sizes:
             for sequence_length in self.args.sequence_lengths:
                 if not self.args.no_inference:
                     if not self.args.no_memory:
-                        memory = self.inference(model_name, batch_size, sequence_length, trace_memory=True)
+                        memory, inference_summary = self.inference(model_name, batch_size, sequence_length, trace_memory=True)
                         inference_result_memory[model_name]["result"][batch_size][sequence_length] = memory
                     if not self.args.no_speed:
                         time = self.inference(model_name, batch_size, sequence_length, trace_memory=False)
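For orientation, the result containers indexed in this hunk are nested dictionaries: model name, then a "result" mapping keyed by batch size and sequence length. A toy sketch of that shape, reproducing only the "result" level visible in the diff:

    # toy illustration of the nested result layout written by Benchmark.run()
    model_name, batch_size, sequence_length = "sshleifer/tiny-gpt2", 1, 8

    inference_result_memory = {model_name: {"result": {batch_size: {}}}}
    inference_result_memory[model_name]["result"][batch_size][sequence_length] = "1.2 MB"  # placeholder value
    print(inference_result_memory[model_name]["result"][batch_size][sequence_length])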
@@ -466,7 +473,9 @@ class Benchmark(ABC):
                 if self.args.training:
                     if not self.args.no_memory:
-                        memory = self.train(model_name, batch_size, sequence_length, trace_memory=True)
+                        memory, train_summary = self.train(model_name, batch_size, sequence_length, trace_memory=True)
                         train_result_memory[model_name]["result"][batch_size][sequence_length] = memory
                     if not self.args.no_speed:
                         time = self.inference(model_name, batch_size, sequence_length, trace_memory=False)
@@ -483,6 +492,10 @@ class Benchmark(ABC):
                 self.print_results(inference_result_memory)
                 self.save_to_csv(inference_result_memory, self.args.inference_memory_csv_file)

+            if self.args.trace_memory_line_by_line:
+                self.print_fn("======= INFERENCE - MEMORY LINE BY LINE TRACE - SUMMARY =======")
+                self.print_memory_trace_statistics(inference_summary)
+
         if self.args.training:
             if not self.args.no_speed:
                 self.print_fn("======= TRAIN - SPEED - RESULT =======")
@@ -494,6 +507,10 @@ class Benchmark(ABC):
                 self.print_results(train_result_memory)
                 self.save_to_csv(train_result_memory, self.args.train_memory_csv_file)

+            if self.args.trace_memory_line_by_line:
+                self.print_fn("======= TRAIN - MEMORY LINE BY LINE TRACE - SUMMARY =======")
+                self.print_memory_trace_statistics(train_summary)
+
         if not self.args.no_env_print:
             self.print_fn("\n======== ENVIRONMENT - INFORMATION ========")
             self.print_fn(
@@ -506,7 +523,14 @@ class Benchmark(ABC):
             for key, value in self.environment_info.items():
                 writer.writerow([key, value])

-        return BenchmarkOutput(inference_result_time, inference_result_memory, train_result_time, train_result_memory)
+        return BenchmarkOutput(
+            inference_result_time,
+            inference_result_memory,
+            train_result_time,
+            train_result_memory,
+            inference_summary,
+            train_summary,
+        )

     @property
     def environment_info(self):
tests/test_benchmark.py  (view file @ c0554776)
@@ -3,7 +3,7 @@ import tempfile
 import unittest
 from pathlib import Path

-from transformers import GPT2Config, is_torch_available
+from transformers import AutoConfig, is_torch_available

 from .utils import require_torch
@@ -45,7 +45,18 @@ class BenchmarkTest(unittest.TestCase):
     def test_inference_with_configs(self):
         MODEL_ID = "sshleifer/tiny-gpt2"
-        config = GPT2Config.from_pretrained(MODEL_ID)
+        config = AutoConfig.from_pretrained(MODEL_ID)
         benchmark_args = PyTorchBenchmarkArguments(
             models=[MODEL_ID], training=False, no_inference=False, sequence_lengths=[8], batch_sizes=[1]
         )
         benchmark = PyTorchBenchmark(benchmark_args, configs=[config])
         results = benchmark.run()
         self.check_results_dict_not_empty(results.time_inference_result)
         self.check_results_dict_not_empty(results.memory_inference_result)

+    def test_inference_encoder_decoder_with_configs(self):
+        MODEL_ID = "sshleifer/tinier_bart"
+        config = AutoConfig.from_pretrained(MODEL_ID)
+        benchmark_args = PyTorchBenchmarkArguments(
+            models=[MODEL_ID], training=False, no_inference=False, sequence_lengths=[8], batch_sizes=[1]
+        )
@@ -56,7 +67,18 @@ class BenchmarkTest(unittest.TestCase):
     def test_train_with_configs(self):
         MODEL_ID = "sshleifer/tiny-gpt2"
-        config = GPT2Config.from_pretrained(MODEL_ID)
+        config = AutoConfig.from_pretrained(MODEL_ID)
         benchmark_args = PyTorchBenchmarkArguments(
             models=[MODEL_ID], training=True, no_inference=True, sequence_lengths=[8], batch_sizes=[1]
         )
         benchmark = PyTorchBenchmark(benchmark_args, configs=[config])
         results = benchmark.run()
         self.check_results_dict_not_empty(results.time_train_result)
         self.check_results_dict_not_empty(results.memory_train_result)

+    def test_train_encoder_decoder_with_configs(self):
+        MODEL_ID = "sshleifer/tinier_bart"
+        config = AutoConfig.from_pretrained(MODEL_ID)
+        benchmark_args = PyTorchBenchmarkArguments(
+            models=[MODEL_ID], training=True, no_inference=True, sequence_lengths=[8], batch_sizes=[1]
+        )
@@ -88,3 +110,29 @@ class BenchmarkTest(unittest.TestCase):
             self.assertTrue(Path(os.path.join(tmp_dir, "inf_mem.csv")).exists())
             self.assertTrue(Path(os.path.join(tmp_dir, "train_mem.csv")).exists())
             self.assertTrue(Path(os.path.join(tmp_dir, "env.csv")).exists())

+    def test_trace_memory(self):
+        MODEL_ID = "sshleifer/tiny-gpt2"
+
+        def _check_summary_is_not_empty(summary):
+            self.assertTrue(hasattr(summary, "sequential"))
+            self.assertTrue(hasattr(summary, "cumulative"))
+            self.assertTrue(hasattr(summary, "current"))
+            self.assertTrue(hasattr(summary, "total"))
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            benchmark_args = PyTorchBenchmarkArguments(
+                models=[MODEL_ID],
+                training=True,
+                no_inference=False,
+                sequence_lengths=[8],
+                batch_sizes=[1],
+                log_filename=os.path.join(tmp_dir, "log.txt"),
+                log_print=True,
+                trace_memory_line_by_line=True,
+            )
+            benchmark = PyTorchBenchmark(benchmark_args)
+            result = benchmark.run()
+            _check_summary_is_not_empty(result.inference_summary)
+            _check_summary_is_not_empty(result.train_summary)
+            self.assertTrue(Path(os.path.join(tmp_dir, "log.txt")).exists())
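The new test exercises the feature end to end. Outside the test suite, the same flow looks roughly like the sketch below; it assumes PyTorchBenchmark and PyTorchBenchmarkArguments are importable from the top-level transformers package and that the tiny Hub model used by the tests can be downloaded.

    from transformers import PyTorchBenchmark, PyTorchBenchmarkArguments  # assumed top-level exports

    benchmark_args = PyTorchBenchmarkArguments(
        models=["sshleifer/tiny-gpt2"],  # tiny model used by the tests above
        training=True,
        no_inference=False,
        sequence_lengths=[8],
        batch_sizes=[1],
        trace_memory_line_by_line=True,
    )
    benchmark = PyTorchBenchmark(benchmark_args)
    result = benchmark.run()

    # the two fields added to BenchmarkOutput by this commit
    print(result.inference_summary)
    print(result.train_summary)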