Unverified Commit a643c630 authored by Patrick von Platen, committed by GitHub

[K Diffusion] Add k diffusion sampler natively (#1603)

* uP

* uP
parent 326de419
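In short, this commit promotes the k-diffusion community pipeline to a native `StableDiffusionKDiffusionPipeline`, adds `k-diffusion` as an optional dependency, and lets users pick a sampler by name via `set_scheduler`. A minimal usage sketch, mirroring the integration test added at the bottom of this commit:

```python
import torch

from diffusers import StableDiffusionKDiffusionPipeline

# Load the new native pipeline (requires `pip install k-diffusion`).
pipe = StableDiffusionKDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("cuda")

# Choose any sampler function from k_diffusion.sampling by name.
pipe.set_scheduler("sample_euler")

generator = torch.Generator(device="cuda").manual_seed(0)
image = pipe("an astronaut riding a horse on mars", generator=generator, num_inference_steps=20).images[0]
```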
@@ -686,7 +686,7 @@ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom
 pipe = pipe.to("cuda")

 prompt = "an astronaut riding a horse on mars"
-pipe.set_sampler("sample_heun")
+pipe.set_scheduler("sample_heun")
 generator = torch.Generator(device="cuda").manual_seed(seed)
 image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]
@@ -721,7 +721,7 @@ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom
 pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
 pipe = pipe.to("cuda")

-pipe.set_sampler("sample_euler")
+pipe.set_scheduler("sample_euler")
 generator = torch.Generator(device="cuda").manual_seed(seed)
 image = pipe(prompt, generator=generator, num_inference_steps=50).images[0]
 ```
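The string given to `set_scheduler` is resolved against `k_diffusion.sampling` (see the pipeline change below), so the valid names can be listed from the installed package. A small sketch; the exact set of `sample_*` functions depends on the installed k-diffusion version:

```python
import k_diffusion

# Every `sample_*` function in k_diffusion.sampling is a valid argument to
# set_scheduler, e.g. the "sample_euler" and "sample_heun" used in these docs.
print([name for name in dir(k_diffusion.sampling) if name.startswith("sample_")])
```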
......
@@ -13,6 +13,7 @@
 # limitations under the License.
 import importlib
+import warnings
 from typing import Callable, List, Optional, Union

 import torch
@@ -111,6 +112,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
         self.k_diffusion_model = CompVisDenoiser(model)

+    def set_sampler(self, scheduler_type: str):
+        warnings.warn("The `set_sampler` method is deprecated, please use `set_scheduler` instead.")
+        return self.set_scheduler(scheduler_type)
+
-    def set_sampler(self, scheduler_type: str):
+    def set_scheduler(self, scheduler_type: str):
         library = importlib.import_module("k_diffusion")
         sampling = getattr(library, "sampling")
         self.sampler = getattr(sampling, scheduler_type)
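`set_scheduler` is a plain name-based lookup: import `k_diffusion`, grab its `sampling` module, and store the function matching `scheduler_type` on the pipeline. An equivalent standalone sketch, with a validation step added here for illustration (the method above does not validate the name):

```python
import importlib


def resolve_sampler(scheduler_type: str):
    # Same lookup as set_scheduler above: fetch e.g. `sample_heun`
    # from k_diffusion.sampling by its function name.
    sampling = importlib.import_module("k_diffusion").sampling
    if not hasattr(sampling, scheduler_type):
        raise ValueError(f"k_diffusion.sampling has no sampler named {scheduler_type!r}")
    return getattr(sampling, scheduler_type)
```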
......
@@ -91,6 +91,7 @@ _deps = [
     "isort>=5.5.4",
     "jax>=0.2.8,!=0.3.2",
     "jaxlib>=0.1.65",
+    "k-diffusion",
     "librosa",
     "modelcards>=0.1.4",
     "numpy",
@@ -182,6 +183,7 @@ extras["docs"] = deps_list("hf-doc-builder")
 extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
 extras["test"] = deps_list(
     "datasets",
+    "k-diffusion",
     "librosa",
     "parameterized",
     "pytest",
......
@@ -5,6 +5,7 @@ from .onnx_utils import OnnxRuntimeModel
 from .utils import (
     is_flax_available,
     is_inflect_available,
+    is_k_diffusion_available,
     is_onnx_available,
     is_scipy_available,
     is_torch_available,
@@ -90,6 +91,11 @@ if is_torch_available() and is_transformers_available():
 else:
     from .utils.dummy_torch_and_transformers_objects import *  # noqa F403

+if is_torch_available() and is_transformers_available() and is_k_diffusion_available():
+    from .pipelines import StableDiffusionKDiffusionPipeline
+else:
+    from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import *  # noqa F403
+
 if is_torch_available() and is_transformers_available() and is_onnx_available():
     from .pipelines import (
         OnnxStableDiffusionImg2ImgPipeline,
......
@@ -15,6 +15,7 @@ deps = {
     "isort": "isort>=5.5.4",
     "jax": "jax>=0.2.8,!=0.3.2",
     "jaxlib": "jaxlib>=0.1.65",
+    "k-diffusion": "k-diffusion",
     "librosa": "librosa",
     "modelcards": "modelcards>=0.1.4",
     "numpy": "numpy",
......
from ..utils import (
is_flax_available,
is_k_diffusion_available,
is_librosa_available,
is_onnx_available,
is_torch_available,
@@ -56,5 +57,8 @@ if is_transformers_available() and is_onnx_available():
         StableDiffusionOnnxPipeline,
     )

+if is_torch_available() and is_transformers_available() and is_k_diffusion_available():
+    from .stable_diffusion import StableDiffusionKDiffusionPipeline
+
 if is_transformers_available() and is_flax_available():
     from .stable_diffusion import FlaxStableDiffusionPipeline
@@ -9,6 +9,7 @@ from PIL import Image
 from ...utils import (
     BaseOutput,
     is_flax_available,
+    is_k_diffusion_available,
     is_onnx_available,
     is_torch_available,
     is_transformers_available,
@@ -48,6 +49,9 @@ if is_transformers_available() and is_torch_available() and is_transformers_vers
 else:
     from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline

+if is_transformers_available() and is_torch_available() and is_k_diffusion_available():
+    from .pipeline_stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline
+
 if is_transformers_available() and is_onnx_available():
     from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
     from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline
......
@@ -29,6 +29,7 @@ from .import_utils import (
     is_accelerate_available,
     is_flax_available,
     is_inflect_available,
+    is_k_diffusion_available,
     is_librosa_available,
     is_modelcards_available,
     is_onnx_available,
......
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from ..utils import DummyObject, requires_backends


class StableDiffusionKDiffusionPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers", "k_diffusion"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers", "k_diffusion"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "k_diffusion"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "k_diffusion"])
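The dummy class keeps `from diffusers import StableDiffusionKDiffusionPipeline` working even when k-diffusion is not installed; the failure is deferred to first use. A sketch of the observable behavior, assuming k-diffusion is absent (the error text comes from `K_DIFFUSION_IMPORT_ERROR` further down):

```python
# With k-diffusion missing, the import itself succeeds because the dummy
# object is substituted, but any use raises a helpful ImportError.
from diffusers import StableDiffusionKDiffusionPipeline

try:
    StableDiffusionKDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
except ImportError as err:
    print(err)  # "... requires the k-diffusion library ... pip install k-diffusion"
```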
@@ -210,6 +210,13 @@ try:
 except importlib_metadata.PackageNotFoundError:
     _xformers_available = False

+_k_diffusion_available = importlib.util.find_spec("k_diffusion") is not None
+try:
+    _k_diffusion_version = importlib_metadata.version("k_diffusion")
+    logger.debug(f"Successfully imported k-diffusion version {_k_diffusion_version}")
+except importlib_metadata.PackageNotFoundError:
+    _k_diffusion_available = False
+

 def is_torch_available():
     return _torch_available
@@ -263,6 +270,10 @@ def is_accelerate_available():
     return _accelerate_available


+def is_k_diffusion_available():
+    return _k_diffusion_available
+
+
 # docstyle-ignore
 FLAX_IMPORT_ERROR = """
 {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -317,6 +328,12 @@ UNIDECODE_IMPORT_ERROR = """
 Unidecode`
 """

+# docstyle-ignore
+K_DIFFUSION_IMPORT_ERROR = """
+{0} requires the k-diffusion library but it was not found in your environment. You can install it with pip: `pip
+install k-diffusion`
+"""
+
 BACKENDS_MAPPING = OrderedDict(
     [
@@ -329,6 +346,7 @@ BACKENDS_MAPPING = OrderedDict(
         ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
         ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
         ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)),
+        ("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)),
     ]
 )
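`BACKENDS_MAPPING` pairs each backend name with an availability check and an error template; `requires_backends` (used by the dummy objects above) walks this mapping and raises when a check fails. A simplified sketch of that mechanism, not the actual diffusers implementation:

```python
def requires_backends_sketch(obj, backends, mapping):
    # Simplified stand-in for diffusers' requires_backends: run each backend's
    # availability callable and raise the formatted template if any check fails.
    name = getattr(obj, "__name__", type(obj).__name__)
    failed = [
        template.format(name)
        for is_available, template in (mapping[b] for b in backends)
        if not is_available()
    ]
    if failed:
        raise ImportError("".join(failed))
```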
......
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest

import numpy as np
import torch

from diffusers import StableDiffusionKDiffusionPipeline
from diffusers.utils import slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu


torch.backends.cuda.matmul.allow_tf32 = False


@slow
@require_torch_gpu
class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_stable_diffusion_1(self):
        sd_pipe = StableDiffusionKDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        sd_pipe.set_scheduler("sample_euler")

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.Generator(device=torch_device).manual_seed(0)
        output = sd_pipe([prompt], generator=generator, guidance_scale=9.0, num_inference_steps=20, output_type="np")

        image = output.images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.8887, 0.915, 0.91, 0.894, 0.909, 0.912, 0.919, 0.925, 0.883])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_2(self):
        sd_pipe = StableDiffusionKDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        sd_pipe.set_scheduler("sample_euler")

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.Generator(device=torch_device).manual_seed(0)
        output = sd_pipe([prompt], generator=generator, guidance_scale=9.0, num_inference_steps=20, output_type="np")

        image = output.images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array(
            [0.826810, 0.81958747, 0.8510199, 0.8376758, 0.83958465, 0.8682068, 0.84370345, 0.85251087, 0.85884345]
        )
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
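Both tests fingerprint the render instead of comparing all 512x512x3 values: `image[0, -3:, -3:, -1]` takes the bottom-right 3x3 patch of the last channel of the first image, so nine floats checked against hard-coded expectations with a 1e-2 tolerance stand in for the whole output. A toy illustration of the slicing:

```python
import numpy as np

# Same layout as the pipeline's "np" output: (batch, height, width, channels).
image = np.random.rand(1, 512, 512, 3)

# Bottom-right 3x3 corner of the last channel: a cheap nine-value
# fingerprint of the whole image, as used in the tests above.
image_slice = image[0, -3:, -3:, -1]
assert image_slice.shape == (3, 3)
```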