Unverified Commit a643c630 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[K Diffusion] Add k diffusion sampler natively (#1603)

* uP

* uP
parent 326de419
...@@ -686,7 +686,7 @@ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom ...@@ -686,7 +686,7 @@ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom
pipe = pipe.to("cuda") pipe = pipe.to("cuda")
prompt = "an astronaut riding a horse on mars" prompt = "an astronaut riding a horse on mars"
pipe.set_sampler("sample_heun") pipe.set_scheduler("sample_heun")
generator = torch.Generator(device="cuda").manual_seed(seed) generator = torch.Generator(device="cuda").manual_seed(seed)
image = pipe(prompt, generator=generator, num_inference_steps=20).images[0] image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]
...@@ -721,7 +721,7 @@ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom ...@@ -721,7 +721,7 @@ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda") pipe = pipe.to("cuda")
pipe.set_sampler("sample_euler") pipe.set_scheduler("sample_euler")
generator = torch.Generator(device="cuda").manual_seed(seed) generator = torch.Generator(device="cuda").manual_seed(seed)
image = pipe(prompt, generator=generator, num_inference_steps=50).images[0] image = pipe(prompt, generator=generator, num_inference_steps=50).images[0]
``` ```
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import importlib import importlib
import warnings
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
import torch import torch
...@@ -111,6 +112,10 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -111,6 +112,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
self.k_diffusion_model = CompVisDenoiser(model) self.k_diffusion_model = CompVisDenoiser(model)
def set_sampler(self, scheduler_type: str): def set_sampler(self, scheduler_type: str):
warnings.warn("The `set_sampler` method is deprecated, please use `set_scheduler` instead.")
return self.set_scheduler(scheduler_type)
def set_scheduler(self, scheduler_type: str):
library = importlib.import_module("k_diffusion") library = importlib.import_module("k_diffusion")
sampling = getattr(library, "sampling") sampling = getattr(library, "sampling")
self.sampler = getattr(sampling, scheduler_type) self.sampler = getattr(sampling, scheduler_type)
......
...@@ -91,6 +91,7 @@ _deps = [ ...@@ -91,6 +91,7 @@ _deps = [
"isort>=5.5.4", "isort>=5.5.4",
"jax>=0.2.8,!=0.3.2", "jax>=0.2.8,!=0.3.2",
"jaxlib>=0.1.65", "jaxlib>=0.1.65",
"k-diffusion",
"librosa", "librosa",
"modelcards>=0.1.4", "modelcards>=0.1.4",
"numpy", "numpy",
...@@ -182,6 +183,7 @@ extras["docs"] = deps_list("hf-doc-builder") ...@@ -182,6 +183,7 @@ extras["docs"] = deps_list("hf-doc-builder")
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards") extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
extras["test"] = deps_list( extras["test"] = deps_list(
"datasets", "datasets",
"k-diffusion",
"librosa", "librosa",
"parameterized", "parameterized",
"pytest", "pytest",
......
...@@ -5,6 +5,7 @@ from .onnx_utils import OnnxRuntimeModel ...@@ -5,6 +5,7 @@ from .onnx_utils import OnnxRuntimeModel
from .utils import ( from .utils import (
is_flax_available, is_flax_available,
is_inflect_available, is_inflect_available,
is_k_diffusion_available,
is_onnx_available, is_onnx_available,
is_scipy_available, is_scipy_available,
is_torch_available, is_torch_available,
...@@ -90,6 +91,11 @@ if is_torch_available() and is_transformers_available(): ...@@ -90,6 +91,11 @@ if is_torch_available() and is_transformers_available():
else: else:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403 from .utils.dummy_torch_and_transformers_objects import * # noqa F403
if is_torch_available() and is_transformers_available() and is_k_diffusion_available():
from .pipelines import StableDiffusionKDiffusionPipeline
else:
from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
if is_torch_available() and is_transformers_available() and is_onnx_available(): if is_torch_available() and is_transformers_available() and is_onnx_available():
from .pipelines import ( from .pipelines import (
OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionImg2ImgPipeline,
......
...@@ -15,6 +15,7 @@ deps = { ...@@ -15,6 +15,7 @@ deps = {
"isort": "isort>=5.5.4", "isort": "isort>=5.5.4",
"jax": "jax>=0.2.8,!=0.3.2", "jax": "jax>=0.2.8,!=0.3.2",
"jaxlib": "jaxlib>=0.1.65", "jaxlib": "jaxlib>=0.1.65",
"k-diffusion": "k-diffusion",
"librosa": "librosa", "librosa": "librosa",
"modelcards": "modelcards>=0.1.4", "modelcards": "modelcards>=0.1.4",
"numpy": "numpy", "numpy": "numpy",
......
from ..utils import ( from ..utils import (
is_flax_available, is_flax_available,
is_k_diffusion_available,
is_librosa_available, is_librosa_available,
is_onnx_available, is_onnx_available,
is_torch_available, is_torch_available,
...@@ -56,5 +57,8 @@ if is_transformers_available() and is_onnx_available(): ...@@ -56,5 +57,8 @@ if is_transformers_available() and is_onnx_available():
StableDiffusionOnnxPipeline, StableDiffusionOnnxPipeline,
) )
if is_torch_available() and is_transformers_available() and is_k_diffusion_available():
from .stable_diffusion import StableDiffusionKDiffusionPipeline
if is_transformers_available() and is_flax_available(): if is_transformers_available() and is_flax_available():
from .stable_diffusion import FlaxStableDiffusionPipeline from .stable_diffusion import FlaxStableDiffusionPipeline
...@@ -9,6 +9,7 @@ from PIL import Image ...@@ -9,6 +9,7 @@ from PIL import Image
from ...utils import ( from ...utils import (
BaseOutput, BaseOutput,
is_flax_available, is_flax_available,
is_k_diffusion_available,
is_onnx_available, is_onnx_available,
is_torch_available, is_torch_available,
is_transformers_available, is_transformers_available,
...@@ -48,6 +49,9 @@ if is_transformers_available() and is_torch_available() and is_transformers_vers ...@@ -48,6 +49,9 @@ if is_transformers_available() and is_torch_available() and is_transformers_vers
else: else:
from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline
if is_transformers_available() and is_torch_available() and is_k_diffusion_available():
from .pipeline_stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline
if is_transformers_available() and is_onnx_available(): if is_transformers_available() and is_onnx_available():
from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline
......
...@@ -29,6 +29,7 @@ from .import_utils import ( ...@@ -29,6 +29,7 @@ from .import_utils import (
is_accelerate_available, is_accelerate_available,
is_flax_available, is_flax_available,
is_inflect_available, is_inflect_available,
is_k_diffusion_available,
is_librosa_available, is_librosa_available,
is_modelcards_available, is_modelcards_available,
is_onnx_available, is_onnx_available,
......
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from ..utils import DummyObject, requires_backends
class StableDiffusionKDiffusionPipeline(metaclass=DummyObject):
    """Dummy stand-in emitted when the real pipeline's backends are missing.

    `DummyObject` / `requires_backends` (imported from `..utils`) make any
    instantiation or `from_config` / `from_pretrained` call raise an error
    telling the user to install the listed backends instead of failing with
    an opaque ImportError at import time.
    """

    # Backends the real StableDiffusionKDiffusionPipeline needs; keys into
    # BACKENDS_MAPPING in import_utils (note the underscore form "k_diffusion").
    _backends = ["torch", "transformers", "k_diffusion"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers", "k_diffusion"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "k_diffusion"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "k_diffusion"])
...@@ -210,6 +210,13 @@ try: ...@@ -210,6 +210,13 @@ try:
except importlib_metadata.PackageNotFoundError: except importlib_metadata.PackageNotFoundError:
_xformers_available = False _xformers_available = False
# Detect k-diffusion. The importable *module* is `k_diffusion`, but the PyPI
# *distribution* is named "k-diffusion" — the metadata lookup must use the
# dashed distribution name or `version()` raises PackageNotFoundError on
# installs where the underscore alias is not normalized, wrongly disabling
# the integration even though the package is present.
_k_diffusion_available = importlib.util.find_spec("k_diffusion") is not None
try:
    _k_diffusion_version = importlib_metadata.version("k-diffusion")
    logger.debug(f"Successfully imported k-diffusion version {_k_diffusion_version}")
except importlib_metadata.PackageNotFoundError:
    # Module spec may exist (e.g. stale path entry) without installed metadata.
    _k_diffusion_available = False
def is_torch_available(): def is_torch_available():
return _torch_available return _torch_available
...@@ -263,6 +270,10 @@ def is_accelerate_available(): ...@@ -263,6 +270,10 @@ def is_accelerate_available():
return _accelerate_available return _accelerate_available
def is_k_diffusion_available():
    """Return whether the `k_diffusion` package is importable (checked once at import time)."""
    return _k_diffusion_available
# docstyle-ignore # docstyle-ignore
FLAX_IMPORT_ERROR = """ FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
...@@ -317,6 +328,12 @@ UNIDECODE_IMPORT_ERROR = """ ...@@ -317,6 +328,12 @@ UNIDECODE_IMPORT_ERROR = """
Unidecode` Unidecode`
""" """
# Error template shown by `requires_backends` when k-diffusion is missing;
# {0} is filled with the name of the object the user tried to use.
# docstyle-ignore
K_DIFFUSION_IMPORT_ERROR = """
{0} requires the k-diffusion library but it was not found in your environment. You can install it with pip: `pip
install k-diffusion`
"""
BACKENDS_MAPPING = OrderedDict( BACKENDS_MAPPING = OrderedDict(
[ [
...@@ -329,6 +346,7 @@ BACKENDS_MAPPING = OrderedDict( ...@@ -329,6 +346,7 @@ BACKENDS_MAPPING = OrderedDict(
("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)), ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)), ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)), ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)),
("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)),
] ]
) )
......
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import numpy as np
import torch
from diffusers import StableDiffusionKDiffusionPipeline
from diffusers.utils import slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
torch.backends.cuda.matmul.allow_tf32 = False
@slow
@require_torch_gpu
class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
    """GPU integration tests for StableDiffusionKDiffusionPipeline with native k-diffusion samplers."""

    def tearDown(self):
        # Release VRAM between tests so subsequent pipelines don't OOM.
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_stable_diffusion_1(self):
        pipe = StableDiffusionKDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.set_scheduler("sample_euler")

        generator = torch.Generator(device=torch_device).manual_seed(0)
        result = pipe(
            ["A painting of a squirrel eating a burger"],
            generator=generator,
            guidance_scale=9.0,
            num_inference_steps=20,
            output_type="np",
        )
        images = result.images
        assert images.shape == (1, 512, 512, 3)

        # Compare the bottom-right 3x3 patch of the last channel against a
        # reference render produced with the same seed.
        corner = images[0, -3:, -3:, -1].flatten()
        expected = np.array([0.8887, 0.915, 0.91, 0.894, 0.909, 0.912, 0.919, 0.925, 0.883])
        assert np.abs(corner - expected).max() < 1e-2

    def test_stable_diffusion_2(self):
        pipe = StableDiffusionKDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.set_scheduler("sample_euler")

        generator = torch.Generator(device=torch_device).manual_seed(0)
        result = pipe(
            ["A painting of a squirrel eating a burger"],
            generator=generator,
            guidance_scale=9.0,
            num_inference_steps=20,
            output_type="np",
        )
        images = result.images
        assert images.shape == (1, 512, 512, 3)

        corner = images[0, -3:, -3:, -1].flatten()
        expected = np.array(
            [0.826810, 0.81958747, 0.8510199, 0.8376758, 0.83958465, 0.8682068, 0.84370345, 0.85251087, 0.85884345]
        )
        assert np.abs(corner - expected).max() < 1e-2
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment