"...text-generation-inference.git" did not exist on "8ec57558cd5b7b2ad3eaacdd6295a3db0c9092d0"
Unverified commit d934d3d7, authored by Mishig Davaadorj, committed by GitHub

FlaxDiffusionPipeline & FlaxStableDiffusionPipeline (#559)



* WIP: flax FlaxDiffusionPipeline & FlaxStableDiffusionPipeline

* todo comment

* Fix imports

* Fix imports

* add dummies

* Fix empty init

* make pipeline work

* up

* Use Flax schedulers (typing, docstring)

* Wrap model imports inside availability checks.

* more updates

* make sure flax is not broken

* make style

* more fixes

* up
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@latenitesoft.com>
parent c6629e6f
@@ -66,6 +66,7 @@ if is_flax_available():
     from .modeling_flax_utils import FlaxModelMixin
     from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
     from .models.vae_flax import FlaxAutoencoderKL
+    from .pipeline_flax_utils import FlaxDiffusionPipeline
     from .schedulers import (
         FlaxDDIMScheduler,
         FlaxDDPMScheduler,
@@ -76,3 +77,8 @@ if is_flax_available():
     )
 else:
     from .utils.dummy_flax_objects import *  # noqa F403
+
+if is_flax_available() and is_transformers_available():
+    from .pipelines import FlaxStableDiffusionPipeline
+else:
+    from .utils.dummy_flax_and_transformers_objects import *  # noqa F403
@@ -306,16 +306,16 @@ class FlaxModelMixin:
         # Load model
         if os.path.isdir(pretrained_model_name_or_path):
-            if os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
-                # Load from a Flax checkpoint
-                model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
-            # At this stage we don't have a weight file so we will raise an error.
-            elif from_pt:
+            if from_pt:
                 if not os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
                     raise EnvironmentError(
                         f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
                     )
                 model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+            elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
+                # Load from a Flax checkpoint
+                model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
+            # At this stage we don't have a weight file so we will raise an error.
             elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
                 raise EnvironmentError(
                     f"{WEIGHTS_NAME} file found in directory {pretrained_model_name_or_path}. Please load the model"
...
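Note that both sides of this hunk test `os.path.join(...)` for truthiness; since `os.path.join` returns a non-empty string for non-empty inputs, the `if not` guard can never fire and the final `elif` always triggers when reached. A corrected sketch (hypothetical, not part of this commit) would check file existence instead:

import os

# Hypothetical fix (not in this commit): test file existence with
# os.path.isfile rather than the always-truthy os.path.join result.
def resolve_pt_weights(pretrained_model_name_or_path: str, weights_name: str) -> str:
    model_file = os.path.join(pretrained_model_name_or_path, weights_name)
    if not os.path.isfile(model_file):
        raise EnvironmentError(
            f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}"
        )
    return model_file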
@@ -12,6 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .unet_2d import UNet2DModel
-from .unet_2d_condition import UNet2DConditionModel
-from .vae import AutoencoderKL, VQModel
+from ..utils import is_flax_available, is_torch_available
+
+
+if is_torch_available():
+    from .unet_2d import UNet2DModel
+    from .unet_2d_condition import UNet2DConditionModel
+    from .vae import AutoencoderKL, VQModel
+
+if is_flax_available():
+    from .unet_2d_condition_flax import FlaxUNet2DConditionModel
+    from .vae_flax import FlaxAutoencoderKL
@@ -144,7 +144,6 @@ class FlaxSpatialTransformer(nn.Module):
     def __call__(self, hidden_states, context, deterministic=True):
         batch, height, width, channels = hidden_states.shape
-        # import ipdb; ipdb.set_trace()
         residual = hidden_states
         hidden_states = self.norm(hidden_states)
         hidden_states = self.proj_in(hidden_states)
...
This diff is collapsed.
-from ..utils import is_onnx_available, is_transformers_available
+from ..utils import is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
 from .ddim import DDIMPipeline
 from .ddpm import DDPMPipeline
 from .latent_diffusion_uncond import LDMPipeline
@@ -7,7 +7,7 @@ from .score_sde_ve import ScoreSdeVePipeline
 from .stochastic_karras_ve import KarrasVePipeline

-if is_transformers_available():
+if is_torch_available() and is_transformers_available():
     from .latent_diffusion import LDMTextToImagePipeline
     from .stable_diffusion import (
         StableDiffusionImg2ImgPipeline,
@@ -17,3 +17,6 @@ if is_transformers_available():

 if is_transformers_available() and is_onnx_available():
     from .stable_diffusion import StableDiffusionOnnxPipeline
+
+if is_transformers_available() and is_flax_available():
+    from .stable_diffusion import FlaxStableDiffusionPipeline
@@ -37,4 +37,24 @@ if is_transformers_available() and is_onnx_available():
     from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline

 if is_transformers_available() and is_flax_available():
+    import flax
+
+    @flax.struct.dataclass
+    class FlaxStableDiffusionPipelineOutput(BaseOutput):
+        """
+        Output class for Stable Diffusion pipelines.
+
+        Args:
+            images (`List[PIL.Image.Image]` or `np.ndarray`)
+                List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height,
+                width, num_channels)`. PIL images or numpy arrays represent the denoised images of the diffusion
+                pipeline.
+            nsfw_content_detected (`List[bool]`)
+                List of flags denoting whether the corresponding generated image likely represents
+                "not-safe-for-work" (nsfw) content.
+        """
+
+        images: Union[List[PIL.Image.Image], np.ndarray]
+        nsfw_content_detected: List[bool]
+
+    from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
     from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
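The output class is declared with `@flax.struct.dataclass` rather than a plain `@dataclass`, which registers it as a JAX pytree so instances can be returned across `jax.jit`/`jax.pmap` boundaries. A minimal standalone sketch of that pattern (not taken from this commit):

import flax
import jax
import jax.numpy as jnp

@flax.struct.dataclass
class Output:
    images: jnp.ndarray          # pytree leaf, traced by JAX transforms
    nsfw_content_detected: bool  # also a pytree leaf

@jax.jit
def make_output(x):
    # The dataclass instance can be returned from a jitted function
    # because flax.struct.dataclass registers it as a pytree node.
    return Output(images=x * 2, nsfw_content_detected=False)

out = make_output(jnp.ones((1, 4)))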
from typing import Dict, List, Optional, Union

import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel

from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
from ...pipeline_flax_utils import FlaxDiffusionPipeline
from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler
from . import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker


class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods
    the library implements for all the pipelines (such as downloading or saving, running on a particular device,
    etc.)

    Args:
        vae ([`FlaxAutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`FlaxCLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel),
            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`FlaxSchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`].
        safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        vae: FlaxAutoencoderKL,
        text_encoder: FlaxCLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: FlaxUNet2DConditionModel,
        scheduler: Union[FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler],
        safety_checker: FlaxStableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        dtype: jnp.dtype = jnp.float32,
    ):
        super().__init__()
        scheduler = scheduler.set_format("np")
        self.dtype = dtype

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    def prepare_inputs(self, prompt: Union[str, List[str]]):
        if not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
        return text_input.input_ids

    def __call__(
        self,
        prompt_ids: jnp.array,
        params: Union[Dict, FrozenDict],
        prng_seed: jax.random.PRNGKey,
        num_inference_steps: Optional[int] = 50,
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        guidance_scale: Optional[float] = 7.5,
        latents: Optional[jnp.array] = None,
        return_dict: bool = True,
        debug: bool = False,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages generating images that are closely linked to the text
                `prompt`, usually at the expense of lower image quality.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`jnp.array`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead
                of a plain tuple.

        Returns:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # get prompt text embeddings
        text_embeddings = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]

        # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
        # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
        batch_size = prompt_ids.shape[0]
        max_length = prompt_ids.shape[-1]
        uncond_input = self.tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
        )
        uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0]
        context = jnp.concatenate([uncond_embeddings, text_embeddings])

        # TODO: check it because the shape is different from the PyTorch StableDiffusionPipeline
        latents_shape = (
            batch_size,
            self.unet.in_channels,
            self.unet.sample_size,
            self.unet.sample_size,
        )
        if latents is None:
            latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

        def loop_body(step, args):
            latents, scheduler_state = args
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            latents_input = jnp.concatenate([latents] * 2)

            t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
            timestep = jnp.broadcast_to(t, latents_input.shape[0])

            # predict the noise residual
            noise_pred = self.unet.apply(
                {"params": params["unet"]},
                jnp.array(latents_input),
                jnp.array(timestep, dtype=jnp.int32),
                encoder_hidden_states=context,
                rngs={},
            ).sample
            # perform guidance
            noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
            return latents, scheduler_state

        scheduler_state = self.scheduler.set_timesteps(params["scheduler"], num_inference_steps=num_inference_steps)

        if debug:
            # run with python for loop
            for i in range(num_inference_steps):
                latents, scheduler_state = loop_body(i, (latents, scheduler_state))
        else:
            latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
        # TODO: check when flax vae gets merged into main
        image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample

        image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
        # image = jnp.asarray(image).transpose(0, 2, 3, 1)

        # run safety checker
        # TODO: check when flax safety checker gets merged into main
        # safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
        # image, has_nsfw_concept = self.safety_checker(
        #     images=image, clip_input=safety_checker_input.pixel_values, params=params["safety_params"]
        # )
        has_nsfw_concept = False

        if not return_dict:
            return (image, has_nsfw_concept)

        return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
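For context, a hedged end-to-end usage sketch of the new pipeline. The checkpoint name is real, but the `revision` holding Flax weights is an assumption, and `from_pretrained` returning a `(pipeline, params)` pair follows the `FlaxDiffusionPipeline` introduced in this PR (whose diff is collapsed above):

import jax
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline

# Assumption: Flax weights for this checkpoint are published under a "flax" revision.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="flax", dtype=jnp.float32
)

prompt_ids = pipeline.prepare_inputs("a photograph of an astronaut riding a horse")
prng_seed = jax.random.PRNGKey(0)

# debug=True replaces jax.lax.fori_loop with a plain Python loop,
# which is easier to step through but much slower.
output = pipeline(prompt_ids, params, prng_seed, num_inference_steps=50)
images = output.images  # array of shape (batch, height, width, 3), values in [0, 1]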
@@ -21,7 +21,6 @@ from typing import Optional, Tuple, Union

 import flax
 import jax.numpy as jnp
-from jax import random

 from ..configuration_utils import ConfigMixin, register_to_config
 from .scheduling_utils import SchedulerMixin, SchedulerOutput
@@ -60,11 +59,12 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
 class DDIMSchedulerState:
     # setable values
     timesteps: jnp.ndarray
+    alphas_cumprod: jnp.ndarray
     num_inference_steps: Optional[int] = None

     @classmethod
-    def create(cls, num_train_timesteps: int):
-        return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1])
+    def create(cls, num_train_timesteps: int, alphas_cumprod: jnp.ndarray):
+        return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1], alphas_cumprod=alphas_cumprod)


 @dataclass
@@ -112,13 +112,9 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[jnp.ndarray] = None,
-        clip_sample: bool = True,
         set_alpha_to_one: bool = True,
         steps_offset: int = 0,
     ):
-        if trained_betas is not None:
-            self.betas = jnp.asarray(trained_betas)
         if beta_schedule == "linear":
             self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
         elif beta_schedule == "scaled_linear":
@@ -131,19 +127,24 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

         self.alphas = 1.0 - self.betas
-        self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
+
+        # HACK for now - clean up later (PVP)
+        self._alphas_cumprod = jnp.cumprod(self.alphas, axis=0)

         # At every step in ddim, we are looking into the previous alphas_cumprod
         # For the final step, there is no previous alphas_cumprod because we are already at 0
         # `set_alpha_to_one` decides whether we set this parameter simply to one or
         # whether we use the final alpha of the "non-previous" one.
-        self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+        self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else float(self._alphas_cumprod[0])

-        self.state = DDIMSchedulerState.create(num_train_timesteps=num_train_timesteps)
+    def create_state(self):
+        return DDIMSchedulerState.create(
+            num_train_timesteps=self.config.num_train_timesteps, alphas_cumprod=self._alphas_cumprod
+        )

-    def _get_variance(self, timestep, prev_timestep):
-        alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+    def _get_variance(self, timestep, prev_timestep, alphas_cumprod):
+        alpha_prod_t = alphas_cumprod[timestep]
+        alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)

         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
@@ -177,9 +178,6 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
         model_output: jnp.ndarray,
         timestep: int,
         sample: jnp.ndarray,
-        key: random.KeyArray,
-        eta: float = 0.0,
-        use_clipped_model_output: bool = False,
         return_dict: bool = True,
     ) -> Union[FlaxSchedulerOutput, Tuple]:
         """
@@ -221,41 +219,28 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
         # 1. get previous step value (=t-1)
         prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps

+        alphas_cumprod = state.alphas_cumprod
+
         # 2. compute alphas, betas
-        alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        alpha_prod_t = alphas_cumprod[timestep]
+        alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)

         beta_prod_t = 1 - alpha_prod_t

         # 3. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
         pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)

-        # 4. Clip "predicted x_0"
-        if self.config.clip_sample:
-            pred_original_sample = jnp.clip(pred_original_sample, -1, 1)
-
-        # 5. compute variance: "sigma_t(η)" -> see formula (16)
+        # 4. compute variance: "sigma_t(η)" -> see formula (16)
         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
-        variance = self._get_variance(timestep, prev_timestep)
-        std_dev_t = eta * variance ** (0.5)
+        variance = self._get_variance(timestep, prev_timestep, alphas_cumprod)
+        std_dev_t = variance ** (0.5)

-        if use_clipped_model_output:
-            # the model_output is always re-derived from the clipped x_0 in Glide
-            model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
-
-        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+        # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
         pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output

-        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+        # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
         prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

-        if eta > 0:
-            key = random.split(key, num=1)
-            noise = random.normal(key=key, shape=model_output.shape)
-            variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
-            prev_sample = prev_sample + variance
-
         if not return_dict:
             return (prev_sample, state)
...
@@ -148,7 +148,8 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         # mainly at formula (9), (12), (13) and the Algorithm 2.
         self.pndm_order = 4

-        self.state = PNDMSchedulerState.create(num_train_timesteps=num_train_timesteps)
+    def create_state(self):
+        return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)

     def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) -> PNDMSchedulerState:
         """
...
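Both Flax schedulers now expose `create_state()` instead of eagerly building `self.state` in `__init__`: the scheduler object stays immutable, and all mutable values (timesteps, `alphas_cumprod`) travel in an explicit state pytree, matching JAX's functional style. A rough usage sketch, with constructor defaults and the `set_timesteps`/`step` calling convention inferred from the hunks above:

from diffusers import FlaxDDIMScheduler

scheduler = FlaxDDIMScheduler()                  # configuration only, no mutable state
state = scheduler.create_state()                 # DDIMSchedulerState pytree
state = scheduler.set_timesteps(state, num_inference_steps=50)

# Each denoising step threads the state through explicitly rather than mutating self:
#   prev_sample, state = scheduler.step(state, model_output, t, sample).to_tuple()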
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from ..utils import DummyObject, requires_backends


class FlaxStableDiffusionPipeline(metaclass=DummyObject):
    _backends = ["flax", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax", "transformers"])
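This dummy class means `from diffusers import FlaxStableDiffusionPipeline` succeeds even when the optional backends are missing; the failure is deferred to instantiation, where `requires_backends` raises an actionable error. A simplified sketch of that helper (`_is_available` is hypothetical; the real `diffusers.utils` uses dedicated checks such as `is_flax_available()`):

import importlib.util

def _is_available(backend: str) -> bool:
    # Hypothetical stand-in for diffusers' per-backend availability checks.
    return importlib.util.find_spec(backend) is not None

def requires_backends(obj, backends):
    # Collect missing backends and raise only if something is absent.
    missing = [b for b in backends if not _is_available(b)]
    if missing:
        name = obj.__name__ if isinstance(obj, type) else type(obj).__name__
        raise ImportError(f"{name} requires the following backends: {', '.join(missing)}")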
@@ -11,42 +11,56 @@ class FlaxModelMixin(metaclass=DummyObject):
         requires_backends(self, ["flax"])


-class FlaxDDIMScheduler(metaclass=DummyObject):
+class FlaxUNet2DConditionModel(metaclass=DummyObject):
     _backends = ["flax"]

     def __init__(self, *args, **kwargs):
         requires_backends(self, ["flax"])


-class FlaxDDPMScheduler(metaclass=DummyObject):
+class FlaxAutoencoderKL(metaclass=DummyObject):
     _backends = ["flax"]

     def __init__(self, *args, **kwargs):
         requires_backends(self, ["flax"])


-class FlaxKarrasVeScheduler(metaclass=DummyObject):
+class FlaxDiffusionPipeline(metaclass=DummyObject):
     _backends = ["flax"]

     def __init__(self, *args, **kwargs):
         requires_backends(self, ["flax"])


-class FlaxLMSDiscreteScheduler(metaclass=DummyObject):
+class FlaxDDIMScheduler(metaclass=DummyObject):
     _backends = ["flax"]

     def __init__(self, *args, **kwargs):
         requires_backends(self, ["flax"])


-class FlaxPNDMScheduler(metaclass=DummyObject):
+class FlaxDDPMScheduler(metaclass=DummyObject):
     _backends = ["flax"]

     def __init__(self, *args, **kwargs):
         requires_backends(self, ["flax"])


-class FlaxUNet2DConditionModel(metaclass=DummyObject):
+class FlaxKarrasVeScheduler(metaclass=DummyObject):
+    _backends = ["flax"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
+class FlaxLMSDiscreteScheduler(metaclass=DummyObject):
+    _backends = ["flax"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
+class FlaxPNDMScheduler(metaclass=DummyObject):
     _backends = ["flax"]

     def __init__(self, *args, **kwargs):
...