Unverified Commit da310757 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

updated doc for stable diffusion pipelines (#1770)

* add a doc page for each pipeline under api/pipelines/stable_diffusion
* add pipeline examples to docstrings
* updated stable_diffusion_2 page
* updated default markdown syntax to list methods based on https://github.com/huggingface/diffusers/pull/1870


* add function decorator
Co-authored-by: default avataryiyixuxu <yixu@Yis-MacBook-Pro.lan>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 8c14ca3d
...@@ -19,7 +19,6 @@ import numpy as np ...@@ -19,7 +19,6 @@ import numpy as np
import torch import torch
import PIL import PIL
from diffusers.utils import is_accelerate_available
from packaging import version from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
...@@ -33,7 +32,7 @@ from ...schedulers import ( ...@@ -33,7 +32,7 @@ from ...schedulers import (
LMSDiscreteScheduler, LMSDiscreteScheduler,
PNDMScheduler, PNDMScheduler,
) )
from ...utils import PIL_INTERPOLATION, deprecate, logging from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
...@@ -41,6 +40,34 @@ from .safety_checker import StableDiffusionSafetyChecker ...@@ -41,6 +40,34 @@ from .safety_checker import StableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import requests
>>> import torch
>>> from PIL import Image
>>> from io import BytesIO
>>> from diffusers import StableDiffusionImg2ImgPipeline
>>> device = "cuda"
>>> model_id_or_path = "runwayml/stable-diffusion-v1-5"
>>> pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
>>> pipe = pipe.to(device)
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
>>> response = requests.get(url)
>>> init_image = Image.open(BytesIO(response.content)).convert("RGB")
>>> init_image = init_image.resize((768, 512))
>>> prompt = "A fantasy landscape, trending on artstation"
>>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
>>> images[0].save("fantasy_landscape.png")
```
"""
def preprocess(image): def preprocess(image):
if isinstance(image, torch.Tensor): if isinstance(image, torch.Tensor):
...@@ -455,6 +482,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -455,6 +482,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
return latents return latents
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
self, self,
prompt: Union[str, List[str]], prompt: Union[str, List[str]],
...@@ -519,6 +547,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -519,6 +547,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
callback_steps (`int`, *optional*, defaults to 1): callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step. called at every step.
Examples:
Returns: Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
......
...@@ -616,6 +616,37 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -616,6 +616,37 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
The frequency at which the `callback` function will be called. If not specified, the callback will be The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step. called at every step.
Examples:
```py
>>> import PIL
>>> import requests
>>> import torch
>>> from io import BytesIO
>>> from diffusers import StableDiffusionInpaintPipeline
>>> def download_image(url):
... response = requests.get(url)
... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
>>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
>>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
>>> init_image = download_image(img_url).resize((512, 512))
>>> mask_image = download_image(mask_url).resize((512, 512))
>>> pipe = StableDiffusionInpaintPipeline.from_pretrained(
... "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
>>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
>>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
```
Returns: Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
......
...@@ -390,6 +390,32 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline): ...@@ -390,6 +390,32 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
The frequency at which the `callback` function will be called. If not specified, the callback will be The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step. called at every step.
Examples:
```py
>>> import requests
>>> from PIL import Image
>>> from io import BytesIO
>>> from diffusers import StableDiffusionUpscalePipeline
>>> import torch
>>> # load model and scheduler
>>> model_id = "stabilityai/stable-diffusion-x4-upscaler"
>>> pipeline = StableDiffusionUpscalePipeline.from_pretrained(
... model_id, revision="fp16", torch_dtype=torch.float16
... )
>>> pipeline = pipeline.to("cuda")
>>> # let's download an image
>>> url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png"
>>> response = requests.get(url)
>>> low_res_img = Image.open(BytesIO(response.content)).convert("RGB")
>>> low_res_img = low_res_img.resize((128, 128))
>>> prompt = "a white cat"
>>> upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0]
>>> upscaled_image.save("upsampled_cat.png")
```
Returns: Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
......
...@@ -60,6 +60,7 @@ class EMAModel: ...@@ -60,6 +60,7 @@ class EMAModel:
to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
at 215.4k steps). at 215.4k steps).
Args: Args:
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
power (float): Exponential factor of EMA warmup. Default: 2/3. power (float): Exponential factor of EMA warmup. Default: 2/3.
......
...@@ -32,6 +32,7 @@ from .constants import ( ...@@ -32,6 +32,7 @@ from .constants import (
WEIGHTS_NAME, WEIGHTS_NAME,
) )
from .deprecation_utils import deprecate from .deprecation_utils import deprecate
from .doc_utils import replace_example_docstring
from .dynamic_modules_utils import get_class_from_dynamic_module from .dynamic_modules_utils import get_class_from_dynamic_module
from .hub_utils import HF_HUB_OFFLINE, http_user_agent from .hub_utils import HF_HUB_OFFLINE, http_user_agent
from .import_utils import ( from .import_utils import (
......
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Doc utilities: Utilities related to documentation
"""
import re
def replace_example_docstring(example_docstring):
def docstring_decorator(fn):
func_doc = fn.__doc__
lines = func_doc.split("\n")
i = 0
while i < len(lines) and re.search(r"^\s*Examples?:\s*$", lines[i]) is None:
i += 1
if i < len(lines):
lines[i] = example_docstring
func_doc = "\n".join(lines)
else:
raise ValueError(
f"The function {fn} should have an empty 'Examples:' in its docstring as placeholder, "
f"current docstring is:\n{func_doc}"
)
fn.__doc__ = func_doc
return fn
return docstring_decorator
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment