"docs/vscode:/vscode.git/clone" did not exist on "a921d8be9dd8b98795b4d8076f3af4f48dc3d24d"
Unverified Commit d934d3d7 authored by Mishig Davaadorj's avatar Mishig Davaadorj Committed by GitHub
Browse files

FlaxDiffusionPipeline & FlaxStableDiffusionPipeline (#559)



* WIP: flax FlaxDiffusionPipeline & FlaxStableDiffusionPipeline

* todo comment

* Fix imports

* Fix imports

* add dummies

* Fix empty init

* make pipeline work

* up

* Use Flax schedulers (typing, docstring)

* Wrap model imports inside availability checks.

* more updates

* make sure flax is not broken

* make style

* more fixes

* up
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarPedro Cuenca <pedro@latenitesoft.com>
parent c6629e6f
......@@ -66,6 +66,7 @@ if is_flax_available():
from .modeling_flax_utils import FlaxModelMixin
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.vae_flax import FlaxAutoencoderKL
from .pipeline_flax_utils import FlaxDiffusionPipeline
from .schedulers import (
FlaxDDIMScheduler,
FlaxDDPMScheduler,
......@@ -76,3 +77,8 @@ if is_flax_available():
)
else:
from .utils.dummy_flax_objects import * # noqa F403
if is_flax_available() and is_transformers_available():
from .pipelines import FlaxStableDiffusionPipeline
else:
from .utils.dummy_flax_and_transformers_objects import * # noqa F403
......@@ -306,16 +306,16 @@ class FlaxModelMixin:
# Load model
if os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
# Load from a Flax checkpoint
model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
# At this stage we don't have a weight file so we will raise an error.
elif from_pt:
if from_pt:
if not os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
)
model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
# Load from a Flax checkpoint
model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
# At this stage we don't have a weight file so we will raise an error.
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
raise EnvironmentError(
f"{WEIGHTS_NAME} file found in directory {pretrained_model_name_or_path}. Please load the model"
......
......@@ -12,6 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .vae import AutoencoderKL, VQModel
from ..utils import is_flax_available, is_torch_available
if is_torch_available():
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .vae import AutoencoderKL, VQModel
if is_flax_available():
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
from .vae_flax import FlaxAutoencoderKL
......@@ -144,7 +144,6 @@ class FlaxSpatialTransformer(nn.Module):
def __call__(self, hidden_states, context, deterministic=True):
batch, height, width, channels = hidden_states.shape
# import ipdb; ipdb.set_trace()
residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)
......
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
import os
from typing import Dict, List, Optional, Union
import numpy as np
import flax
import PIL
from flax.core.frozen_dict import FrozenDict
from huggingface_hub import snapshot_download
from PIL import Image
from tqdm.auto import tqdm
from .configuration_utils import ConfigMixin
from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging
if is_transformers_available():
from transformers import FlaxPreTrainedModel
INDEX_FILE = "diffusion_flax_model.bin"
logger = logging.get_logger(__name__)
LOADABLE_CLASSES = {
"diffusers": {
"FlaxModelMixin": ["save_pretrained", "from_pretrained"],
"SchedulerMixin": ["save_config", "from_config"],
"FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"],
},
"transformers": {
"PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
"PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
"FlaxPreTrainedModel": ["save_pretrained", "from_pretrained"],
"FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
},
}
ALL_IMPORTABLE_CLASSES = {}
for library in LOADABLE_CLASSES:
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
class DummyChecker:
def __init__(self):
self.dummy = True
def import_flax_or_no_model(module, class_name):
try:
# 1. First make sure that if a Flax object is present, import this one
class_obj = getattr(module, "Flax" + class_name)
except AttributeError:
# 2. If this doesn't work, it's not a model and we don't append "Flax"
class_obj = getattr(module, class_name)
except AttributeError:
raise ValueError(f"Neither Flax{class_name} nor {class_name} exist in {module}")
return class_obj
@flax.struct.dataclass
class FlaxImagePipelineOutput(BaseOutput):
"""
Output class for image pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
class FlaxDiffusionPipeline(ConfigMixin):
r"""
Base class for all models.
[`FlaxDiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion
pipelines and handles methods for loading, downloading and saving models as well as a few methods common to all
pipelines to:
- enabling/disabling the progress bar for the denoising iteration
Class attributes:
- **config_name** ([`str`]) -- name of the config file that will store the class and module names of all
components of the diffusion pipeline.
"""
config_name = "model_index.json"
def register_modules(self, **kwargs):
# import it here to avoid circular import
from diffusers import pipelines
for name, module in kwargs.items():
# retrieve library
library = module.__module__.split(".")[0]
# check if the module is a pipeline module
pipeline_dir = module.__module__.split(".")[-2]
path = module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
if library not in LOADABLE_CLASSES or is_pipeline_module:
library = pipeline_dir
# retrieve class_name
class_name = module.__class__.__name__
register_dict = {name: (library, class_name)}
# save model index config
self.register_to_config(**register_dict)
# set models
setattr(self, name, module)
def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union[Dict, FrozenDict]):
# TODO: handle inference_state
"""
Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading
method. The pipeline can easily be re-loaded using the `[`~FlaxDiffusionPipeline.from_pretrained`]` class
method.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to which to save. Will be created if it doesn't exist.
"""
self.save_config(save_directory)
model_index_dict = dict(self.config)
model_index_dict.pop("_class_name")
model_index_dict.pop("_diffusers_version")
model_index_dict.pop("_module", None)
for pipeline_component_name in model_index_dict.keys():
sub_model = getattr(self, pipeline_component_name)
model_cls = sub_model.__class__
save_method_name = None
# search for the model's base class in LOADABLE_CLASSES
for library_name, library_classes in LOADABLE_CLASSES.items():
library = importlib.import_module(library_name)
for base_class, save_load_methods in library_classes.items():
class_candidate = getattr(library, base_class)
if issubclass(model_cls, class_candidate):
# if we found a suitable base class in LOADABLE_CLASSES then grab its save method
save_method_name = save_load_methods[0]
break
if save_method_name is not None:
break
# TODO(Patrick, Suraj): to delete after
if isinstance(sub_model, DummyChecker):
continue
save_method = getattr(sub_model, save_method_name)
expects_params = "params" in set(inspect.signature(save_method).parameters.keys())
if expects_params:
save_method(
os.path.join(save_directory, pipeline_component_name), params=params[pipeline_component_name]
)
else:
save_method(os.path.join(save_directory, pipeline_component_name))
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r"""
Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights.
The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
task.
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
weights are discarded.
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A string, the *repo id* of a pretrained pipeline hosted inside a model repo on
https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like
`CompVis/ldm-text2im-large-256`.
- A path to a *directory* containing pipeline weights saved using
[`~FlaxDiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`.
dtype (`str` or `jnp.dtype`, *optional*):
Override the default `jnp.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
will be automatically derived from the model's weights.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
mirror (`str`, *optional*):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Please refer to the mirror site for more information. specify the folder name here.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
specific pipeline class. The overritten components are then directly passed to the pipelines `__init__`
method. See example below for more information.
<Tip>
Passing `use_auth_token=True`` is required when you want to use a private model, *e.g.*
`"CompVis/stable-diffusion-v1-4"`
</Tip>
<Tip>
Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
this method in a firewalled environment.
</Tip>
Examples:
```py
>>> from diffusers import FlaxDiffusionPipeline
>>> # Download pipeline from huggingface.co and cache.
>>> pipeline = FlaxDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
>>> # Download pipeline that requires an authorization token
>>> # For more information on access tokens, please refer to this section
>>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
>>> pipeline = FlaxDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
>>> # Download pipeline, but overwrite scheduler
>>> from diffusers import LMSDiscreteScheduler
>>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
>>> pipeline = FlaxDiffusionPipeline.from_pretrained(
... "CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True
... )
```
"""
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
from_pt = kwargs.pop("from_pt", False)
dtype = kwargs.pop("dtype", None)
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
if not os.path.isdir(pretrained_model_name_or_path):
config_dict = cls.get_config_dict(
pretrained_model_name_or_path,
cache_dir=cache_dir,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
)
# make sure we only download sub-folders and `diffusers` filenames
folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
allow_patterns = [os.path.join(k, "*") for k in folder_names]
allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name]
# download all allow_patterns
cached_folder = snapshot_download(
pretrained_model_name_or_path,
cache_dir=cache_dir,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
allow_patterns=allow_patterns,
)
else:
cached_folder = pretrained_model_name_or_path
config_dict = cls.get_config_dict(cached_folder)
# 2. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
if cls != FlaxDiffusionPipeline:
pipeline_class = cls
else:
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
# some modules can be passed directly to the init
# in this case they are already instantiated in `kwargs`
# extract them here
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_kwargs = {}
# inference_params
params = {}
# import it here to avoid circular import
from diffusers import pipelines
# 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
# TODO(Patrick, Suraj) - delete later
if class_name == "DummyChecker":
library_name = "stable_diffusion"
class_name = "StableDiffusionSafetyChecker"
is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None
# if the model is in a pipeline module, then we load it from the pipeline
if name in passed_class_obj:
# 1. check that passed_class_obj has correct parent class
if not is_pipeline_module:
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
expected_class_obj = None
for class_name, class_candidate in class_candidates.items():
if issubclass(class_obj, class_candidate):
expected_class_obj = class_candidate
if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
raise ValueError(
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
f" {expected_class_obj}"
)
else:
logger.warn(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
" has the correct type"
)
# set passed class object
loaded_sub_model = passed_class_obj[name]
elif is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)
if from_pt:
class_obj = import_flax_or_no_model(pipeline_module, class_name)
else:
class_obj = getattr(pipeline_module, class_name)
importable_classes = ALL_IMPORTABLE_CLASSES
class_candidates = {c: class_obj for c in importable_classes.keys()}
else:
# else we just import it from the library.
library = importlib.import_module(library_name)
if from_pt:
class_obj = import_flax_or_no_model(library, class_name)
else:
class_obj = getattr(library, class_name)
importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
if loaded_sub_model is None:
load_method_name = None
for class_name, class_candidate in class_candidates.items():
if issubclass(class_obj, class_candidate):
load_method_name = importable_classes[class_name][1]
load_method = getattr(class_obj, load_method_name)
# check if the module is in a subdirectory
if os.path.isdir(os.path.join(cached_folder, name)):
loadable_folder = os.path.join(cached_folder, name)
else:
loaded_sub_model = cached_folder
if issubclass(class_obj, FlaxModelMixin):
# TODO(Patrick, Suraj) - Fix this as soon as Safety checker is fixed here
if name == "safety_checker":
loaded_sub_model = DummyChecker()
loaded_params = DummyChecker()
else:
loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
params[name] = loaded_params
elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
# make sure we don't initialize the weights to save time
if from_pt:
# TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
loaded_sub_model = load_method(loadable_folder, from_pt=from_pt)
loaded_params = loaded_sub_model.params
del loaded_sub_model._params
else:
loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
params[name] = loaded_params
elif issubclass(class_obj, SchedulerMixin):
loaded_sub_model = load_method(loadable_folder)
params[name] = loaded_sub_model.create_state()
else:
loaded_sub_model = load_method(loadable_folder)
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
model = pipeline_class(**init_kwargs, dtype=dtype)
return model, params
@staticmethod
def numpy_to_pil(images):
"""
Convert a numpy image or a batch of images to a PIL image.
"""
if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
return pil_images
# TODO: make it compatible with jax.lax
def progress_bar(self, iterable):
if not hasattr(self, "_progress_bar_config"):
self._progress_bar_config = {}
elif not isinstance(self._progress_bar_config, dict):
raise ValueError(
f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
)
return tqdm(iterable, **self._progress_bar_config)
def set_progress_bar_config(self, **kwargs):
self._progress_bar_config = kwargs
from ..utils import is_onnx_available, is_transformers_available
from ..utils import is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
from .ddim import DDIMPipeline
from .ddpm import DDPMPipeline
from .latent_diffusion_uncond import LDMPipeline
......@@ -7,7 +7,7 @@ from .score_sde_ve import ScoreSdeVePipeline
from .stochastic_karras_ve import KarrasVePipeline
if is_transformers_available():
if is_torch_available() and is_transformers_available():
from .latent_diffusion import LDMTextToImagePipeline
from .stable_diffusion import (
StableDiffusionImg2ImgPipeline,
......@@ -17,3 +17,6 @@ if is_transformers_available():
if is_transformers_available() and is_onnx_available():
from .stable_diffusion import StableDiffusionOnnxPipeline
if is_transformers_available() and is_flax_available():
from .stable_diffusion import FlaxStableDiffusionPipeline
......@@ -37,4 +37,24 @@ if is_transformers_available() and is_onnx_available():
from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline
if is_transformers_available() and is_flax_available():
import flax
@flax.struct.dataclass
class FlaxStableDiffusionPipelineOutput(BaseOutput):
"""
Output class for Stable Diffusion pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
nsfw_content_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: List[bool]
from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
from typing import Dict, List, Optional, Union
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
from ...pipeline_flax_utils import FlaxDiffusionPipeline
from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler
from . import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`FlaxAutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`FlaxCLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel),
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`FlaxSchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`].
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offsensive or harmful.
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
def __init__(
self,
vae: FlaxAutoencoderKL,
text_encoder: FlaxCLIPTextModel,
tokenizer: CLIPTokenizer,
unet: FlaxUNet2DConditionModel,
scheduler: Union[FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler],
safety_checker: FlaxStableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
dtype: jnp.dtype = jnp.float32,
):
super().__init__()
scheduler = scheduler.set_format("np")
self.dtype = dtype
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
def prepare_inputs(self, prompt: Union[str, List[str]]):
if not isinstance(prompt, (str, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
text_input = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="np",
)
return text_input.input_ids
def __call__(
self,
prompt_ids: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.PRNGKey,
num_inference_steps: Optional[int] = 50,
height: Optional[int] = 512,
width: Optional[int] = 512,
guidance_scale: Optional[float] = 7.5,
latents: Optional[jnp.array] = None,
return_dict: bool = True,
debug: bool = False,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`jnp.array`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
a plain tuple.
Returns:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
`tuple. When returning a tuple, the first element is a list with the generated images, and the second
element is a list of `bool`s denoting whether the corresponding generated image likely represents
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
"""
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
# get prompt text embeddings
text_embeddings = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
# TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
# implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
batch_size = prompt_ids.shape[0]
max_length = prompt_ids.shape[-1]
uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0]
context = jnp.concatenate([uncond_embeddings, text_embeddings])
# TODO: check it because the shape is different from Pytorhc StableDiffusionPipeline
latents_shape = (
batch_size,
self.unet.in_channels,
self.unet.sample_size,
self.unet.sample_size,
)
if latents is None:
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
def loop_body(step, args):
latents, scheduler_state = args
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
latents_input = jnp.concatenate([latents] * 2)
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
timestep = jnp.broadcast_to(t, latents_input.shape[0])
# predict the noise residual
noise_pred = self.unet.apply(
{"params": params["unet"]},
jnp.array(latents_input),
jnp.array(timestep, dtype=jnp.int32),
encoder_hidden_states=context,
rngs={},
).sample
# perform guidance
noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
return latents, scheduler_state
scheduler_state = self.scheduler.set_timesteps(params["scheduler"], num_inference_steps=num_inference_steps)
if debug:
# run with python for loop
for i in range(num_inference_steps):
latents, scheduler_state = loop_body(i, (latents, scheduler_state))
else:
latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
# TODO: check when flax vae gets merged into main
image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
# image = jnp.asarray(image).transpose(0, 2, 3, 1)
# run safety checker
# TODO: check when flax safety checker gets merged into main
# safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
# image, has_nsfw_concept = self.safety_checker(
# images=image, clip_input=safety_cheker_input.pixel_values, params=params["safety_params"]
# )
has_nsfw_concept = False
if not return_dict:
return (image, has_nsfw_concept)
return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
......@@ -21,7 +21,6 @@ from typing import Optional, Tuple, Union
import flax
import jax.numpy as jnp
from jax import random
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
......@@ -60,11 +59,12 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
class DDIMSchedulerState:
# setable values
timesteps: jnp.ndarray
alphas_cumprod: jnp.ndarray
num_inference_steps: Optional[int] = None
@classmethod
def create(cls, num_train_timesteps: int):
return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1])
def create(cls, num_train_timesteps: int, alphas_cumprod: jnp.ndarray):
return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1], alphas_cumprod=alphas_cumprod)
@dataclass
......@@ -112,13 +112,9 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[jnp.ndarray] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
):
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
if beta_schedule == "linear":
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
elif beta_schedule == "scaled_linear":
......@@ -131,19 +127,24 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
# HACK for now - clean up later (PVP)
self._alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else float(self._alphas_cumprod[0])
self.state = DDIMSchedulerState.create(num_train_timesteps=num_train_timesteps)
def create_state(self):
return DDIMSchedulerState.create(
num_train_timesteps=self.config.num_train_timesteps, alphas_cumprod=self._alphas_cumprod
)
def _get_variance(self, timestep, prev_timestep):
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
def _get_variance(self, timestep, prev_timestep, alphas_cumprod):
alpha_prod_t = alphas_cumprod[timestep]
alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
......@@ -177,9 +178,6 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
key: random.KeyArray,
eta: float = 0.0,
use_clipped_model_output: bool = False,
return_dict: bool = True,
) -> Union[FlaxSchedulerOutput, Tuple]:
"""
......@@ -221,41 +219,28 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
# 1. get previous step value (=t-1)
prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
alphas_cumprod = state.alphas_cumprod
# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
alpha_prod_t = alphas_cumprod[timestep]
alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
# 4. Clip "predicted x_0"
if self.config.clip_sample:
pred_original_sample = jnp.clip(pred_original_sample, -1, 1)
# 5. compute variance: "sigma_t(η)" -> see formula (16)
# 4. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
variance = self._get_variance(timestep, prev_timestep)
std_dev_t = eta * variance ** (0.5)
variance = self._get_variance(timestep, prev_timestep, alphas_cumprod)
std_dev_t = variance ** (0.5)
if use_clipped_model_output:
# the model_output is always re-derived from the clipped x_0 in Glide
model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
# 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
# 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
if eta > 0:
key = random.split(key, num=1)
noise = random.normal(key=key, shape=model_output.shape)
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
prev_sample = prev_sample + variance
if not return_dict:
return (prev_sample, state)
......
......@@ -148,7 +148,8 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
# mainly at formula (9), (12), (13) and the Algorithm 2.
self.pndm_order = 4
self.state = PNDMSchedulerState.create(num_train_timesteps=num_train_timesteps)
def create_state(self):
return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) -> PNDMSchedulerState:
"""
......
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from ..utils import DummyObject, requires_backends
class FlaxStableDiffusionPipeline(metaclass=DummyObject):
_backends = ["flax", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax", "transformers"])
......@@ -11,42 +11,56 @@ class FlaxModelMixin(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxDDIMScheduler(metaclass=DummyObject):
class FlaxUNet2DConditionModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxDDPMScheduler(metaclass=DummyObject):
class FlaxAutoencoderKL(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxKarrasVeScheduler(metaclass=DummyObject):
class FlaxDiffusionPipeline(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxLMSDiscreteScheduler(metaclass=DummyObject):
class FlaxDDIMScheduler(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxPNDMScheduler(metaclass=DummyObject):
class FlaxDDPMScheduler(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxUNet2DConditionModel(metaclass=DummyObject):
class FlaxKarrasVeScheduler(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxLMSDiscreteScheduler(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxPNDMScheduler(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment