Unverified Commit 7bbbfbfd authored by Juan Acevedo, committed by GitHub
Browse files

Jax infer support negative prompt (#1337)



* support negative prompts in sd jax pipeline

* pass batched neg_prompt

* only encode when negative prompt is None
Co-authored-by: Juan Acevedo <jfacevedo@google.com>
parent 30220905
...@@ -165,6 +165,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): ...@@ -165,6 +165,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
guidance_scale: float = 7.5, guidance_scale: float = 7.5,
latents: Optional[jnp.array] = None, latents: Optional[jnp.array] = None,
debug: bool = False, debug: bool = False,
neg_prompt_ids: jnp.array = None,
): ):
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
...@@ -177,10 +178,14 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): ...@@ -177,10 +178,14 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
batch_size = prompt_ids.shape[0] batch_size = prompt_ids.shape[0]
max_length = prompt_ids.shape[-1] max_length = prompt_ids.shape[-1]
if neg_prompt_ids is None:
uncond_input = self.tokenizer( uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
) ).input_ids
uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0] else:
uncond_input = neg_prompt_ids
uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
context = jnp.concatenate([uncond_embeddings, text_embeddings]) context = jnp.concatenate([uncond_embeddings, text_embeddings])
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
...@@ -251,6 +256,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): ...@@ -251,6 +256,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
return_dict: bool = True, return_dict: bool = True,
jit: bool = False, jit: bool = False,
debug: bool = False, debug: bool = False,
neg_prompt_ids: jnp.array = None,
**kwargs, **kwargs,
): ):
r""" r"""
...@@ -298,11 +304,30 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): ...@@ -298,11 +304,30 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
""" """
if jit: if jit:
images = _p_generate( images = _p_generate(
self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug self,
prompt_ids,
params,
prng_seed,
num_inference_steps,
height,
width,
guidance_scale,
latents,
debug,
neg_prompt_ids,
) )
else: else:
images = self._generate( images = self._generate(
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug prompt_ids,
params,
prng_seed,
num_inference_steps,
height,
width,
guidance_scale,
latents,
debug,
neg_prompt_ids,
) )
if self.safety_checker is not None: if self.safety_checker is not None:
...@@ -333,10 +358,29 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): ...@@ -333,10 +358,29 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
# TODO: maybe use a config dict instead of so many static argnums # TODO: maybe use a config dict instead of so many static argnums
@partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9)) @partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9))
def _p_generate( def _p_generate(
pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug pipe,
prompt_ids,
params,
prng_seed,
num_inference_steps,
height,
width,
guidance_scale,
latents,
debug,
neg_prompt_ids,
): ):
return pipe._generate( return pipe._generate(
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug prompt_ids,
params,
prng_seed,
num_inference_steps,
height,
width,
guidance_scale,
latents,
debug,
neg_prompt_ids,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment