Unverified Commit da5ab51d authored by nikhil-masterful's avatar nikhil-masterful Committed by GitHub
Browse files

Add GLIGEN implementation (#4441)

* Add GLIGEN implementation

* GLIGEN: Fix code quality check failures

* GLIGEN: Fix Import block un-sorted or un-formatted failures

* GLIGEN: Fix check_repository_consistency failures

* GLIGEN: Add 'PositionNet' to versatile_diffusion/modeling_text_unet.py

* GLIGEN: check_repository_consistency: fix 'copy does not match' error

* GLIGEN: Fix review comments (1)

* GLIGEN: Fix E721 Do not compare types, use `isinstance()` failures

* GLIGEN : Ensure _encode_prompt() copy matches to StableDiffusionPipeline

* GLIGEN: Fix ruff E721 failure in unidiffuser/test_unidiffuser.py

* GLIGEN: doc_builder: restyle pipeline_stable_diffusion_gligen.py

* GIGLEN: reset files unrelated to gligen

* GLIGEN: Fix documentation comments (1)

* GLIGEN: Fix review comments (2)

* GLIGEN: Added FastTest

* GLIGEN: Fix review comments (3)
parent 5175d3d7
......@@ -267,6 +267,8 @@
title: LDM3D Text-to-(RGB, Depth)
- local: api/pipelines/stable_diffusion/adapter
title: Stable Diffusion T2I-adapter
- local: api/pipelines/stable_diffusion/gligen
title: GLIGEN (Grounded Language-to-Image Generation)
title: Stable Diffusion
- local: api/pipelines/stable_unclip
title: Stable unCLIP
......
<!--Copyright 2023 The GLIGEN Authors and The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# GLIGEN (Grounded Language-to-Image Generation)
The GLIGEN model was created by researchers and engineers from [University of Wisconsin-Madison, Columbia University, and Microsoft](https://github.com/gligen/GLIGEN). The [`StableDiffusionGLIGENPipeline`] can generate photorealistic images conditioned on grounding inputs. Along with text and bounding boxes, if input images are given, this pipeline can insert objects described by text at the region defined by bounding boxes. Otherwise, it'll generate an image described by the caption/prompt and insert objects described by text at the region defined by bounding boxes. It's trained on COCO2014D and COCO2014CD datasets, and the model uses a frozen CLIP ViT-L/14 text encoder to condition itself on grounding inputs.
The abstract from the [paper](https://huggingface.co/papers/2301.07093) is:
*Large-scale text-to-image diffusion models have made amazing advances. However, the status quo is to use text input alone, which can impede controllability. In this work, we propose GLIGEN, Grounded-Language-to-Image Generation, a novel approach that builds upon and extends the functionality of existing pre-trained text-to-image diffusion models by enabling them to also be conditioned on grounding inputs. To preserve the vast concept knowledge of the pre-trained model, we freeze all of its weights and inject the grounding information into new trainable layers via a gated mechanism. Our model achieves open-world grounded text2img generation with caption and bounding box condition inputs, and the grounding ability generalizes well to novel spatial configurations and concepts. GLIGEN’s zeroshot performance on COCO and LVIS outperforms existing supervised layout-to-image baselines by a large margin.*
<Tip>
Make sure to check out the Stable Diffusion [Tips](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality and how to reuse pipeline components efficiently!
If you want to use one of the official checkpoints for a task, explore the [gligen](https://huggingface.co/gligen) Hub organizations!
</Tip>
This pipeline was contributed by [Nikhil Gajendrakumar](https://github.com/nikhil-masterful).
## StableDiffusionGLIGENPipeline
[[autodoc]] StableDiffusionGLIGENPipeline
- all
- __call__
- enable_vae_slicing
- disable_vae_slicing
- enable_vae_tiling
- disable_vae_tiling
- enable_model_cpu_offload
- prepare_latents
- enable_fuser
## StableDiffusionPipelineOutput
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
......@@ -171,6 +171,7 @@ else:
StableDiffusionControlNetPipeline,
StableDiffusionDepth2ImgPipeline,
StableDiffusionDiffEditPipeline,
StableDiffusionGLIGENPipeline,
StableDiffusionImageVariationPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
......
......@@ -24,6 +24,38 @@ from .embeddings import CombinedTimestepLabelEmbeddings
from .lora import LoRACompatibleLinear
@maybe_allow_in_graph
class GatedSelfAttentionDense(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, d_head):
super().__init__()
# we need a linear projection since we need cat visual feature and obj feature
self.linear = nn.Linear(context_dim, query_dim)
self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
self.ff = FeedForward(query_dim, activation_fn="geglu")
self.norm1 = nn.LayerNorm(query_dim)
self.norm2 = nn.LayerNorm(query_dim)
self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
self.enabled = True
def forward(self, x, objs):
if not self.enabled:
return x
n_visual = x.shape[1]
objs = self.linear(objs)
x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
return x
@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
r"""
......@@ -62,6 +94,7 @@ class BasicTransformerBlock(nn.Module):
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm",
final_dropout: bool = False,
attention_type: str = "default",
):
super().__init__()
self.only_cross_attention = only_cross_attention
......@@ -120,6 +153,10 @@ class BasicTransformerBlock(nn.Module):
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
# 4. Fuser
if attention_type == "gated":
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
......@@ -150,7 +187,9 @@ class BasicTransformerBlock(nn.Module):
else:
norm_hidden_states = self.norm1(hidden_states)
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
# 0. Prepare GLIGEN inputs
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
attn_output = self.attn1(
norm_hidden_states,
......@@ -162,6 +201,11 @@ class BasicTransformerBlock(nn.Module):
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = attn_output + hidden_states
# 1.5 GLIGEN Control
if gligen_kwargs is not None:
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
# 1.5 ends
# 2. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = (
......
......@@ -544,3 +544,59 @@ class AttentionPooling(nn.Module):
a = a.reshape(bs, -1, 1).transpose(1, 2)
return a[:, 0, :] # cls_token
class FourierEmbedder(nn.Module):
def __init__(self, num_freqs=64, temperature=100):
super().__init__()
self.num_freqs = num_freqs
self.temperature = temperature
freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
freq_bands = freq_bands[None, None, None]
self.register_buffer("freq_bands", freq_bands, persistent=False)
def __call__(self, x):
x = self.freq_bands * x.unsqueeze(-1)
return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
class PositionNet(nn.Module):
def __init__(self, positive_len, out_dim, fourier_freqs=8):
super().__init__()
self.positive_len = positive_len
self.out_dim = out_dim
self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy
if isinstance(out_dim, tuple):
out_dim = out_dim[0]
self.linears = nn.Sequential(
nn.Linear(self.positive_len + self.position_dim, 512),
nn.SiLU(),
nn.Linear(512, 512),
nn.SiLU(),
nn.Linear(512, out_dim),
)
self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
def forward(self, boxes, masks, positive_embeddings):
masks = masks.unsqueeze(-1)
# embedding position (it may includes padding as placeholder)
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 -> B*N*C
# learnable null embedding
positive_null = self.null_positive_feature.view(1, 1, -1)
xyxy_null = self.null_position_feature.view(1, 1, -1)
# replace padding with learnable null embedding
positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
return objs
......@@ -91,6 +91,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
upcast_attention: bool = False,
norm_type: str = "layer_norm",
norm_elementwise_affine: bool = True,
attention_type: str = "default",
):
super().__init__()
self.use_linear_projection = use_linear_projection
......@@ -183,6 +184,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
attention_type=attention_type,
)
for d in range(num_layers)
]
......
......@@ -49,6 +49,7 @@ def get_down_block(
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
attention_type="default",
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
cross_attention_norm=None,
......@@ -129,6 +130,7 @@ def get_down_block(
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_type=attention_type,
)
elif down_block_type == "SimpleCrossAttnDownBlock2D":
if cross_attention_dim is None:
......@@ -244,6 +246,7 @@ def get_up_block(
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
attention_type="default",
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
cross_attention_norm=None,
......@@ -307,6 +310,7 @@ def get_up_block(
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_type=attention_type,
)
elif up_block_type == "SimpleCrossAttnUpBlock2D":
if cross_attention_dim is None:
......@@ -556,6 +560,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
dual_cross_attention=False,
use_linear_projection=False,
upcast_attention=False,
attention_type="default",
):
super().__init__()
......@@ -592,6 +597,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
attention_type=attention_type,
)
)
else:
......@@ -934,6 +940,7 @@ class CrossAttnDownBlock2D(nn.Module):
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
attention_type="default",
):
super().__init__()
resnets = []
......@@ -970,6 +977,7 @@ class CrossAttnDownBlock2D(nn.Module):
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
attention_type=attention_type,
)
)
else:
......@@ -2068,6 +2076,7 @@ class CrossAttnUpBlock2D(nn.Module):
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
attention_type="default",
):
super().__init__()
resnets = []
......@@ -2106,6 +2115,7 @@ class CrossAttnUpBlock2D(nn.Module):
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
attention_type=attention_type,
)
)
else:
......
......@@ -28,6 +28,7 @@ from .embeddings import (
ImageHintTimeEmbedding,
ImageProjection,
ImageTimeEmbedding,
PositionNet,
TextImageProjection,
TextImageTimeEmbedding,
TextTimeEmbedding,
......@@ -198,6 +199,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
conv_in_kernel: int = 3,
conv_out_kernel: int = 3,
projection_class_embeddings_input_dim: Optional[int] = None,
attention_type: str = "default",
class_embeddings_concat: bool = False,
mid_block_only_cross_attention: Optional[bool] = None,
cross_attention_norm: Optional[str] = None,
......@@ -446,6 +448,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_type=attention_type,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
......@@ -469,6 +472,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
attention_type=attention_type,
)
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
self.mid_block = UNetMidBlock2DSimpleCrossAttn(
......@@ -535,6 +539,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_type=attention_type,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
......@@ -560,6 +565,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
)
if attention_type == "gated":
positive_len = 768
if isinstance(cross_attention_dim, int):
positive_len = cross_attention_dim
elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
positive_len = cross_attention_dim[0]
self.position_net = PositionNet(positive_len=positive_len, out_dim=cross_attention_dim)
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
......@@ -895,6 +908,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
# 2. pre-process
sample = self.conv_in(sample)
# 2.5 GLIGEN position net
if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
cross_attention_kwargs = cross_attention_kwargs.copy()
gligen_args = cross_attention_kwargs.pop("gligen")
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
# 3. down
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
......
......@@ -90,6 +90,7 @@ else:
StableDiffusionAttendAndExcitePipeline,
StableDiffusionDepth2ImgPipeline,
StableDiffusionDiffEditPipeline,
StableDiffusionGLIGENPipeline,
StableDiffusionImageVariationPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
......
......@@ -45,6 +45,7 @@ else:
from .pipeline_cycle_diffusion import CycleDiffusionPipeline
from .pipeline_stable_diffusion import StableDiffusionPipeline
from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline
from .pipeline_stable_diffusion_gligen import StableDiffusionGLIGENPipeline
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
......
......@@ -153,6 +153,62 @@ def get_up_block(
raise ValueError(f"{up_block_type} is not supported.")
class FourierEmbedder(nn.Module):
def __init__(self, num_freqs=64, temperature=100):
super().__init__()
self.num_freqs = num_freqs
self.temperature = temperature
freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
freq_bands = freq_bands[None, None, None]
self.register_buffer("freq_bands", freq_bands, persistent=False)
def __call__(self, x):
x = self.freq_bands * x.unsqueeze(-1)
return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
class PositionNet(nn.Module):
def __init__(self, positive_len, out_dim, fourier_freqs=8):
super().__init__()
self.positive_len = positive_len
self.out_dim = out_dim
self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy
if isinstance(out_dim, tuple):
out_dim = out_dim[0]
self.linears = nn.Sequential(
nn.Linear(self.positive_len + self.position_dim, 512),
nn.SiLU(),
nn.Linear(512, 512),
nn.SiLU(),
nn.Linear(512, out_dim),
)
self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
def forward(self, boxes, masks, positive_embeddings):
masks = masks.unsqueeze(-1)
# embedding position (it may includes padding as placeholder)
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 -> B*N*C
# learnable null embedding
positive_null = self.null_positive_feature.view(1, 1, -1)
xyxy_null = self.null_position_feature.view(1, 1, -1)
# replace padding with learnable null embedding
positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
return objs
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat
class UNetFlatConditionModel(ModelMixin, ConfigMixin):
r"""
......@@ -298,6 +354,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
conv_in_kernel: int = 3,
conv_out_kernel: int = 3,
projection_class_embeddings_input_dim: Optional[int] = None,
attention_type: str = "default",
class_embeddings_concat: bool = False,
mid_block_only_cross_attention: Optional[bool] = None,
cross_attention_norm: Optional[str] = None,
......@@ -556,6 +613,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_type=attention_type,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
......@@ -579,6 +637,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
attention_type=attention_type,
)
elif mid_block_type == "UNetMidBlockFlatSimpleCrossAttn":
self.mid_block = UNetMidBlockFlatSimpleCrossAttn(
......@@ -645,6 +704,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_type=attention_type,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
......@@ -670,6 +730,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
)
if attention_type == "gated":
positive_len = 768
if isinstance(cross_attention_dim, int):
positive_len = cross_attention_dim
elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
positive_len = cross_attention_dim[0]
self.position_net = PositionNet(positive_len=positive_len, out_dim=cross_attention_dim)
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
......@@ -1012,6 +1080,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
# 2. pre-process
sample = self.conv_in(sample)
# 2.5 GLIGEN position net
if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
cross_attention_kwargs = cross_attention_kwargs.copy()
gligen_args = cross_attention_kwargs.pop("gligen")
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
# 3. down
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
......@@ -1331,6 +1405,7 @@ class CrossAttnDownBlockFlat(nn.Module):
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
attention_type="default",
):
super().__init__()
resnets = []
......@@ -1367,6 +1442,7 @@ class CrossAttnDownBlockFlat(nn.Module):
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
attention_type=attention_type,
)
)
else:
......@@ -1569,6 +1645,7 @@ class CrossAttnUpBlockFlat(nn.Module):
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
attention_type="default",
):
super().__init__()
resnets = []
......@@ -1607,6 +1684,7 @@ class CrossAttnUpBlockFlat(nn.Module):
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
attention_type=attention_type,
)
)
else:
......@@ -1711,6 +1789,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
dual_cross_attention=False,
use_linear_projection=False,
upcast_attention=False,
attention_type="default",
):
super().__init__()
......@@ -1747,6 +1826,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
attention_type=attention_type,
)
)
else:
......
......@@ -602,6 +602,21 @@ class StableDiffusionDiffEditPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionGLIGENPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionImageVariationPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
......
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import (
AutoencoderKL,
DDIMScheduler,
StableDiffusionGLIGENPipeline,
UNet2DConditionModel,
)
from diffusers.utils.testing_utils import enable_full_determinism
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_PARAMS,
)
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
enable_full_determinism()
class GligenPipelineFastTests(
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
):
pipeline_class = StableDiffusionGLIGENPipeline
params = TEXT_TO_IMAGE_PARAMS | {"gligen_phrases", "gligen_boxes"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
attention_type="gated",
)
# unet.position_net = PositionNet(32,32)
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
sample_size=128,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A modern livingroom",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"gligen_phrases": ["a birthday cake"],
"gligen_boxes": [[0.2676, 0.6088, 0.4773, 0.7183]],
"output_type": "np",
}
return inputs
def test_gligen(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionGLIGENPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5069, 0.5561, 0.4577, 0.4792, 0.5203, 0.4089, 0.5039, 0.4919, 0.4499])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_attention_slicing_forward_pass(self):
super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(batch_size=3, expected_max_diff=3e-3)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment