"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "dd9d483d03962fea127f59661f3ae6156e7a91d2"
Unverified Commit aefc0c04 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

fix benchmark non standard model (#5801)

parent 8ce610bc
......@@ -88,7 +88,7 @@ class PyTorchBenchmark(Benchmark):
if self.args.torchscript:
config.torchscript = True
has_model_class_in_config = hasattr(config, "architecture") and len(config.architectures) > 1
has_model_class_in_config = hasattr(config, "architectures") and len(config.architectures) > 0
if not self.args.only_pretrain_model and has_model_class_in_config:
try:
model_class = config.architectures[0]
......@@ -138,7 +138,7 @@ class PyTorchBenchmark(Benchmark):
def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
config = self.config_dict[model_name]
has_model_class_in_config = hasattr(config, "architecture") and len(config.architectures) > 1
has_model_class_in_config = hasattr(config, "architectures") and len(config.architectures) > 0
if not self.args.only_pretrain_model and has_model_class_in_config:
try:
model_class = config.architectures[0]
......
......@@ -132,7 +132,7 @@ class TensorFlowBenchmark(Benchmark):
if self.args.fp16:
raise NotImplementedError("Mixed precision is currently not supported.")
has_model_class_in_config = hasattr(config, "architecture") and len(config.architectures) > 1
has_model_class_in_config = hasattr(config, "architectures") and len(config.architectures) > 0
if not self.args.only_pretrain_model and has_model_class_in_config:
try:
model_class = "TF" + config.architectures[0] # prepend 'TF' for tensorflow model
......@@ -172,7 +172,7 @@ class TensorFlowBenchmark(Benchmark):
if self.args.fp16:
raise NotImplementedError("Mixed precision is currently not supported.")
has_model_class_in_config = hasattr(config, "architecture") and len(config.architectures) > 1
has_model_class_in_config = hasattr(config, "architectures") and len(config.architectures) > 0
if not self.args.only_pretrain_model and has_model_class_in_config:
try:
model_class = "TF" + config.architectures[0] # prepend 'TF' for tensorflow model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment