Unverified Commit 0a09af2f authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

update flax scheduler API (#822)

* update flax scheduler API

* remove set format

* fix call to scale_model_input

* update flax pndm

* use int32

* update docstr
parent f1d4289b
...@@ -170,6 +170,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): ...@@ -170,6 +170,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
timestep = jnp.broadcast_to(t, latents_input.shape[0]) timestep = jnp.broadcast_to(t, latents_input.shape[0])
latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)
# predict the noise residual # predict the noise residual
noise_pred = self.unet.apply( noise_pred = self.unet.apply(
{"params": params["unet"]}, {"params": params["unet"]},
...@@ -189,6 +191,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): ...@@ -189,6 +191,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
) )
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
if debug: if debug:
# run with python for loop # run with python for loop
for i in range(num_inference_steps): for i in range(num_inference_steps):
......
...@@ -141,6 +141,23 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -141,6 +141,23 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
# whether we use the final alpha of the "non-previous" one. # whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else float(self._alphas_cumprod[0]) self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else float(self._alphas_cumprod[0])
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
def scale_model_input(
    self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
) -> jnp.ndarray:
    """
    Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
    current timestep. This scheduler applies no scaling: `sample` is returned unchanged.

    Args:
        state (`DDIMSchedulerState`): the `FlaxDDIMScheduler` state data class instance (not read here).
        sample (`jnp.ndarray`): input sample
        timestep (`int`, optional): current timestep (not read here)

    Returns:
        `jnp.ndarray`: scaled input sample (identical to `sample` for this scheduler)
    """
    # No-op: `state` and `timestep` are accepted only so the call signature matches
    # schedulers that do rescale their input.
    return sample
def create_state(self): def create_state(self):
return DDIMSchedulerState.create( return DDIMSchedulerState.create(
num_train_timesteps=self.config.num_train_timesteps, alphas_cumprod=self._alphas_cumprod num_train_timesteps=self.config.num_train_timesteps, alphas_cumprod=self._alphas_cumprod
......
...@@ -153,6 +153,9 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -153,6 +153,9 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
# mainly at formula (9), (12), (13) and the Algorithm 2. # mainly at formula (9), (12), (13) and the Algorithm 2.
self.pndm_order = 4 self.pndm_order = 4
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
def create_state(self): def create_state(self):
return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps) return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
...@@ -196,7 +199,7 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -196,7 +199,7 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
) )
return state.replace( return state.replace(
timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int64), timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int32),
counter=0, counter=0,
# Reserve space for the state variables # Reserve space for the state variables
cur_model_output=jnp.zeros(shape), cur_model_output=jnp.zeros(shape),
...@@ -204,6 +207,23 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -204,6 +207,23 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
ets=jnp.zeros((4,) + shape), ets=jnp.zeros((4,) + shape),
) )
def scale_model_input(
    self,
    state: PNDMSchedulerState,
    sample: jnp.ndarray,
    timestep: Optional[int] = None,
) -> jnp.ndarray:
    """
    Pass-through hook kept so this scheduler can be swapped with schedulers that must rescale the denoising
    model input according to the current timestep.

    Args:
        state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance (not read here).
        sample (`jnp.ndarray`): input sample
        timestep (`int`, optional): current timestep (not read here)

    Returns:
        `jnp.ndarray`: the input sample, returned as-is
    """
    # Nothing to scale for this scheduler; hand the sample straight back.
    return sample
def step( def step(
self, self,
state: PNDMSchedulerState, state: PNDMSchedulerState,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment