"...git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "2c8aad97fc8d7647ee8b2df2de9312cce0355ef6"
Unverified Commit 9e9ed353 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

fix loading sharded checkpoints from subfolder (#8798)



* fix load sharded checkpoints from subfolder{

* style

* os.path.join

* add a small test

---------
Co-authored-by: default avatarsayakpaul <spsayakpaul@gmail.com>
parent 7833ed95
...@@ -221,7 +221,7 @@ def _fetch_index_file( ...@@ -221,7 +221,7 @@ def _fetch_index_file(
local_files_only=local_files_only, local_files_only=local_files_only,
token=token, token=token,
revision=revision, revision=revision,
subfolder=subfolder, subfolder=None,
user_agent=user_agent, user_agent=user_agent,
commit_hash=commit_hash, commit_hash=commit_hash,
) )
......
...@@ -455,10 +455,13 @@ def _get_checkpoint_shard_files( ...@@ -455,10 +455,13 @@ def _get_checkpoint_shard_files(
# At this stage pretrained_model_name_or_path is a model identifier on the Hub # At this stage pretrained_model_name_or_path is a model identifier on the Hub
allow_patterns = original_shard_filenames allow_patterns = original_shard_filenames
if subfolder is not None:
allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns]
ignore_patterns = ["*.json", "*.md"] ignore_patterns = ["*.json", "*.md"]
if not local_files_only: if not local_files_only:
# `model_info` call must guarded with the above condition. # `model_info` call must guarded with the above condition.
model_files_info = model_info(pretrained_model_name_or_path) model_files_info = model_info(pretrained_model_name_or_path, revision=revision)
for shard_file in original_shard_filenames: for shard_file in original_shard_filenames:
shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings) shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
if not shard_file_present: if not shard_file_present:
...@@ -481,6 +484,8 @@ def _get_checkpoint_shard_files( ...@@ -481,6 +484,8 @@ def _get_checkpoint_shard_files(
ignore_patterns=ignore_patterns, ignore_patterns=ignore_patterns,
user_agent=user_agent, user_agent=user_agent,
) )
if subfolder is not None:
cached_folder = os.path.join(cached_folder, subfolder)
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
# we don't have to catch them here. We have also dealt with EntryNotFoundError. # we don't have to catch them here. We have also dealt with EntryNotFoundError.
......
...@@ -1045,6 +1045,18 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test ...@@ -1045,6 +1045,18 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert loaded_model assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16) assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_gpu
def test_load_sharded_checkpoint_from_hub_subfolder(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet"
)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_gpu @require_torch_gpu
def test_load_sharded_checkpoint_from_hub_local(self): def test_load_sharded_checkpoint_from_hub_local(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common() _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment