"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "4915524fa189651a1ab08b44690cc0cb8b772282"
Unverified Commit 93a81a3f authored by camenduru's avatar camenduru Committed by GitHub
Browse files

Fix Flax pipeline: width and height are ignored #838 (#848)

* Fix Flax pipeline: width and height are ignored #838

* Fix Flax pipeline: width and height are ignored
parent 1d3234cb
...@@ -152,12 +152,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): ...@@ -152,12 +152,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0] uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0]
context = jnp.concatenate([uncond_embeddings, text_embeddings]) context = jnp.concatenate([uncond_embeddings, text_embeddings])
latents_shape = ( latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
batch_size,
self.unet.in_channels,
self.unet.sample_size,
self.unet.sample_size,
)
if latents is None: if latents is None:
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment