Unverified Commit 95a78328 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

fix load sharded checkpoint from a subfolder (local path) (#8913)



fix
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent c646fbc1
...@@ -448,7 +448,7 @@ def _get_checkpoint_shard_files( ...@@ -448,7 +448,7 @@ def _get_checkpoint_shard_files(
_check_if_shards_exist_locally( _check_if_shards_exist_locally(
pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames
) )
return pretrained_model_name_or_path, sharded_metadata return shards_path, sharded_metadata
# At this stage pretrained_model_name_or_path is a model identifier on the Hub # At this stage pretrained_model_name_or_path is a model identifier on the Hub
allow_patterns = original_shard_filenames allow_patterns = original_shard_filenames
...@@ -492,10 +492,12 @@ def _get_checkpoint_shard_files( ...@@ -492,10 +492,12 @@ def _get_checkpoint_shard_files(
) from e ) from e
# If `local_files_only=True`, `cached_folder` may not contain all the shard files. # If `local_files_only=True`, `cached_folder` may not contain all the shard files.
if local_files_only: elif local_files_only:
_check_if_shards_exist_locally( _check_if_shards_exist_locally(
local_dir=cache_dir, subfolder=subfolder, original_shard_filenames=original_shard_filenames local_dir=cache_dir, subfolder=subfolder, original_shard_filenames=original_shard_filenames
) )
if subfolder is not None:
cached_folder = os.path.join(cached_folder, subfolder)
return cached_folder, sharded_metadata return cached_folder, sharded_metadata
......
...@@ -1068,6 +1068,17 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test ...@@ -1068,6 +1068,17 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert loaded_model assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16) assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_gpu
def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
loaded_model = self.model_class.from_pretrained(ckpt_path, subfolder="unet", local_files_only=True)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_gpu @require_torch_gpu
def test_load_sharded_checkpoint_device_map_from_hub(self): def test_load_sharded_checkpoint_device_map_from_hub(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common() _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
...@@ -1077,6 +1088,17 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test ...@@ -1077,6 +1088,17 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert loaded_model assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16) assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_gpu
def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map="auto"
)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_gpu @require_torch_gpu
def test_load_sharded_checkpoint_device_map_from_hub_local(self): def test_load_sharded_checkpoint_device_map_from_hub_local(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common() _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
...@@ -1087,6 +1109,18 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test ...@@ -1087,6 +1109,18 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert loaded_model assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16) assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_gpu
def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
loaded_model = self.model_class.from_pretrained(
ckpt_path, local_files_only=True, subfolder="unet", device_map="auto"
)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_peft_backend @require_peft_backend
def test_lora(self): def test_lora(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment