Unverified commit c71c19c5, authored by YiYi Xu and committed by GitHub
Browse files

A few fixes for sharded checkpoints (#8656)



fix
Co-authored-by: yiyixuxu <yixu310@gmail.com>
parent adc31940
...@@ -819,7 +819,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -819,7 +819,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
offload_folder=offload_folder, offload_folder=offload_folder,
offload_state_dict=offload_state_dict, offload_state_dict=offload_state_dict,
dtype=torch_dtype, dtype=torch_dtype,
force_hook=force_hook, force_hooks=force_hook,
strict=True, strict=True,
) )
model._undo_temp_convert_self_to_deprecated_attention_blocks() model._undo_temp_convert_self_to_deprecated_attention_blocks()
......
...@@ -898,6 +898,7 @@ class ModelTesterMixin: ...@@ -898,6 +898,7 @@ class ModelTesterMixin:
self.assertTrue(actual_num_shards == expected_num_shards) self.assertTrue(actual_num_shards == expected_num_shards)
new_model = self.model_class.from_pretrained(tmp_dir) new_model = self.model_class.from_pretrained(tmp_dir)
new_model = new_model.to(torch_device)
torch.manual_seed(0) torch.manual_seed(0)
new_output = new_model(**inputs_dict) new_output = new_model(**inputs_dict)
...@@ -933,6 +934,7 @@ class ModelTesterMixin: ...@@ -933,6 +934,7 @@ class ModelTesterMixin:
self.assertTrue(actual_num_shards == expected_num_shards) self.assertTrue(actual_num_shards == expected_num_shards)
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto") new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto")
new_model = new_model.to(torch_device)
torch.manual_seed(0) torch.manual_seed(0)
new_output = new_model(**inputs_dict) new_output = new_model(**inputs_dict)
......
...@@ -1039,6 +1039,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test ...@@ -1039,6 +1039,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
def test_load_sharded_checkpoint_from_hub(self): def test_load_sharded_checkpoint_from_hub(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common() _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained("hf-internal-testing/unet2d-sharded-dummy") loaded_model = self.model_class.from_pretrained("hf-internal-testing/unet2d-sharded-dummy")
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict) new_output = loaded_model(**inputs_dict)
assert loaded_model assert loaded_model
...@@ -1049,6 +1050,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test ...@@ -1049,6 +1050,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
_, inputs_dict = self.prepare_init_args_and_inputs_for_common() _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy") ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True) loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict) new_output = loaded_model(**inputs_dict)
assert loaded_model assert loaded_model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment