Unverified commit a062e47e, authored by YiYi Xu, committed by GitHub
Browse files

add flax pipelines to api doc + doc string examples (#2600)



* add api doc for flax pipeline + doc string examples

* make style

---------
Co-authored-by: yiyixuxu <yixu@yis-macbook-pro.lan>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 75f1210a
...@@ -30,3 +30,7 @@ proposed by Chenlin Meng, Yutong He, Yang Song, Jiaming Song, Jiajun Wu, Jun-Yan ...@@ -30,3 +30,7 @@ proposed by Chenlin Meng, Yutong He, Yang Song, Jiaming Song, Jiajun Wu, Jun-Yan
- disable_attention_slicing - disable_attention_slicing
- enable_xformers_memory_efficient_attention - enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention - disable_xformers_memory_efficient_attention
[[autodoc]] FlaxStableDiffusionImg2ImgPipeline
- all
- __call__
\ No newline at end of file
...@@ -31,3 +31,7 @@ Available checkpoints are: ...@@ -31,3 +31,7 @@ Available checkpoints are:
- disable_attention_slicing - disable_attention_slicing
- enable_xformers_memory_efficient_attention - enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention - disable_xformers_memory_efficient_attention
[[autodoc]] FlaxStableDiffusionInpaintPipeline
- all
- __call__
\ No newline at end of file
...@@ -39,3 +39,7 @@ Available Checkpoints are: ...@@ -39,3 +39,7 @@ Available Checkpoints are:
- disable_xformers_memory_efficient_attention - disable_xformers_memory_efficient_attention
- enable_vae_tiling - enable_vae_tiling
- disable_vae_tiling - disable_vae_tiling
[[autodoc]] FlaxStableDiffusionPipeline
- all
- __call__
...@@ -24,6 +24,7 @@ from flax.jax_utils import unreplicate ...@@ -24,6 +24,7 @@ from flax.jax_utils import unreplicate
from flax.training.common_utils import shard from flax.training.common_utils import shard
from packaging import version from packaging import version
from PIL import Image from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
...@@ -33,7 +34,7 @@ from ...schedulers import ( ...@@ -33,7 +34,7 @@ from ...schedulers import (
FlaxLMSDiscreteScheduler, FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler, FlaxPNDMScheduler,
) )
from ...utils import deprecate, logging from ...utils import deprecate, logging, replace_example_docstring
from ..pipeline_flax_utils import FlaxDiffusionPipeline from ..pipeline_flax_utils import FlaxDiffusionPipeline
from . import FlaxStableDiffusionPipelineOutput from . import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
...@@ -44,6 +45,39 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name ...@@ -44,6 +45,39 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Set to True to use python for loop instead of jax.fori_loop for easier debugging # Set to True to use python for loop instead of jax.fori_loop for easier debugging
DEBUG = False DEBUG = False
# Usage example for `FlaxStableDiffusionPipeline.__call__`. It is handed to the
# `replace_example_docstring` decorator on that method (presumably to fill the
# `Examples:` placeholder in its docstring -- confirm against the decorator).
# Every line inside the fence carries a `>>> ` / `... ` prompt, comments
# included, so that no line is mistaken for expected output of the statement
# before it.
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import jax
>>> import numpy as np
>>> from flax.jax_utils import replicate
>>> from flax.training.common_utils import shard
>>> from diffusers import FlaxStableDiffusionPipeline
>>> pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
... "runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jax.numpy.bfloat16
... )
>>> prompt = "a photo of an astronaut riding a horse on mars"
>>> prng_seed = jax.random.PRNGKey(0)
>>> num_inference_steps = 50
>>> num_samples = jax.device_count()
>>> prompt = num_samples * [prompt]
>>> prompt_ids = pipeline.prepare_inputs(prompt)
>>> # shard inputs and rng
>>> params = replicate(params)
>>> prng_seed = jax.random.split(prng_seed, jax.device_count())
>>> prompt_ids = shard(prompt_ids)
>>> images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
>>> images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
```
"""
class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
r""" r"""
...@@ -272,6 +306,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): ...@@ -272,6 +306,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
return image return image
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
self, self,
prompt_ids: jnp.array, prompt_ids: jnp.array,
...@@ -316,6 +351,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): ...@@ -316,6 +351,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
a plain tuple. a plain tuple.
Examples:
Returns: Returns:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
......
...@@ -23,6 +23,7 @@ from flax.core.frozen_dict import FrozenDict ...@@ -23,6 +23,7 @@ from flax.core.frozen_dict import FrozenDict
from flax.jax_utils import unreplicate from flax.jax_utils import unreplicate
from flax.training.common_utils import shard from flax.training.common_utils import shard
from PIL import Image from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
...@@ -32,7 +33,7 @@ from ...schedulers import ( ...@@ -32,7 +33,7 @@ from ...schedulers import (
FlaxLMSDiscreteScheduler, FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler, FlaxPNDMScheduler,
) )
from ...utils import PIL_INTERPOLATION, logging from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring
from ..pipeline_flax_utils import FlaxDiffusionPipeline from ..pipeline_flax_utils import FlaxDiffusionPipeline
from . import FlaxStableDiffusionPipelineOutput from . import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
...@@ -43,6 +44,64 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name ...@@ -43,6 +44,64 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Set to True to use python for loop instead of jax.fori_loop for easier debugging # Set to True to use python for loop instead of jax.fori_loop for easier debugging
DEBUG = False DEBUG = False
# Usage example for `FlaxStableDiffusionImg2ImgPipeline.__call__`. It is handed
# to the `replace_example_docstring` decorator on that method (presumably to
# fill the `Examples:` placeholder in its docstring -- confirm against the
# decorator).
# The body of `create_key` is indented after its `... ` prompt so the snippet
# is still valid Python once the prompts are stripped.
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import jax
>>> import numpy as np
>>> import jax.numpy as jnp
>>> from flax.jax_utils import replicate
>>> from flax.training.common_utils import shard
>>> import requests
>>> from io import BytesIO
>>> from PIL import Image
>>> from diffusers import FlaxStableDiffusionImg2ImgPipeline
>>> def create_key(seed=0):
...     return jax.random.PRNGKey(seed)
>>> rng = create_key(0)
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
>>> response = requests.get(url)
>>> init_img = Image.open(BytesIO(response.content)).convert("RGB")
>>> init_img = init_img.resize((768, 512))
>>> prompts = "A fantasy landscape, trending on artstation"
>>> pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained(
... "CompVis/stable-diffusion-v1-4",
... revision="flax",
... dtype=jnp.bfloat16,
... )
>>> num_samples = jax.device_count()
>>> rng = jax.random.split(rng, jax.device_count())
>>> prompt_ids, processed_image = pipeline.prepare_inputs(
... prompt=[prompts] * num_samples, image=[init_img] * num_samples
... )
>>> p_params = replicate(params)
>>> prompt_ids = shard(prompt_ids)
>>> processed_image = shard(processed_image)
>>> output = pipeline(
... prompt_ids=prompt_ids,
... image=processed_image,
... params=p_params,
... prng_seed=rng,
... strength=0.75,
... num_inference_steps=50,
... jit=True,
... height=512,
... width=768,
... ).images
>>> output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
```
"""
class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
r""" r"""
...@@ -277,6 +336,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): ...@@ -277,6 +336,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
return image return image
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
self, self,
prompt_ids: jnp.array, prompt_ids: jnp.array,
...@@ -332,6 +392,8 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): ...@@ -332,6 +392,8 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release. exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
Examples:
Returns: Returns:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
......
...@@ -24,6 +24,7 @@ from flax.jax_utils import unreplicate ...@@ -24,6 +24,7 @@ from flax.jax_utils import unreplicate
from flax.training.common_utils import shard from flax.training.common_utils import shard
from packaging import version from packaging import version
from PIL import Image from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
...@@ -33,7 +34,7 @@ from ...schedulers import ( ...@@ -33,7 +34,7 @@ from ...schedulers import (
FlaxLMSDiscreteScheduler, FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler, FlaxPNDMScheduler,
) )
from ...utils import PIL_INTERPOLATION, deprecate, logging from ...utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring
from ..pipeline_flax_utils import FlaxDiffusionPipeline from ..pipeline_flax_utils import FlaxDiffusionPipeline
from . import FlaxStableDiffusionPipelineOutput from . import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
...@@ -44,6 +45,60 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name ...@@ -44,6 +45,60 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Set to True to use python for loop instead of jax.fori_loop for easier debugging # Set to True to use python for loop instead of jax.fori_loop for easier debugging
DEBUG = False DEBUG = False
# Usage example for `FlaxStableDiffusionInpaintPipeline.__call__`. It is handed
# to the `replace_example_docstring` decorator on that method (presumably to
# fill the `Examples:` placeholder in its docstring -- confirm against the
# decorator).
# Every line inside the fence carries a `>>> ` / `... ` prompt, comments
# included, and the body of `download_image` is indented after its prompt, so
# the snippet is still valid Python once the prompts are stripped.
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import jax
>>> import numpy as np
>>> from flax.jax_utils import replicate
>>> from flax.training.common_utils import shard
>>> import PIL
>>> import requests
>>> from io import BytesIO
>>> from diffusers import FlaxStableDiffusionInpaintPipeline
>>> def download_image(url):
...     response = requests.get(url)
...     return PIL.Image.open(BytesIO(response.content)).convert("RGB")
>>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
>>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
>>> init_image = download_image(img_url).resize((512, 512))
>>> mask_image = download_image(mask_url).resize((512, 512))
>>> pipeline, params = FlaxStableDiffusionInpaintPipeline.from_pretrained(
... "xvjiarui/stable-diffusion-2-inpainting"
... )
>>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
>>> prng_seed = jax.random.PRNGKey(0)
>>> num_inference_steps = 50
>>> num_samples = jax.device_count()
>>> prompt = num_samples * [prompt]
>>> init_image = num_samples * [init_image]
>>> mask_image = num_samples * [mask_image]
>>> prompt_ids, processed_masked_images, processed_masks = pipeline.prepare_inputs(
... prompt, init_image, mask_image
... )
>>> # shard inputs and rng
>>> params = replicate(params)
>>> prng_seed = jax.random.split(prng_seed, jax.device_count())
>>> prompt_ids = shard(prompt_ids)
>>> processed_masked_images = shard(processed_masked_images)
>>> processed_masks = shard(processed_masks)
>>> images = pipeline(
... prompt_ids, processed_masks, processed_masked_images, params, prng_seed, num_inference_steps, jit=True
... ).images
>>> images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
```
"""
class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline): class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
r""" r"""
...@@ -332,6 +387,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline): ...@@ -332,6 +387,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
return image return image
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
self, self,
prompt_ids: jnp.array, prompt_ids: jnp.array,
...@@ -378,6 +434,8 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline): ...@@ -378,6 +434,8 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
a plain tuple. a plain tuple.
Examples:
Returns: Returns:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment