Commit 2f158a7d authored by ai_public

ip-adapter
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
import math
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange
# FFN
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(),
nn.Linear(inner_dim, dim, bias=False),
)
def reshape_tensor(x, heads):
bs, length, width = x.shape
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, length, heads, -1)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x = x.transpose(1, 2)
# make the (bs, n_heads, length, dim_per_head) layout contiguous after the transpose
x = x.reshape(bs, heads, length, -1)
return x
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latents (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, l, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
return self.to_out(out)
class Resampler(nn.Module):
def __init__(
self,
dim=1024,
depth=8,
dim_head=64,
heads=16,
num_queries=8,
embedding_dim=768,
output_dim=1024,
ff_mult=4,
max_seq_len: int = 257, # CLIP tokens + CLS token
apply_pos_emb: bool = False,
num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
):
super().__init__()
self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
self.proj_in = nn.Linear(embedding_dim, dim)
self.proj_out = nn.Linear(dim, output_dim)
self.norm_out = nn.LayerNorm(output_dim)
self.to_latents_from_mean_pooled_seq = (
nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim * num_latents_mean_pooled),
Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
)
if num_latents_mean_pooled > 0
else None
)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]
)
)
def forward(self, x):
if self.pos_emb is not None:
n, device = x.shape[1], x.device
pos_emb = self.pos_emb(torch.arange(n, device=device))
x = x + pos_emb
latents = self.latents.repeat(x.size(0), 1, 1)
x = self.proj_in(x)
if self.to_latents_from_mean_pooled_seq:
meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
latents = torch.cat((meanpooled_latents, latents), dim=-2)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
latents = ff(latents) + latents
latents = self.proj_out(latents)
return self.norm_out(latents)
def masked_mean(t, *, dim, mask=None):
if mask is None:
return t.mean(dim=dim)
denom = mask.sum(dim=dim, keepdim=True)
mask = rearrange(mask, "b n -> b n 1")
masked_t = t.masked_fill(~mask, 0.0)
return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
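# A minimal, self-contained smoke test for the blocks above (illustrative,
# assumed sizes; not part of the original module). Run this file directly to
# execute it.
if __name__ == "__main__":
    # 257 image tokens (CLIP patches + CLS) are cross-attended into 8 latent
    # queries of width 1024; the output keeps the latent shape.
    attn = PerceiverAttention(dim=1024, dim_head=64, heads=16)
    x = torch.randn(2, 257, 1024)       # image features (b, n1, D)
    latents = torch.randn(2, 8, 1024)   # latent queries (b, n2, D)
    print(attn(x, latents).shape)       # torch.Size([2, 8, 1024])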
from typing import Callable, List, Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.models.attention_processor import Attention
class JointAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size = encoder_hidden_states.shape[0]
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
# attention
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# Split the attention outputs.
hidden_states, encoder_hidden_states = (
hidden_states[:, : residual.shape[1]],
hidden_states[:, residual.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if not attn.context_pre_only:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
return hidden_states, encoder_hidden_states
class IPJointAttnProcessor2_0(torch.nn.Module):
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self, context_dim, hidden_dim, scale=1.0):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
super().__init__()
self.scale = scale
self.add_k_proj_ip = nn.Linear(context_dim, hidden_dim)
self.add_v_proj_ip = nn.Linear(context_dim, hidden_dim)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
ip_hidden_states: torch.FloatTensor = None,
*args,
**kwargs,
) -> torch.FloatTensor:
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size = encoder_hidden_states.shape[0]
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
sample_query = query # latent query
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
# attention
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# Split the attention outputs.
hidden_states, encoder_hidden_states = (
hidden_states[:, : residual.shape[1]],
hidden_states[:, residual.shape[1] :],
)
# for ip-adapter
ip_key = self.add_k_proj_ip(ip_hidden_states)
ip_value = self.add_v_proj_ip(ip_hidden_states)
ip_query = sample_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_hidden_states = F.scaled_dot_product_attention(ip_query, ip_key, ip_value, dropout_p=0.0, is_causal=False)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(ip_query.dtype)
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if not attn.context_pre_only:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
return hidden_states, encoder_hidden_states
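# Hedged wiring sketch (not part of this commit): how IPJointAttnProcessor2_0
# might be attached to an SD3-style MM-DiT transformer. The checkpoint id,
# context_dim (the Resampler output width), and scale are assumptions, and the
# surrounding pipeline would still have to be adapted to pass `ip_hidden_states`
# (the resampled image tokens) into each attention call.
if __name__ == "__main__":
    from diffusers import SD3Transformer2DModel

    transformer = SD3Transformer2DModel.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers",  # assumed (gated) checkpoint
        subfolder="transformer",
        torch_dtype=torch.float16,
    )
    hidden_dim = transformer.config.num_attention_heads * transformer.config.attention_head_dim
    attn_procs = {
        name: IPJointAttnProcessor2_0(context_dim=1280, hidden_dim=hidden_dim, scale=1.0)
        for name in transformer.attn_processors
    }
    transformer.set_attn_processor(attn_procs)
    # The new projection layers are registered on each Attention module, so a
    # plain .to(...) moves them together with the rest of the transformer.
    transformer.to(dtype=torch.float16)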
import torch
from resampler import Resampler
from transformers import CLIPVisionModel
BATCH_SIZE = 2
OUTPUT_DIM = 1280
NUM_QUERIES = 8
NUM_LATENTS_MEAN_POOLED = 4 # 0 for no mean pooling (previous behavior)
APPLY_POS_EMB = True # False for no positional embeddings (previous behavior)
IMAGE_ENCODER_NAME_OR_PATH = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
def main():
image_encoder = CLIPVisionModel.from_pretrained(IMAGE_ENCODER_NAME_OR_PATH)
embedding_dim = image_encoder.config.hidden_size
print(f"image_encoder hidden size: ", embedding_dim)
image_proj_model = Resampler(
dim=1024,
depth=2,
dim_head=64,
heads=16,
num_queries=NUM_QUERIES,
embedding_dim=embedding_dim,
output_dim=OUTPUT_DIM,
ff_mult=2,
max_seq_len=257,
apply_pos_emb=APPLY_POS_EMB,
num_latents_mean_pooled=NUM_LATENTS_MEAN_POOLED,
)
dummy_images = torch.randn(BATCH_SIZE, 3, 224, 224)
with torch.no_grad():
image_embeds = image_encoder(dummy_images, output_hidden_states=True).hidden_states[-2]
print("image_embds shape: ", image_embeds.shape)
with torch.no_grad():
ip_tokens = image_proj_model(image_embeds)
print("ip_tokens shape:", ip_tokens.shape)
assert ip_tokens.shape == (BATCH_SIZE, NUM_QUERIES + NUM_LATENTS_MEAN_POOLED, OUTPUT_DIM)
if __name__ == "__main__":
main()
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
attn_maps = {}
def hook_fn(name):
def forward_hook(module, input, output):
if hasattr(module.processor, "attn_map"):
attn_maps[name] = module.processor.attn_map
del module.processor.attn_map
return forward_hook
def register_cross_attention_hook(unet):
for name, module in unet.named_modules():
if name.split('.')[-1].startswith('attn2'):
module.register_forward_hook(hook_fn(name))
return unet
def upscale(attn_map, target_size):
attn_map = torch.mean(attn_map, dim=0)
attn_map = attn_map.permute(1,0)
temp_size = None
for i in range(0,5):
scale = 2 ** i
if (target_size[0] // scale) * (target_size[1] // scale) == attn_map.shape[1] * 64:
temp_size = (target_size[0] // (scale * 8), target_size[1] // (scale * 8))
break
assert temp_size is not None, "temp_size cannot be None"
attn_map = attn_map.view(attn_map.shape[0], *temp_size)
attn_map = F.interpolate(
attn_map.unsqueeze(0).to(dtype=torch.float32),
size=target_size,
mode='bilinear',
align_corners=False
)[0]
attn_map = torch.softmax(attn_map, dim=0)
return attn_map
def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
idx = 0 if instance_or_negative else 1
net_attn_maps = []
for name, attn_map in attn_maps.items():
attn_map = attn_map.cpu() if detach else attn_map
attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
attn_map = upscale(attn_map, image_size)
net_attn_maps.append(attn_map)
net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0)
return net_attn_maps
def attnmaps2images(net_attn_maps):
#total_attn_scores = 0
images = []
for attn_map in net_attn_maps:
attn_map = attn_map.cpu().numpy()
#total_attn_scores += attn_map.mean().item()
normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
normalized_attn_map = normalized_attn_map.astype(np.uint8)
#print("norm: ", normalized_attn_map.shape)
image = Image.fromarray(normalized_attn_map)
#image = fix_save_attn_map(attn_map)
images.append(image)
#print(total_attn_scores)
return images
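# Hedged convenience sketch (not part of the original utilities): ties the hook,
# aggregation, and rendering helpers above together. It assumes the pipeline's
# attn2 processors expose an `attn_map` attribute; the stock diffusers
# processors do not, so a recording processor must be installed beforehand.
def save_token_attn_maps(pipe, prompt, out_prefix="attn_map_token"):
    # Register forward hooks on every cross-attention module of the UNet.
    pipe.unet = register_cross_attention_hook(pipe.unet)
    image = pipe(prompt, num_inference_steps=30).images[0]
    # Aggregate the hooked maps over layers and render one grayscale image per token.
    net_attn_maps = get_net_attn_map((image.height, image.width))
    for i, attn_img in enumerate(attnmaps2images(net_attn_maps)):
        attn_img.save(f"{out_prefix}_{i}.png")
    return image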
def is_torch2_available():
return hasattr(F, "scaled_dot_product_attention")
def get_generator(seed, device):
if seed is not None:
if isinstance(seed, list):
generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]
else:
generator = torch.Generator(device).manual_seed(seed)
else:
generator = None
return generator
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "411c59b3-f177-4a10-8925-d931ce572eaa",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipelineLegacy, DDIMScheduler, AutoencoderKL\n",
"from PIL import Image\n",
"\n",
"from ip_adapter import IPAdapter"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6b6dc69c-192d-4d74-8b1e-f0d9ccfbdb49",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"current_dir = os.getcwd()\n",
"print(current_dir)\n",
"\n",
"base_model_path = f\"{current_dir}/pretrained_models/sd1.5/Realistic_Vision_v4.0_noVAE\"\n",
"vae_model_path = f\"{current_dir}/pretrained_models/sd1.5/sd-vae-ft-mse\"\n",
"image_encoder_path = f\"{current_dir}/pretrained_models/models/image_encoder/\"\n",
"ip_ckpt = f\"{current_dir}/pretrained_models/models/ip-adapter_sd15.safetensors\"\n",
"device = \"cuda\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "63ec542f-8474-4f38-9457-073425578073",
"metadata": {},
"outputs": [],
"source": [
"def image_grid(imgs, rows, cols):\n",
" assert len(imgs) == rows*cols\n",
"\n",
" w, h = imgs[0].size\n",
" grid = Image.new('RGB', size=(cols*w, rows*h))\n",
" grid_w, grid_h = grid.size\n",
" \n",
" for i, img in enumerate(imgs):\n",
" grid.paste(img, box=(i%cols*w, i//cols*h))\n",
" return grid\n",
"\n",
"noise_scheduler = DDIMScheduler(\n",
" num_train_timesteps=1000,\n",
" beta_start=0.00085,\n",
" beta_end=0.012,\n",
" beta_schedule=\"scaled_linear\",\n",
" clip_sample=False,\n",
" set_alpha_to_one=False,\n",
" steps_offset=1,\n",
")\n",
"vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)"
]
},
{
"cell_type": "markdown",
"id": "d8081d92-8f42-4bcd-9f83-44aec3f549a9",
"metadata": {},
"source": [
"## Image Variations"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3849f9d0-5f68-4a49-9190-69dd50720cae",
"metadata": {},
"outputs": [],
"source": [
"# load SD pipeline\n",
"pipe = StableDiffusionPipeline.from_pretrained(\n",
" base_model_path,\n",
" torch_dtype=torch.float16,\n",
" scheduler=noise_scheduler,\n",
" vae=vae,\n",
" feature_extractor=None,\n",
" safety_checker=None\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec09e937-3904-4d8e-a559-9066502ded36",
"metadata": {},
"outputs": [],
"source": [
"# read image prompt\n",
"image = Image.open(\"assets/images/woman.png\")\n",
"image.resize((256, 256))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "81b1ab06-d3ed-4a7e-a356-9ddf1a2eecd6",
"metadata": {},
"outputs": [],
"source": [
"# load ip-adapter\n",
"ip_model = IPAdapter(pipe, image_encoder_path, ip_ckpt, device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b77f52de-a9e4-44e1-aeec-8165414f1273",
"metadata": {},
"outputs": [],
"source": [
"# generate image variations\n",
"images = ip_model.generate(pil_image=image, num_samples=4, num_inference_steps=50, seed=42)\n",
"grid = image_grid(images, 1, 4)\n",
"grid"
]
},
{
"cell_type": "markdown",
"id": "cf199405-7cb5-4f78-9973-5fe51c632a41",
"metadata": {},
"source": [
"## Image-to-Image"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f089ad0-4683-46d7-ab58-9e5fe8f34c67",
"metadata": {},
"outputs": [],
"source": [
"# load SD Img2Img pipe\n",
"del pipe, ip_model\n",
"torch.cuda.empty_cache()\n",
"pipe = StableDiffusionImg2ImgPipeline.from_pretrained(\n",
" base_model_path,\n",
" torch_dtype=torch.float16,\n",
" scheduler=noise_scheduler,\n",
" vae=vae,\n",
" feature_extractor=None,\n",
" safety_checker=None\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b8db2b55-2f56-4eef-b2ca-c5126b14feb7",
"metadata": {},
"outputs": [],
"source": [
"# read image prompt\n",
"image = Image.open(\"assets/images/river.png\")\n",
"g_image = Image.open(\"assets/images/vermeer.jpg\")\n",
"image_grid([image.resize((256, 256)), g_image.resize((256, 256))], 1, 2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a501f284-f295-4673-96ab-e34378da62ab",
"metadata": {},
"outputs": [],
"source": [
"# load ip-adapter\n",
"ip_model = IPAdapter(pipe, image_encoder_path, ip_ckpt, device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f58fff74-9ff2-46e6-bc8a-2ad4ae1fbe0f",
"metadata": {},
"outputs": [],
"source": [
"# generate\n",
"images = ip_model.generate(pil_image=image, num_samples=4, num_inference_steps=50, seed=42, image=g_image, strength=0.6)\n",
"grid = image_grid(images, 1, 4)\n",
"grid"
]
},
{
"cell_type": "markdown",
"id": "420a7c45-8697-411f-8374-3c81d5d972e3",
"metadata": {},
"source": [
"## Inpainting"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "385cb339-3326-4523-a7db-b09e62d39c80",
"metadata": {},
"outputs": [],
"source": [
"# load SD Inpainting pipe\n",
"del pipe, ip_model\n",
"torch.cuda.empty_cache()\n",
"pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(\n",
" base_model_path,\n",
" torch_dtype=torch.float16,\n",
" scheduler=noise_scheduler,\n",
" vae=vae,\n",
" feature_extractor=None,\n",
" safety_checker=None\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c47f8ce5-eed0-41ef-9dbb-2272ec4bc224",
"metadata": {},
"outputs": [],
"source": [
"# read image prompt\n",
"image = Image.open(\"assets/images/girl.png\")\n",
"image.resize((256, 256))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f9b77289-65f5-459b-ada5-5c7c265bb4a6",
"metadata": {},
"outputs": [],
"source": [
"masked_image = Image.open(\"assets/inpainting/image.png\").resize((512, 768))\n",
"mask = Image.open(\"assets/inpainting/mask.png\").resize((512, 768))\n",
"image_grid([masked_image.resize((256, 384)), mask.resize((256, 384))], 1, 2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e49dbdaa-58eb-4bcf-acab-fa5e08f96dcb",
"metadata": {},
"outputs": [],
"source": [
"# load ip-adapter\n",
"ip_model = IPAdapter(pipe, image_encoder_path, ip_ckpt, device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "945f6800-18b8-4d95-9f5e-e7035166cbbd",
"metadata": {},
"outputs": [],
"source": [
"# generate\n",
"images = ip_model.generate(pil_image=image, num_samples=4, num_inference_steps=50,\n",
" seed=42, image=masked_image, mask_image=mask, strength=0.7, )\n",
"grid = image_grid(images, 1, 4)\n",
"grid"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "411c59b3-f177-4a10-8925-d931ce572eaa",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from diffusers import StableDiffusionAdapterPipeline, T2IAdapter, DDIMScheduler, AutoencoderKL\n",
"from PIL import Image\n",
"\n",
"from ip_adapter import IPAdapter"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6b6dc69c-192d-4d74-8b1e-f0d9ccfbdb49",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"current_dir = os.getcwd()\n",
"print(current_dir)\n",
"\n",
"base_model_path = f\"{current_dir}/pretrained_models/sd1.5/Realistic_Vision_v4.0_noVAE\"\n",
"vae_model_path = f\"{current_dir}/pretrained_models/sd1.5/sd-vae-ft-mse\"\n",
"image_encoder_path = f\"{current_dir}/pretrained_models/models/image_encoder/\"\n",
"ip_ckpt = f\"{current_dir}/pretrained_models/models/ip-adapter_sd15.safetensors\"\n",
"device = \"cuda\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "63ec542f-8474-4f38-9457-073425578073",
"metadata": {},
"outputs": [],
"source": [
"def image_grid(imgs, rows, cols):\n",
" assert len(imgs) == rows*cols\n",
"\n",
" w, h = imgs[0].size\n",
" grid = Image.new('RGB', size=(cols*w, rows*h))\n",
" grid_w, grid_h = grid.size\n",
" \n",
" for i, img in enumerate(imgs):\n",
" grid.paste(img, box=(i%cols*w, i//cols*h))\n",
" return grid\n",
"\n",
"noise_scheduler = DDIMScheduler(\n",
" num_train_timesteps=1000,\n",
" beta_start=0.00085,\n",
" beta_end=0.012,\n",
" beta_schedule=\"scaled_linear\",\n",
" clip_sample=False,\n",
" set_alpha_to_one=False,\n",
" steps_offset=1,\n",
")\n",
"vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3849f9d0-5f68-4a49-9190-69dd50720cae",
"metadata": {},
"outputs": [],
"source": [
"# load t2i-adapter\n",
"adapter_model_path = f\"{current_dir}/pretrained_models/sd1.5/diffusers/t2iadapter_depth_sd15v2/\"\n",
"adapter = T2IAdapter.from_pretrained(adapter_model_path, torch_dtype=torch.float16)\n",
"# load SD pipeline\n",
"pipe = StableDiffusionAdapterPipeline.from_pretrained(\n",
" base_model_path,\n",
" adapter=adapter,\n",
" torch_dtype=torch.float16,\n",
" scheduler=noise_scheduler,\n",
" vae=vae,\n",
" feature_extractor=None,\n",
" safety_checker=None\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec09e937-3904-4d8e-a559-9066502ded36",
"metadata": {},
"outputs": [],
"source": [
"# read image prompt\n",
"image = Image.open(\"assets/images/river.png\")\n",
"depth_map = Image.open(\"assets/structure_controls/depth2.png\")\n",
"image_grid([image.resize((256, 256)), depth_map.resize((256, 256))], 1, 2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "81b1ab06-d3ed-4a7e-a356-9ddf1a2eecd6",
"metadata": {},
"outputs": [],
"source": [
"# load ip-adapter\n",
"ip_model = IPAdapter(pipe, image_encoder_path, ip_ckpt, device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b77f52de-a9e4-44e1-aeec-8165414f1273",
"metadata": {},
"outputs": [],
"source": [
"# generate image\n",
"num_samples = 4\n",
"depth_map = [depth_map] * num_samples # a bug of diffuser, we have to set the number by hard code\n",
"images = ip_model.generate(pil_image=image, image=depth_map, num_samples=num_samples, num_inference_steps=50, seed=42)\n",
"grid = image_grid(images, 1, num_samples)\n",
"grid"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "a36078d9-c788-4323-b9af-88225e6c6c94",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL, KandinskyV22PriorPipeline\n",
"from PIL import Image\n",
"\n",
"from ip_adapter import IPAdapter"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f2a71bc9-de68-4de4-b6c3-16c92fac3e45",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"current_dir = os.getcwd()\n",
"\n",
"print(current_dir)\n",
"# TODO\n",
"base_model_path = f\"{current_dir}/pretrained_models/sd1.5/Realistic_Vision_v4.0_noVAE\"\n",
"vae_model_path = f\"{current_dir}/pretrained_models/sd1.5/sd-vae-ft-mse\"\n",
"image_encoder_path = f\"{current_dir}/pretrained_models/sdxl_models/image_encoder/\"\n",
"prior_model_path = f\"{current_dir}/pretrained_models/kandinsky-2-2-prior\"\n",
"ip_ckpt = f\"{current_dir}/pretrained_models/models/ip-adapter_sd15_vit-G.safetensors\"\n",
"device = \"cuda\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2d3092ca-f27e-4491-aacb-f0991f3a30ce",
"metadata": {},
"outputs": [],
"source": [
"def image_grid(imgs, rows, cols):\n",
" assert len(imgs) == rows*cols\n",
"\n",
" w, h = imgs[0].size\n",
" grid = Image.new('RGB', size=(cols*w, rows*h))\n",
" grid_w, grid_h = grid.size\n",
" \n",
" for i, img in enumerate(imgs):\n",
" grid.paste(img, box=(i%cols*w, i//cols*h))\n",
" return grid\n",
"\n",
"noise_scheduler = DDIMScheduler(\n",
" num_train_timesteps=1000,\n",
" beta_start=0.00085,\n",
" beta_end=0.012,\n",
" beta_schedule=\"scaled_linear\",\n",
" clip_sample=False,\n",
" set_alpha_to_one=False,\n",
" steps_offset=1,\n",
")\n",
"vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4b558ca3-e671-4d10-9137-5bf34f710124",
"metadata": {},
"outputs": [],
"source": [
"# load SD pipeline\n",
"pipe = StableDiffusionPipeline.from_pretrained(\n",
" base_model_path,\n",
" torch_dtype=torch.float16,\n",
" scheduler=noise_scheduler,\n",
" vae=vae,\n",
" feature_extractor=None,\n",
" safety_checker=None\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "216952a5-f70d-4aec-b705-fb235e540e3d",
"metadata": {},
"outputs": [],
"source": [
"# load Prior pipeline\n",
"pipe_prior = KandinskyV22PriorPipeline.from_pretrained(prior_model_path, torch_dtype=torch.float16).to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b02182b7-d3cb-4684-a6dd-8515a7f3f861",
"metadata": {},
"outputs": [],
"source": [
"# load ip-adapter\n",
"ip_model = IPAdapter(pipe, image_encoder_path, ip_ckpt, device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d0d42e9-6259-48ac-817a-ddf5164cb6ef",
"metadata": {},
"outputs": [],
"source": [
"# generate clip image embeds\n",
"prompt = [\n",
" \"a photograph of an astronaut riding a horse\",\n",
" \"a macro wildlife photo of a green frog in a rainforest pond, highly detailed, eye-level shot\",\n",
" \"kid's coloring book, a happy young girl holding a flower, cartoon, thick lines, black and white, white background\",\n",
" \"a professional photograph of a woman with red and very short hair\",\n",
"]\n",
"clip_image_embeds = pipe_prior(prompt, generator=torch.manual_seed(42)).image_embeds"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2097cffc-bf93-44ca-9e6b-9d099604d4e1",
"metadata": {},
"outputs": [],
"source": [
"# generate image\n",
"images = ip_model.generate(clip_image_embeds=clip_image_embeds, num_samples=1, width=512, height=512, num_inference_steps=50, seed=42)\n",
"image_grid(images, 1, 4)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
# Convert image-text data to JSON format; this script is intended for CC3M-style data.
# Output format:
# [
# {"text": "a dog", "image_file": "dog.jpg"}
# ]
import json
from pathlib import Path
def convert_to_json(data_root: str,
save_path: str):
data_root = Path(data_root)
txt_path_list = [*data_root.glob("*.txt")]
image_path_list = [*data_root.glob("*.png"),
*data_root.glob("*.jpg"),
*data_root.glob("*.jpeg")]
text_path_mapping = {
txt_path.stem: txt_path for txt_path in txt_path_list
}
image_path_mapping = {
image_path.stem: image_path for image_path in image_path_list
}
keys = list(set(text_path_mapping.keys()) & set(image_path_mapping.keys()))
results = []
for key in keys:
with open(text_path_mapping[key]) as f:
text = f.read().strip()
results.append({"text": text, "image_file": str(image_path_mapping[key])})
with open(save_path, "w") as f:
json.dump(results, f, ensure_ascii=False)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--data_root", type=str, help="图像-文本存储位置")
parser.add_argument("--save_path", type=str, help="json文件存储位置")
args = parser.parse_args()
convert_to_json(args.data_root, args.save_path)
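# Example invocation (hypothetical paths; the script name is whatever this file
# is saved as):
#
#   python convert_to_json.py --data_root /data/cc3m/shard_0 --save_path /data/cc3m/shard_0.json
#
# which yields a file like:
#
#   [{"text": "a dog", "image_file": "/data/cc3m/shard_0/dog.jpg"}, ...]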
[tool.poetry]
name = "ip-adapter"
version = "0.1.0"
description = "IP-Adapter: Text Compatible Image Prompt Adapter for Text-to-Image Diffusion Models"
authors = ["Ye, Hu", "Zhang, Jun", "Liu, Sibo", "Han, Xiao", "Yang, Wei"]
license = "Apache-2.0"
readme = "README.md"
packages = [{ include = "ip_adapter" }]
[tool.poetry.dependencies]
python = ">=3.6"
[tool.ruff]
line-length = 119
# Deprecation of Cuda 11.6 and Python 3.7 support for PyTorch 2.0
target-version = "py38"
# A list of file patterns to omit from linting, in addition to those specified by exclude.
extend-exclude = ["__pycache__", "*.pyc", "*.egg-info", ".cache"]
select = ["E", "F", "W", "C90", "I", "UP", "B", "C4", "RET", "RUF", "SIM"]
ignore = [
"UP006", # UP006: Use list instead of typing.List for type annotations
"UP007", # UP007: Use X | Y for type annotations
"UP009",
"UP035",
"UP038",
"E402",
"RET504",
]
[tool.isort]
profile = "black"
[tool.black]
line-length = 119
skip-string-normalization = 1
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"