"vscode:/vscode.git/clone" did not exist on "1571110648dc5b0e603316c9ce2b0f16ac85cdbb"
Unverified commit 8e2c4cd5, authored by Patrick von Platen, committed by GitHub

Deprecate sample size (#1406)

* up

* up

* fix

* uP

* more fixes

* up

* uP

* up

* up

* uP

* fix final tests
parent bb2c64a0
@@ -91,9 +91,6 @@ class ConfigMixin:
     def register_to_config(self, **kwargs):
         if self.config_name is None:
             raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
-        kwargs["_class_name"] = self.__class__.__name__
-        kwargs["_diffusers_version"] = __version__
         # Special case for `kwargs` used in deprecation warning added to schedulers
         # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
         # or solve in a more general way.
@@ -462,7 +459,7 @@ class ConfigMixin:
         unused_kwargs = {**config_dict, **kwargs}
         # 7. Define "hidden" config parameters that were saved for compatible classes
-        hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict and not k.startswith("_")}
+        hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
         return init_dict, unused_kwargs, hidden_config_dict
@@ -493,6 +490,9 @@ class ConfigMixin:
             `str`: String containing all the attributes that make up this configuration instance in JSON format.
         """
         config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
+        config_dict["_class_name"] = self.__class__.__name__
+        config_dict["_diffusers_version"] = __version__
         return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

     def to_json_file(self, json_file_path: Union[str, os.PathLike]):
@@ -520,6 +520,7 @@ def register_to_config(init):
     def inner_init(self, *args, **kwargs):
         # Ignore private kwargs in the init.
         init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
+        config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
         init(self, *args, **init_kwargs)
         if not isinstance(self, ConfigMixin):
             raise RuntimeError(
@@ -545,6 +546,7 @@ def register_to_config(init):
                 if k not in ignore and k not in new_kwargs
             }
         )
+        new_kwargs = {**config_init_kwargs, **new_kwargs}
         getattr(self, "register_to_config")(**new_kwargs)

     return inner_init
@@ -562,7 +564,7 @@ def flax_register_to_config(cls):
         )
         # Ignore private kwargs in the init. Retrieve all passed attributes
-        init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
+        init_kwargs = {k: v for k, v in kwargs.items()}
         # Retrieve default values
         fields = dataclasses.fields(self)
...
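Taken together, the ConfigMixin hunks above move the `_class_name` / `_diffusers_version` bookkeeping out of `register_to_config` and into `to_json_string`: the private metadata is no longer baked into every registered config, only attached when the config is serialized. A minimal, self-contained sketch of the post-change behavior (the `TinyConfigMixin` class and version string are illustrative stand-ins, not the diffusers source):

import json

DIFFUSERS_VERSION = "0.9.0"  # hypothetical stand-in for diffusers.__version__

class TinyConfigMixin:
    def register_to_config(self, **kwargs):
        # After this commit: only user-facing parameters are stored here.
        self._internal_dict = dict(kwargs)

    def to_json_string(self):
        config_dict = dict(self._internal_dict) if hasattr(self, "_internal_dict") else {}
        # Private metadata is injected at serialization time instead.
        config_dict["_class_name"] = self.__class__.__name__
        config_dict["_diffusers_version"] = DIFFUSERS_VERSION
        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

class TinyScheduler(TinyConfigMixin):
    def __init__(self, sample_size=64):
        self.register_to_config(sample_size=sample_size)

scheduler = TinyScheduler()
assert "_class_name" not in scheduler._internal_dict  # not stored in the config...
print(scheduler.to_json_string())  # ...but present in the serialized JSON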
@@ -448,7 +448,7 @@ class ModelMixin(torch.nn.Module):
         if low_cpu_mem_usage:
             # Instantiate model with empty weights
             with accelerate.init_empty_weights():
-                model, unused_kwargs = cls.from_config(
+                config, unused_kwargs = cls.load_config(
                     config_path,
                     cache_dir=cache_dir,
                     return_unused_kwargs=True,
@@ -462,6 +462,7 @@ class ModelMixin(torch.nn.Module):
                     device_map=device_map,
                     **kwargs,
                 )
+                model = cls.from_config(config, **unused_kwargs)
             # if device_map is None, load the state dict and move the params from meta device to the cpu
             if device_map is None:
@@ -482,7 +483,7 @@ class ModelMixin(torch.nn.Module):
                 "error_msgs": [],
             }
         else:
-            model, unused_kwargs = cls.from_config(
+            config, unused_kwargs = cls.load_config(
                 config_path,
                 cache_dir=cache_dir,
                 return_unused_kwargs=True,
@@ -496,6 +497,7 @@ class ModelMixin(torch.nn.Module):
                 device_map=device_map,
                 **kwargs,
             )
+            model = cls.from_config(config, **unused_kwargs)
             state_dict = load_state_dict(model_file)
             model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
...
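The ModelMixin hunks split what used to be a single `from_config(config_path, ...)` call into `load_config` (parse the config file, return the dict plus unused kwargs) followed by an explicit `from_config(config, **unused_kwargs)` that builds the module. A hedged, self-contained sketch of that two-step flow with toy stand-ins (the real methods take many more arguments):

import json
import os
import tempfile

class TinyModel:
    config_name = "config.json"

    def __init__(self, sample_size=64):
        self.sample_size = sample_size

    @classmethod
    def load_config(cls, path, return_unused_kwargs=False, **kwargs):
        # Step 1: parsing only -- cheap, touches no weights, so it is safe to
        # run inside a context like accelerate.init_empty_weights().
        with open(os.path.join(path, cls.config_name)) as f:
            config = json.load(f)
        unused = {k: v for k, v in kwargs.items() if k not in config}
        return (config, unused) if return_unused_kwargs else config

    @classmethod
    def from_config(cls, config, **kwargs):
        # Step 2: instantiation, dropping private "_..." metadata keys.
        init_kwargs = {k: v for k, v in {**config, **kwargs}.items() if not k.startswith("_")}
        return cls(**init_kwargs)

with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "config.json"), "w") as f:
        json.dump({"sample_size": 64, "_class_name": "TinyModel"}, f)
    config, unused_kwargs = TinyModel.load_config(tmp, return_unused_kwargs=True)
    model = TinyModel.from_config(config, **unused_kwargs)
    assert model.sample_size == 64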
@@ -18,6 +18,7 @@ from typing import Callable, List, Optional, Union
 import torch
 from diffusers.utils import is_accelerate_available
+from packaging import version
 from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
 from ...configuration_utils import FrozenDict
@@ -132,6 +133,27 @@ class AltDiffusionPipeline(DiffusionPipeline):
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
+        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+            version.parse(unet.config._diffusers_version).base_version
+        ) < version.parse("0.9.0.dev0")
+        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+            deprecation_message = (
+                "The configuration file of the unet has set the default `sample_size` to smaller than"
+                " 64, which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                " configuration file. Please make sure to update the config accordingly, as leaving `sample_size=32`"
+                " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+                " the `unet/config.json` file"
+            )
+            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(unet.config)
+            new_config["sample_size"] = 64
+            unet._internal_dict = FrozenDict(new_config)
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
...
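This same guard is added verbatim to every pipeline `__init__` touched by the commit (the remaining diffs below repeat it). Its core is the version gate: only configs written by diffusers older than 0.9.0.dev0 that also carry `sample_size < 64` get patched and warned about. A standalone sketch of just that predicate, using the same `packaging.version` helper the diff imports (the function name is illustrative):

from packaging import version

def needs_sample_size_patch(config_version: str, sample_size: int) -> bool:
    # base_version strips dev/post segments, so "0.9.0.dev0" compares as
    # "0.9.0" and is NOT considered older than the "0.9.0.dev0" cutoff.
    too_old = version.parse(version.parse(config_version).base_version) < version.parse("0.9.0.dev0")
    return too_old and sample_size < 64

assert needs_sample_size_patch("0.8.0", 32) is True        # old config -> patch and warn
assert needs_sample_size_patch("0.9.0.dev0", 32) is False  # new enough -> trust the config
assert needs_sample_size_patch("0.8.0", 64) is False       # nothing to fix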
@@ -20,6 +20,7 @@ import torch
 import PIL
 from diffusers.utils import is_accelerate_available
+from packaging import version
 from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
 from ...configuration_utils import FrozenDict
@@ -145,6 +146,27 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
+        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+            version.parse(unet.config._diffusers_version).base_version
+        ) < version.parse("0.9.0.dev0")
+        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+            deprecation_message = (
+                "The configuration file of the unet has set the default `sample_size` to smaller than"
+                " 64, which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                " configuration file. Please make sure to update the config accordingly, as leaving `sample_size=32`"
+                " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+                " the `unet/config.json` file"
+            )
+            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(unet.config)
+            new_config["sample_size"] = 64
+            unet._internal_dict = FrozenDict(new_config)
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
...
@@ -20,6 +20,7 @@ import torch
 import PIL
 from diffusers.utils import is_accelerate_available
+from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from ...configuration_utils import FrozenDict
@@ -176,6 +177,26 @@ class CycleDiffusionPipeline(DiffusionPipeline):
                 "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
+        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+            version.parse(unet.config._diffusers_version).base_version
+        ) < version.parse("0.9.0.dev0")
+        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+            deprecation_message = (
+                "The configuration file of the unet has set the default `sample_size` to smaller than"
+                " 64, which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                " configuration file. Please make sure to update the config accordingly, as leaving `sample_size=32`"
+                " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+                " the `unet/config.json` file"
+            )
+            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(unet.config)
+            new_config["sample_size"] = 64
+            unet._internal_dict = FrozenDict(new_config)
         self.register_modules(
             vae=vae,
...
@@ -23,6 +23,7 @@ import jax.numpy as jnp
 from flax.core.frozen_dict import FrozenDict
 from flax.jax_utils import unreplicate
 from flax.training.common_utils import shard
+from packaging import version
 from PIL import Image
 from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
@@ -34,7 +35,7 @@ from ...schedulers import (
     FlaxLMSDiscreteScheduler,
     FlaxPNDMScheduler,
 )
-from ...utils import logging
+from ...utils import deprecate, logging
 from . import FlaxStableDiffusionPipelineOutput
 from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
@@ -97,6 +98,27 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
                 " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
             )
+        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+            version.parse(unet.config._diffusers_version).base_version
+        ) < version.parse("0.9.0.dev0")
+        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+            deprecation_message = (
+                "The configuration file of the unet has set the default `sample_size` to smaller than"
+                " 64, which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                " configuration file. Please make sure to update the config accordingly, as leaving `sample_size=32`"
+                " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+                " the `unet/config.json` file"
+            )
+            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(unet.config)
+            new_config["sample_size"] = 64
+            unet._internal_dict = FrozenDict(new_config)
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
...
@@ -18,6 +18,7 @@ from typing import Callable, List, Optional, Union
 import numpy as np
 import torch
+from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTokenizer
 from ...configuration_utils import FrozenDict
@@ -98,6 +99,27 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
+        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+            version.parse(unet.config._diffusers_version).base_version
+        ) < version.parse("0.9.0.dev0")
+        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+            deprecation_message = (
+                "The configuration file of the unet has set the default `sample_size` to smaller than"
+                " 64, which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                " configuration file. Please make sure to update the config accordingly, as leaving `sample_size=32`"
+                " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+                " the `unet/config.json` file"
+            )
+            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(unet.config)
+            new_config["sample_size"] = 64
+            unet._internal_dict = FrozenDict(new_config)
         self.register_modules(
             vae_encoder=vae_encoder,
             vae_decoder=vae_decoder,
...
@@ -19,6 +19,7 @@ import numpy as np
 import torch
 import PIL
+from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTokenizer
 from ...configuration_utils import FrozenDict
@@ -134,6 +135,27 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
+        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+            version.parse(unet.config._diffusers_version).base_version
+        ) < version.parse("0.9.0.dev0")
+        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+            deprecation_message = (
+                "The configuration file of the unet has set the default `sample_size` to smaller than"
+                " 64, which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                " configuration file. Please make sure to update the config accordingly, as leaving `sample_size=32`"
+                " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+                " the `unet/config.json` file"
+            )
+            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(unet.config)
+            new_config["sample_size"] = 64
+            unet._internal_dict = FrozenDict(new_config)
         self.register_modules(
             vae_encoder=vae_encoder,
             vae_decoder=vae_decoder,
...
@@ -19,6 +19,7 @@ import numpy as np
 import torch
 import PIL
+from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTokenizer
 from ...configuration_utils import FrozenDict
@@ -148,6 +149,27 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
+        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+            version.parse(unet.config._diffusers_version).base_version
+        ) < version.parse("0.9.0.dev0")
+        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+            deprecation_message = (
+                "The configuration file of the unet has set the default `sample_size` to smaller than"
+                " 64, which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                " configuration file. Please make sure to update the config accordingly, as leaving `sample_size=32`"
+                " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+                " the `unet/config.json` file"
+            )
+            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(unet.config)
+            new_config["sample_size"] = 64
+            unet._internal_dict = FrozenDict(new_config)
         self.register_modules(
             vae_encoder=vae_encoder,
             vae_decoder=vae_decoder,
...
@@ -5,6 +5,7 @@ import numpy as np
 import torch
 import PIL
+from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTokenizer
 from ...configuration_utils import FrozenDict
@@ -133,6 +134,27 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
+        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+            version.parse(unet.config._diffusers_version).base_version
+        ) < version.parse("0.9.0.dev0")
+        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+            deprecation_message = (
+                "The configuration file of the unet has set the default `sample_size` to smaller than"
+                " 64, which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                " configuration file. Please make sure to update the config accordingly, as leaving `sample_size=32`"
+                " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+                " the `unet/config.json` file"
+            )
+            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(unet.config)
+            new_config["sample_size"] = 64
+            unet._internal_dict = FrozenDict(new_config)
         self.register_modules(
             vae_encoder=vae_encoder,
             vae_decoder=vae_decoder,
...
@@ -18,6 +18,7 @@ from typing import Callable, List, Optional, Union
 import torch
 from diffusers.utils import is_accelerate_available
+from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from ...configuration_utils import FrozenDict
@@ -131,6 +132,27 @@ class StableDiffusionPipeline(DiffusionPipeline):
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
+        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+            version.parse(unet.config._diffusers_version).base_version
+        ) < version.parse("0.9.0.dev0")
+        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+            deprecation_message = (
+                "The configuration file of the unet has set the default `sample_size` to smaller than"
+                " 64, which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                " configuration file. Please make sure to update the config accordingly, as leaving `sample_size=32`"
+                " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+                " the `unet/config.json` file"
+            )
+            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(unet.config)
+            new_config["sample_size"] = 64
+            unet._internal_dict = FrozenDict(new_config)
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
...
@@ -19,8 +19,10 @@ import torch
 import PIL
 from diffusers.utils import is_accelerate_available
+from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
+from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import (
@@ -31,7 +33,7 @@ from ...schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from ...utils import logging
+from ...utils import deprecate, logging
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
@@ -100,6 +102,27 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
+        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+            version.parse(unet.config._diffusers_version).base_version
+        ) < version.parse("0.9.0.dev0")
+        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+            deprecation_message = (
+                "The configuration file of the unet has set the default `sample_size` to smaller than"
+                " 64, which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                " configuration file. Please make sure to update the config accordingly, as leaving `sample_size=32`"
+                " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+                " the `unet/config.json` file"
+            )
+            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(unet.config)
+            new_config["sample_size"] = 64
+            unet._internal_dict = FrozenDict(new_config)
         self.register_modules(
             vae=vae,
             image_encoder=image_encoder,
...
@@ -20,6 +20,7 @@ import torch
 import PIL
 from diffusers.utils import is_accelerate_available
+from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from ...configuration_utils import FrozenDict
@@ -144,6 +145,27 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
+        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+            version.parse(unet.config._diffusers_version).base_version
+        ) < version.parse("0.9.0.dev0")
+        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+            deprecation_message = (
+                "The configuration file of the unet has set the default `sample_size` to smaller than"
+                " 64, which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                " configuration file. Please make sure to update the config accordingly, as leaving `sample_size=32`"
+                " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+                " the `unet/config.json` file"
+            )
+            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(unet.config)
+            new_config["sample_size"] = 64
+            unet._internal_dict = FrozenDict(new_config)
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
...
@@ -20,6 +20,7 @@ import torch
 import PIL
 from diffusers.utils import is_accelerate_available
+from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from ...configuration_utils import FrozenDict
@@ -209,6 +210,27 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
+        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+            version.parse(unet.config._diffusers_version).base_version
+        ) < version.parse("0.9.0.dev0")
+        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+            deprecation_message = (
+                "The configuration file of the unet has set the default `sample_size` to smaller than"
+                " 64, which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                " configuration file. Please make sure to update the config accordingly, as leaving `sample_size=32`"
+                " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+                " the `unet/config.json` file"
+            )
+            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(unet.config)
+            new_config["sample_size"] = 64
+            unet._internal_dict = FrozenDict(new_config)
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
...
@@ -20,6 +20,7 @@ import torch
 import PIL
 from diffusers.utils import is_accelerate_available
+from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from ...configuration_utils import FrozenDict
@@ -157,6 +158,27 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
+        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+            version.parse(unet.config._diffusers_version).base_version
+        ) < version.parse("0.9.0.dev0")
+        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+            deprecation_message = (
+                "The configuration file of the unet has set the default `sample_size` to smaller than"
+                " 64, which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                " configuration file. Please make sure to update the config accordingly, as leaving `sample_size=32`"
+                " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+                " the `unet/config.json` file"
+            )
+            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(unet.config)
+            new_config["sample_size"] = 64
+            unet._internal_dict = FrozenDict(new_config)
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
...
@@ -5,6 +5,7 @@ from typing import Callable, List, Optional, Union
 import numpy as np
 import torch
+from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from ...configuration_utils import FrozenDict
@@ -126,6 +127,27 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
+        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+            version.parse(unet.config._diffusers_version).base_version
+        ) < version.parse("0.9.0.dev0")
+        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+            deprecation_message = (
+                "The configuration file of the unet has set the default `sample_size` to smaller than"
+                " 64, which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                " configuration file. Please make sure to update the config accordingly, as leaving `sample_size=32`"
+                " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+                " the `unet/config.json` file"
+            )
+            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(unet.config)
+            new_config["sample_size"] = 64
+            unet._internal_dict = FrozenDict(new_config)
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
...
@@ -32,7 +32,7 @@ def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn
     if warning is not None:
         warning = warning + " " if standard_warn else ""
-        warnings.warn(warning + message, DeprecationWarning)
+        warnings.warn(warning + message, FutureWarning)
     if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0:
         call_frame = inspect.getouterframes(inspect.currentframe())[1]
...
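The one-line change above swaps the warning class from `DeprecationWarning`, which CPython's default filters hide unless the warning is triggered from `__main__` or a test runner, to `FutureWarning`, which is always shown to end users. A minimal check that mirrors the updated tests below (plain `warnings`, not the diffusers helper):

import warnings

def deprecate_sketch(message: str):
    # After this commit the diffusers helper warns with FutureWarning, which
    # Python's default warning filters always display; DeprecationWarning is
    # suppressed outside of __main__ and test runners.
    warnings.warn(message, FutureWarning)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    deprecate_sketch("`sample_size` smaller than 64 is deprecated")

assert caught[0].category is FutureWarning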
@@ -262,18 +262,17 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             [prompt],
             generator=generator,
             guidance_scale=6.0,
-            height=536,
-            width=536,
+            height=136,
+            width=136,
             num_inference_steps=2,
             output_type="np",
         )
-        sd_pipe.enable_attention_slicing()
         image = output.images
         image_slice = image[0, -3:, -3:, -1]
-        assert image.shape == (1, 536, 536, 3)
-        expected_slice = np.array([0.5445, 0.8108, 0.6242, 0.4863, 0.5779, 0.5423, 0.4749, 0.4589, 0.4616])
+        assert image.shape == (1, 136, 136, 3)
+        expected_slice = np.array([0.5524, 0.5626, 0.6069, 0.4727, 0.386, 0.3995, 0.4613, 0.4328, 0.4269])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -766,18 +765,18 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         prompt = "hey"
-        output = sd_pipe(prompt, number_of_steps=2, output_type="np")
+        output = sd_pipe(prompt, number_of_steps=1, output_type="np")
         image_shape = output.images[0].shape[:2]
         assert image_shape == (64, 64)
-        output = sd_pipe(prompt, number_of_steps=2, height=96, width=96, output_type="np")
+        output = sd_pipe(prompt, number_of_steps=1, height=96, width=96, output_type="np")
         image_shape = output.images[0].shape[:2]
         assert image_shape == (96, 96)
         config = dict(sd_pipe.unet.config)
         config["sample_size"] = 96
         sd_pipe.unet = UNet2DConditionModel.from_config(config).to(torch_device)
-        output = sd_pipe(prompt, number_of_steps=2, output_type="np")
+        output = sd_pipe(prompt, number_of_steps=1, output_type="np")
         image_shape = output.images[0].shape[:2]
         assert image_shape == (192, 192)
...
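For context on the `(192, 192)` assertion: with `sample_size` now the source of truth, the pipelines default to `height = width = unet.config.sample_size * vae_scale_factor` when the caller passes no explicit size. The fast-test dummy VAE downsamples by 2x, while the real Stable Diffusion VAE downsamples by 8x. A back-of-the-envelope check (the helper name is illustrative):

def default_output_size(sample_size: int, vae_scale_factor: int) -> int:
    # Default pixel resolution implied by the unet's latent sample_size.
    return sample_size * vae_scale_factor

assert default_output_size(96, 2) == 192  # dummy test unet patched to sample_size=96
assert default_output_size(64, 8) == 512  # real SD v1 checkpoints after this commit
assert default_output_size(32, 8) == 256  # why a stale sample_size=32 config misbehaves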
@@ -26,7 +26,7 @@ class DeprecateTester(unittest.TestCase):
     def test_deprecate_function_arg(self):
         kwargs = {"deprecated_arg": 4}
-        with self.assertWarns(DeprecationWarning) as warning:
+        with self.assertWarns(FutureWarning) as warning:
             output = deprecate("deprecated_arg", self.higher_version, "message", take_from=kwargs)
         assert output == 4
@@ -39,7 +39,7 @@ class DeprecateTester(unittest.TestCase):
     def test_deprecate_function_arg_tuple(self):
         kwargs = {"deprecated_arg": 4}
-        with self.assertWarns(DeprecationWarning) as warning:
+        with self.assertWarns(FutureWarning) as warning:
             output = deprecate(("deprecated_arg", self.higher_version, "message"), take_from=kwargs)
         assert output == 4
@@ -51,7 +51,7 @@ class DeprecateTester(unittest.TestCase):
     def test_deprecate_function_args(self):
         kwargs = {"deprecated_arg_1": 4, "deprecated_arg_2": 8}
-        with self.assertWarns(DeprecationWarning) as warning:
+        with self.assertWarns(FutureWarning) as warning:
             output_1, output_2 = deprecate(
                 ("deprecated_arg_1", self.higher_version, "Hey"),
                 ("deprecated_arg_2", self.higher_version, "Hey"),
@@ -81,7 +81,7 @@ class DeprecateTester(unittest.TestCase):
         assert "got an unexpected keyword argument `deprecated_arg`" in str(error.exception)

     def test_deprecate_arg_no_kwarg(self):
-        with self.assertWarns(DeprecationWarning) as warning:
+        with self.assertWarns(FutureWarning) as warning:
             deprecate(("deprecated_arg", self.higher_version, "message"))
         assert (
@@ -90,7 +90,7 @@ class DeprecateTester(unittest.TestCase):
         )

     def test_deprecate_args_no_kwarg(self):
-        with self.assertWarns(DeprecationWarning) as warning:
+        with self.assertWarns(FutureWarning) as warning:
             deprecate(
                 ("deprecated_arg_1", self.higher_version, "Hey"),
                 ("deprecated_arg_2", self.higher_version, "Hey"),
@@ -108,7 +108,7 @@ class DeprecateTester(unittest.TestCase):
         class Args:
             arg = 5

-        with self.assertWarns(DeprecationWarning) as warning:
+        with self.assertWarns(FutureWarning) as warning:
             arg = deprecate(("arg", self.higher_version, "message"), take_from=Args())
         assert arg == 5
@@ -122,7 +122,7 @@ class DeprecateTester(unittest.TestCase):
             arg = 5
             foo = 7

-        with self.assertWarns(DeprecationWarning) as warning:
+        with self.assertWarns(FutureWarning) as warning:
             arg_1, arg_2 = deprecate(
                 ("arg", self.higher_version, "message"),
                 ("foo", self.higher_version, "message"),
@@ -158,7 +158,7 @@ class DeprecateTester(unittest.TestCase):
         )

     def test_deprecate_incorrect_no_standard_warn(self):
-        with self.assertWarns(DeprecationWarning) as warning:
+        with self.assertWarns(FutureWarning) as warning:
             deprecate(("deprecated_arg", self.higher_version, "This message is better!!!"), standard_warn=False)
         assert str(warning.warning) == "This message is better!!!"
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
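The new YAML file above follows the standard CompVis layout: each `target` names a class to instantiate and the sibling `params` block holds its constructor kwargs. A hedged sketch of reading it (the file name and PyYAML usage are illustrative; this commit only adds the file):

import yaml  # pip install pyyaml

with open("v1-inference.yaml") as f:  # illustrative file name
    cfg = yaml.safe_load(f)

unet_cfg = cfg["model"]["params"]["unet_config"]
print(unet_cfg["target"])                 # ldm.modules.diffusionmodules.openaimodel.UNetModel
print(unet_cfg["params"]["context_dim"])  # 768, the CLIP text embedding width
vae_cfg = cfg["model"]["params"]["first_stage_config"]
print(vae_cfg["params"]["ddconfig"]["z_channels"])  # 4 latent channels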