Unverified Commit 11542431 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Core] fix variant-identification. (#9253)



* fix variant-idenitification.

* fix variant

* fix sharded variant checkpoint loading.

* Apply suggestions from code review

* fixes.

* more fixes.

* remove print.

* fixes

* fixes

* comments

* fixes

* apply suggestions.

* hub_utils.py

* fix test

* updates

* fixes

* fixes

* Apply suggestions from code review
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>

* updates.

* removep patch file.

---------
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent 81cf3b2f
...@@ -31,6 +31,7 @@ from ..utils import ( ...@@ -31,6 +31,7 @@ from ..utils import (
WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME,
_add_variant, _add_variant,
_get_model_file, _get_model_file,
deprecate,
is_accelerate_available, is_accelerate_available,
is_torch_version, is_torch_version,
logging, logging,
...@@ -228,3 +229,67 @@ def _fetch_index_file( ...@@ -228,3 +229,67 @@ def _fetch_index_file(
index_file = None index_file = None
return index_file return index_file
def _fetch_index_file_legacy(
is_local,
pretrained_model_name_or_path,
subfolder,
use_safetensors,
cache_dir,
variant,
force_download,
proxies,
local_files_only,
token,
revision,
user_agent,
commit_hash,
):
if is_local:
index_file = Path(
pretrained_model_name_or_path,
subfolder or "",
SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
).as_posix()
splits = index_file.split(".")
split_index = -3 if ".cache" in index_file else -2
splits = splits[:-split_index] + [variant] + splits[-split_index:]
index_file = ".".join(splits)
if os.path.exists(index_file):
deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
index_file = Path(index_file)
else:
index_file = None
else:
if variant is not None:
index_file_in_repo = Path(
subfolder or "",
SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
).as_posix()
splits = index_file_in_repo.split(".")
split_index = -2
splits = splits[:-split_index] + [variant] + splits[-split_index:]
index_file_in_repo = ".".join(splits)
try:
index_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=index_file_in_repo,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=None,
user_agent=user_agent,
commit_hash=commit_hash,
)
index_file = Path(index_file)
deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
except (EntryNotFoundError, EnvironmentError):
index_file = None
return index_file
...@@ -54,6 +54,7 @@ from ..utils.hub_utils import ( ...@@ -54,6 +54,7 @@ from ..utils.hub_utils import (
from .model_loading_utils import ( from .model_loading_utils import (
_determine_device_map, _determine_device_map,
_fetch_index_file, _fetch_index_file,
_fetch_index_file_legacy,
_load_state_dict_into_model, _load_state_dict_into_model,
load_model_dict_into_meta, load_model_dict_into_meta,
load_state_dict, load_state_dict,
...@@ -309,11 +310,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -309,11 +310,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant) weights_name = _add_variant(weights_name, variant)
weight_name_split = weights_name.split(".") weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
if len(weight_name_split) in [2, 3]: ".safetensors", "{suffix}.safetensors"
weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:]) )
else:
raise ValueError(f"Invalid {weights_name} provided.")
os.makedirs(save_directory, exist_ok=True) os.makedirs(save_directory, exist_ok=True)
...@@ -624,21 +623,26 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -624,21 +623,26 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
is_sharded = False is_sharded = False
index_file = None index_file = None
is_local = os.path.isdir(pretrained_model_name_or_path) is_local = os.path.isdir(pretrained_model_name_or_path)
index_file = _fetch_index_file( index_file_kwargs = {
is_local=is_local, "is_local": is_local,
pretrained_model_name_or_path=pretrained_model_name_or_path, "pretrained_model_name_or_path": pretrained_model_name_or_path,
subfolder=subfolder or "", "subfolder": subfolder or "",
use_safetensors=use_safetensors, "use_safetensors": use_safetensors,
cache_dir=cache_dir, "cache_dir": cache_dir,
variant=variant, "variant": variant,
force_download=force_download, "force_download": force_download,
proxies=proxies, "proxies": proxies,
local_files_only=local_files_only, "local_files_only": local_files_only,
token=token, "token": token,
revision=revision, "revision": revision,
user_agent=user_agent, "user_agent": user_agent,
commit_hash=commit_hash, "commit_hash": commit_hash,
) }
index_file = _fetch_index_file(**index_file_kwargs)
# In case the index file was not found we still have to consider the legacy format.
# this becomes applicable when the variant is not None.
if variant is not None and (index_file is None or not os.path.exists(index_file)):
index_file = _fetch_index_file_legacy(**index_file_kwargs)
if index_file is not None and index_file.is_file(): if index_file is not None and index_file.is_file():
is_sharded = True is_sharded = True
......
...@@ -50,7 +50,6 @@ from ..utils import ( ...@@ -50,7 +50,6 @@ from ..utils import (
DEPRECATED_REVISION_ARGS, DEPRECATED_REVISION_ARGS,
BaseOutput, BaseOutput,
PushToHubMixin, PushToHubMixin,
deprecate,
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
is_torch_npu_available, is_torch_npu_available,
...@@ -58,7 +57,7 @@ from ..utils import ( ...@@ -58,7 +57,7 @@ from ..utils import (
logging, logging,
numpy_to_pil, numpy_to_pil,
) )
from ..utils.hub_utils import load_or_create_model_card, populate_model_card from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
from ..utils.torch_utils import is_compiled_module from ..utils.torch_utils import is_compiled_module
...@@ -735,6 +734,21 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -735,6 +734,21 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
else: else:
cached_folder = pretrained_model_name_or_path cached_folder = pretrained_model_name_or_path
# The variant filenames can have the legacy sharding checkpoint format that we check and throw
# a warning if detected.
if variant is not None and _check_legacy_sharding_variant_format(folder=cached_folder, variant=variant):
warn_msg = (
f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
"Please check your files carefully:\n\n"
"- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
"- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
"If you find any files in the deprecated format:\n"
"1. Remove all existing checkpoint files for this variant.\n"
"2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
"This will ensure you're using the most up-to-date and compatible checkpoint format."
)
logger.warning(warn_msg)
config_dict = cls.load_config(cached_folder) config_dict = cls.load_config(cached_folder)
# pop out "_ignore_files" as it is only needed for download # pop out "_ignore_files" as it is only needed for download
...@@ -745,6 +759,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -745,6 +759,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
# Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors` # Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors`
# with variant being `"fp16"`. # with variant being `"fp16"`.
model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict) model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict)
if len(model_variants) == 0 and variant is not None:
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
raise ValueError(error_message)
# 3. Load the pipeline class, if using custom module then load it from the hub # 3. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it # if we load from explicit class, let's use it
...@@ -1251,6 +1268,22 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1251,6 +1268,22 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
model_info_call_error = e # save error to reraise it if model is not cached locally model_info_call_error = e # save error to reraise it if model is not cached locally
if not local_files_only: if not local_files_only:
filenames = {sibling.rfilename for sibling in info.siblings}
if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
warn_msg = (
f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
"Please check your files carefully:\n\n"
"- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
"- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
"If you find any files in the deprecated format:\n"
"1. Remove all existing checkpoint files for this variant.\n"
"2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
"This will ensure you're using the most up-to-date and compatible checkpoint format."
)
logger.warning(warn_msg)
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
config_file = hf_hub_download( config_file = hf_hub_download(
pretrained_model_name, pretrained_model_name,
cls.config_name, cls.config_name,
...@@ -1267,9 +1300,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1267,9 +1300,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
# retrieve all folder_names that contain relevant files # retrieve all folder_names that contain relevant files
folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"] folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]
filenames = {sibling.rfilename for sibling in info.siblings}
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
diffusers_module = importlib.import_module(__name__.split(".")[0]) diffusers_module = importlib.import_module(__name__.split(".")[0])
pipelines = getattr(diffusers_module, "pipelines") pipelines = getattr(diffusers_module, "pipelines")
...@@ -1292,13 +1322,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1292,13 +1322,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
) )
if len(variant_filenames) == 0 and variant is not None: if len(variant_filenames) == 0 and variant is not None:
deprecation_message = ( error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." raise ValueError(error_message)
f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`"
"if such variant modeling files are not available. Doing so will lead to an error in v0.24.0 as defaulting to non-variant"
"modeling files is deprecated."
)
deprecate("no variant default", "0.24.0", deprecation_message, standard_warn=False)
# remove ignored filenames # remove ignored filenames
model_filenames = set(model_filenames) - set(ignore_filenames) model_filenames = set(model_filenames) - set(ignore_filenames)
......
...@@ -271,8 +271,7 @@ if cache_version < 1: ...@@ -271,8 +271,7 @@ if cache_version < 1:
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
if variant is not None: if variant is not None:
splits = weights_name.split(".") splits = weights_name.split(".")
split_index = -2 if weights_name.endswith(".index.json") else -1 splits = splits[:-1] + [variant] + splits[-1:]
splits = splits[:-split_index] + [variant] + splits[-split_index:]
weights_name = ".".join(splits) weights_name = ".".join(splits)
return weights_name return weights_name
...@@ -502,6 +501,19 @@ def _get_checkpoint_shard_files( ...@@ -502,6 +501,19 @@ def _get_checkpoint_shard_files(
return cached_folder, sharded_metadata return cached_folder, sharded_metadata
def _check_legacy_sharding_variant_format(folder: str = None, filenames: List[str] = None, variant: str = None):
if filenames and folder:
raise ValueError("Both `filenames` and `folder` cannot be provided.")
if not filenames:
filenames = []
for _, _, files in os.walk(folder):
for file in files:
filenames.append(os.path.basename(file))
transformers_index_format = r"\d{5}-of-\d{5}"
variant_file_re = re.compile(rf".*-{transformers_index_format}\.{variant}\.[a-z]+$")
return any(variant_file_re.match(f) is not None for f in filenames)
class PushToHubMixin: class PushToHubMixin:
""" """
A Mixin to push a model, scheduler, or pipeline to the Hugging Face Hub. A Mixin to push a model, scheduler, or pipeline to the Hugging Face Hub.
......
...@@ -27,8 +27,9 @@ import numpy as np ...@@ -27,8 +27,9 @@ import numpy as np
import requests_mock import requests_mock
import torch import torch
from accelerate.utils import compute_module_sizes from accelerate.utils import compute_module_sizes
from huggingface_hub import ModelCard, delete_repo from huggingface_hub import ModelCard, delete_repo, snapshot_download
from huggingface_hub.utils import is_jinja_available from huggingface_hub.utils import is_jinja_available
from parameterized import parameterized
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from diffusers.models import UNet2DConditionModel from diffusers.models import UNet2DConditionModel
...@@ -39,7 +40,13 @@ from diffusers.models.attention_processor import ( ...@@ -39,7 +40,13 @@ from diffusers.models.attention_processor import (
XFormersAttnProcessor, XFormersAttnProcessor,
) )
from diffusers.training_utils import EMAModel from diffusers.training_utils import EMAModel
from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, is_torch_npu_available, is_xformers_available, logging from diffusers.utils import (
SAFE_WEIGHTS_INDEX_NAME,
WEIGHTS_INDEX_NAME,
is_torch_npu_available,
is_xformers_available,
logging,
)
from diffusers.utils.hub_utils import _add_variant from diffusers.utils.hub_utils import _add_variant
from diffusers.utils.testing_utils import ( from diffusers.utils.testing_utils import (
CaptureLogger, CaptureLogger,
...@@ -100,6 +107,52 @@ class ModelUtilsTest(unittest.TestCase): ...@@ -100,6 +107,52 @@ class ModelUtilsTest(unittest.TestCase):
# make sure that error message states what keys are missing # make sure that error message states what keys are missing
assert "conv_out.bias" in str(error_context.exception) assert "conv_out.bias" in str(error_context.exception)
@parameterized.expand(
[
("hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds", "unet", False),
("hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds", "unet", True),
("hf-internal-testing/tiny-sd-unet-with-sharded-ckpt", None, False),
("hf-internal-testing/tiny-sd-unet-with-sharded-ckpt", None, True),
]
)
def test_variant_sharded_ckpt_legacy_format_raises_warning(self, repo_id, subfolder, use_local):
def load_model(path):
kwargs = {"variant": "fp16"}
if subfolder:
kwargs["subfolder"] = subfolder
return UNet2DConditionModel.from_pretrained(path, **kwargs)
with self.assertWarns(FutureWarning) as warning:
if use_local:
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = snapshot_download(repo_id=repo_id)
_ = load_model(tmpdirname)
else:
_ = load_model(repo_id)
warning_message = str(warning.warnings[0].message)
self.assertIn("This serialization format is now deprecated to standardize the serialization", warning_message)
# Local tests are already covered down below.
@parameterized.expand(
[
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", None, "fp16"),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "unet", "fp16"),
("hf-internal-testing/tiny-sd-unet-sharded-no-variants", None, None),
("hf-internal-testing/tiny-sd-unet-sharded-no-variants-subfolder", "unet", None),
]
)
def test_variant_sharded_ckpt_loads_from_hub(self, repo_id, subfolder, variant=None):
def load_model():
kwargs = {}
if variant:
kwargs["variant"] = variant
if subfolder:
kwargs["subfolder"] = subfolder
return UNet2DConditionModel.from_pretrained(repo_id, **kwargs)
assert load_model()
def test_cached_files_are_used_when_no_internet(self): def test_cached_files_are_used_when_no_internet(self):
# A mock response for an HTTP head request to emulate server down # A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock() response_mock = mock.Mock()
...@@ -924,6 +977,7 @@ class ModelTesterMixin: ...@@ -924,6 +977,7 @@ class ModelTesterMixin:
# testing if loading works with the variant when the checkpoint is sharded should be # testing if loading works with the variant when the checkpoint is sharded should be
# enough. # enough.
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB", variant=variant) model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB", variant=variant)
index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, index_filename))) self.assertTrue(os.path.exists(os.path.join(tmp_dir, index_filename)))
...@@ -976,6 +1030,44 @@ class ModelTesterMixin: ...@@ -976,6 +1030,44 @@ class ModelTesterMixin:
new_output = new_model(**inputs_dict) new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
# This test is okay without a GPU because we're not running any execution. We're just serializing
# and check if the resultant files are following an expected format.
def test_variant_sharded_ckpt_right_format(self):
for use_safe in [True, False]:
extension = ".safetensors" if use_safe else ".bin"
config, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval()
model_size = compute_module_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
variant = "fp16"
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(
tmp_dir, variant=variant, max_shard_size=f"{max_shard_size}KB", safe_serialization=use_safe
)
index_variant = _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safe else WEIGHTS_INDEX_NAME, variant)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, index_variant)))
# Now check if the right number of shards exists. First, let's get the number of shards.
# Since this number can be dependent on the model being tested, it's important that we calculate it
# instead of hardcoding it.
expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, index_variant))
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(extension)])
self.assertTrue(actual_num_shards == expected_num_shards)
# Check if the variant is present as a substring in the checkpoints.
shard_files = [
file
for file in os.listdir(tmp_dir)
if file.endswith(extension) or ("index" in file and "json" in file)
]
assert all(variant in f for f in shard_files)
# Check if the sharded checkpoints were serialized in the right format.
shard_files = [file for file in os.listdir(tmp_dir) if file.endswith(extension)]
# Example: diffusion_pytorch_model.fp16-00001-of-00002.safetensors
assert all(f.split(".")[1].split("-")[0] == variant for f in shard_files)
@is_staging_test @is_staging_test
class ModelPushToHubTester(unittest.TestCase): class ModelPushToHubTester(unittest.TestCase):
......
...@@ -1036,9 +1036,15 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test ...@@ -1036,9 +1036,15 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4) assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
@require_torch_gpu @require_torch_gpu
def test_load_sharded_checkpoint_from_hub(self): @parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
]
)
def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common() _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained("hf-internal-testing/unet2d-sharded-dummy") loaded_model = self.model_class.from_pretrained(repo_id, variant=variant)
loaded_model = loaded_model.to(torch_device) loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict) new_output = loaded_model(**inputs_dict)
...@@ -1046,11 +1052,15 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test ...@@ -1046,11 +1052,15 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert new_output.sample.shape == (4, 4, 16, 16) assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_gpu @require_torch_gpu
def test_load_sharded_checkpoint_from_hub_subfolder(self): @parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
]
)
def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common() _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained( loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant)
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet"
)
loaded_model = loaded_model.to(torch_device) loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict) new_output = loaded_model(**inputs_dict)
...@@ -1080,20 +1090,30 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test ...@@ -1080,20 +1090,30 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert new_output.sample.shape == (4, 4, 16, 16) assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_gpu @require_torch_gpu
def test_load_sharded_checkpoint_device_map_from_hub(self): @parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
]
)
def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common() _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained("hf-internal-testing/unet2d-sharded-dummy", device_map="auto") loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, device_map="auto")
new_output = loaded_model(**inputs_dict) new_output = loaded_model(**inputs_dict)
assert loaded_model assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16) assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_gpu @require_torch_gpu
def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self): @parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
]
)
def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, variant):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common() _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained( loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, subfolder="unet", device_map="auto")
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map="auto"
)
new_output = loaded_model(**inputs_dict) new_output = loaded_model(**inputs_dict)
assert loaded_model assert loaded_model
...@@ -1121,18 +1141,6 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test ...@@ -1121,18 +1141,6 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert loaded_model assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16) assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_gpu
def test_load_sharded_checkpoint_with_variant_from_hub(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(
"hf-internal-testing/unet2d-sharded-with-variant-dummy", variant="fp16"
)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_peft_backend @require_peft_backend
def test_lora(self): def test_lora(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
......
...@@ -30,6 +30,7 @@ import requests_mock ...@@ -30,6 +30,7 @@ import requests_mock
import safetensors.torch import safetensors.torch
import torch import torch
import torch.nn as nn import torch.nn as nn
from huggingface_hub import snapshot_download
from parameterized import parameterized from parameterized import parameterized
from PIL import Image from PIL import Image
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
...@@ -551,6 +552,50 @@ class DownloadTests(unittest.TestCase): ...@@ -551,6 +552,50 @@ class DownloadTests(unittest.TestCase):
assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3 assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
assert not any(f.endswith(other_format) for f in files) assert not any(f.endswith(other_format) for f in files)
def test_download_variants_with_sharded_checkpoints(self):
# Here we test for downloading of "variant" files belonging to the `unet` and
# the `text_encoder`. Their checkpoints can be sharded.
for use_safetensors in [True, False]:
for variant in ["fp16", None]:
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-stable-diffusion-pipe-variants-right-format",
safety_checker=None,
cache_dir=tmpdirname,
variant=variant,
use_safetensors=use_safetensors,
)
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist]
# Check for `model_ext` and `variant`.
model_ext = ".safetensors" if use_safetensors else ".bin"
unexpected_ext = ".bin" if use_safetensors else ".safetensors"
model_files = [f for f in files if f.endswith(model_ext)]
assert not any(f.endswith(unexpected_ext) for f in files)
assert all(variant in f for f in model_files if f.endswith(model_ext) and variant is not None)
def test_download_legacy_variants_with_sharded_ckpts_raises_warning(self):
repo_id = "hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds"
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
deprecated_warning_msg = "Warning: The repository contains sharded checkpoints for variant"
for is_local in [True, False]:
with CaptureLogger(logger) as cap_logger:
with tempfile.TemporaryDirectory() as tmpdirname:
local_repo_id = repo_id
if is_local:
local_repo_id = snapshot_download(repo_id, cache_dir=tmpdirname)
_ = DiffusionPipeline.from_pretrained(
local_repo_id,
safety_checker=None,
variant="fp16",
use_safetensors=True,
)
assert deprecated_warning_msg in str(cap_logger), "Deprecation warning not found in logs"
def test_download_safetensors_only_variant_exists_for_model(self): def test_download_safetensors_only_variant_exists_for_model(self):
variant = None variant = None
use_safetensors = True use_safetensors = True
...@@ -655,7 +700,7 @@ class DownloadTests(unittest.TestCase): ...@@ -655,7 +700,7 @@ class DownloadTests(unittest.TestCase):
out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="np").images out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="np").images
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
pipe.save_pretrained(tmpdirname) pipe.save_pretrained(tmpdirname, variant=variant, safe_serialization=use_safe)
pipe_2 = StableDiffusionPipeline.from_pretrained( pipe_2 = StableDiffusionPipeline.from_pretrained(
tmpdirname, safe_serialization=use_safe, variant=variant tmpdirname, safe_serialization=use_safe, variant=variant
) )
...@@ -1646,7 +1691,7 @@ class PipelineFastTests(unittest.TestCase): ...@@ -1646,7 +1691,7 @@ class PipelineFastTests(unittest.TestCase):
def test_error_no_variant_available(self): def test_error_no_variant_available(self):
variant = "fp16" variant = "fp16"
with self.assertRaises(ValueError) as error_context: with self.assertRaises(ValueError) as error_context:
_ = StableDiffusionPipeline.download( _ = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/diffusers-stable-diffusion-tiny-all", variant=variant "hf-internal-testing/diffusers-stable-diffusion-tiny-all", variant=variant
) )
......
...@@ -1824,6 +1824,74 @@ class PipelineTesterMixin: ...@@ -1824,6 +1824,74 @@ class PipelineTesterMixin:
# accounts for models that modify the number of inference steps based on strength # accounts for models that modify the number of inference steps based on strength
assert pipe.guidance_scale == (inputs["guidance_scale"] + pipe.num_timesteps) assert pipe.guidance_scale == (inputs["guidance_scale"] + pipe.num_timesteps)
def test_serialization_with_variants(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
model_components = [
component_name for component_name, component in pipe.components.items() if isinstance(component, nn.Module)
]
variant = "fp16"
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)
with open(f"{tmpdir}/model_index.json", "r") as f:
config = json.load(f)
for subfolder in os.listdir(tmpdir):
if not os.path.isfile(subfolder) and subfolder in model_components:
folder_path = os.path.join(tmpdir, subfolder)
is_folder = os.path.isdir(folder_path) and subfolder in config
assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))
def test_loading_with_variants(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
variant = "fp16"
def is_nan(tensor):
if tensor.ndimension() == 0:
has_nan = torch.isnan(tensor).item()
else:
has_nan = torch.isnan(tensor).any()
return has_nan
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, variant=variant)
model_components_pipe = {
component_name: component
for component_name, component in pipe.components.items()
if isinstance(component, nn.Module)
}
model_components_pipe_loaded = {
component_name: component
for component_name, component in pipe_loaded.components.items()
if isinstance(component, nn.Module)
}
for component_name in model_components_pipe:
pipe_component = model_components_pipe[component_name]
pipe_loaded_component = model_components_pipe_loaded[component_name]
for p1, p2 in zip(pipe_component.parameters(), pipe_loaded_component.parameters()):
# nan check for luminanext (mps).
if not (is_nan(p1) and is_nan(p2)):
self.assertTrue(torch.equal(p1, p2))
def test_loading_with_incorrect_variants_raises_error(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
variant = "fp16"
with tempfile.TemporaryDirectory() as tmpdir:
# Don't save with variants.
pipe.save_pretrained(tmpdir, safe_serialization=False)
with self.assertRaises(ValueError) as error:
_ = self.pipeline_class.from_pretrained(tmpdir, variant=variant)
assert f"You are trying to load the model files of the `variant={variant}`" in str(error.exception)
def test_StableDiffusionMixin_component(self): def test_StableDiffusionMixin_component(self):
"""Any pipeline that have LDMFuncMixin should have vae and unet components.""" """Any pipeline that have LDMFuncMixin should have vae and unet components."""
if not issubclass(self.pipeline_class, StableDiffusionMixin): if not issubclass(self.pipeline_class, StableDiffusionMixin):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment