Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
be36cf92
Commit
be36cf92
authored
Oct 30, 2019
by
Timothy Liu
Committed by
Lysandre Debut
Oct 31, 2019
Browse files
Added mixed precision support to benchmarks.py
parent
2a5663c2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
4 deletions
+14
-4
examples/benchmarks.py
examples/benchmarks.py
+14
-4
No files found.
examples/benchmarks.py
View file @
be36cf92
...
...
@@ -253,18 +253,22 @@ def create_setup_and_compute(model_names: List[str],
average_over
:
int
=
3
,
torchscript
:
bool
=
False
,
xla
:
bool
=
False
,
amp
:
bool
=
False
,
fp16
:
bool
=
False
,
save_to_csv
:
bool
=
False
,
csv_filename
:
str
=
f
"results_
{
round
(
time
())
}
.csv"
):
if
xla
:
tf
.
config
.
optimizer
.
set_jit
(
True
)
if
amp
:
tf
.
config
.
optimizer
.
set_experimental_options
({
"auto_mixed_precision"
:
True
})
if
tensorflow
:
dictionary
=
{
model_name
:
{}
for
model_name
in
model_names
}
results
=
_compute_tensorflow
(
model_names
,
dictionary
,
average_over
)
results
=
_compute_tensorflow
(
model_names
,
dictionary
,
average_over
,
amp
)
else
:
device
=
'cuda'
if
(
gpu
and
torch
.
cuda
.
is_available
())
else
'cpu'
dictionary
=
{
model_name
:
{}
for
model_name
in
model_names
}
results
=
_compute_pytorch
(
model_names
,
dictionary
,
average_over
,
device
,
torchscript
)
results
=
_compute_pytorch
(
model_names
,
dictionary
,
average_over
,
device
,
torchscript
,
fp16
)
print
(
"=========== RESULTS ==========="
)
for
model_name
in
model_names
:
...
...
@@ -302,7 +306,7 @@ def create_setup_and_compute(model_names: List[str],
writer
.
writerow
({
'model'
:
model_name
,
**
model_results
})
def
_compute_pytorch
(
model_names
,
dictionary
,
average_over
,
device
,
torchscript
):
def
_compute_pytorch
(
model_names
,
dictionary
,
average_over
,
device
,
torchscript
,
fp16
):
for
c
,
model_name
in
enumerate
(
model_names
):
print
(
f
"
{
c
+
1
}
/
{
len
(
model_names
)
}
"
)
config
=
AutoConfig
.
from_pretrained
(
model_name
,
torchscript
=
torchscript
)
...
...
@@ -319,6 +323,8 @@ def _compute_pytorch(model_names, dictionary, average_over, device, torchscript)
dictionary
[
model_name
][
"results"
]
=
{
i
:
{}
for
i
in
batch_sizes
}
for
batch_size
in
batch_sizes
:
if
fp16
:
model
.
half
()
model
.
to
(
device
)
model
.
eval
()
for
slice_size
in
slice_sizes
:
...
...
@@ -346,7 +352,7 @@ def _compute_pytorch(model_names, dictionary, average_over, device, torchscript)
return
dictionary
def
_compute_tensorflow
(
model_names
,
dictionary
,
average_over
):
def
_compute_tensorflow
(
model_names
,
dictionary
,
average_over
,
amp
):
for
c
,
model_name
in
enumerate
(
model_names
):
print
(
f
"
{
c
+
1
}
/
{
len
(
model_names
)
}
"
)
config
=
AutoConfig
.
from_pretrained
(
model_name
)
...
...
@@ -409,6 +415,8 @@ def main():
"the correct dependencies are "
"installed"
)
parser
.
add_argument
(
"--xla"
,
required
=
False
,
action
=
"store_true"
,
help
=
"TensorFlow only: use XLA acceleration."
)
parser
.
add_argument
(
"--amp"
,
required
=
False
,
action
=
"store_true"
,
help
=
"TensorFlow only: use automatic mixed precision acceleration."
)
parser
.
add_argument
(
"--fp16"
,
required
=
False
,
action
=
"store_true"
,
help
=
"PyTorch only: use FP16 to accelerate inference."
)
parser
.
add_argument
(
"--keras_predict"
,
required
=
False
,
action
=
"store_true"
,
help
=
"Whether to use model.predict "
"instead of model() to do a "
"forward pass."
)
...
...
@@ -442,6 +450,7 @@ def main():
tensorflow
=
False
,
gpu
=
args
.
torch_cuda
,
torchscript
=
args
.
torchscript
,
fp16
=
args
.
fp16
,
save_to_csv
=
args
.
save_to_csv
,
csv_filename
=
args
.
csv_filename
,
average_over
=
args
.
average_over
...
...
@@ -455,6 +464,7 @@ def main():
model_names
=
args
.
models
,
tensorflow
=
True
,
xla
=
args
.
xla
,
amp
=
args
.
amp
,
save_to_csv
=
args
.
save_to_csv
,
csv_filename
=
args
.
csv_filename
,
average_over
=
args
.
average_over
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment