Unverified commit 8909ab4b, authored by Sayak Paul, committed by GitHub

[Tests] fix: device map tests for models (#7825)

* fix: device module tests

* remove patch file

* Empty-Commit
parent c1edb03c
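For context before the diff: these mixin tests reload a saved model with `device_map="auto"`, a per-device `max_memory` budget, and an `offload_folder`, and the commit skips them for model classes that do not define `_no_split_modules`. The sketch below illustrates only the loading API the tests exercise; it is not part of this commit. The tiny `UNet2DConditionModel` configuration and the memory figures are illustrative assumptions, and a CUDA device 0 is assumed to be available for the `max_memory` mapping.

# Minimal usage sketch of the API these tests exercise (not part of this commit).
# The tiny UNet config and the memory budgets are illustrative assumptions;
# the real tests use self.model_class and self.model_split_percents instead.
import tempfile

from diffusers import UNet2DConditionModel

model = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=1,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
)

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir)
    # Reload with an automatically inferred device map. Layers that do not fit
    # within the GPU-0 cap spill to CPU, and to disk via offload_folder if needed.
    reloaded = UNet2DConditionModel.from_pretrained(
        tmp_dir,
        device_map="auto",
        max_memory={0: "200MB", "cpu": "1GB"},
        offload_folder=tmp_dir,
    )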
@@ -691,6 +691,9 @@ class ModelTesterMixin:
     def test_cpu_offload(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
+        if model._no_split_modules is None:
+            return
+
         model = model.to(torch_device)

         torch.manual_seed(0)
@@ -718,6 +721,9 @@ class ModelTesterMixin:
     def test_disk_offload_without_safetensors(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
+        if model._no_split_modules is None:
+            return
+
         model = model.to(torch_device)

         torch.manual_seed(0)
@@ -728,12 +734,12 @@ class ModelTesterMixin:
             model.cpu().save_pretrained(tmp_dir, safe_serialization=False)

             with self.assertRaises(ValueError):
-                max_size = int(self.model_split_percents[1] * model_size)
+                max_size = int(self.model_split_percents[0] * model_size)
                 max_memory = {0: max_size, "cpu": max_size}
                 # This errors out because it's missing an offload folder
                 new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)

-            max_size = int(self.model_split_percents[1] * model_size)
+            max_size = int(self.model_split_percents[0] * model_size)
             max_memory = {0: max_size, "cpu": max_size}
             new_model = self.model_class.from_pretrained(
                 tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
@@ -749,6 +755,9 @@ class ModelTesterMixin:
     def test_disk_offload_with_safetensors(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
+        if model._no_split_modules is None:
+            return
+
         model = model.to(torch_device)

         torch.manual_seed(0)
@@ -758,7 +767,7 @@ class ModelTesterMixin:
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir)

-            max_size = int(self.model_split_percents[1] * model_size)
+            max_size = int(self.model_split_percents[0] * model_size)
             max_memory = {0: max_size, "cpu": max_size}
             new_model = self.model_class.from_pretrained(
                 tmp_dir, device_map="auto", offload_folder=tmp_dir, max_memory=max_memory
@@ -774,6 +783,9 @@ class ModelTesterMixin:
     def test_model_parallelism(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
+        if model._no_split_modules is None:
+            return
+
         model = model.to(torch_device)

         torch.manual_seed(0)
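Taken together, the hunks apply one pattern to every device-map test in the mixin: return early when the model class does not declare `_no_split_modules` (which `device_map="auto"` relies on to know where a model may be split), and size the `max_memory` budget from `model_split_percents`. Below is a standalone sketch of that pattern, assuming the surrounding mixin provides `model_class`, `prepare_init_args_and_inputs_for_common`, and `model_split_percents`, and that a CUDA device 0 is available; the helper name `_run_disk_offload_check` and the size computation are hypothetical, not code from this commit.

# Hypothetical standalone version of the gating + disk-offload pattern in the diff above.
# `tester` stands in for a ModelTesterMixin instance; names and sizes are assumptions.
import tempfile

import torch


def _run_disk_offload_check(tester):
    config, inputs_dict = tester.prepare_init_args_and_inputs_for_common()
    model = tester.model_class(**config).eval()
    if model._no_split_modules is None:
        # device_map="auto" needs _no_split_modules; skip models that don't declare it.
        return

    # Rough parameter footprint in bytes; the real tests use an accelerate helper for this.
    model_size = sum(p.numel() * p.element_size() for p in model.parameters())

    with tempfile.TemporaryDirectory() as tmp_dir:
        model.cpu().save_pretrained(tmp_dir)

        max_size = int(tester.model_split_percents[0] * model_size)
        max_memory = {0: max_size, "cpu": max_size}

        # With a budget this tight some weights must go to disk, so omitting
        # offload_folder raises a ValueError.
        with tester.assertRaises(ValueError):
            tester.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)

        new_model = tester.model_class.from_pretrained(
            tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
        )
        torch.manual_seed(0)
        new_model(**inputs_dict)

The actual mixin tests additionally run the model before saving and compare the original output against the output of the reloaded, offloaded model.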