Unverified commit e4325606, authored by Marc Sun, committed by GitHub

Fix loading sharded checkpoints when we have variants (#9061)



* Fix loading sharded checkpoint when we have variant

* add test

* remove print

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 926daa30
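The fix targets how sharded checkpoints are handed to accelerate. `accelerate.load_checkpoint_and_dispatch` accepts a single weights file, a shard index JSON, or a checkpoint folder; when given a folder, it discovers the shard index by scanning for files ending in `.index.json`. A variant index such as `diffusion_pytorch_model.safetensors.index.fp16.json` does not match that pattern, so sharded checkpoints saved with a variant failed to resolve. The diff below therefore passes the resolved index file instead of the cached folder. A minimal sketch of that resolution, assuming diffusers' safetensors naming conventions (the helper and its file-name patterns are illustrative, not the library's internals):

import os
from typing import Optional


def resolve_checkpoint_path(folder: str, variant: Optional[str] = None) -> str:
    """Hypothetical helper: prefer the (variant) shard index over the bare folder."""
    base = "diffusion_pytorch_model"
    index_name = (
        f"{base}.safetensors.index.{variant}.json"
        if variant
        else f"{base}.safetensors.index.json"
    )
    index_file = os.path.join(folder, index_name)
    if os.path.isfile(index_file):
        # Sharded checkpoint: return the index file itself. Returning `folder`
        # would make accelerate scan for `*.index.json`, which a variant index
        # like `...safetensors.index.fp16.json` does not match.
        return index_file
    # Single-file checkpoint: fall back to the (variant) weights name.
    weights_name = f"{base}.{variant}.safetensors" if variant else f"{base}.safetensors"
    return os.path.join(folder, weights_name)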
@@ -773,7 +773,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     try:
         accelerate.load_checkpoint_and_dispatch(
             model,
-            model_file if not is_sharded else sharded_ckpt_cached_folder,
+            model_file if not is_sharded else index_file,
             device_map,
             max_memory=max_memory,
             offload_folder=offload_folder,
@@ -803,7 +803,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     model._temp_convert_self_to_deprecated_attention_blocks()
     accelerate.load_checkpoint_and_dispatch(
         model,
-        model_file if not is_sharded else sharded_ckpt_cached_folder,
+        model_file if not is_sharded else index_file,
         device_map,
         max_memory=max_memory,
         offload_folder=offload_folder,
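Both hunks make the same one-line change in the two dispatch paths (the regular path and the temporary deprecated-attention-blocks path): when `is_sharded` is true, accelerate now receives `index_file` rather than `sharded_ckpt_cached_folder`. A minimal sketch of the corrected call, assuming an already-downloaded fp16-variant layout (the local path is illustrative):

from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from diffusers import UNet2DConditionModel

# Build the model skeleton without allocating weights, then dispatch the
# sharded checkpoint across devices. The index JSON, not its parent folder,
# pins the exact variant-suffixed shards to load.
with init_empty_weights():
    model = UNet2DConditionModel()  # default config; real code uses the repo's config

model = load_checkpoint_and_dispatch(
    model,
    checkpoint="unet/diffusion_pytorch_model.safetensors.index.fp16.json",  # assumed path
    device_map="auto",
)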
@@ -1121,6 +1121,18 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)
 
+    @require_torch_gpu
+    def test_load_sharded_checkpoint_with_variant_from_hub(self):
+        _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        loaded_model = self.model_class.from_pretrained(
+            "hf-internal-testing/unet2d-sharded-with-variant-dummy", variant="fp16"
+        )
+        loaded_model = loaded_model.to(torch_device)
+        new_output = loaded_model(**inputs_dict)
+
+        assert loaded_model
+        assert new_output.sample.shape == (4, 4, 16, 16)
+
     @require_peft_backend
     def test_lora(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
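With the fix, the scenario the new test covers also works from user code. A short usage sketch against the test's dummy repo (`UNet2DConditionModel` is the class this test suite exercises):

import torch
from diffusers import UNet2DConditionModel

# Loads a sharded fp16-variant checkpoint from the Hub; before this fix the
# sharded + variant combination failed to locate its shard index.
unet = UNet2DConditionModel.from_pretrained(
    "hf-internal-testing/unet2d-sharded-with-variant-dummy", variant="fp16"
)
unet = unet.to("cuda" if torch.cuda.is_available() else "cpu")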