Unverified Commit 688031c5, authored by Pedro Cuenca, committed by GitHub

Fix import with Flax but without PyTorch (#688)

* Don't use `load_state_dict` if torch is not installed.

* Define `SchedulerOutput` to use torch or flax arrays.

* Don't import LMSDiscreteScheduler without torch.

* Create distinct FlaxSchedulerOutput.

* Additional changes required for FlaxSchedulerMixin

* Do not import torch pipelines in Flax.

* Revert "Define `SchedulerOutput` to use torch or flax arrays."

This reverts commit f653140134b74d9ffec46d970eb46925fe3a409d.

* Prefix Flax scheduler outputs for consistency.

* make style

* FlaxSchedulerOutput is now a dataclass.

* Don't use f-string without placeholders.

* Add blank line.

* Style (docstrings)
parent 7d0ba592
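
The fix applied throughout this diff is an availability-guard pattern: backend-specific imports only happen when the backend is actually installed, and a dummy placeholder (or a clear error) takes over otherwise. Below is a minimal sketch of the idea, not the diffusers implementation itself: `importlib.util.find_spec` stands in for diffusers' internal `is_torch_available()` check, and the dummy class plays the role of the auto-generated `dummy_pt_objects`.

```python
# Minimal sketch of the availability-guard import pattern (illustrative,
# not the actual diffusers code). find_spec returns None when a package
# is not importable, so no ImportError is raised just by checking.
import importlib.util


def is_torch_available() -> bool:
    return importlib.util.find_spec("torch") is not None


if is_torch_available():
    import torch  # safe: only imported when torch is really installed
else:
    class _DummyTorchObject:
        """Placeholder that fails loudly only when actually used."""

        def __init__(self, *args, **kwargs):
            raise ImportError("This object requires PyTorch, which is not installed.")
```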
@@ -73,6 +73,7 @@ if is_flax_available():
         FlaxKarrasVeScheduler,
         FlaxLMSDiscreteScheduler,
         FlaxPNDMScheduler,
+        FlaxSchedulerMixin,
         FlaxScoreSdeVeScheduler,
     )
 else:
...
@@ -27,8 +27,8 @@ from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests import HTTPError

+from . import is_torch_available
 from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
-from .modeling_utils import load_state_dict
 from .utils import (
     CONFIG_NAME,
     DIFFUSERS_CACHE,

@@ -391,6 +391,14 @@ class FlaxModelMixin:
                 )

         if from_pt:
+            if is_torch_available():
+                from .modeling_utils import load_state_dict
+            else:
+                raise EnvironmentError(
+                    "Can't load the model in PyTorch format because PyTorch is not installed. "
+                    "Please, install PyTorch or use native Flax weights."
+                )
+
             # Step 1: Get the pytorch file
             pytorch_model_file = load_state_dict(model_file)
...
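
With the deferred import in place, a Flax-only environment can still load native Flax weights; only `from_pt=True` needs torch, and that case now fails with a clear error instead of breaking at import time. A hedged usage sketch (the model class and `from_pt` kwarg appear in the diff; the checkpoint ids below are illustrative placeholders):

```python
# Usage sketch in a Flax-only environment (no torch installed).
# Checkpoint ids are hypothetical; substitute a real Flax checkpoint.
from diffusers import FlaxUNet2DConditionModel

# Native Flax weights: works without PyTorch.
model, params = FlaxUNet2DConditionModel.from_pretrained(
    "some-org/some-flax-checkpoint", subfolder="unet"
)

# PyTorch weights: raises EnvironmentError when torch is missing,
# rather than crashing when the module is first imported.
try:
    model, params = FlaxUNet2DConditionModel.from_pretrained(
        "some-org/some-pt-checkpoint", subfolder="unet", from_pt=True
    )
except EnvironmentError as err:
    print(err)
```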
@@ -30,7 +30,7 @@ from tqdm.auto import tqdm

 from .configuration_utils import ConfigMixin
 from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
-from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin
+from .schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
 from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging

@@ -46,7 +46,7 @@ logger = logging.get_logger(__name__)

 LOADABLE_CLASSES = {
     "diffusers": {
         "FlaxModelMixin": ["save_pretrained", "from_pretrained"],
-        "SchedulerMixin": ["save_config", "from_config"],
+        "FlaxSchedulerMixin": ["save_config", "from_config"],
         "FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"],
     },
     "transformers": {

@@ -436,7 +436,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
                 else:
                     loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
                     params[name] = loaded_params
-            elif issubclass(class_obj, SchedulerMixin):
+            elif issubclass(class_obj, FlaxSchedulerMixin):
                 loaded_sub_model, scheduler_state = load_method(loadable_folder)
                 params[name] = scheduler_state
             else:
...
 from ..utils import is_flax_available, is_onnx_available, is_torch_available, is_transformers_available

-from .ddim import DDIMPipeline
-from .ddpm import DDPMPipeline
-from .latent_diffusion_uncond import LDMPipeline
-from .pndm import PNDMPipeline
-from .score_sde_ve import ScoreSdeVePipeline
-from .stochastic_karras_ve import KarrasVePipeline
+if is_torch_available():
+    from .ddim import DDIMPipeline
+    from .ddpm import DDPMPipeline
+    from .latent_diffusion_uncond import LDMPipeline
+    from .pndm import PNDMPipeline
+    from .score_sde_ve import ScoreSdeVePipeline
+    from .stochastic_karras_ve import KarrasVePipeline
+else:
+    from ..utils.dummy_pt_objects import *  # noqa F403
+
 if is_torch_available() and is_transformers_available():
     from .latent_diffusion import LDMTextToImagePipeline
     from .stable_diffusion import (
...
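
When torch is absent, the wildcard import above pulls in dummy placeholders instead of the real pipeline classes, so `from diffusers import DDPMPipeline` still succeeds and only fails on use. A simplified stand-in for such a dummy (the real `dummy_pt_objects` module is auto-generated by diffusers tooling; this sketch only mirrors the failure behavior it provides):

```python
# Simplified stand-in for an auto-generated dummy pipeline class.
# Only the error-on-use behavior is shown; names and messages are
# illustrative, not the generated diffusers code.
class DDPMPipeline:
    def __init__(self, *args, **kwargs):
        raise ImportError(
            "DDPMPipeline requires the PyTorch backend, but torch is not installed. "
            "Install torch to use this pipeline."
        )

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        raise ImportError("DDPMPipeline requires the PyTorch backend.")
```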
@@ -6,7 +6,7 @@ import numpy as np
 import PIL
 from PIL import Image

-from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_transformers_available
+from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_torch_available, is_transformers_available


 @dataclass

@@ -27,7 +27,7 @@ class StableDiffusionPipelineOutput(BaseOutput):
     nsfw_content_detected: List[bool]


-if is_transformers_available():
+if is_transformers_available() and is_torch_available():
     from .pipeline_stable_diffusion import StableDiffusionPipeline
     from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
     from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
...
@@ -34,10 +34,12 @@ if is_flax_available():
     from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
     from .scheduling_pndm_flax import FlaxPNDMScheduler
     from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
+    from .scheduling_utils_flax import FlaxSchedulerMixin
 else:
     from ..utils.dummy_flax_objects import *  # noqa F403

-if is_scipy_available():
+
+if is_scipy_available() and is_torch_available():
     from .scheduling_lms_discrete import LMSDiscreteScheduler
 else:
     from ..utils.dummy_torch_and_scipy_objects import *  # noqa F403
@@ -23,7 +23,7 @@ import flax
 import jax.numpy as jnp

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput


 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:

@@ -68,11 +68,11 @@ class DDIMSchedulerState:

 @dataclass
-class FlaxSchedulerOutput(SchedulerOutput):
+class FlaxDDIMSchedulerOutput(FlaxSchedulerOutput):
     state: DDIMSchedulerState


-class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
+class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
     """
     Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
     diffusion probabilistic models (DDPMs) with non-Markovian guidance.

@@ -183,7 +183,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
         timestep: int,
         sample: jnp.ndarray,
         return_dict: bool = True,
-    ) -> Union[FlaxSchedulerOutput, Tuple]:
+    ) -> Union[FlaxDDIMSchedulerOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
         process from the learned model outputs (most often the predicted noise).

@@ -197,11 +197,11 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
             key (`random.KeyArray`): a PRNG key.
             eta (`float`): weight of noise for added noise in diffusion step.
             use_clipped_model_output (`bool`): TODO
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class

         Returns:
-            [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
-            When returning a tuple, the first element is the sample tensor.
+            [`FlaxDDIMSchedulerOutput`] or `tuple`: [`FlaxDDIMSchedulerOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is the sample tensor.

         """
         if state.num_inference_steps is None:

@@ -252,7 +252,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
         if not return_dict:
             return (prev_sample, state)

-        return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
+        return FlaxDDIMSchedulerOutput(prev_sample=prev_sample, state=state)

     def add_noise(
         self,
...
@@ -23,7 +23,7 @@ import jax.numpy as jnp
 from jax import random

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput


 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:

@@ -67,11 +67,11 @@ class DDPMSchedulerState:

 @dataclass
-class FlaxSchedulerOutput(SchedulerOutput):
+class FlaxDDPMSchedulerOutput(FlaxSchedulerOutput):
     state: DDPMSchedulerState


-class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
+class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
     """
     Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
     Langevin dynamics sampling.

@@ -191,7 +191,7 @@ class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
         key: random.KeyArray,
         predict_epsilon: bool = True,
         return_dict: bool = True,
-    ) -> Union[FlaxSchedulerOutput, Tuple]:
+    ) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
         process from the learned model outputs (most often the predicted noise).

@@ -205,11 +205,11 @@ class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
             key (`random.KeyArray`): a PRNG key.
             predict_epsilon (`bool`):
                 optional flag to use when model predicts the samples directly instead of the noise, epsilon.
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class

         Returns:
-            [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
-            When returning a tuple, the first element is the sample tensor.
+            [`FlaxDDPMSchedulerOutput`] or `tuple`: [`FlaxDDPMSchedulerOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is the sample tensor.

         """
         t = timestep

@@ -257,7 +257,7 @@ class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
         if not return_dict:
             return (pred_prev_sample, state)

-        return FlaxSchedulerOutput(prev_sample=pred_prev_sample, state=state)
+        return FlaxDDPMSchedulerOutput(prev_sample=pred_prev_sample, state=state)

     def add_noise(
         self,
...
@@ -22,7 +22,7 @@ from jax import random

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput
-from .scheduling_utils import SchedulerMixin
+from .scheduling_utils_flax import FlaxSchedulerMixin


 @flax.struct.dataclass

@@ -56,7 +56,7 @@ class FlaxKarrasVeOutput(BaseOutput):
     state: KarrasVeSchedulerState


-class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
+class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
     """
     Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
     the VE column of Table 1 from [1] for reference.

@@ -172,7 +172,7 @@ class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
             sigma_hat (`float`): TODO
             sigma_prev (`float`): TODO
             sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class

         Returns:
             [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] or `tuple`: Updated sample in the diffusion

@@ -211,7 +211,7 @@ class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
             sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
             sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
             derivative (`torch.FloatTensor` or `np.ndarray`): TODO
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class

         Returns:
             prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
...
@@ -20,7 +20,7 @@ import jax.numpy as jnp
 from scipy import integrate

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput


 @flax.struct.dataclass

@@ -37,11 +37,11 @@ class LMSDiscreteSchedulerState:

 @dataclass
-class FlaxSchedulerOutput(SchedulerOutput):
+class FlaxLMSSchedulerOutput(FlaxSchedulerOutput):
     state: LMSDiscreteSchedulerState


-class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
+class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
     """
     Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
     Katherine Crowson:

@@ -147,7 +147,7 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sample: jnp.ndarray,
         order: int = 4,
         return_dict: bool = True,
-    ) -> Union[SchedulerOutput, Tuple]:
+    ) -> Union[FlaxLMSSchedulerOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
         process from the learned model outputs (most often the predicted noise).

@@ -159,11 +159,11 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
             sample (`jnp.ndarray`):
                 current instance of sample being created by diffusion process.
             order: coefficient for multi-step inference.
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than FlaxLMSSchedulerOutput class

         Returns:
-            [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
-            When returning a tuple, the first element is the sample tensor.
+            [`FlaxLMSSchedulerOutput`] or `tuple`: [`FlaxLMSSchedulerOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is the sample tensor.

         """
         sigma = state.sigmas[timestep]

@@ -189,7 +189,7 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         if not return_dict:
             return (prev_sample, state)

-        return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
+        return FlaxLMSSchedulerOutput(prev_sample=prev_sample, state=state)

     def add_noise(
         self,
...
@@ -23,7 +23,7 @@ import jax
 import jax.numpy as jnp

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput


 def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:

@@ -76,11 +76,11 @@ class PNDMSchedulerState:

 @dataclass
-class FlaxSchedulerOutput(SchedulerOutput):
+class FlaxPNDMSchedulerOutput(FlaxSchedulerOutput):
     state: PNDMSchedulerState


-class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
+class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
     """
     Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
     namely Runge-Kutta method and a linear multi-step method.

@@ -211,7 +211,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         timestep: int,
         sample: jnp.ndarray,
         return_dict: bool = True,
-    ) -> Union[FlaxSchedulerOutput, Tuple]:
+    ) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
         process from the learned model outputs (most often the predicted noise).

@@ -224,11 +224,11 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
             timestep (`int`): current discrete timestep in the diffusion chain.
             sample (`jnp.ndarray`):
                 current instance of sample being created by diffusion process.
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class

         Returns:
-            [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
-            When returning a tuple, the first element is the sample tensor.
+            [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is the sample tensor.

         """
         if self.config.skip_prk_steps:

@@ -249,7 +249,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         if not return_dict:
             return (prev_sample, state)

-        return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
+        return FlaxPNDMSchedulerOutput(prev_sample=prev_sample, state=state)

     def step_prk(
         self,

@@ -257,7 +257,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         model_output: jnp.ndarray,
         timestep: int,
         sample: jnp.ndarray,
-    ) -> Union[FlaxSchedulerOutput, Tuple]:
+    ) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
         """
         Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
         solution to the differential equation.

@@ -268,11 +268,11 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
             timestep (`int`): current discrete timestep in the diffusion chain.
             sample (`jnp.ndarray`):
                 current instance of sample being created by diffusion process.
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class

         Returns:
-            [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
-            When returning a tuple, the first element is the sample tensor.
+            [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is the sample tensor.

         """
         if state.num_inference_steps is None:

@@ -327,7 +327,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         model_output: jnp.ndarray,
         timestep: int,
         sample: jnp.ndarray,
-    ) -> Union[FlaxSchedulerOutput, Tuple]:
+    ) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
         """
         Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
         times to approximate the solution.

@@ -338,11 +338,11 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
             timestep (`int`): current discrete timestep in the diffusion chain.
             sample (`jnp.ndarray`):
                 current instance of sample being created by diffusion process.
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class

         Returns:
-            [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
-            When returning a tuple, the first element is the sample tensor.
+            [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is the sample tensor.

         """
         if state.num_inference_steps is None:
...
@@ -22,7 +22,7 @@ import jax.numpy as jnp
 from jax import random

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput


 @flax.struct.dataclass

@@ -38,7 +38,7 @@ class ScoreSdeVeSchedulerState:

 @dataclass
-class FlaxSdeVeOutput(SchedulerOutput):
+class FlaxSdeVeOutput(FlaxSchedulerOutput):
     """
     Output class for the ScoreSdeVeScheduler's step function output.

@@ -56,7 +56,7 @@ class FlaxSdeVeOutput(SchedulerOutput):
     prev_sample_mean: Optional[jnp.ndarray] = None


-class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
+class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
     """
     The variance exploding stochastic differential equation (SDE) scheduler.

@@ -168,7 +168,7 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
             sample (`jnp.ndarray`):
                 current instance of sample being created by diffusion process.
             generator: random number generator.
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class

         Returns:
             [`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When

@@ -216,7 +216,7 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         sample: jnp.ndarray,
         key: random.KeyArray,
         return_dict: bool = True,
-    ) -> Union[SchedulerOutput, Tuple]:
+    ) -> Union[FlaxSdeVeOutput, Tuple]:
         """
         Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
         after making the prediction for the previous timestep.

@@ -227,7 +227,7 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
             sample (`jnp.ndarray`):
                 current instance of sample being created by diffusion process.
             generator: random number generator.
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class

         Returns:
             [`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
...
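
Finally, the commit adds the new `scheduling_utils_flax.py` module (imported as `.scheduling_utils_flax` throughout the hunks above), defining the shared Flax base classes that each per-scheduler output subclasses: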
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import dataclass

import jax.numpy as jnp

from ..utils import BaseOutput


SCHEDULER_CONFIG_NAME = "scheduler_config.json"


@dataclass
class FlaxSchedulerOutput(BaseOutput):
    """
    Base class for the scheduler's step function output.

    Args:
        prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: jnp.ndarray


class FlaxSchedulerMixin:
    """
    Mixin containing common functions for the schedulers.
    """

    config_name = SCHEDULER_CONFIG_NAME

    def set_format(self, tensor_format="pt"):
        warnings.warn(
            "The method `set_format` is deprecated and will be removed in version `0.5.0`. "
            "If you're running your code in PyTorch, you can safely remove this function as the schedulers "
            "are always in PyTorch.",
            DeprecationWarning,
        )
        return self
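
To see how concrete schedulers build on these base classes, here is a minimal sketch mirroring the `FlaxDDIMScheduler` hunks above. The scheduler name, state fields, and pass-through `step` body are hypothetical; only the subclassing pattern and the `(prev_sample, state)` return contract come from the diff. The import path assumes the module location introduced by this PR.

```python
# Illustrative sketch (not part of the PR): a concrete Flax scheduler
# pairing its own state pytree with a distinct output subclass, which is
# exactly why this PR renames the per-scheduler outputs.
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import flax
import jax.numpy as jnp

from diffusers.schedulers.scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput


@flax.struct.dataclass
class MySchedulerState:
    # Immutable pytree holding everything step() needs between calls.
    timesteps: jnp.ndarray
    num_inference_steps: Optional[int] = None


@dataclass
class FlaxMySchedulerOutput(FlaxSchedulerOutput):
    # Each scheduler attaches its own state class to the shared base output.
    state: MySchedulerState


class FlaxMyScheduler(FlaxSchedulerMixin):
    def step(
        self,
        state: MySchedulerState,
        model_output: jnp.ndarray,
        timestep: int,
        sample: jnp.ndarray,
        return_dict: bool = True,
    ) -> Union[FlaxMySchedulerOutput, Tuple]:
        # A real scheduler computes prev_sample from model_output here;
        # this sketch just passes the sample through unchanged.
        prev_sample = sample
        if not return_dict:
            return (prev_sample, state)
        return FlaxMySchedulerOutput(prev_sample=prev_sample, state=state)
```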