Unverified Commit 4c483deb authored by YiYi Xu, committed by GitHub

[refactor embeddings] gligen + ip-adapter (#6244)



* refactor ip-adapter-imageproj, gligen

---------
Co-authored-by: yiyixuxu <yixu310@gmail,com>
parent 1ac07d8a
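At a glance, this commit renames four helpers in `diffusers.models.embeddings` and updates every call site; behavior is unchanged. A quick reference for reviewers, with the mapping taken from the hunks below (the import assumes a build that includes this commit):

```python
# Renames introduced by this commit (old -> new), per the hunks below:
#   MLPProjection   -> IPAdapterFullImageProjection
#   Resampler       -> IPAdapterPlusImageProjection
#   PositionNet     -> GLIGENTextBoundingboxProjection
#   FourierEmbedder -> get_fourier_embeds_from_boundingbox (nn.Module -> plain function)
from diffusers.models.embeddings import (
    GLIGENTextBoundingboxProjection,
    IPAdapterFullImageProjection,
    IPAdapterPlusImageProjection,
    get_fourier_embeds_from_boundingbox,
)
```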
@@ -24,7 +24,7 @@ import torch.nn.functional as F
 from huggingface_hub.utils import validate_hf_hub_args
 from torch import nn
 
-from ..models.embeddings import ImageProjection, MLPProjection, Resampler
+from ..models.embeddings import ImageProjection, IPAdapterFullImageProjection, IPAdapterPlusImageProjection
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from ..utils import (
     USE_PEFT_BACKEND,
@@ -712,7 +712,7 @@ class UNet2DConditionLoadersMixin:
             clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
             cross_attention_dim = state_dict["proj.3.weight"].shape[0]
 
-            image_projection = MLPProjection(
+            image_projection = IPAdapterFullImageProjection(
                 cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
             )
@@ -730,7 +730,7 @@ class UNet2DConditionLoadersMixin:
             hidden_dims = state_dict["latents"].shape[2]
             heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
 
-            image_projection = Resampler(
+            image_projection = IPAdapterPlusImageProjection(
                 embed_dims=embed_dims,
                 output_dims=output_dims,
                 hidden_dims=hidden_dims,
@@ -780,7 +780,7 @@ class UNet2DConditionLoadersMixin:
         num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]
 
         # Set encoder_hid_proj after loading ip_adapter weights,
-        # because `Resampler` also has `attn_processors`.
+        # because `IPAdapterPlusImageProjection` also has `attn_processors`.
         self.encoder_hid_proj = None
 
         # set ip-adapter cross-attention processors & load state_dict
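Review note: the loader picks the projection class purely from checkpoint key shapes. A minimal sketch of that dispatch under the new names (the helper name `make_image_projection` and the `proj_in`/`proj_out` key layout for Plus checkpoints are assumptions for illustration, not code from this commit):

```python
from diffusers.models.embeddings import IPAdapterFullImageProjection, IPAdapterPlusImageProjection


def make_image_projection(state_dict):
    # Full-face IP-Adapter checkpoints store an MLP under "proj.*" keys.
    if "proj.3.weight" in state_dict:
        clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
        cross_attention_dim = state_dict["proj.3.weight"].shape[0]
        return IPAdapterFullImageProjection(
            cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
        )
    # IP-Adapter Plus checkpoints store perceiver-style resampler weights instead.
    embed_dims = state_dict["proj_in.weight"].shape[1]    # assumption: Plus checkpoint key layout
    output_dims = state_dict["proj_out.weight"].shape[0]  # assumption
    hidden_dims = state_dict["latents"].shape[2]
    heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
    return IPAdapterPlusImageProjection(
        embed_dims=embed_dims, output_dims=output_dims, hidden_dims=hidden_dims, heads=heads
    )
```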
...
@@ -462,7 +462,7 @@ class ImageProjection(nn.Module):
         return image_embeds
 
 
-class MLPProjection(nn.Module):
+class IPAdapterFullImageProjection(nn.Module):
     def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
         super().__init__()
         from .attention import FeedForward
@@ -621,29 +621,34 @@ class AttentionPooling(nn.Module):
         return a[:, 0, :]  # cls_token
 
 
-class FourierEmbedder(nn.Module):
-    def __init__(self, num_freqs=64, temperature=100):
-        super().__init__()
-
-        self.num_freqs = num_freqs
-        self.temperature = temperature
-
-        freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
-        freq_bands = freq_bands[None, None, None]
-        self.register_buffer("freq_bands", freq_bands, persistent=False)
-
-    def __call__(self, x):
-        x = self.freq_bands * x.unsqueeze(-1)
-        return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
+def get_fourier_embeds_from_boundingbox(embed_dim, box):
+    """
+    Args:
+        embed_dim: int
+        box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline
+    Returns:
+        [B x N x embed_dim] tensor of positional embeddings
+    """
+
+    batch_size, num_boxes = box.shape[:2]
+
+    emb = 100 ** (torch.arange(embed_dim) / embed_dim)
+    emb = emb[None, None, None].to(device=box.device, dtype=box.dtype)
+    emb = emb * box.unsqueeze(-1)
+
+    emb = torch.stack((emb.sin(), emb.cos()), dim=-1)
+    emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4)
+
+    return emb
 
 
-class PositionNet(nn.Module):
+class GLIGENTextBoundingboxProjection(nn.Module):
     def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8):
         super().__init__()
         self.positive_len = positive_len
         self.out_dim = out_dim
 
-        self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
+        self.fourier_embedder_dim = fourier_freqs
         self.position_dim = fourier_freqs * 2 * 4  # 2: sin/cos, 4: xyxy
 
         if isinstance(out_dim, tuple):
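Review note: a quick shape check of the new stateless helper (a sketch; note the trailing dimension is embed_dim * 2 * 4, so the docstring's `[B x N x embed_dim]` is approximate):

```python
import torch

from diffusers.models.embeddings import get_fourier_embeds_from_boundingbox

boxes = torch.rand(2, 30, 4)  # B x N x 4, normalized xyxy coordinates
emb = get_fourier_embeds_from_boundingbox(8, boxes)  # 8 matches the fourier_freqs default above
print(emb.shape)  # torch.Size([2, 30, 64]) -> 8 freqs * 2 (sin/cos) * 4 (xyxy)
```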
@@ -692,7 +697,7 @@ class PositionNet(nn.Module):
         masks = masks.unsqueeze(-1)
 
         # embedding position (it may includes padding as placeholder)
-        xyxy_embedding = self.fourier_embedder(boxes)  # B*N*4 -> B*N*C
+        xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes)  # B*N*4 -> B*N*C
 
         # learnable null embedding
         xyxy_null = self.null_position_feature.view(1, 1, -1)
@@ -787,7 +792,7 @@ class PixArtAlphaTextProjection(nn.Module):
         return hidden_states
 
 
-class Resampler(nn.Module):
+class IPAdapterPlusImageProjection(nn.Module):
     """Resampler of IP-Adapter Plus.
 
     Args:
...
@@ -32,10 +32,10 @@ from .attention_processor import (
 )
 from .embeddings import (
     GaussianFourierProjection,
+    GLIGENTextBoundingboxProjection,
     ImageHintTimeEmbedding,
     ImageProjection,
     ImageTimeEmbedding,
-    PositionNet,
     TextImageProjection,
     TextImageTimeEmbedding,
     TextTimeEmbedding,
@@ -615,7 +615,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 positive_len = cross_attention_dim[0]
 
             feature_type = "text-only" if attention_type == "gated" else "text-image"
-            self.position_net = PositionNet(
+            self.position_net = GLIGENTextBoundingboxProjection(
                 positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
             )
...
@@ -187,7 +187,7 @@ class FourierEmbedder(nn.Module):
         return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
 
 
-class PositionNet(nn.Module):
+class GLIGENTextBoundingboxProjection(nn.Module):
     def __init__(self, positive_len, out_dim, feature_type, fourier_freqs=8):
         super().__init__()
         self.positive_len = positive_len
@@ -820,7 +820,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 positive_len = cross_attention_dim[0]
 
             feature_type = "text-only" if attention_type == "gated" else "text-image"
-            self.position_net = PositionNet(
+            self.position_net = GLIGENTextBoundingboxProjection(
                 positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
             )
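Review note: construction and forward behavior are unchanged by the rename; a minimal sketch with illustrative dimensions (768 stands in for SD's cross-attention width and is an assumption, not a value from this diff):

```python
import torch

from diffusers.models.embeddings import GLIGENTextBoundingboxProjection

position_net = GLIGENTextBoundingboxProjection(positive_len=768, out_dim=768, feature_type="text-only")

objs = position_net(
    boxes=torch.rand(1, 30, 4),                  # normalized xyxy boxes
    masks=torch.ones(1, 30),                     # 1 = real object, 0 = padding
    positive_embeddings=torch.rand(1, 30, 768),  # per-phrase text embeddings
)
print(objs.shape)  # torch.Size([1, 30, 768])
```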
...
@@ -730,7 +730,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline):
             )
             gligen_phrases = gligen_phrases[:max_objs]
             gligen_boxes = gligen_boxes[:max_objs]
-        # prepare batched input to the PositionNet (boxes, phrases, mask)
+        # prepare batched input to the GLIGENTextBoundingboxProjection (boxes, phrases, mask)
         # Get tokens for phrases from pre-trained CLIPTokenizer
         tokenizer_inputs = self.tokenizer(gligen_phrases, padding=True, return_tensors="pt").to(device)
         # For the token, we use the same pre-trained text encoder
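Review note: nothing changes at the user-facing API; grounded-generation calls look the same as before. A usage sketch (prompt and box values are illustrative; the checkpoint id follows the GLIGEN docs):

```python
import torch

from diffusers import StableDiffusionGLIGENPipeline

pipe = StableDiffusionGLIGENPipeline.from_pretrained(
    "masterful/gligen-1-4-generation-text-box", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    prompt="a birthday cake on a wooden table",
    gligen_phrases=["a birthday cake"],      # one phrase per bounding box
    gligen_boxes=[[0.25, 0.45, 0.75, 0.9]],  # normalized xyxy
    gligen_scheduled_sampling_beta=1.0,
    num_inference_steps=50,
).images[0]
```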
...
@@ -26,7 +26,7 @@ from pytest import mark
 from diffusers import UNet2DConditionModel
 from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, IPAdapterAttnProcessor
-from diffusers.models.embeddings import ImageProjection, Resampler
+from diffusers.models.embeddings import ImageProjection, IPAdapterPlusImageProjection
 from diffusers.utils import logging
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
@@ -133,7 +133,7 @@ def create_ip_adapter_plus_state_dict(model):
     # "image_proj" (ImageProjection layer weights)
     cross_attention_dim = model.config["cross_attention_dim"]
-    image_projection = Resampler(
+    image_projection = IPAdapterPlusImageProjection(
         embed_dims=cross_attention_dim, output_dims=cross_attention_dim, dim_head=32, heads=2, num_queries=4
     )
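Review note: the test builds a deliberately tiny resampler; for intuition, a sketch of what that instantiation maps (the 257-token input length mimics CLIP vision hidden states and is an assumption):

```python
import torch

from diffusers.models.embeddings import IPAdapterPlusImageProjection

image_projection = IPAdapterPlusImageProjection(
    embed_dims=32, output_dims=32, dim_head=32, heads=2, num_queries=4
)

image_embeds = torch.rand(1, 257, 32)  # batch x tokens x embed_dims
out = image_projection(image_embeds)   # resamples any token count down to num_queries
print(out.shape)  # torch.Size([1, 4, 32]) -> num_queries tokens of width output_dims
```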
...