Unverified Commit c11d11d6 authored by YiYi Xu, committed by GitHub

[draft v2] AutoPipeline (#4138)



* initial

* style

* from ...pipelines -> from ..pipeline_util

* make style

* fix-copies

* fix value_guided_sampling oops

* style

* add test

* Show failing test

* update from_pipe

* fix

* add controlnet, additional test and register unused original config

* update for controlnet

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* store unused config as private attribute and pass if can

* add doc

* kandinsky inpaint pipeline does not work with decoder checkpoint

* update doc

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* style

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix

* Apply suggestions from code review

---------
Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent d74561da
......@@ -182,6 +182,8 @@
title: Audio Diffusion
- local: api/pipelines/audioldm
title: AudioLDM
- local: api/pipelines/auto_pipeline
title: AutoPipeline
- local: api/pipelines/consistency_models
title: Consistency Models
- local: api/pipelines/controlnet
......
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# AutoPipeline
In many cases, one checkpoint can be used for multiple tasks. For example, the same checkpoint may work for Text-to-Image, Image-to-Image, and Inpainting. Normally, however, you would need to know which pipeline class corresponds to your checkpoint and task.
AutoPipeline is designed to make it easy to use multiple pipelines in your workflow. We currently provide three AutoPipeline classes, one per task: [`AutoPipelineForText2Image`], [`AutoPipelineForImage2Image`], and [`AutoPipelineForInpainting`]. Choose the AutoPipeline class for the task you want to perform, and it will automatically retrieve the matching pipeline given the name or path of the pretrained weights.
For example, to perform Image-to-Image with a Stable Diffusion v1.5 checkpoint, you can do
```python
from diffusers import AutoPipelineForImage2Image

pipe_i2i = AutoPipelineForImage2Image.from_pretrained("runwayml/stable-diffusion-v1-5")
```
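Under the hood, `from_pretrained` reads the pipeline class recorded in the checkpoint's configuration and maps it to its Image-to-Image counterpart. Purely for illustration (you never need to spell out the concrete class yourself), the call above resolves, for this particular checkpoint, to roughly the following explicit load:
```python
from diffusers import StableDiffusionImg2ImgPipeline

# Roughly what AutoPipelineForImage2Image resolves to for a Stable Diffusion v1.5
# checkpoint; the auto class figures out this class name from the checkpoint's config.
pipe_i2i = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
```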
AutoPipeline also lets you switch between tasks seamlessly using the same checkpoint, without allocating additional memory. For example, to reuse the Image-to-Image pipeline we just created for inpainting, you can do
```python
from diffusers import AutoPipelineForInpainting
pipe_inpaint = AutoPipelineForInpainting.from_pipe(pipe_i2i)
```
All the components are transferred to the inpainting pipeline at no extra memory cost.
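As an informal check (a sketch, assuming the pipelines built above), the components of the new pipeline are the same objects as in the source pipeline rather than copies:
```python
# No weights are reloaded or duplicated: the inpainting pipeline reuses
# the exact component objects of the Image-to-Image pipeline.
assert pipe_inpaint.unet is pipe_i2i.unet
assert pipe_inpaint.vae is pipe_i2i.vae
assert pipe_inpaint.text_encoder is pipe_i2i.text_encoder
```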
AutoPipeline currently supports the Text-to-Image, Image-to-Image, and Inpainting tasks for the following diffusion models:
- [Stable Diffusion](./stable_diffusion)
- [Stable Diffusion ControlNet](./controlnet)
- [Stable Diffusion XL](./stable_diffusion/stable_diffusion_xl)
- [IF](./if)
- [Kandinsky](./kandinsky)
- [Kandinsky 2.2](./kandinsky)
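Any of these checkpoints can be loaded through the task-specific class in the same way. For example, a minimal Text-to-Image sketch with the same Stable Diffusion checkpoint used above:
```python
from diffusers import AutoPipelineForText2Image

# Resolves to the checkpoint's text-to-image pipeline class automatically.
pipe_t2i = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5")
image = pipe_t2i("a photo of an astronaut riding a horse on mars").images[0]
```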
## AutoPipelineForText2Image
[[autodoc]] AutoPipelineForText2Image
- all
- from_pretrained
- from_pipe
## AutoPipelineForImage2Image
[[autodoc]] AutoPipelineForImage2Image
- all
- from_pretrained
- from_pipe
## AutoPipelineForInpainting
[[autodoc]] AutoPipelineForInpainting
- all
- from_pretrained
- from_pipe
......@@ -62,6 +62,9 @@ else:
)
from .pipelines import (
AudioPipelineOutput,
AutoPipelineForImage2Image,
AutoPipelineForInpainting,
AutoPipelineForText2Image,
ConsistencyModelPipeline,
DanceDiffusionPipeline,
DDIMPipeline,
......
......@@ -17,6 +17,7 @@ try:
except OptionalDependencyNotAvailable:
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .auto_pipeline import AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image
from .consistency_models import ConsistencyModelPipeline
from .dance_diffusion import DanceDiffusionPipeline
from .ddim import DDIMPipeline
......
This diff is collapsed.
......@@ -20,8 +20,6 @@ from transformers import (
)
from ...models import UNet2DConditionModel, VQModel
from ...pipelines import DiffusionPipeline
from ...pipelines.pipeline_utils import ImagePipelineOutput
from ...schedulers import DDIMScheduler, DDPMScheduler
from ...utils import (
is_accelerate_available,
......@@ -30,6 +28,7 @@ from ...utils import (
randn_tensor,
replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_encoder import MultilingualCLIP
......
......@@ -23,8 +23,6 @@ from transformers import (
)
from ...models import UNet2DConditionModel, VQModel
from ...pipelines import DiffusionPipeline
from ...pipelines.pipeline_utils import ImagePipelineOutput
from ...schedulers import DDIMScheduler
from ...utils import (
is_accelerate_available,
......@@ -33,6 +31,7 @@ from ...utils import (
randn_tensor,
replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_encoder import MultilingualCLIP
......
......@@ -25,8 +25,6 @@ from transformers import (
)
from ...models import UNet2DConditionModel, VQModel
from ...pipelines import DiffusionPipeline
from ...pipelines.pipeline_utils import ImagePipelineOutput
from ...schedulers import DDIMScheduler
from ...utils import (
is_accelerate_available,
......@@ -35,6 +33,7 @@ from ...utils import (
randn_tensor,
replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_encoder import MultilingualCLIP
......
......@@ -21,7 +21,6 @@ import torch
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
from ...models import PriorTransformer
from ...pipelines import DiffusionPipeline
from ...schedulers import UnCLIPScheduler
from ...utils import (
BaseOutput,
......@@ -29,6 +28,7 @@ from ...utils import (
randn_tensor,
replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
......@@ -17,8 +17,6 @@ from typing import List, Optional, Union
import torch
from ...models import UNet2DConditionModel, VQModel
from ...pipelines import DiffusionPipeline
from ...pipelines.pipeline_utils import ImagePipelineOutput
from ...schedulers import DDPMScheduler
from ...utils import (
is_accelerate_available,
......@@ -27,6 +25,7 @@ from ...utils import (
randn_tensor,
replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
......@@ -17,8 +17,6 @@ from typing import List, Optional, Union
import torch
from ...models import UNet2DConditionModel, VQModel
from ...pipelines import DiffusionPipeline
from ...pipelines.pipeline_utils import ImagePipelineOutput
from ...schedulers import DDPMScheduler
from ...utils import (
is_accelerate_available,
......@@ -27,6 +25,7 @@ from ...utils import (
randn_tensor,
replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
......@@ -20,8 +20,6 @@ import torch
from PIL import Image
from ...models import UNet2DConditionModel, VQModel
from ...pipelines import DiffusionPipeline
from ...pipelines.pipeline_utils import ImagePipelineOutput
from ...schedulers import DDPMScheduler
from ...utils import (
is_accelerate_available,
......@@ -30,6 +28,7 @@ from ...utils import (
randn_tensor,
replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
......@@ -20,8 +20,6 @@ import torch
from PIL import Image
from ...models import UNet2DConditionModel, VQModel
from ...pipelines import DiffusionPipeline
from ...pipelines.pipeline_utils import ImagePipelineOutput
from ...schedulers import DDPMScheduler
from ...utils import (
is_accelerate_available,
......@@ -30,6 +28,7 @@ from ...utils import (
randn_tensor,
replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
......@@ -22,8 +22,6 @@ import torch.nn.functional as F
from PIL import Image
from ...models import UNet2DConditionModel, VQModel
from ...pipelines import DiffusionPipeline
from ...pipelines.pipeline_utils import ImagePipelineOutput
from ...schedulers import DDPMScheduler
from ...utils import (
is_accelerate_available,
......@@ -32,6 +30,7 @@ from ...utils import (
randn_tensor,
replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
......@@ -5,7 +5,6 @@ import torch
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
from ...models import PriorTransformer
from ...pipelines import DiffusionPipeline
from ...schedulers import UnCLIPScheduler
from ...utils import (
logging,
......@@ -13,6 +12,7 @@ from ...utils import (
replace_example_docstring,
)
from ..kandinsky import KandinskyPriorPipelineOutput
from ..pipeline_utils import DiffusionPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
......@@ -5,7 +5,6 @@ import torch
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
from ...models import PriorTransformer
from ...pipelines import DiffusionPipeline
from ...schedulers import UnCLIPScheduler
from ...utils import (
logging,
......@@ -13,6 +12,7 @@ from ...utils import (
replace_example_docstring,
)
from ..kandinsky import KandinskyPriorPipelineOutput
from ..pipeline_utils import DiffusionPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
......@@ -22,7 +22,6 @@ import torch
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from ...models import PriorTransformer
from ...pipelines import DiffusionPipeline
from ...schedulers import HeunDiscreteScheduler
from ...utils import (
BaseOutput,
......@@ -32,6 +31,7 @@ from ...utils import (
randn_tensor,
replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline
from .renderer import ShapERenderer
......
......@@ -21,7 +21,6 @@ import torch
from transformers import CLIPImageProcessor, CLIPVisionModel
from ...models import PriorTransformer
from ...pipelines import DiffusionPipeline
from ...schedulers import HeunDiscreteScheduler
from ...utils import (
BaseOutput,
......@@ -29,6 +28,7 @@ from ...utils import (
randn_tensor,
replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline
from .renderer import ShapERenderer
......
......@@ -23,9 +23,9 @@ from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras
from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...pipelines import DiffusionPipeline
from ...schedulers import LMSDiscreteScheduler
from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
......
......@@ -21,10 +21,9 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from transformers.models.clip.modeling_clip import CLIPTextModelOutput
from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
from ...pipelines import DiffusionPipeline
from ...pipelines.pipeline_utils import ImagePipelineOutput
from ...schedulers import UnCLIPScheduler
from ...utils import logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_proj import UnCLIPTextProjModel
......