Unverified Commit 6a7f1f09 authored by Pedro Cuenca, committed by GitHub

Flax: avoid recompilation when params change (#1096)

* Do not recompile when guidance_scale changes.

* Remove debug for simplicity.

* make style

* Make guidance_scale an array.

* Make DEBUG a constant to avoid passing it down.

* Add comments for clarification.
parent 170ebd28
```diff
@@ -42,6 +42,9 @@ from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+# Set to True to use python for loop instead of jax.fori_loop for easier debugging
+DEBUG = False
+
 class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
     r"""
@@ -187,7 +190,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
         width: Optional[int] = None,
         guidance_scale: float = 7.5,
         latents: Optional[jnp.array] = None,
-        debug: bool = False,
         neg_prompt_ids: jnp.array = None,
     ):
         # 0. Default height and width to unet
@@ -260,8 +262,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
 
-        if debug:
+        if DEBUG:
             # run with python for loop
             for i in range(num_inference_steps):
                 latents, scheduler_state = loop_body(i, (latents, scheduler_state))
@@ -283,11 +284,10 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
         num_inference_steps: int = 50,
         height: Optional[int] = None,
         width: Optional[int] = None,
-        guidance_scale: float = 7.5,
+        guidance_scale: Union[float, jnp.array] = 7.5,
         latents: jnp.array = None,
         return_dict: bool = True,
         jit: bool = False,
-        debug: bool = False,
         neg_prompt_ids: jnp.array = None,
     ):
         r"""
@@ -334,6 +334,14 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor
 
+        if isinstance(guidance_scale, float):
+            # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
+            # shape information, as they may be sharded (when `jit` is `True`), or not.
+            guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
+            if len(prompt_ids.shape) > 2:
+                # Assume sharded
+                guidance_scale = guidance_scale.reshape(prompt_ids.shape[:2])
+
         if jit:
             images = _p_generate(
                 self,
```
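To make the shapes in the hunk above concrete: with `jit=True`, `prompt_ids` arrive sharded as `(devices, batch, seq_len)`, and the scalar is replicated to match the leading dimensions. A hypothetical walkthrough, assuming 8 devices and one prompt per device (the numbers are mine, not from the commit):

```python
import jax.numpy as jnp

prompt_ids = jnp.zeros((8, 1, 77), dtype=jnp.int32)  # sharded: (devices, batch, seq_len)
guidance_scale = 7.5

if isinstance(guidance_scale, float):
    guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])  # shape (8,)
    if len(prompt_ids.shape) > 2:
        # Sharded case: match the leading (devices, batch) dimensions
        guidance_scale = guidance_scale.reshape(prompt_ids.shape[:2])   # shape (8, 1)

print(guidance_scale.shape)  # (8, 1); pmap then hands each device a (1,) slice
```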
```diff
@@ -345,7 +353,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
                 width,
                 guidance_scale,
                 latents,
-                debug,
                 neg_prompt_ids,
             )
         else:
@@ -358,7 +365,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
                 width,
                 guidance_scale,
                 latents,
-                debug,
                 neg_prompt_ids,
             )
@@ -388,8 +394,13 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
         return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
 
-# TODO: maybe use a config dict instead of so many static argnums
-@partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9))
+# Static argnums are pipe, num_inference_steps, height, width. A change would trigger recompilation.
+# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
+@partial(
+    jax.pmap,
+    in_axes=(None, 0, 0, 0, None, None, None, 0, 0, 0),
+    static_broadcasted_argnums=(0, 4, 5, 6),
+)
 def _p_generate(
     pipe,
     prompt_ids,
@@ -400,7 +411,6 @@ def _p_generate(
     width,
     guidance_scale,
     latents,
-    debug,
     neg_prompt_ids,
 ):
     return pipe._generate(
@@ -412,7 +422,6 @@ def _p_generate(
         width,
         guidance_scale,
         latents,
-        debug,
         neg_prompt_ids,
     )
...
```
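For context, here is how the pipeline might be driven after this change. The checkpoint name, revision, and setup follow the usual Flax Stable Diffusion example and are assumptions on my part, not part of the commit:

```python
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline

# Assumed checkpoint/revision, for illustration only
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="flax", dtype=jnp.bfloat16
)

prompts = ["a photo of an astronaut riding a horse"] * jax.device_count()
prompt_ids = shard(pipeline.prepare_inputs(prompts))  # (devices, batch, seq_len)
p_params = replicate(params)
rng = jax.random.split(jax.random.PRNGKey(0), jax.device_count())

# First call compiles; varying guidance_scale afterwards no longer recompiles,
# since it is now traced (a per-device array) rather than static.
images = pipeline(prompt_ids, p_params, rng, guidance_scale=7.5, jit=True).images
images = pipeline(prompt_ids, p_params, rng, guidance_scale=9.0, jit=True).images
```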