Unverified Commit aefc0c04 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

fix benchmark non standard model (#5801)

parent 8ce610bc
@@ -88,7 +88,7 @@ class PyTorchBenchmark(Benchmark):
         if self.args.torchscript:
             config.torchscript = True
-        has_model_class_in_config = hasattr(config, "architecture") and len(config.architectures) > 1
+        has_model_class_in_config = hasattr(config, "architectures") and len(config.architectures) > 0
         if not self.args.only_pretrain_model and has_model_class_in_config:
             try:
                 model_class = config.architectures[0]
@@ -138,7 +138,7 @@ class PyTorchBenchmark(Benchmark):
     def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
         config = self.config_dict[model_name]
-        has_model_class_in_config = hasattr(config, "architecture") and len(config.architectures) > 1
+        has_model_class_in_config = hasattr(config, "architectures") and len(config.architectures) > 0
         if not self.args.only_pretrain_model and has_model_class_in_config:
             try:
                 model_class = config.architectures[0]
...
@@ -132,7 +132,7 @@ class TensorFlowBenchmark(Benchmark):
         if self.args.fp16:
             raise NotImplementedError("Mixed precision is currently not supported.")
-        has_model_class_in_config = hasattr(config, "architecture") and len(config.architectures) > 1
+        has_model_class_in_config = hasattr(config, "architectures") and len(config.architectures) > 0
         if not self.args.only_pretrain_model and has_model_class_in_config:
             try:
                 model_class = "TF" + config.architectures[0]  # prepend 'TF' for tensorflow model
@@ -172,7 +172,7 @@ class TensorFlowBenchmark(Benchmark):
         if self.args.fp16:
             raise NotImplementedError("Mixed precision is currently not supported.")
-        has_model_class_in_config = hasattr(config, "architecture") and len(config.architectures) > 1
+        has_model_class_in_config = hasattr(config, "architectures") and len(config.architectures) > 0
         if not self.args.only_pretrain_model and has_model_class_in_config:
             try:
                 model_class = "TF" + config.architectures[0]  # prepend 'TF' for tensorflow model
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment