Unverified Commit 526827c3 authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Fix scheduler type mismatch (#3041)

When doing generation manually and using guidance_scale as a static
argument.
parent cb63febf
...@@ -245,6 +245,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): ...@@ -245,6 +245,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])
# Ensure model output will be `float32` before going into the scheduler
guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32)
latents_shape = ( latents_shape = (
batch_size, batch_size,
self.unet.config.in_channels, self.unet.config.in_channels,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment