Unverified Commit 2bbd5329 authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Docs: short section on changing the scheduler in Flax (#2181)



* Short doc on changing the scheduler in Flax.

* Apply fix from @patil-suraj
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

---------
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
parent 68ef0666
...@@ -176,6 +176,7 @@ image ...@@ -176,6 +176,7 @@ image
<br> <br>
</p> </p>
If you are a JAX/Flax user, please check [this section](#changing-the-scheduler-in-flax) instead.
## Compare schedulers ## Compare schedulers
...@@ -260,3 +261,54 @@ image ...@@ -260,3 +261,54 @@ image
As you can see most images look very similar and are arguably of very similar quality. It often really depends on the specific use case which scheduler to choose. A good approach is always to run multiple different As you can see most images look very similar and are arguably of very similar quality. It often really depends on the specific use case which scheduler to choose. A good approach is always to run multiple different
schedulers to compare results. schedulers to compare results.
## Changing the Scheduler in Flax
If you are a JAX/Flax user, you can also change the default pipeline scheduler. This is a complete example of how to run inference using the Flax Stable Diffusion pipeline and the super-fast [DDPM-Solver++ scheduler](../api/schedulers/multistep_dpm_solver):
```Python
import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline, FlaxDPMSolverMultistepScheduler
model_id = "runwayml/stable-diffusion-v1-5"
scheduler, scheduler_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler"
)
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
model_id,
scheduler=scheduler,
revision="bf16",
dtype=jax.numpy.bfloat16,
)
params["scheduler"] = scheduler_state
# Generate 1 image per parallel device (8 on TPUv2-8 or TPUv3-8)
prompt = "a photo of an astronaut riding a horse on mars"
num_samples = jax.device_count()
prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 25
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, jax.device_count())
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
```
<Tip warning={true}>
The following Flax schedulers are _not yet compatible_ with the Flax Stable Diffusion Pipeline:
- `FlaxLMSDiscreteScheduler`
- `FlaxDDPMScheduler`
</Tip>
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment