Unverified commit 38466c36, authored by Nguyễn Công Tú Anh and committed by GitHub

Add GLIGEN Text Image implementation (#4777)

* Add GLIGEN Text Image implementation

* add style transfer from image

* fix check_repository_consistency

* add script to convert GLIGEN model to Diffusers

* rename attention type

* fix code style

* remove PositionNetTextImage

* Revert "fix check_repository_consistency"

This reverts commit 15f098c96e00bb9e67b831161615b30a2d28d815.

* change attention type name

* update docs for GLIGEN

* change examples with hf-document-image

* fix style

* add CLIPImageProjection for GLIGEN

* Add new encode_prompt, load project matrix in pipe init

* move CLIPImageProjection to stable_diffusion

* add comment
parent 5f740d0f
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.

# GLIGEN (Grounded Language-to-Image Generation)

-The GLIGEN model was created by researchers and engineers from [University of Wisconsin-Madison, Columbia University, and Microsoft](https://github.com/gligen/GLIGEN). The [`StableDiffusionGLIGENPipeline`] can generate photorealistic images conditioned on grounding inputs. Along with text and bounding boxes, if input images are given, this pipeline can insert objects described by text at the region defined by bounding boxes. Otherwise, it'll generate an image described by the caption/prompt and insert objects described by text at the region defined by bounding boxes. It's trained on COCO2014D and COCO2014CD datasets, and the model uses a frozen CLIP ViT-L/14 text encoder to condition itself on grounding inputs.
+The GLIGEN model was created by researchers and engineers from [University of Wisconsin-Madison, Columbia University, and Microsoft](https://github.com/gligen/GLIGEN). The [`StableDiffusionGLIGENPipeline`] and [`StableDiffusionGLIGENTextImagePipeline`] can generate photorealistic images conditioned on grounding inputs. Given text and bounding boxes, [`StableDiffusionGLIGENPipeline`] generates an image described by the caption/prompt and inserts objects described by text at the regions defined by the bounding boxes. If input images are given as well, [`StableDiffusionGLIGENTextImagePipeline`] can insert the objects depicted in those images at the regions defined by the bounding boxes. The models are trained on the COCO2014D and COCO2014CD datasets, and they use a frozen CLIP ViT-L/14 text encoder to condition themselves on grounding inputs.

The abstract from the [paper](https://huggingface.co/papers/2301.07093) is:

@@ -26,7 +26,7 @@ If you want to use one of the official checkpoints for a task, explore the [glig

</Tip>

-This pipeline was contributed by [Nikhil Gajendrakumar](https://github.com/nikhil-masterful).
+[`StableDiffusionGLIGENPipeline`] was contributed by [Nikhil Gajendrakumar](https://github.com/nikhil-masterful), and [`StableDiffusionGLIGENTextImagePipeline`] was contributed by [Nguyễn Công Tú Anh](https://github.com/tuanh123789).

## StableDiffusionGLIGENPipeline

@@ -41,6 +41,19 @@ This pipeline was contributed by [Nikhil Gajendrakumar](https://github.com/nikhi
- prepare_latents
- enable_fuser

+## StableDiffusionGLIGENTextImagePipeline
+
+[[autodoc]] StableDiffusionGLIGENTextImagePipeline
+- all
+- __call__
+- enable_vae_slicing
+- disable_vae_slicing
+- enable_vae_tiling
+- disable_vae_tiling
+- enable_model_cpu_offload
+- prepare_latents
+- enable_fuser
+
## StableDiffusionPipelineOutput

[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
This diff is collapsed.
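Since the diff of the new pipeline file itself is collapsed above, here is a minimal usage sketch of `StableDiffusionGLIGENTextImagePipeline`, assembled from the grounding inputs exercised in the fast test at the end of this commit. The checkpoint id below is a placeholder, not an official checkpoint; substitute a GLIGEN text-image checkpoint converted with the script added in this PR.

import torch
from diffusers import StableDiffusionGLIGENTextImagePipeline
from diffusers.utils import load_image

# Placeholder checkpoint id -- replace with your converted GLIGEN text-image checkpoint.
pipe = StableDiffusionGLIGENTextImagePipeline.from_pretrained(
    "your-org/gligen-text-image", torch_dtype=torch.float16
).to("cuda")

# Grounding inputs mirrored from the fast test added in this commit.
gligen_image = load_image(
    "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/gligen/livingroom_modern.png"
)

image = pipe(
    prompt="A modern livingroom",
    gligen_phrases=["a birthday cake"],               # text describing each grounded object
    gligen_images=[gligen_image],                     # one reference image per phrase
    gligen_boxes=[[0.2676, 0.6088, 0.4773, 0.7183]],  # normalized xyxy box per object
    num_inference_steps=50,
    guidance_scale=7.5,
    output_type="pil",
).images[0]
image.save("gligen_text_image.png")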
@@ -66,6 +66,7 @@ else:
        AutoPipelineForImage2Image,
        AutoPipelineForInpainting,
        AutoPipelineForText2Image,
+        CLIPImageProjection,
        ConsistencyModelPipeline,
        DanceDiffusionPipeline,
        DDIMPipeline,
@@ -176,6 +177,7 @@ else:
        StableDiffusionDepth2ImgPipeline,
        StableDiffusionDiffEditPipeline,
        StableDiffusionGLIGENPipeline,
+        StableDiffusionGLIGENTextImagePipeline,
        StableDiffusionImageVariationPipeline,
        StableDiffusionImg2ImgPipeline,
        StableDiffusionInpaintPipeline,
@@ -154,7 +154,7 @@ class BasicTransformerBlock(nn.Module):
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

        # 4. Fuser
-        if attention_type == "gated":
+        if attention_type == "gated" or attention_type == "gated-text-image":
            self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)

        # let chunk size default to None
@@ -563,7 +563,7 @@ class FourierEmbedder(nn.Module):
class PositionNet(nn.Module):
-    def __init__(self, positive_len, out_dim, fourier_freqs=8):
+    def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8):
        super().__init__()
        self.positive_len = positive_len
        self.out_dim = out_dim
@@ -573,30 +573,83 @@ class PositionNet(nn.Module):
        if isinstance(out_dim, tuple):
            out_dim = out_dim[0]

-        self.linears = nn.Sequential(
-            nn.Linear(self.positive_len + self.position_dim, 512),
-            nn.SiLU(),
-            nn.Linear(512, 512),
-            nn.SiLU(),
-            nn.Linear(512, out_dim),
-        )
-        self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
+        if feature_type == "text-only":
+            self.linears = nn.Sequential(
+                nn.Linear(self.positive_len + self.position_dim, 512),
+                nn.SiLU(),
+                nn.Linear(512, 512),
+                nn.SiLU(),
+                nn.Linear(512, out_dim),
+            )
+            self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
+        elif feature_type == "text-image":
+            self.linears_text = nn.Sequential(
+                nn.Linear(self.positive_len + self.position_dim, 512),
+                nn.SiLU(),
+                nn.Linear(512, 512),
+                nn.SiLU(),
+                nn.Linear(512, out_dim),
+            )
+            self.linears_image = nn.Sequential(
+                nn.Linear(self.positive_len + self.position_dim, 512),
+                nn.SiLU(),
+                nn.Linear(512, 512),
+                nn.SiLU(),
+                nn.Linear(512, out_dim),
+            )
+            self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
+            self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))

        self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))

-    def forward(self, boxes, masks, positive_embeddings):
+    def forward(
+        self,
+        boxes,
+        masks,
+        positive_embeddings=None,
+        phrases_masks=None,
+        image_masks=None,
+        phrases_embeddings=None,
+        image_embeddings=None,
+    ):
        masks = masks.unsqueeze(-1)

        # embedding position (it may includes padding as placeholder)
        xyxy_embedding = self.fourier_embedder(boxes)  # B*N*4 -> B*N*C

        # learnable null embedding
-        positive_null = self.null_positive_feature.view(1, 1, -1)
        xyxy_null = self.null_position_feature.view(1, 1, -1)

        # replace padding with learnable null embedding
-        positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
        xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null

-        objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
+        # PositionNet with text-only information
+        if positive_embeddings is not None:
+            # learnable null embedding
+            positive_null = self.null_positive_feature.view(1, 1, -1)
+            # replace padding with learnable null embedding
+            positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
+            objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
+        # PositionNet with text and image information
+        else:
+            phrases_masks = phrases_masks.unsqueeze(-1)
+            image_masks = image_masks.unsqueeze(-1)
+            # learnable null embedding
+            text_null = self.null_text_feature.view(1, 1, -1)
+            image_null = self.null_image_feature.view(1, 1, -1)
+            # replace padding with learnable null embedding
+            phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null
+            image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null
+            objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1))
+            objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1))
+            objs = torch.cat([objs_text, objs_image], dim=1)

        return objs
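To make the new text-image branch concrete, below is a small sketch of how the extended `PositionNet.forward` is driven. The import path is an assumption (the class is the one defined in this hunk), and the shapes follow the B*N*4 box convention noted in the comments above.

import torch

# Assumed import path for the PositionNet class patched above; adjust it to the
# module this hunk belongs to in your diffusers checkout.
from diffusers.models.embeddings import PositionNet

net = PositionNet(positive_len=768, out_dim=768, feature_type="text-image")

batch, num_objs = 1, 2
boxes = torch.rand(batch, num_objs, 4)                   # B*N*4 normalized xyxy boxes
masks = torch.ones(batch, num_objs)                      # which of the N slots hold real objects
phrases_embeddings = torch.randn(batch, num_objs, 768)   # per-object text embeddings
image_embeddings = torch.randn(batch, num_objs, 768)     # per-object image embeddings
phrases_masks = torch.ones(batch, num_objs)
image_masks = torch.ones(batch, num_objs)

objs = net(
    boxes,
    masks,
    phrases_masks=phrases_masks,
    image_masks=image_masks,
    phrases_embeddings=phrases_embeddings,
    image_embeddings=image_embeddings,
)

# Text and image tokens are concatenated along the object axis, so the output is B x 2N x out_dim.
print(objs.shape)  # torch.Size([1, 4, 768])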
@@ -565,13 +565,17 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
        )

-        if attention_type == "gated":
+        if attention_type in ["gated", "gated-text-image"]:
            positive_len = 768
            if isinstance(cross_attention_dim, int):
                positive_len = cross_attention_dim
            elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
                positive_len = cross_attention_dim[0]
-            self.position_net = PositionNet(positive_len=positive_len, out_dim=cross_attention_dim)
+            feature_type = "text-only" if attention_type == "gated" else "text-image"
+            self.position_net = PositionNet(
+                positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
+            )

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
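The wiring above is easiest to see end to end by instantiating a small UNet with the new attention type; the tiny configuration below mirrors the fast test added later in this commit.

from diffusers import UNet2DConditionModel

# With attention_type="gated-text-image", the UNet builds a PositionNet in
# "text-image" mode and gated self-attention fusers in its transformer blocks.
unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
    attention_type="gated-text-image",
)

print(type(unet.position_net).__name__)  # PositionNet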
@@ -94,6 +94,7 @@ else:
        StableDiffusionDepth2ImgPipeline,
        StableDiffusionDiffEditPipeline,
        StableDiffusionGLIGENPipeline,
+        StableDiffusionGLIGENTextImagePipeline,
        StableDiffusionImageVariationPipeline,
        StableDiffusionImg2ImgPipeline,
        StableDiffusionInpaintPipeline,
@@ -111,6 +112,7 @@ else:
        StableUnCLIPImg2ImgPipeline,
        StableUnCLIPPipeline,
    )
+    from .stable_diffusion.clip_image_project_model import CLIPImageProjection
    from .stable_diffusion_safe import StableDiffusionPipelineSafe
    from .stable_diffusion_xl import (
        StableDiffusionXLImg2ImgPipeline,
@@ -42,10 +42,12 @@ try:
except OptionalDependencyNotAvailable:
    from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
else:
+    from .clip_image_project_model import CLIPImageProjection
    from .pipeline_cycle_diffusion import CycleDiffusionPipeline
    from .pipeline_stable_diffusion import StableDiffusionPipeline
    from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline
    from .pipeline_stable_diffusion_gligen import StableDiffusionGLIGENPipeline
+    from .pipeline_stable_diffusion_gligen_text_image import StableDiffusionGLIGENTextImagePipeline
    from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
    from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
    from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
# Copyright 2023 The GLIGEN Authors and HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin
class CLIPImageProjection(ModelMixin, ConfigMixin):
@register_to_config
def __init__(self, hidden_size: int = 768):
super().__init__()
self.hidden_size = hidden_size
self.project = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
def forward(self, x):
return self.project(x)
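The new CLIPImageProjection above is a single bias-free linear layer registered as a loadable Diffusers model; a quick sketch of how it behaves is below (its role of projecting CLIP image embeddings for the grounding inputs is inferred from the rest of this commit).

import torch
from diffusers import CLIPImageProjection  # exported at the top level by this commit

# Default hidden_size (768) matches the CLIP ViT-L/14 embedding width.
image_project = CLIPImageProjection()

image_embeds = torch.randn(2, 768)       # e.g. pooled CLIP image embeddings
projected = image_project(image_embeds)  # mapped by one bias-free nn.Linear
print(projected.shape)                   # torch.Size([2, 768])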
@@ -170,7 +170,7 @@ class FourierEmbedder(nn.Module):
class PositionNet(nn.Module):
-    def __init__(self, positive_len, out_dim, fourier_freqs=8):
+    def __init__(self, positive_len, out_dim, feature_type, fourier_freqs=8):
        super().__init__()
        self.positive_len = positive_len
        self.out_dim = out_dim
@@ -180,32 +180,72 @@ class PositionNet(nn.Module):
        if isinstance(out_dim, tuple):
            out_dim = out_dim[0]

-        self.linears = nn.Sequential(
-            nn.Linear(self.positive_len + self.position_dim, 512),
-            nn.SiLU(),
-            nn.Linear(512, 512),
-            nn.SiLU(),
-            nn.Linear(512, out_dim),
-        )
-        self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
+        if feature_type == "text-only":
+            self.linears = nn.Sequential(
+                nn.Linear(self.positive_len + self.position_dim, 512),
+                nn.SiLU(),
+                nn.Linear(512, 512),
+                nn.SiLU(),
+                nn.Linear(512, out_dim),
+            )
+            self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
+        elif feature_type == "text-image":
+            self.linears_text = nn.Sequential(
+                nn.Linear(self.positive_len + self.position_dim, 512),
+                nn.SiLU(),
+                nn.Linear(512, 512),
+                nn.SiLU(),
+                nn.Linear(512, out_dim),
+            )
+            self.linears_image = nn.Sequential(
+                nn.Linear(self.positive_len + self.position_dim, 512),
+                nn.SiLU(),
+                nn.Linear(512, 512),
+                nn.SiLU(),
+                nn.Linear(512, out_dim),
+            )
+            self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
+            self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))

        self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))

-    def forward(self, boxes, masks, positive_embeddings):
+    def forward(
+        self,
+        boxes,
+        masks,
+        positive_embeddings=None,
+        phrases_masks=None,
+        image_masks=None,
+        phrases_embeddings=None,
+        image_embeddings=None,
+    ):
        masks = masks.unsqueeze(-1)

-        # embedding position (it may includes padding as placeholder)
-        xyxy_embedding = self.fourier_embedder(boxes)  # B*N*4 -> B*N*C
+        xyxy_embedding = self.fourier_embedder(boxes)

-        # learnable null embedding
-        positive_null = self.null_positive_feature.view(1, 1, -1)
        xyxy_null = self.null_position_feature.view(1, 1, -1)

-        # replace padding with learnable null embedding
-        positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
        xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null

-        objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
+        if positive_embeddings is not None:
+            positive_null = self.null_positive_feature.view(1, 1, -1)
+            positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
+            objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
+        else:
+            phrases_masks = phrases_masks.unsqueeze(-1)
+            image_masks = image_masks.unsqueeze(-1)
+            text_null = self.null_text_feature.view(1, 1, -1)
+            image_null = self.null_image_feature.view(1, 1, -1)
+            phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null
+            image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null
+            objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1))
+            objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1))
+            objs = torch.cat([objs_text, objs_image], dim=1)

        return objs
@@ -730,13 +770,17 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
        )

-        if attention_type == "gated":
+        if attention_type in ["gated", "gated-text-image"]:
            positive_len = 768
            if isinstance(cross_attention_dim, int):
                positive_len = cross_attention_dim
            elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
                positive_len = cross_attention_dim[0]
-            self.position_net = PositionNet(positive_len=positive_len, out_dim=cross_attention_dim)
+            feature_type = "text-only" if attention_type == "gated" else "text-image"
+            self.position_net = PositionNet(
+                positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
+            )

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -315,6 +315,21 @@ class AutoPipelineForText2Image(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


+class CLIPImageProjection(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
class ConsistencyModelPipeline(metaclass=DummyObject):
    _backends = ["torch"]
@@ -677,6 +677,21 @@ class StableDiffusionGLIGENPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


+class StableDiffusionGLIGENTextImagePipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
class StableDiffusionImageVariationPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from transformers import (
CLIPProcessor,
CLIPTextConfig,
CLIPTextModel,
CLIPTokenizer,
CLIPVisionConfig,
CLIPVisionModelWithProjection,
)
from diffusers import (
AutoencoderKL,
CLIPImageProjection,
DDIMScheduler,
StableDiffusionGLIGENTextImagePipeline,
UNet2DConditionModel,
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import enable_full_determinism
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_PARAMS,
)
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
enable_full_determinism()
class GligenTextImagePipelineFastTests(
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
):
pipeline_class = StableDiffusionGLIGENTextImagePipeline
params = TEXT_TO_IMAGE_PARAMS | {"gligen_phrases", "gligen_images", "gligen_boxes"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
attention_type="gated-text-image",
)
# unet.position_net = PositionNet(32,32)
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
sample_size=128,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
image_encoder_config = CLIPVisionConfig(
hidden_size=32,
projection_dim=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
)
image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
image_project = CLIPImageProjection(hidden_size=32)
components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": image_encoder,
"image_project": image_project,
"processor": processor,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
gligen_images = load_image(
"https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/gligen/livingroom_modern.png"
)
inputs = {
"prompt": "A modern livingroom",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"gligen_phrases": ["a birthday cake"],
"gligen_images": [gligen_images],
"gligen_boxes": [[0.2676, 0.6088, 0.4773, 0.7183]],
"output_type": "np",
}
return inputs
def test_gligen(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionGLIGENTextImagePipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5069, 0.5561, 0.4577, 0.4792, 0.5203, 0.4089, 0.5039, 0.4919, 0.4499])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_attention_slicing_forward_pass(self):
super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(batch_size=3, expected_max_diff=3e-3)