Unverified Commit 98f6e1ee authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix model parallelism test (#17439)

parent 7535d92e
......@@ -2203,7 +2203,7 @@ class ModelTesterMixin:
@require_torch_gpu
def test_cpu_offload(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.num_hidden_layers < 5:
if isinstance(getattr(config, "num_hidden_layers", None), int) and config.num_hidden_layers < 5:
config.num_hidden_layers = 5
for model_class in self.all_model_classes:
......@@ -2236,7 +2236,7 @@ class ModelTesterMixin:
@require_torch_multi_gpu
def test_model_parallelism(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.num_hidden_layers < 5:
if isinstance(getattr(config, "num_hidden_layers", None), int) and config.num_hidden_layers < 5:
config.num_hidden_layers = 5
for model_class in self.all_model_classes:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment