"docs/vscode:/vscode.git/clone" did not exist on "7291ea0bff57a017e71b1ea8ec01ff19da298bf0"
Commit 1cfd9748 authored by Lysandre's avatar Lysandre
Browse files

Option to benchmark only one of the two libraries

parent 777faa8a
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" Benchmarking the library on inference and training """ """ Benchmarking the library on inference and training """
import tensorflow as tf
# If checking the tensors placement # If checking the tensors placement
# tf.debugging.set_log_device_placement(True) # tf.debugging.set_log_device_placement(True)
...@@ -23,15 +22,18 @@ from typing import List ...@@ -23,15 +22,18 @@ from typing import List
import timeit import timeit
from transformers import is_tf_available, is_torch_available from transformers import is_tf_available, is_torch_available
from time import time from time import time
import torch
import argparse import argparse
import csv import csv
if not is_torch_available() or not is_tf_available(): if is_tf_available():
raise ImportError("TensorFlow and Pytorch should be installed on the system.") import tensorflow as tf
from transformers import TFAutoModel
if is_torch_available():
import torch
from transformers import AutoModel
from transformers import AutoConfig, AutoModel, AutoTokenizer, TFAutoModel from transformers import AutoConfig, AutoTokenizer
input_text = """Bent over their instruments, three hundred Fertilizers were plunged, as input_text = """Bent over their instruments, three hundred Fertilizers were plunged, as
the Director of Hatcheries and Conditioning entered the room, in the the Director of Hatcheries and Conditioning entered the room, in the
...@@ -434,6 +436,7 @@ def main(): ...@@ -434,6 +436,7 @@ def main():
print("Running with arguments", args) print("Running with arguments", args)
if args.torch: if args.torch:
if is_torch_available():
create_setup_and_compute( create_setup_and_compute(
model_names=args.models, model_names=args.models,
tensorflow=False, tensorflow=False,
...@@ -443,8 +446,11 @@ def main(): ...@@ -443,8 +446,11 @@ def main():
csv_filename=args.csv_filename, csv_filename=args.csv_filename,
average_over=args.average_over average_over=args.average_over
) )
else:
raise ImportError("Trying to run a PyTorch benchmark but PyTorch was not found in the environment.")
if args.tensorflow: if args.tensorflow:
if is_tf_available():
create_setup_and_compute( create_setup_and_compute(
model_names=args.models, model_names=args.models,
tensorflow=True, tensorflow=True,
...@@ -453,7 +459,8 @@ def main(): ...@@ -453,7 +459,8 @@ def main():
csv_filename=args.csv_filename, csv_filename=args.csv_filename,
average_over=args.average_over average_over=args.average_over
) )
else:
raise ImportError("Trying to run a TensorFlow benchmark but TensorFlow was not found in the environment.")
if __name__ == '__main__': if __name__ == '__main__':
main() main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment