Unverified Commit c71c19c5 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

a few fix for shard checkpoints (#8656)



fix
Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
parent adc31940
......@@ -819,7 +819,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
force_hook=force_hook,
force_hooks=force_hook,
strict=True,
)
model._undo_temp_convert_self_to_deprecated_attention_blocks()
......
......@@ -898,6 +898,7 @@ class ModelTesterMixin:
self.assertTrue(actual_num_shards == expected_num_shards)
new_model = self.model_class.from_pretrained(tmp_dir)
new_model = new_model.to(torch_device)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
......@@ -933,6 +934,7 @@ class ModelTesterMixin:
self.assertTrue(actual_num_shards == expected_num_shards)
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto")
new_model = new_model.to(torch_device)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
......
......@@ -1039,6 +1039,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
def test_load_sharded_checkpoint_from_hub(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained("hf-internal-testing/unet2d-sharded-dummy")
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
......@@ -1049,6 +1050,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment