Unverified commit 37d113cc, authored by Kashif Rasul, committed by GitHub
Browse files

DiT Pipeline (#1806)



* added dit model

* import

* initial pipeline

* initial convert script

* initial pipeline

* make style

* raise valueerror

* single function

* rename classes

* use DDIMScheduler

* timesteps embedder

* samples to cpu

* fix var names

* fix numpy type

* use timesteps class for proj

* fix typo

* fix arg name

* flip_sin_to_cos and better var names

* fix C shape cal

* make style

* remove unused imports

* cleanup

* add back patch_size

* initial dit doc

* typo

* Update docs/source/api/pipelines/dit.mdx
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* added copyright license headers

* added example usage and toc

* fix variable names asserts

* remove comment

* added docs

* fix typo

* upstream changes

* set proper device for drop_ids

* added initial dit pipeline test

* update docs

* fix imports

* make fix-copies

* isort

* fix imports

* get rid of more magic numbers

* fix code when guidance is off

* remove block_kwargs

* cleanup script

* removed to_2tuple

* use FeedForward class instead of another MLP

* style

* work on merging DiTBlock with BasicTransformerBlock

* added missing final_dropout and args to BasicTransformerBlock

* use norm from block

* fix arg

* remove unused arg

* fix call to class_embedder

* use timesteps

* make style

* attn_output gets multiplied

* removed commented code

* use Transformer2D

* use self.is_input_patches

* fix flags

* fixed conversion to use Transformer2DModel

* fixes for pipeline

* remove dit.py

* fix timesteps device

* use randn_tensor and fix fp16 inf.

* timesteps_emb already the right dtype

* fix dit test class

* fix test and style

* fix norm2 usage in vq-diffusion

* added author names to pipeline and ImageNet labels link

* fix tests

* use norm_type as string

* rename dit to transformer

* fix name

* fix test

* set  norm_type = "layer" by default

* fix tests

* do not skip common tests

* Update src/diffusers/models/attention.py
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* revert AdaLayerNorm API

* fix norm_type name

* make sure all components are in eval mode

* revert norm2 API

* compact

* finish deprecation

* add slow tests

* remove @

* refactor some stuff

* upload

* Update src/diffusers/pipelines/dit/pipeline_dit.py

* finish more

* finish docs

* improve docs

* finish docs
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: William Berman <WLBberman@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 7e29b747
......@@ -21,8 +21,8 @@ from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
......@@ -82,7 +82,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
the `dtype` used for params and computation.
"""
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
dtype: jnp.dtype
......
......@@ -21,8 +21,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
......@@ -92,7 +91,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
"""
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
......
......@@ -23,8 +23,8 @@ import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
add_noise_common,
......@@ -110,7 +110,7 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
the `dtype` used for params and computation.
"""
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
dtype: jnp.dtype
pndm_order: int
......
......@@ -14,6 +14,7 @@
import importlib
import os
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional, Union
import torch
......@@ -24,6 +25,21 @@ from ..utils import BaseOutput
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
class KarrasDiffusionSchedulers(Enum):
    """Enumeration of scheduler class names that are treated as one compatible group.

    Scheduler classes derive their ``_compatibles`` list from the member *names*
    (``[e.name for e in KarrasDiffusionSchedulers]``), so each member name must
    match the corresponding scheduler class name exactly.
    """

    # NOTE(review): only the names appear to be consumed; the integer values
    # look like plain ordinals -- confirm before relying on them.
    DDIMScheduler = 1
    DDPMScheduler = 2
    PNDMScheduler = 3
    LMSDiscreteScheduler = 4
    EulerDiscreteScheduler = 5
    HeunDiscreteScheduler = 6
    EulerAncestralDiscreteScheduler = 7
    DPMSolverMultistepScheduler = 8
    DPMSolverSinglestepScheduler = 9
    KDPM2DiscreteScheduler = 10
    KDPM2AncestralDiscreteScheduler = 11
    DEISMultistepScheduler = 12
@dataclass
class SchedulerOutput(BaseOutput):
"""
......
......@@ -15,16 +15,24 @@ import importlib
import math
import os
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional, Tuple, Union
import flax
import jax.numpy as jnp
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
from ..utils import BaseOutput
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = ["Flax" + c for c in _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS]
class FlaxKarrasDiffusionSchedulers(Enum):
    """Enumeration of Flax scheduler class names that are treated as one compatible group.

    Flax scheduler classes derive their ``_compatibles`` list from the member
    *names* (``[e.name for e in FlaxKarrasDiffusionSchedulers]``), so each
    member name must match the corresponding Flax scheduler class name exactly.
    """

    # NOTE(review): only the names appear to be consumed; the integer values
    # look like plain ordinals -- confirm before relying on them.
    FlaxDDIMScheduler = 1
    FlaxDDPMScheduler = 2
    FlaxPNDMScheduler = 3
    FlaxLMSDiscreteScheduler = 4
    FlaxDPMSolverMultistepScheduler = 5
@dataclass
......
......@@ -19,7 +19,6 @@ from packaging import version
from .. import __version__
from .constants import (
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
CONFIG_NAME,
DIFFUSERS_CACHE,
DIFFUSERS_DYNAMIC_MODULE_NAME,
......
......@@ -30,18 +30,3 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [
"DDIMScheduler",
"DDPMScheduler",
"PNDMScheduler",
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"HeunDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
"DPMSolverSinglestepScheduler",
"KDPM2DiscreteScheduler",
"KDPM2AncestralDiscreteScheduler",
"DEISMultistepScheduler",
]
......@@ -227,6 +227,21 @@ class DiffusionPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class DiTPipeline(metaclass=DummyObject):
    # Stand-in for the real DiTPipeline: the constructor and both loader
    # classmethods accept any arguments and only call `requires_backends`
    # with the "torch" backend.
    # NOTE(review): presumably `requires_backends` raises when torch is not
    # installed -- confirm against its definition.
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])
class ImagePipelineOutput(metaclass=DummyObject):
_backends = ["torch"]
......
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import numpy as np
import torch
from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DPMSolverMultistepScheduler, Transformer2DModel
from diffusers.utils import load_numpy, slow
from diffusers.utils.testing_utils import require_torch_gpu
from ...test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
class DiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    """Fast checks for `DiTPipeline` assembled from tiny, freshly constructed components.

    The shared test cases come from `PipelineTesterMixin`; this class supplies
    the dummy components/inputs and one pipeline-specific regression test.
    """

    pipeline_class = DiTPipeline
    # NOTE(review): flag read by the shared mixin; presumably turns off its
    # CPU-offload test for this pipeline -- confirm against the mixin.
    test_cpu_offload = False

    def get_dummy_components(self):
        """Return the keyword arguments used to construct a minimal `DiTPipeline`."""
        # Seed before building the modules; presumably this makes their initial
        # weights (and the expected values in `test_inference`) reproducible.
        torch.manual_seed(0)
        transformer = Transformer2DModel(
            sample_size=4,
            num_layers=2,
            patch_size=2,
            attention_head_dim=2,
            num_attention_heads=2,
            in_channels=4,
            out_channels=8,
            attention_bias=True,
            activation_fn="gelu-approximate",
            num_embeds_ada_norm=1000,
            norm_type="ada_norm_zero",
            norm_elementwise_affine=False,
        )
        vae = AutoencoderKL()
        scheduler = DDIMScheduler()
        # Both modules are switched to eval mode before being handed to the pipeline.
        components = {"transformer": transformer.eval(), "vae": vae.eval(), "scheduler": scheduler}
        return components

    def get_dummy_inputs(self, device, seed=0):
        """Return call arguments for the pipeline with a generator seeded by `seed`."""
        if str(device).startswith("mps"):
            # On mps the global default generator is seeded instead of creating
            # a device-bound `torch.Generator`.
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "class_labels": [1],
            "generator": generator,
            "num_inference_steps": 2,
            "output_type": "numpy",
        }
        return inputs

    def test_inference(self):
        # Runs the dummy pipeline on CPU and compares a 3x3 corner of the last
        # channel of the first image against recorded reference values.
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1]

        self.assertEqual(image.shape, (1, 4, 4, 3))
        expected_slice = np.array(
            [0.44405967, 0.33592293, 0.6093237, 0.48981372, 0.79098296, 0.7504172, 0.59413105, 0.49462673, 0.35190058]
        )
        max_diff = np.abs(image_slice.flatten() - expected_slice).max()
        self.assertLessEqual(max_diff, 1e-3)

    def test_inference_batch_single_identical(self):
        # Delegates to the shared mixin check, asking for its relaxed tolerance mode.
        self._test_inference_batch_single_identical(relax_max_difference=True)
@require_torch_gpu
@slow
class DiTPipelineIntegrationTests(unittest.TestCase):
    """Slow GPU tests that compare pretrained DiT checkpoints against stored reference images."""

    def tearDown(self):
        # Drop Python references and return cached CUDA memory after each test.
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_dit_256(self):
        # 256px checkpoint with its default scheduler, 40 inference steps.
        generator = torch.manual_seed(0)

        pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256")
        pipe.to("cuda")

        words = ["vase", "umbrella", "white shark", "white wolf"]
        ids = pipe.get_label_ids(words)

        images = pipe(ids, generator=generator, num_inference_steps=40, output_type="np").images

        for word, image in zip(words, images):
            expected_image = load_numpy(
                f"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/dit/{word}.npy"
            )
            # Compare the largest absolute per-pixel error. Summing the signed
            # differences lets positive and negative errors cancel, which can
            # hide a wrong image.
            max_diff = np.abs(expected_image - image).max()
            assert max_diff < 1e-3, f"{word}: max abs pixel difference {max_diff}"

    def test_dit_512_fp16(self):
        # 512px checkpoint in float16 with DPMSolverMultistepScheduler, 25 inference steps.
        generator = torch.manual_seed(0)

        pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-512", torch_dtype=torch.float16)
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.to("cuda")

        words = ["vase", "umbrella", "white shark", "white wolf"]
        ids = pipe.get_label_ids(words)

        images = pipe(ids, generator=generator, num_inference_steps=25, output_type="np").images

        for word, image in zip(words, images):
            expected_image = load_numpy(
                "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
                f"/dit/{word}_fp16.npy"
            )
            # Same comparison as in test_dit_256: maximum absolute error, so
            # opposite-signed errors cannot cancel.
            max_diff = np.abs(expected_image - image).max()
            assert max_diff < 1e-3, f"{word}: max abs pixel difference {max_diff}"
......@@ -36,7 +36,7 @@ class PipelineTesterMixin:
equivalence of dict and tuple outputs, etc.
"""
allowed_required_args = ["source_prompt", "prompt", "image", "mask_image", "example_image"]
allowed_required_args = ["source_prompt", "prompt", "image", "mask_image", "example_image", "class_labels"]
required_optional_params = ["generator", "num_inference_steps", "return_dict"]
num_inference_steps_args = ["num_inference_steps"]
......@@ -194,8 +194,8 @@ class PipelineTesterMixin:
):
if self.pipeline_class.__name__ in ["CycleDiffusionPipeline", "RePaintPipeline"]:
# RePaint can hardly be made deterministic since the scheduler is currently always
# indeterministic
# CycleDiffusion is also slighly undeterministic
# nondeterministic
# CycleDiffusion is also slightly nondeterministic
return
if test_max_difference is None:
......@@ -515,7 +515,7 @@ class PipelineTesterMixin:
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forward_pass(self):
def test_xformers_attention_forwardGenerator_pass(self):
if not self.test_xformers_attention:
return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment