Unverified commit bcd61fd3, authored by timdalxx, committed by GitHub
Browse files

[docs] add docstrings in `pipeline_stable_diffusion.py` (#9590)



* fix the issue on flux dreambooth lora training

* update : origin main code

* docs: update pipeline_stable_diffusion docstring

* docs: update pipeline_stable_diffusion docstring

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* fix: style

* fix: style

* fix: copies

* make fix-copies

* remove extra newline

---------
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent d27ecc59
...@@ -88,9 +88,21 @@ EXAMPLE_DOC_STRING = """ ...@@ -88,9 +88,21 @@ EXAMPLE_DOC_STRING = """
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
""" r"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Args:
noise_cfg (`torch.Tensor`):
The predicted noise tensor for the guided diffusion process.
noise_pred_text (`torch.Tensor`):
The predicted noise tensor for the text-guided diffusion process.
guidance_rescale (`float`, *optional*, defaults to 0.0):
A rescale factor applied to the noise predictions.
Returns:
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
""" """
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
...@@ -110,7 +122,7 @@ def retrieve_timesteps( ...@@ -110,7 +122,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None, sigmas: Optional[List[float]] = None,
**kwargs, **kwargs,
): ):
""" r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
......
...@@ -92,9 +92,21 @@ EXAMPLE_DOC_STRING = """ ...@@ -92,9 +92,21 @@ EXAMPLE_DOC_STRING = """
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
""" r"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Args:
noise_cfg (`torch.Tensor`):
The predicted noise tensor for the guided diffusion process.
noise_pred_text (`torch.Tensor`):
The predicted noise tensor for the text-guided diffusion process.
guidance_rescale (`float`, *optional*, defaults to 0.0):
A rescale factor applied to the noise predictions.
Returns:
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
""" """
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
...@@ -128,7 +140,7 @@ def retrieve_timesteps( ...@@ -128,7 +140,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None, sigmas: Optional[List[float]] = None,
**kwargs, **kwargs,
): ):
""" r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
......
...@@ -105,9 +105,21 @@ EXAMPLE_DOC_STRING = """ ...@@ -105,9 +105,21 @@ EXAMPLE_DOC_STRING = """
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
""" r"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Args:
noise_cfg (`torch.Tensor`):
The predicted noise tensor for the guided diffusion process.
noise_pred_text (`torch.Tensor`):
The predicted noise tensor for the text-guided diffusion process.
guidance_rescale (`float`, *optional*, defaults to 0.0):
A rescale factor applied to the noise predictions.
Returns:
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
""" """
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
...@@ -141,7 +153,7 @@ def retrieve_timesteps( ...@@ -141,7 +153,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None, sigmas: Optional[List[float]] = None,
**kwargs, **kwargs,
): ):
""" r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
......
...@@ -178,7 +178,7 @@ def retrieve_timesteps( ...@@ -178,7 +178,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None, sigmas: Optional[List[float]] = None,
**kwargs, **kwargs,
): ):
""" r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
......
...@@ -122,7 +122,7 @@ def retrieve_timesteps( ...@@ -122,7 +122,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None, sigmas: Optional[List[float]] = None,
**kwargs, **kwargs,
): ):
""" r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
......
...@@ -65,9 +65,21 @@ EXAMPLE_DOC_STRING = """ ...@@ -65,9 +65,21 @@ EXAMPLE_DOC_STRING = """
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
""" r"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Args:
noise_cfg (`torch.Tensor`):
The predicted noise tensor for the guided diffusion process.
noise_pred_text (`torch.Tensor`):
The predicted noise tensor for the text-guided diffusion process.
guidance_rescale (`float`, *optional*, defaults to 0.0):
A rescale factor applied to the noise predictions.
Returns:
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
""" """
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
...@@ -86,7 +98,7 @@ def retrieve_timesteps( ...@@ -86,7 +98,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None, sigmas: Optional[List[float]] = None,
**kwargs, **kwargs,
): ):
""" r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
...@@ -145,7 +157,7 @@ class StableDiffusionPipeline( ...@@ -145,7 +157,7 @@ class StableDiffusionPipeline(
IPAdapterMixin, IPAdapterMixin,
FromSingleFileMixin, FromSingleFileMixin,
): ):
r""" """
Pipeline for text-to-image generation using Stable Diffusion. Pipeline for text-to-image generation using Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
......
...@@ -119,7 +119,7 @@ def retrieve_timesteps( ...@@ -119,7 +119,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None, sigmas: Optional[List[float]] = None,
**kwargs, **kwargs,
): ):
""" r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
......
...@@ -60,7 +60,7 @@ def retrieve_timesteps( ...@@ -60,7 +60,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None, sigmas: Optional[List[float]] = None,
**kwargs, **kwargs,
): ):
""" r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
......
...@@ -77,7 +77,7 @@ def retrieve_timesteps( ...@@ -77,7 +77,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None, sigmas: Optional[List[float]] = None,
**kwargs, **kwargs,
): ):
""" r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
......
...@@ -98,7 +98,7 @@ def retrieve_timesteps( ...@@ -98,7 +98,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None, sigmas: Optional[List[float]] = None,
**kwargs, **kwargs,
): ):
""" r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
......
...@@ -97,7 +97,7 @@ def retrieve_timesteps( ...@@ -97,7 +97,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None, sigmas: Optional[List[float]] = None,
**kwargs, **kwargs,
): ):
""" r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
......
...@@ -61,9 +61,21 @@ EXAMPLE_DOC_STRING = """ ...@@ -61,9 +61,21 @@ EXAMPLE_DOC_STRING = """
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
""" r"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Args:
noise_cfg (`torch.Tensor`):
The predicted noise tensor for the guided diffusion process.
noise_pred_text (`torch.Tensor`):
The predicted noise tensor for the text-guided diffusion process.
guidance_rescale (`float`, *optional*, defaults to 0.0):
A rescale factor applied to the noise predictions.
Returns:
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
""" """
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
...@@ -83,7 +95,7 @@ def retrieve_timesteps( ...@@ -83,7 +95,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None, sigmas: Optional[List[float]] = None,
**kwargs, **kwargs,
): ):
""" r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
......
...@@ -61,9 +61,21 @@ EXAMPLE_DOC_STRING = """ ...@@ -61,9 +61,21 @@ EXAMPLE_DOC_STRING = """
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
""" r"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Args:
noise_cfg (`torch.Tensor`):
The predicted noise tensor for the guided diffusion process.
noise_pred_text (`torch.Tensor`):
The predicted noise tensor for the text-guided diffusion process.
guidance_rescale (`float`, *optional*, defaults to 0.0):
A rescale factor applied to the noise predictions.
Returns:
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
""" """
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
...@@ -83,7 +95,7 @@ def retrieve_timesteps( ...@@ -83,7 +95,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None, sigmas: Optional[List[float]] = None,
**kwargs, **kwargs,
): ):
""" r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
......
...@@ -87,9 +87,21 @@ EXAMPLE_DOC_STRING = """ ...@@ -87,9 +87,21 @@ EXAMPLE_DOC_STRING = """
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
""" r"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Args:
noise_cfg (`torch.Tensor`):
The predicted noise tensor for the guided diffusion process.
noise_pred_text (`torch.Tensor`):
The predicted noise tensor for the text-guided diffusion process.
guidance_rescale (`float`, *optional*, defaults to 0.0):
A rescale factor applied to the noise predictions.
Returns:
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
""" """
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
...@@ -109,7 +121,7 @@ def retrieve_timesteps( ...@@ -109,7 +121,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None, sigmas: Optional[List[float]] = None,
**kwargs, **kwargs,
): ):
""" r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
......
...@@ -90,9 +90,21 @@ EXAMPLE_DOC_STRING = """ ...@@ -90,9 +90,21 @@ EXAMPLE_DOC_STRING = """
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
""" r"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Args:
noise_cfg (`torch.Tensor`):
The predicted noise tensor for the guided diffusion process.
noise_pred_text (`torch.Tensor`):
The predicted noise tensor for the text-guided diffusion process.
guidance_rescale (`float`, *optional*, defaults to 0.0):
A rescale factor applied to the noise predictions.
Returns:
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
""" """
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
...@@ -126,7 +138,7 @@ def retrieve_timesteps( ...@@ -126,7 +138,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None, sigmas: Optional[List[float]] = None,
**kwargs, **kwargs,
): ):
""" r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
......
...@@ -101,9 +101,21 @@ EXAMPLE_DOC_STRING = """ ...@@ -101,9 +101,21 @@ EXAMPLE_DOC_STRING = """
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
""" r"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Args:
noise_cfg (`torch.Tensor`):
The predicted noise tensor for the guided diffusion process.
noise_pred_text (`torch.Tensor`):
The predicted noise tensor for the text-guided diffusion process.
guidance_rescale (`float`, *optional*, defaults to 0.0):
A rescale factor applied to the noise predictions.
Returns:
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
""" """
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
...@@ -153,7 +165,7 @@ def retrieve_timesteps( ...@@ -153,7 +165,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None, sigmas: Optional[List[float]] = None,
**kwargs, **kwargs,
): ):
""" r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
......
...@@ -71,7 +71,7 @@ def retrieve_timesteps( ...@@ -71,7 +71,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None, sigmas: Optional[List[float]] = None,
**kwargs, **kwargs,
): ):
""" r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
......
...@@ -127,7 +127,7 @@ def retrieve_timesteps( ...@@ -127,7 +127,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None, sigmas: Optional[List[float]] = None,
**kwargs, **kwargs,
): ):
""" r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
......
...@@ -119,9 +119,21 @@ def _preprocess_adapter_image(image, height, width): ...@@ -119,9 +119,21 @@ def _preprocess_adapter_image(image, height, width):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
""" r"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Args:
noise_cfg (`torch.Tensor`):
The predicted noise tensor for the guided diffusion process.
noise_pred_text (`torch.Tensor`):
The predicted noise tensor for the text-guided diffusion process.
guidance_rescale (`float`, *optional*, defaults to 0.0):
A rescale factor applied to the noise predictions.
Returns:
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
""" """
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
...@@ -141,7 +153,7 @@ def retrieve_timesteps( ...@@ -141,7 +153,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None, sigmas: Optional[List[float]] = None,
**kwargs, **kwargs,
): ):
""" r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
......
...@@ -310,9 +310,21 @@ def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_s ...@@ -310,9 +310,21 @@ def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_s
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
""" r"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Args:
noise_cfg (`torch.Tensor`):
The predicted noise tensor for the guided diffusion process.
noise_pred_text (`torch.Tensor`):
The predicted noise tensor for the text-guided diffusion process.
guidance_rescale (`float`, *optional*, defaults to 0.0):
A rescale factor applied to the noise predictions.
Returns:
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
""" """
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment