"docs/vscode:/vscode.git/clone" did not exist on "38f89e595b56c0bbea6e993c3c3705ca502bf884"
Unverified commit ed224f94 authored by Dhruv Nair, committed by GitHub

Add single file support for Stable Cascade (#7274)



* update

* update

* update

* update

* update

* update

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 531e7191
@@ -81,6 +81,87 @@ SCHEDULER_DEFAULT_CONFIG = {
    "timestep_spacing": "leading",
}
STABLE_CASCADE_DEFAULT_CONFIGS = {
    "stage_c": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "prior"},
    "stage_c_lite": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "prior_lite"},
    "stage_b": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "decoder"},
    "stage_b_lite": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "decoder_lite"},
}
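# For illustration only: each entry above resolves to a config hosted on the Hub via the
# standard ConfigMixin API, mirroring the `load_config` call used in the tests below.
from diffusers import StableCascadeUNet

config = StableCascadeUNet.load_config(
    pretrained_model_name_or_path="diffusers/stable-cascade-configs", subfolder="prior"
)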
def convert_stable_cascade_unet_single_file_to_diffusers(original_state_dict):
    is_stage_c = "clip_txt_mapper.weight" in original_state_dict

    if is_stage_c:
        state_dict = {}
        for key in original_state_dict.keys():
            if key.endswith("in_proj_weight"):
                # Split the fused QKV projection into separate diffusers-style q/k/v weights
                weights = original_state_dict[key].chunk(3, 0)
                state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
                state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
                state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
            elif key.endswith("in_proj_bias"):
                weights = original_state_dict[key].chunk(3, 0)
                state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
                state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
                state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
            elif key.endswith("out_proj.weight"):
                weights = original_state_dict[key]
                state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
            elif key.endswith("out_proj.bias"):
                weights = original_state_dict[key]
                state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
            else:
                state_dict[key] = original_state_dict[key]
    else:
        state_dict = {}
        for key in original_state_dict.keys():
            if key.endswith("in_proj_weight"):
                weights = original_state_dict[key].chunk(3, 0)
                state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
                state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
                state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
            elif key.endswith("in_proj_bias"):
                weights = original_state_dict[key].chunk(3, 0)
                state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
                state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
                state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
            elif key.endswith("out_proj.weight"):
                weights = original_state_dict[key]
                state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
            elif key.endswith("out_proj.bias"):
                weights = original_state_dict[key]
                state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
            # rename clip_mapper to clip_txt_pooled_mapper
            elif key.endswith("clip_mapper.weight"):
                weights = original_state_dict[key]
                state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
            elif key.endswith("clip_mapper.bias"):
                weights = original_state_dict[key]
                state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
            else:
                state_dict[key] = original_state_dict[key]

    return state_dict
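# A small self-contained sketch of the renaming above (toy shapes, hypothetical key name):
# a fused in_proj_weight of shape (3*dim, dim) is chunked into q/k/v projections, and
# "attn.in_proj_weight" is rewritten to "to_q.weight" / "to_k.weight" / "to_v.weight".
import torch

dim = 4
toy_state_dict = {"blocks.0.attn.in_proj_weight": torch.randn(3 * dim, dim)}
converted = convert_stable_cascade_unet_single_file_to_diffusers(toy_state_dict)
assert converted["blocks.0.to_q.weight"].shape == (dim, dim)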
def infer_stable_cascade_single_file_config(checkpoint):
    is_stage_c = "clip_txt_mapper.weight" in checkpoint
    is_stage_b = "down_blocks.1.0.channelwise.0.weight" in checkpoint

    # The lite variants are distinguished from the full models by their smaller projection dimensions
    if is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 1536):
        config_type = "stage_c_lite"
    elif is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 2048):
        config_type = "stage_c"
    elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 576:
        config_type = "stage_b_lite"
    elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 640:
        config_type = "stage_b"

    return STABLE_CASCADE_DEFAULT_CONFIGS[config_type]
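# Quick illustration of the shape-based dispatch (toy tensor; the second dimension is
# arbitrary here, since only the first dimension of clip_txt_mapper.weight is inspected):
import torch

toy_checkpoint = {"clip_txt_mapper.weight": torch.randn(2048, 1280)}
assert infer_stable_cascade_single_file_config(toy_checkpoint) == STABLE_CASCADE_DEFAULT_CONFIGS["stage_c"]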
DIFFUSERS_TO_LDM_MAPPING = {
    "unet": {
        "layers": {
@@ -229,10 +310,34 @@ def fetch_ldm_config_and_checkpoint(
    cache_dir=None,
    local_files_only=None,
    revision=None,
):
    checkpoint = load_single_file_model_checkpoint(
        pretrained_model_link_or_path,
        resume_download=resume_download,
        force_download=force_download,
        proxies=proxies,
        token=token,
        cache_dir=cache_dir,
        local_files_only=local_files_only,
        revision=revision,
    )
    original_config = fetch_original_config(class_name, checkpoint, original_config_file)

    return original_config, checkpoint


def load_single_file_model_checkpoint(
    pretrained_model_link_or_path,
    resume_download=False,
    force_download=False,
    proxies=None,
    token=None,
    cache_dir=None,
    local_files_only=None,
    revision=None,
):
    if os.path.isfile(pretrained_model_link_or_path):
        checkpoint = load_state_dict(pretrained_model_link_or_path)
    else:
        repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
        checkpoint_path = _get_model_file(
@@ -252,9 +357,7 @@ def fetch_ldm_config_and_checkpoint(
while "state_dict" in checkpoint: while "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"] checkpoint = checkpoint["state_dict"]
original_config = fetch_original_config(class_name, checkpoint, original_config_file) return checkpoint
return original_config, checkpoint
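# The extracted helper can also be used on its own to fetch just the state dict,
# e.g. (hypothetical local path) for a checkpoint saved to disk:
#
#     checkpoint = load_single_file_model_checkpoint("/path/to/stage_c_bf16.safetensors")
#
# Nested "state_dict" wrappers are unwrapped before the checkpoint is returned.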
def infer_original_config_file(class_name, checkpoint):
...
@@ -42,6 +42,11 @@ from ..utils import (
    set_adapter_layers,
    set_weights_and_activate_adapters,
)
from .single_file_utils import (
    convert_stable_cascade_unet_single_file_to_diffusers,
    infer_stable_cascade_single_file_config,
    load_single_file_model_checkpoint,
)
from .utils import AttnProcsLayers
@@ -896,3 +901,103 @@ class UNet2DConditionLoadersMixin:
        self.config.encoder_hid_dim_type = "ip_image_proj"
        self.to(dtype=self.dtype, device=self.device)
class FromOriginalUNetMixin:
    """
    Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`StableCascadeUNet`].
    """
    @classmethod
    @validate_hf_hub_args
    def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
        r"""
        Instantiate a [`StableCascadeUNet`] from pretrained UNet weights saved in the original `.ckpt` or
        `.safetensors` format. The model is set in evaluation mode (`model.eval()`) by default.

        Parameters:
            pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:
                    - A link to the `.ckpt` file (for example
                      `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
                    - A path to a *file* containing all pipeline weights.
            config (`dict`, *optional*):
                Dictionary containing the configuration of the model.
            torch_dtype (`str` or `torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
                dtype is automatically derived from the model's weights.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
                incompletely downloaded files are deleted.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to overwrite load and saveable variables of the model.
        """
        config = kwargs.pop("config", None)
        resume_download = kwargs.pop("resume_download", False)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        token = kwargs.pop("token", None)
        cache_dir = kwargs.pop("cache_dir", None)
        local_files_only = kwargs.pop("local_files_only", None)
        revision = kwargs.pop("revision", None)
        torch_dtype = kwargs.pop("torch_dtype", None)

        class_name = cls.__name__
        if class_name != "StableCascadeUNet":
            raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet")

        checkpoint = load_single_file_model_checkpoint(
            pretrained_model_link_or_path,
            resume_download=resume_download,
            force_download=force_download,
            proxies=proxies,
            token=token,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            revision=revision,
        )

        if config is None:
            config = infer_stable_cascade_single_file_config(checkpoint)
            model_config = cls.load_config(**config, **kwargs)
        else:
            model_config = config

        # Instantiate on the meta device when accelerate is available to avoid allocating weights twice
        ctx = init_empty_weights if is_accelerate_available() else nullcontext
        with ctx():
            model = cls.from_config(model_config, **kwargs)

        diffusers_format_checkpoint = convert_stable_cascade_unet_single_file_to_diffusers(checkpoint)
        if is_accelerate_available():
            unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
            if len(unexpected_keys) > 0:
                logger.warning(
                    f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {', '.join(unexpected_keys)}"
                )
        else:
            model.load_state_dict(diffusers_format_checkpoint)

        if torch_dtype is not None:
            model.to(torch_dtype)

        return model
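# End-to-end usage, mirroring the tests below: the config is inferred from the checkpoint's
# tensor shapes unless an explicit `config` dict is supplied, and `torch_dtype` is optional.
import torch
from diffusers import StableCascadeUNet

single_file_url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_c_bf16.safetensors"
unet = StableCascadeUNet.from_single_file(single_file_url, torch_dtype=torch.bfloat16)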
@@ -21,6 +21,7 @@ import torch
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders.unet import FromOriginalUNetMixin
from ...utils import BaseOutput
from ..attention_processor import Attention
from ..modeling_utils import ModelMixin
@@ -134,7 +135,7 @@ class StableCascadeUNetOutput(BaseOutput):
    sample: torch.FloatTensor = None


class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
...
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import torch
from diffusers import StableCascadeUNet
from diffusers.utils import logging
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    numpy_cosine_similarity_distance,
    require_torch_gpu,
    slow,
)
from diffusers.utils.torch_utils import randn_tensor
logger = logging.get_logger(__name__)
enable_full_determinism()
@slow
class StableCascadeUNetModelSlowTests(unittest.TestCase):
    def tearDown(self) -> None:
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_stable_cascade_unet_prior_single_file_components(self):
        single_file_url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_c_bf16.safetensors"
        single_file_unet = StableCascadeUNet.from_single_file(single_file_url)
        single_file_unet_config = single_file_unet.config

        del single_file_unet
        gc.collect()
        torch.cuda.empty_cache()

        unet = StableCascadeUNet.from_pretrained(
            "stabilityai/stable-cascade-prior", subfolder="prior", revision="refs/pr/2", variant="bf16"
        )
        unet_config = unet.config
        del unet
        gc.collect()
        torch.cuda.empty_cache()

        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"]
        for param_name, param_value in single_file_unet_config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert unet_config[param_name] == param_value
    def test_stable_cascade_unet_decoder_single_file_components(self):
        single_file_url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_bf16.safetensors"
        single_file_unet = StableCascadeUNet.from_single_file(single_file_url)
        single_file_unet_config = single_file_unet.config

        del single_file_unet
        gc.collect()
        torch.cuda.empty_cache()

        unet = StableCascadeUNet.from_pretrained(
            "stabilityai/stable-cascade", subfolder="decoder", revision="refs/pr/44", variant="bf16"
        )
        unet_config = unet.config
        del unet
        gc.collect()
        torch.cuda.empty_cache()

        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"]
        for param_name, param_value in single_file_unet_config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert unet_config[param_name] == param_value
    def test_stable_cascade_unet_config_loading(self):
        config = StableCascadeUNet.load_config(
            pretrained_model_name_or_path="diffusers/stable-cascade-configs", subfolder="prior"
        )
        single_file_url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_c_bf16.safetensors"
        single_file_unet = StableCascadeUNet.from_single_file(single_file_url, config=config)
        single_file_unet_config = single_file_unet.config

        del single_file_unet
        gc.collect()
        torch.cuda.empty_cache()

        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"]
        for param_name, param_value in config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert single_file_unet_config[param_name] == param_value
    @require_torch_gpu
    def test_stable_cascade_unet_single_file_prior_forward_pass(self):
        dtype = torch.bfloat16
        generator = torch.Generator("cpu")

        model_inputs = {
            "sample": randn_tensor((1, 16, 24, 24), generator=generator.manual_seed(0)).to("cuda", dtype),
            "timestep_ratio": torch.tensor([1]).to("cuda", dtype),
            "clip_text_pooled": randn_tensor((1, 1, 1280), generator=generator.manual_seed(0)).to("cuda", dtype),
            "clip_text": randn_tensor((1, 77, 1280), generator=generator.manual_seed(0)).to("cuda", dtype),
            "clip_img": randn_tensor((1, 1, 768), generator=generator.manual_seed(0)).to("cuda", dtype),
            "pixels": randn_tensor((1, 3, 8, 8), generator=generator.manual_seed(0)).to("cuda", dtype),
        }

        unet = StableCascadeUNet.from_pretrained(
            "stabilityai/stable-cascade-prior",
            subfolder="prior",
            revision="refs/pr/2",
            variant="bf16",
            torch_dtype=dtype,
        )
        unet.to("cuda")
        with torch.no_grad():
            prior_output = unet(**model_inputs).sample.float().cpu().numpy()

        # Remove the UNet from GPU memory before loading the single file UNet model
        del unet
        gc.collect()
        torch.cuda.empty_cache()

        single_file_url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_c_bf16.safetensors"
        single_file_unet = StableCascadeUNet.from_single_file(single_file_url, torch_dtype=dtype)
        single_file_unet.to("cuda")
        with torch.no_grad():
            prior_single_file_output = single_file_unet(**model_inputs).sample.float().cpu().numpy()

        # Remove the single file UNet from GPU memory before comparing outputs
        del single_file_unet
        gc.collect()
        torch.cuda.empty_cache()

        max_diff = numpy_cosine_similarity_distance(prior_output.flatten(), prior_single_file_output.flatten())
        assert max_diff < 8e-3
    @require_torch_gpu
    def test_stable_cascade_unet_single_file_decoder_forward_pass(self):
        dtype = torch.float32
        generator = torch.Generator("cpu")

        model_inputs = {
            "sample": randn_tensor((1, 4, 256, 256), generator=generator.manual_seed(0)).to("cuda", dtype),
            "timestep_ratio": torch.tensor([1]).to("cuda", dtype),
            "clip_text": randn_tensor((1, 77, 1280), generator=generator.manual_seed(0)).to("cuda", dtype),
            "clip_text_pooled": randn_tensor((1, 1, 1280), generator=generator.manual_seed(0)).to("cuda", dtype),
            "pixels": randn_tensor((1, 3, 8, 8), generator=generator.manual_seed(0)).to("cuda", dtype),
        }

        unet = StableCascadeUNet.from_pretrained(
            "stabilityai/stable-cascade",
            subfolder="decoder",
            revision="refs/pr/44",
            torch_dtype=dtype,
        )
        unet.to("cuda")
        with torch.no_grad():
            decoder_output = unet(**model_inputs).sample.float().cpu().numpy()

        # Remove the UNet from GPU memory before loading the single file UNet model
        del unet
        gc.collect()
        torch.cuda.empty_cache()

        single_file_url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b.safetensors"
        single_file_unet = StableCascadeUNet.from_single_file(single_file_url, torch_dtype=dtype)
        single_file_unet.to("cuda")
        with torch.no_grad():
            decoder_single_file_output = single_file_unet(**model_inputs).sample.float().cpu().numpy()

        # Remove the single file UNet from GPU memory before comparing outputs
        del single_file_unet
        gc.collect()
        torch.cuda.empty_cache()

        max_diff = numpy_cosine_similarity_distance(decoder_output.flatten(), decoder_single_file_output.flatten())
        assert max_diff < 1e-4