Unverified Commit 46fac824 authored by Chi's avatar Chi Committed by GitHub
Browse files

Solve missing clip_sample implementation in FlaxDDIMScheduler. (#7017)



* I added a new doc string to the class. This is more flexible to understanding other developers what are doing and where it's using.

* Update src/diffusers/models/unet_2d_blocks.py

This changes suggest by maintener.
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/models/unet_2d_blocks.py

Add suggested text
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* Update unet_2d_blocks.py

I changed the Parameter to Args text.

* Update unet_2d_blocks.py

proper indentation set in this file.

* Update unet_2d_blocks.py

a little bit of change in the act_fun argument line.

* I run the black command to reformat style in the code

* Update unet_2d_blocks.py

similar doc-string add to have in the original diffusion repository.

* Fix bug for mention in this issue section #6901

* Update src/diffusers/schedulers/scheduling_ddim_flax.py
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>

* Fix linter

* Restore empty line

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
parent b33b64f5
...@@ -85,7 +85,9 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -85,7 +85,9 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
trained_betas (`jnp.ndarray`, optional): trained_betas (`jnp.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
clip_sample (`bool`, default `True`): clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability. option to clip predicted sample between for numerical stability. The clip range is determined by `clip_sample_range`.
clip_sample_range (`float`, default `1.0`):
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
set_alpha_to_one (`bool`, default `True`): set_alpha_to_one (`bool`, default `True`):
each diffusion step uses the value of alphas product at that step and at the previous one. For the final each diffusion step uses the value of alphas product at that step and at the previous one. For the final
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
...@@ -117,6 +119,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -117,6 +119,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
beta_end: float = 0.02, beta_end: float = 0.02,
beta_schedule: str = "linear", beta_schedule: str = "linear",
trained_betas: Optional[jnp.ndarray] = None, trained_betas: Optional[jnp.ndarray] = None,
clip_sample: bool = True,
clip_sample_range: float = 1.0,
set_alpha_to_one: bool = True, set_alpha_to_one: bool = True,
steps_offset: int = 0, steps_offset: int = 0,
prediction_type: str = "epsilon", prediction_type: str = "epsilon",
...@@ -267,6 +271,12 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -267,6 +271,12 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
" `v_prediction`" " `v_prediction`"
) )
# 4. Clip or threshold "predicted x_0"
if self.config.clip_sample:
pred_original_sample = pred_original_sample.clip(
-self.config.clip_sample_range, self.config.clip_sample_range
)
# 4. compute variance: "sigma_t(η)" -> see formula (16) # 4. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
variance = self._get_variance(state, timestep, prev_timestep) variance = self._get_variance(state, timestep, prev_timestep)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment