"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "9442b3ce316878cf24d59905184f47c315d3f083"
Unverified commit 98f6e1ee authored by Sylvain Gugger, committed by GitHub

Fix model parallelism test (#17439)

parent 7535d92e
@@ -2203,7 +2203,7 @@ class ModelTesterMixin:
     @require_torch_gpu
     def test_cpu_offload(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        if config.num_hidden_layers < 5:
+        if isinstance(getattr(config, "num_hidden_layers", None), int) and config.num_hidden_layers < 5:
             config.num_hidden_layers = 5
         for model_class in self.all_model_classes:
@@ -2236,7 +2236,7 @@ class ModelTesterMixin:
     @require_torch_multi_gpu
     def test_model_parallelism(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        if config.num_hidden_layers < 5:
+        if isinstance(getattr(config, "num_hidden_layers", None), int) and config.num_hidden_layers < 5:
             config.num_hidden_layers = 5
         for model_class in self.all_model_classes:
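
The guarded version first checks that the config actually exposes an integer num_hidden_layers before comparing it, presumably because some configs (for example composite configs that keep layer counts in nested sub-configs) have no such top-level attribute, so the bare `config.num_hidden_layers < 5` comparison would fail. Below is a minimal sketch of the pattern; SimpleConfig, CompositeConfig, and ensure_min_layers are illustrative names, not part of the patch.

    class SimpleConfig:
        """Plain config with an integer layer count, like most single-model configs."""

        def __init__(self):
            self.num_hidden_layers = 2


    class CompositeConfig:
        """Hypothetical composite config: the layer counts live in nested sub-configs,
        so there is no top-level integer num_hidden_layers to compare against."""

        def __init__(self):
            self.encoder = SimpleConfig()
            self.decoder = SimpleConfig()


    def ensure_min_layers(config, minimum=5):
        """Bump num_hidden_layers only when it exists and is an int,
        mirroring the guarded check added in the patched tests."""
        value = getattr(config, "num_hidden_layers", None)
        if isinstance(value, int) and value < minimum:
            config.num_hidden_layers = minimum
        return config


    print(ensure_min_layers(SimpleConfig()).num_hidden_layers)  # 5: bumped to the minimum
    print(hasattr(ensure_min_layers(CompositeConfig()), "num_hidden_layers"))  # False: left untouched

Configs without an integer num_hidden_layers simply pass through unchanged, which is the behavior the patched tests rely on.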