Commit 1cfd9748 authored by Lysandre's avatar Lysandre
Browse files

Option to benchmark only one of the two libraries

parent 777faa8a
......@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Benchmarking the library on inference and training """
import tensorflow as tf
# If checking the tensors placement
# tf.debugging.set_log_device_placement(True)
......@@ -23,15 +22,18 @@ from typing import List
import timeit
from transformers import is_tf_available, is_torch_available
from time import time
import torch
import argparse
import csv
if not is_torch_available() or not is_tf_available():
raise ImportError("TensorFlow and Pytorch should be installed on the system.")
if is_tf_available():
import tensorflow as tf
from transformers import TFAutoModel
if is_torch_available():
import torch
from transformers import AutoModel
from transformers import AutoConfig, AutoModel, AutoTokenizer, TFAutoModel
from transformers import AutoConfig, AutoTokenizer
input_text = """Bent over their instruments, three hundred Fertilizers were plunged, as
the Director of Hatcheries and Conditioning entered the room, in the
......@@ -434,26 +436,31 @@ def main():
print("Running with arguments", args)
if args.torch:
create_setup_and_compute(
model_names=args.models,
tensorflow=False,
gpu=args.torch_cuda,
torchscript=args.torchscript,
save_to_csv=args.save_to_csv,
csv_filename=args.csv_filename,
average_over=args.average_over
)
if is_torch_available():
create_setup_and_compute(
model_names=args.models,
tensorflow=False,
gpu=args.torch_cuda,
torchscript=args.torchscript,
save_to_csv=args.save_to_csv,
csv_filename=args.csv_filename,
average_over=args.average_over
)
else:
raise ImportError("Trying to run a PyTorch benchmark but PyTorch was not found in the environment.")
if args.tensorflow:
create_setup_and_compute(
model_names=args.models,
tensorflow=True,
xla=args.xla,
save_to_csv=args.save_to_csv,
csv_filename=args.csv_filename,
average_over=args.average_over
)
if is_tf_available():
create_setup_and_compute(
model_names=args.models,
tensorflow=True,
xla=args.xla,
save_to_csv=args.save_to_csv,
csv_filename=args.csv_filename,
average_over=args.average_over
)
else:
raise ImportError("Trying to run a TensorFlow benchmark but TensorFlow was not found in the environment.")
if __name__ == '__main__':
main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment