Unverified Commit c11d11d6 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

[draft v2] AutoPipeline (#4138)



* initial

* style

* from ...pipelines -> from ..pipeline_utils

* make style

* fix-copies

* fix value_guided_sampling oops

* style

* add test

* Show failing test

* update from_pipe

* fix

* add controlnet, additional test and register unused original config

* update for controlnet

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* store unused config as private attribute and pass if can

* add doc

* kandinsky inpaint pipeline does not work with decoder checkpoint

* update doc

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* style

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix

* Apply suggestions from code review

---------
Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent d74561da
...@@ -26,9 +26,9 @@ from transformers import ( ...@@ -26,9 +26,9 @@ from transformers import (
) )
from ...models import UNet2DConditionModel, UNet2DModel from ...models import UNet2DConditionModel, UNet2DModel
from ...pipelines import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import UnCLIPScheduler from ...schedulers import UnCLIPScheduler
from ...utils import logging, randn_tensor from ...utils import logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_proj import UnCLIPTextProjModel from .text_proj import UnCLIPTextProjModel
......
...@@ -255,6 +255,51 @@ class AudioPipelineOutput(metaclass=DummyObject): ...@@ -255,6 +255,51 @@ class AudioPipelineOutput(metaclass=DummyObject):
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
# Stand-in for the image-to-image auto pipeline that declares "torch" as its
# only backend. The constructor and both classmethod loaders do nothing except
# call `requires_backends(..., ["torch"])`; that helper is not visible here and
# presumably raises when the backend is missing -- TODO confirm.
class AutoPipelineForImage2Image(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])
# Stand-in for the inpainting auto pipeline that declares "torch" as its only
# backend. The constructor and both classmethod loaders do nothing except call
# `requires_backends(..., ["torch"])`; that helper is not visible here and
# presumably raises when the backend is missing -- TODO confirm.
class AutoPipelineForInpainting(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])
# Stand-in for the text-to-image auto pipeline that declares "torch" as its
# only backend. The constructor and both classmethod loaders do nothing except
# call `requires_backends(..., ["torch"])`; that helper is not visible here and
# presumably raises when the backend is missing -- TODO confirm.
class AutoPipelineForText2Image(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])
class ConsistencyModelPipeline(metaclass=DummyObject): class ConsistencyModelPipeline(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
......
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
from collections import OrderedDict
import torch
from diffusers import (
AutoPipelineForImage2Image,
AutoPipelineForInpainting,
AutoPipelineForText2Image,
ControlNetModel,
)
from diffusers.pipelines.auto_pipeline import (
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
AUTO_INPAINT_PIPELINES_MAPPING,
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
)
from diffusers.utils import slow
# Checkpoint repository to load for each model family key. The keys are also
# used to index the AUTO_*_PIPELINES_MAPPING tables in the integration tests,
# and the insertion order below is the order in which those tests iterate.
PRETRAINED_MODEL_REPO_MAPPING = OrderedDict(
    {
        "stable-diffusion": "runwayml/stable-diffusion-v1-5",
        "if": "DeepFloyd/IF-I-XL-v1.0",
        "kandinsky": "kandinsky-community/kandinsky-2-1",
        "kandinsky22": "kandinsky-community/kandinsky-2-2-decoder",
    }
)
class AutoPipelineFastTest(unittest.TestCase):
    # Checks, on the small "hf-internal-testing" checkpoints, that converting a
    # pipeline between auto classes with `from_pipe` keeps the pipeline config
    # and that explicit keyword overrides are reflected in the config.
    #
    # Assertions go through the unittest helpers rather than bare `assert`
    # statements so they still run under `python -O` and report both operands
    # on failure, matching the style of the integration tests in this file.

    def test_from_pipe_consistent(self):
        pipe = AutoPipelineForText2Image.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-pipe", requires_safety_checker=False
        )
        original_config = dict(pipe.config)

        # text2img -> img2img must not alter the config
        pipe = AutoPipelineForImage2Image.from_pipe(pipe)
        self.assertEqual(dict(pipe.config), original_config)

        # ... and neither must the conversion back to text2img
        pipe = AutoPipelineForText2Image.from_pipe(pipe)
        self.assertEqual(dict(pipe.config), original_config)

    def test_from_pipe_override(self):
        pipe = AutoPipelineForText2Image.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-pipe", requires_safety_checker=False
        )

        # A keyword passed to `from_pipe` must win over the value loaded above.
        pipe = AutoPipelineForImage2Image.from_pipe(pipe, requires_safety_checker=True)
        self.assertIs(pipe.config.requires_safety_checker, True)

        pipe = AutoPipelineForText2Image.from_pipe(pipe, requires_safety_checker=True)
        self.assertIs(pipe.config.requires_safety_checker, True)

    def test_from_pipe_consistent_sdxl(self):
        pipe = AutoPipelineForImage2Image.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-xl-pipe",
            requires_aesthetics_score=True,
            force_zeros_for_empty_prompt=False,
        )
        original_config = dict(pipe.config)

        # Round trip img2img -> text2img -> img2img; only the final config is
        # compared against the one captured right after loading.
        pipe = AutoPipelineForText2Image.from_pipe(pipe)
        pipe = AutoPipelineForImage2Image.from_pipe(pipe)
        self.assertEqual(dict(pipe.config), original_config)
@slow
class AutoPipelineIntegrationTest(unittest.TestCase):
    # Slow checks against the full checkpoints listed in
    # PRETRAINED_MODEL_REPO_MAPPING plus a ControlNet checkpoint.

    def test_pipe_auto(self):
        # (auto class, mapping from model family key to concrete pipeline class),
        # in the order text2img, img2img, inpaint.
        all_tasks = (
            (AutoPipelineForText2Image, AUTO_TEXT2IMAGE_PIPELINES_MAPPING),
            (AutoPipelineForImage2Image, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING),
            (AutoPipelineForInpainting, AUTO_INPAINT_PIPELINES_MAPPING),
        )

        for model_name, model_repo in PRETRAINED_MODEL_REPO_MAPPING.items():
            # Inpainting is left out for the kandinsky families, both as the
            # source pipeline and as a conversion target.
            tasks = all_tasks if "kandinsky" not in model_name else all_tasks[:2]

            for source_cls, source_mapping in tasks:
                source_pipe = source_cls.from_pretrained(model_repo, variant="fp16", torch_dtype=torch.float16)
                self.assertIsInstance(source_pipe, source_mapping[model_name])

                # Convert the loaded pipeline to every supported task in turn.
                for target_cls, target_mapping in tasks:
                    pipe_to = target_cls.from_pipe(source_pipe)
                    self.assertIsInstance(pipe_to, target_mapping[model_name])

                # Drop the references before loading the next checkpoint.
                del source_pipe, pipe_to
                gc.collect()

    def test_from_pipe_consistent(self):
        for model_name, model_repo in PRETRAINED_MODEL_REPO_MAPPING.items():
            auto_pipes = [AutoPipelineForText2Image, AutoPipelineForImage2Image]
            if model_name not in ("kandinsky", "kandinsky22"):
                auto_pipes.append(AutoPipelineForInpainting)

            # Load with each auto class, then check that every conversion
            # reproduces the config of the freshly loaded pipeline.
            for source_cls in auto_pipes:
                source_pipe = source_cls.from_pretrained(model_repo, variant="fp16", torch_dtype=torch.float16)
                expected_config = dict(source_pipe.config)

                for target_cls in auto_pipes:
                    pipe_to = target_cls.from_pipe(source_pipe)
                    self.assertEqual(dict(pipe_to.config), expected_config)

                del source_pipe, pipe_to
                gc.collect()

    def test_controlnet(self):
        model_repo = "runwayml/stable-diffusion-v1-5"
        controlnet_repo = "lllyasviel/sd-controlnet-canny"
        pipeline_key = "stable-diffusion-controlnet"
        tasks = (
            (AutoPipelineForText2Image, AUTO_TEXT2IMAGE_PIPELINES_MAPPING),
            (AutoPipelineForImage2Image, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING),
            (AutoPipelineForInpainting, AUTO_INPAINT_PIPELINES_MAPPING),
        )

        controlnet = ControlNetModel.from_pretrained(controlnet_repo, torch_dtype=torch.float16)

        # from_pretrained: passing a controlnet must resolve to the controlnet
        # variant of each task's pipeline.
        loaded_pipes = []
        for auto_cls, mapping in tasks:
            loaded = auto_cls.from_pretrained(model_repo, controlnet=controlnet, torch_dtype=torch.float16)
            self.assertIsInstance(loaded, mapping[pipeline_key])
            loaded_pipes.append(loaded)

        # from_pipe: every source converts to every task, and the converted
        # config equals the config of the pipeline loaded directly for that task.
        for pipe_from in loaded_pipes:
            for (auto_cls, mapping), reference_pipe in zip(tasks, loaded_pipes):
                pipe_to = auto_cls.from_pipe(pipe_from)
                self.assertIsInstance(pipe_to, mapping[pipeline_key])
                self.assertEqual(dict(pipe_to.config), dict(reference_pipe.config))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment