Unverified commit 89a78be5, authored Jul 16, 2020 by Patrick von Platen, committed by GitHub on Jul 16, 2020

fix benchmark for longformer (#5808)
Parent: aefc0c04
Showing 3 changed files with 38 additions and 4 deletions (+38, -4):
src/transformers/benchmark/benchmark.py (+10, -2)
src/transformers/benchmark/benchmark_tf.py (+10, -2)
tests/test_benchmark.py (+18, -0)
src/transformers/benchmark/benchmark.py
@@ -88,7 +88,11 @@ class PyTorchBenchmark(Benchmark):
         if self.args.torchscript:
             config.torchscript = True

-        has_model_class_in_config = hasattr(config, "architectures") and len(config.architectures) > 0
+        has_model_class_in_config = (
+            hasattr(config, "architectures")
+            and isinstance(config.architectures, list)
+            and len(config.architectures) > 0
+        )
         if not self.args.only_pretrain_model and has_model_class_in_config:
             try:
                 model_class = config.architectures[0]

@@ -138,7 +142,11 @@ class PyTorchBenchmark(Benchmark):
     def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
         config = self.config_dict[model_name]

-        has_model_class_in_config = hasattr(config, "architectures") and len(config.architectures) > 0
+        has_model_class_in_config = (
+            hasattr(config, "architectures")
+            and isinstance(config.architectures, list)
+            and len(config.architectures) > 0
+        )
         if not self.args.only_pretrain_model and has_model_class_in_config:
             try:
                 model_class = config.architectures[0]
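The reason for the extra guard: `config.architectures` is not always a list, it can be `None`, in which case the old one-line check crashed on `len(None)` instead of falling back to the bare pretrained model, and this is presumably what broke the Longformer benchmark. A minimal sketch of the difference, using a `SimpleNamespace` as an illustrative stand-in for a real config object:

# Minimal sketch (not part of the commit): why the isinstance() check is needed.
from types import SimpleNamespace

# Stand-in for a config whose `architectures` attribute is None.
config = SimpleNamespace(architectures=None)

# Old check: hasattr() is True, so len(None) is evaluated and raises TypeError.
try:
    hasattr(config, "architectures") and len(config.architectures) > 0
except TypeError as err:
    print("old check raises:", err)

# New check: isinstance() short-circuits before len() is ever called.
has_model_class_in_config = (
    hasattr(config, "architectures")
    and isinstance(config.architectures, list)
    and len(config.architectures) > 0
)
print("new check:", has_model_class_in_config)  # False, no exception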
src/transformers/benchmark/benchmark_tf.py
@@ -132,7 +132,11 @@ class TensorFlowBenchmark(Benchmark):
         if self.args.fp16:
             raise NotImplementedError("Mixed precision is currently not supported.")

-        has_model_class_in_config = hasattr(config, "architectures") and len(config.architectures) > 0
+        has_model_class_in_config = (
+            hasattr(config, "architectures")
+            and isinstance(config.architectures, list)
+            and len(config.architectures) > 0
+        )
         if not self.args.only_pretrain_model and has_model_class_in_config:
             try:
                 model_class = "TF" + config.architectures[0]  # prepend 'TF' for tensorflow model

@@ -172,7 +176,11 @@ class TensorFlowBenchmark(Benchmark):
         if self.args.fp16:
             raise NotImplementedError("Mixed precision is currently not supported.")

-        has_model_class_in_config = hasattr(config, "architectures") and len(config.architectures) > 0
+        has_model_class_in_config = (
+            hasattr(config, "architectures")
+            and isinstance(config.architectures, list)
+            and len(config.architectures) > 0
+        )
         if not self.args.only_pretrain_model and has_model_class_in_config:
             try:
                 model_class = "TF" + config.architectures[0]  # prepend 'TF' for tensorflow model
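The TensorFlow benchmark applies the same guard and then prefixes the architecture name with "TF" before resolving the model class. A hedged sketch of how such a prefixed name is typically resolved: the getattr lookup below is an assumption about the surrounding code, not part of this hunk, and it assumes transformers plus TensorFlow are installed.

import transformers

config = transformers.AutoConfig.from_pretrained("sshleifer/tiny-gpt2")

has_model_class_in_config = (
    hasattr(config, "architectures")
    and isinstance(config.architectures, list)
    and len(config.architectures) > 0
)

if has_model_class_in_config:
    # e.g. "GPT2LMHeadModel" -> "TFGPT2LMHeadModel"
    model_class = "TF" + config.architectures[0]
    tf_model_cls = getattr(transformers, model_class)  # assumed lookup by name
    print(tf_model_cls.__name__)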
tests/test_benchmark.py
@@ -86,6 +86,24 @@ class BenchmarkTest(unittest.TestCase):
         self.check_results_dict_not_empty(results.time_inference_result)
         self.check_results_dict_not_empty(results.memory_inference_result)

+    def test_inference_no_model_no_architecuters(self):
+        MODEL_ID = "sshleifer/tiny-gpt2"
+        config = AutoConfig.from_pretrained(MODEL_ID)
+        # set architectures equal to `None`
+        config.architectures = None
+        benchmark_args = PyTorchBenchmarkArguments(
+            models=[MODEL_ID],
+            training=True,
+            no_inference=False,
+            sequence_lengths=[8],
+            batch_sizes=[1],
+            no_multi_process=True,
+        )
+        benchmark = PyTorchBenchmark(benchmark_args, configs=[config])
+        results = benchmark.run()
+        self.check_results_dict_not_empty(results.time_inference_result)
+        self.check_results_dict_not_empty(results.memory_inference_result)
+
     def test_train_no_configs(self):
         MODEL_ID = "sshleifer/tiny-gpt2"
         benchmark_args = PyTorchBenchmarkArguments(
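The new test reproduces the failure directly: it loads the config of a tiny checkpoint, forces `architectures = None`, and checks that the benchmark still produces non-empty results. A rough sketch of exercising the same code path outside the test harness; the argument names are taken from the diff above, and torch plus network access to the Hub are assumed.

from transformers import AutoConfig, PyTorchBenchmark, PyTorchBenchmarkArguments

MODEL_ID = "sshleifer/tiny-gpt2"
config = AutoConfig.from_pretrained(MODEL_ID)
config.architectures = None  # the case that used to crash the benchmark

benchmark_args = PyTorchBenchmarkArguments(
    models=[MODEL_ID],
    training=False,
    no_inference=False,
    sequence_lengths=[8],
    batch_sizes=[1],
    no_multi_process=True,
)
results = PyTorchBenchmark(benchmark_args, configs=[config]).run()
print(results.time_inference_result)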