Unverified Commit 04c7c176 authored by Fanli Lin's avatar Fanli Lin Committed by GitHub
Browse files

[tests] make `test_model_parallelism` device-agnostic (#30844)

* enable on xpu

* fix style

* add comment and mps
parent 42d8dd87
...@@ -76,6 +76,7 @@ from transformers.testing_utils import ( ...@@ -76,6 +76,7 @@ from transformers.testing_utils import (
require_safetensors, require_safetensors,
require_torch, require_torch,
require_torch_gpu, require_torch_gpu,
require_torch_multi_accelerator,
require_torch_multi_gpu, require_torch_multi_gpu,
require_torch_sdpa, require_torch_sdpa,
slow, slow,
...@@ -3009,8 +3010,11 @@ class ModelTesterMixin: ...@@ -3009,8 +3010,11 @@ class ModelTesterMixin:
param_device = device_map[param_name] param_device = device_map[param_name]
if param_device in ["cpu", "disk"]: if param_device in ["cpu", "disk"]:
self.assertEqual(param.device, torch.device("meta")) self.assertEqual(param.device, torch.device("meta"))
elif param_device in ["mps"]:
self.assertEqual(param.device, torch.device("mps"))
else: else:
self.assertEqual(param.device, torch.device(param_device)) # when loaded with device_map, `param_device` are integer values for cuda/xpu/npu/mlu
self.assertEqual(param.device, torch.device(f"{torch_device}:{param_device}"))
@require_accelerate @require_accelerate
@mark.accelerate_tests @mark.accelerate_tests
...@@ -3129,7 +3133,7 @@ class ModelTesterMixin: ...@@ -3129,7 +3133,7 @@ class ModelTesterMixin:
@require_accelerate @require_accelerate
@mark.accelerate_tests @mark.accelerate_tests
@require_torch_multi_gpu @require_torch_multi_accelerator
def test_model_parallelism(self): def test_model_parallelism(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...@@ -3155,7 +3159,6 @@ class ModelTesterMixin: ...@@ -3155,7 +3159,6 @@ class ModelTesterMixin:
new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded # Making sure part of the model will actually end up offloaded
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1}) self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
self.check_device_map_is_respected(new_model, new_model.hf_device_map) self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0) torch.manual_seed(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment