Unverified commit ab3fd671, authored by Pedro Cuenca, committed by GitHub

Flax pipeline pndm (#583)



* WIP: flax FlaxDiffusionPipeline & FlaxStableDiffusionPipeline

* todo comment

* Fix imports

* Fix imports

* add dummies

* Fix empty init

* make pipeline work

* up

* Allow dtype to be overridden on model load.

This may be a temporary solution until #567 is addressed.
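
For context, this is the kind of call the override enables; a minimal sketch, assuming a Flax model class with the usual `from_pretrained(..., dtype=...)` keyword (class and repo id are illustrative, not part of this commit):

```python
import jax.numpy as jnp
from diffusers import FlaxUNet2DConditionModel

# Load the weights directly in half precision instead of the checkpoint dtype.
# Model class and repo id are examples only.
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", dtype=jnp.bfloat16
)
```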

* Convert params to bfloat16 or fp16 after loading.

This deals with the weights, not the model.
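
A minimal sketch of that post-load conversion, assuming `params` is an ordinary Flax parameter pytree (the helper name is ours, not from this commit):

```python
import jax
import jax.numpy as jnp

def params_to_bf16(params):
    # Cast only floating-point leaves; integer buffers are left untouched.
    # This converts the stored weights, not the model definition.
    return jax.tree_util.tree_map(
        lambda p: p.astype(jnp.bfloat16) if jnp.issubdtype(p.dtype, jnp.floating) else p,
        params,
    )
```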

* Use Flax schedulers (typing, docstring)

* PNDM: replace control flow with jax functions.

Otherwise jitting/parallelization don't work properly, because Python
control flow can't deal with traced objects.

I temporarily removed `step_prk`.
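
A standalone toy example of the underlying issue, not code from this commit: under `jax.jit` the inputs are tracers, so a Python `if` on them raises a concretization error, while `jnp.where` and `jax.lax.switch` stay traceable:

```python
import jax
import jax.numpy as jnp

@jax.jit
def step_python_if(counter):
    # Raises a concretization error on first call: `counter` is a tracer
    # under jit and has no concrete boolean value.
    if counter % 2 == 0:
        return counter + 1
    return counter - 1

@jax.jit
def step_where(counter):
    # Both branches are computed; the result is selected element-wise.
    return jnp.where(counter % 2 == 0, counter + 1, counter - 1)

@jax.jit
def step_switch(index, x):
    # lax.switch dispatches on a traced index, the same mechanism the
    # scheduler now uses to choose between step_prk and step_plms.
    return jax.lax.switch(index, (lambda x: x + 1.0, lambda x: x - 1.0), x)
```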

* Pass latents shape to scheduler set_timesteps()

PNDMScheduler uses it to reserve space; other schedulers will just
ignore it.
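
Call-site sketch, matching the `set_timesteps` signature in the diff below (latent shape and step count are illustrative):

```python
import jax.numpy as jnp
from diffusers import FlaxPNDMScheduler

scheduler = FlaxPNDMScheduler()
state = scheduler.create_state()

latents_shape = (1, 4, 64, 64)  # example latent shape
# FlaxPNDMScheduler preallocates cur_model_output, cur_sample, and the
# 4-entry ets history with this shape; other schedulers ignore `shape`.
state = scheduler.set_timesteps(state, shape=latents_shape, num_inference_steps=50)
```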

* Wrap model imports inside availability checks.
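
The guard pattern this refers to, as it appears in the package `__init__` files (paths as in the diff below):

```python
# diffusers/pipelines/stable_diffusion/__init__.py
from ...utils import is_flax_available, is_transformers_available

if is_transformers_available() and is_flax_available():
    # Imported only when jax, flax, and transformers are installed.
    from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
    from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
```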

* Optionally return state in from_config.

Useful for Flax schedulers.
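
Sketch of the two ways to obtain a scheduler state; the one-step `from_config` form is an assumption about the call shape, since the diff below only shows `create_state`:

```python
from diffusers import FlaxPNDMScheduler

# Two-step form, confirmed by the diff below:
scheduler = FlaxPNDMScheduler()
state = scheduler.create_state()

# One-step form this commit enables (call shape assumed, not verified):
# scheduler, state = FlaxPNDMScheduler.from_config("<scheduler config>")
```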

* Do not convert model weights to dtype.

* Re-enable PRK steps with functional implementation.

Values returned still not verified for correctness.
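
The functional rewrite relies on JAX's immutable index updates to maintain the fixed-size `ets` history reserved in `set_timesteps`; a toy illustration (shapes are made up):

```python
import jax.numpy as jnp

ets = jnp.zeros((4, 2, 2))  # fixed-size history, as reserved in set_timesteps
e_t = jnp.ones((2, 2))

# jnp arrays are immutable: .at[i].set(...) returns an updated copy,
# which is what allows these updates inside jit-compiled code.
ets = ets.at[0].set(e_t)

# Sliding the window, as counter_other does after computing its output:
ets = ets.at[0].set(ets[1])
```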

* Remove left over has_state var.

* make style

* Apply suggestion list -> tuple
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Apply suggestion list -> tuple
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Remove unused comments.

* Use zeros instead of empty.
Co-authored-by: Mishig Davaadorj <dmishig@gmail.com>
Co-authored-by: Mishig Davaadorj <mishig.davaadorj@coloradocollege.edu>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent c070e5f0
@@ -56,5 +56,6 @@ if is_transformers_available() and is_flax_available():
         images: Union[List[PIL.Image.Image], np.ndarray]
         nsfw_content_detected: List[bool]
 
+    from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
     from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
     from .safety_checker_flax import FlaxStableDiffusionSafetyChecker

@@ -186,7 +186,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
             latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
             return latents, scheduler_state
 
-        scheduler_state = self.scheduler.set_timesteps(params["scheduler"], num_inference_steps=num_inference_steps)
+        scheduler_state = self.scheduler.set_timesteps(
+            params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
+        )
 
         if debug:
            # run with python for loop
...
@@ -19,6 +19,7 @@ from dataclasses import dataclass
 from typing import Optional, Tuple, Union
 
 import flax
+import jax
 import jax.numpy as jnp
 
 from ..configuration_utils import ConfigMixin, register_to_config

@@ -155,7 +156,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
     def create_state(self):
         return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
 
-    def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) -> PNDMSchedulerState:
+    def set_timesteps(self, state: PNDMSchedulerState, shape: Tuple, num_inference_steps: int) -> PNDMSchedulerState:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

@@ -196,8 +197,11 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         return state.replace(
             timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int64),
-            ets=jnp.array([]),
             counter=0,
+            # Reserve space for the state variables
+            cur_model_output=jnp.zeros(shape),
+            cur_sample=jnp.zeros(shape),
+            ets=jnp.zeros((4,) + shape),
         )
 
     def step(
@@ -227,22 +231,32 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
             When returning a tuple, the first element is the sample tensor.
 
         """
-        if state.counter < len(state.prk_timesteps) and not self.config.skip_prk_steps:
-            return self.step_prk(
-                state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict
+        if self.config.skip_prk_steps:
+            prev_sample, state = self.step_plms(
+                state=state, model_output=model_output, timestep=timestep, sample=sample
             )
         else:
-            return self.step_plms(
-                state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict
+            prev_sample, state = jax.lax.switch(
+                jnp.where(state.counter < len(state.prk_timesteps), 0, 1),
+                (self.step_prk, self.step_plms),
+                # Args to either branch
+                state,
+                model_output,
+                timestep,
+                sample,
             )
 
+        if not return_dict:
+            return (prev_sample, state)
+
+        return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
+
     def step_prk(
         self,
         state: PNDMSchedulerState,
         model_output: jnp.ndarray,
         timestep: int,
         sample: jnp.ndarray,
-        return_dict: bool = True,
     ) -> Union[FlaxSchedulerOutput, Tuple]:
         """
         Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
@@ -266,34 +280,46 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
                 "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
             )
 
-        diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2
+        diff_to_prev = jnp.where(
+            state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2
+        )
         prev_timestep = timestep - diff_to_prev
         timestep = state.prk_timesteps[state.counter // 4 * 4]
 
-        if state.counter % 4 == 0:
-            state = state.replace(
-                cur_model_output=state.cur_model_output + 1 / 6 * model_output,
-                ets=state.ets.append(model_output),
-                cur_sample=sample,
+        def remainder_0(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+            return (
+                state.replace(
+                    cur_model_output=state.cur_model_output + 1 / 6 * model_output,
+                    ets=state.ets.at[ets_at].set(model_output),
+                    cur_sample=sample,
+                ),
+                model_output,
             )
-        elif (self.counter - 1) % 4 == 0:
-            state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
-        elif (self.counter - 2) % 4 == 0:
-            state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
-        elif (self.counter - 3) % 4 == 0:
-            model_output = state.cur_model_output + 1 / 6 * model_output
-            state = state.replace(cur_model_output=0)
 
-        # cur_sample should not be `None`
-        cur_sample = state.cur_sample if state.cur_sample is not None else sample
+        def remainder_1(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+            return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output
+
+        def remainder_2(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+            return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output
+
+        def remainder_3(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+            model_output = state.cur_model_output + 1 / 6 * model_output
+            return state.replace(cur_model_output=jnp.zeros_like(state.cur_model_output)), model_output
+
+        state, model_output = jax.lax.switch(
+            state.counter % 4,
+            (remainder_0, remainder_1, remainder_2, remainder_3),
+            # Args to either branch
+            state,
+            model_output,
+            state.counter // 4,
+        )
 
+        cur_sample = state.cur_sample
         prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
         state = state.replace(counter=state.counter + 1)
 
-        if not return_dict:
-            return (prev_sample, state)
-
-        return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
+        return (prev_sample, state)
 
     def step_plms(
         self,
@@ -301,7 +327,6 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         model_output: jnp.ndarray,
         timestep: int,
         sample: jnp.ndarray,
-        return_dict: bool = True,
     ) -> Union[FlaxSchedulerOutput, Tuple]:
         """
         Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
@@ -334,36 +359,91 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
             )
 
         prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
+        prev_timestep = jnp.where(prev_timestep > 0, prev_timestep, 0)
 
-        if state.counter != 1:
-            state = state.replace(ets=state.ets.append(model_output))
-        else:
-            prev_timestep = timestep
-            timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps
+        # Reference:
+        # if state.counter != 1:
+        #     state.ets.append(model_output)
+        # else:
+        #     prev_timestep = timestep
+        #     timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps
+
+        prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep)
+        timestep = jnp.where(
+            state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep
+        )
 
-        if len(state.ets) == 1 and state.counter == 0:
-            model_output = model_output
-            state = state.replace(cur_sample=sample)
-        elif len(state.ets) == 1 and state.counter == 1:
-            model_output = (model_output + state.ets[-1]) / 2
-            sample = state.cur_sample
-            state = state.replace(cur_sample=None)
-        elif len(state.ets) == 2:
-            model_output = (3 * state.ets[-1] - state.ets[-2]) / 2
-        elif len(state.ets) == 3:
-            model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12
-        else:
-            model_output = (1 / 24) * (
-                55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]
-            )
+        # Reference:
+        # if len(state.ets) == 1 and state.counter == 0:
+        #     model_output = model_output
+        #     state.cur_sample = sample
+        # elif len(state.ets) == 1 and state.counter == 1:
+        #     model_output = (model_output + state.ets[-1]) / 2
+        #     sample = state.cur_sample
+        #     state.cur_sample = None
+        # elif len(state.ets) == 2:
+        #     model_output = (3 * state.ets[-1] - state.ets[-2]) / 2
+        # elif len(state.ets) == 3:
+        #     model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12
+        # else:
+        #     model_output = (1 / 24) * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4])
+
+        def counter_0(state: PNDMSchedulerState):
+            ets = state.ets.at[0].set(model_output)
+            return state.replace(
+                ets=ets,
+                cur_sample=sample,
+                cur_model_output=jnp.array(model_output, dtype=jnp.float32),
+            )
+
+        def counter_1(state: PNDMSchedulerState):
+            return state.replace(
+                cur_model_output=(model_output + state.ets[0]) / 2,
+            )
+
+        def counter_2(state: PNDMSchedulerState):
+            ets = state.ets.at[1].set(model_output)
+            return state.replace(
+                ets=ets,
+                cur_model_output=(3 * ets[1] - ets[0]) / 2,
+                cur_sample=sample,
+            )
+
+        def counter_3(state: PNDMSchedulerState):
+            ets = state.ets.at[2].set(model_output)
+            return state.replace(
+                ets=ets,
+                cur_model_output=(23 * ets[2] - 16 * ets[1] + 5 * ets[0]) / 12,
+                cur_sample=sample,
+            )
+
+        def counter_other(state: PNDMSchedulerState):
+            ets = state.ets.at[3].set(model_output)
+            next_model_output = (1 / 24) * (55 * ets[3] - 59 * ets[2] + 37 * ets[1] - 9 * ets[0])
+
+            ets = ets.at[0].set(ets[1])
+            ets = ets.at[1].set(ets[2])
+            ets = ets.at[2].set(ets[3])
+
+            return state.replace(
+                ets=ets,
+                cur_model_output=next_model_output,
+                cur_sample=sample,
+            )
+
+        counter = jnp.clip(state.counter, 0, 4)
+        state = jax.lax.switch(
+            counter,
+            [counter_0, counter_1, counter_2, counter_3, counter_other],
+            state,
+        )
 
+        sample = state.cur_sample
+        model_output = state.cur_model_output
         prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
         state = state.replace(counter=state.counter + 1)
 
-        if not return_dict:
-            return (prev_sample, state)
-
-        return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
+        return (prev_sample, state)
 
     def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
         # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
@@ -379,7 +459,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         # model_output -> e_θ(x_t, t)
         # prev_sample -> x_(t−δ)
         alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        alpha_prod_t_prev = jnp.where(prev_timestep >= 0, self.alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
...