# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..
# *************************************************************************
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
class AttentionBase:
def __init__(self):
self.cur_step = 0
self.num_att_layers = -1
self.cur_att_layer = 0
def after_step(self):
pass
def __call__(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
out = self.forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs)
self.cur_att_layer += 1
if self.cur_att_layer == self.num_att_layers:
self.cur_att_layer = 0
self.cur_step += 1
# after step
self.after_step()
return out
def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = rearrange(out, 'b h n d -> b n (h d)')
return out
def reset(self):
self.cur_step = 0
self.cur_att_layer = 0
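# Hedged usage sketch (not part of the original file): a minimal smoke test of the
# AttentionBase bookkeeping, assuming 2 attention layers per denoising step and
# torch >= 2.0 for scaled_dot_product_attention.
def _demo_attention_base():
    base = AttentionBase()
    base.num_att_layers = 2
    q = k = v = torch.randn(1, 8, 16, 40)  # (batch, heads, tokens, head_dim)
    for _ in range(4):                     # two denoising steps' worth of attention calls
        out = base(q, k, v, is_cross=False, place_in_unet="down", num_heads=8)
    assert out.shape == (1, 16, 8 * 40)    # heads folded back into the channel dim
    assert base.cur_step == 2 and base.cur_att_layer == 0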
class MutualSelfAttentionControl(AttentionBase):
def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, guidance_scale=7.5):
"""
        Mutual self-attention control for the Stable Diffusion model
        Args:
            start_step: the denoising step at which to start mutual self-attention control
            start_layer: the attention layer at which to start mutual self-attention control
            layer_idx: list of the layers to apply mutual self-attention control to
            step_idx: list of the steps to apply mutual self-attention control to
            total_steps: the total number of denoising steps
            guidance_scale: classifier-free guidance scale; values > 1.0 indicate an unconditional branch is present
"""
super().__init__()
self.total_steps = total_steps
self.start_step = start_step
self.start_layer = start_layer
self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, 16))
self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps))
        # store the guidance scale to decide whether an unconditional branch is present
self.guidance_scale = guidance_scale
print("step_idx: ", self.step_idx)
print("layer_idx: ", self.layer_idx)
def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
"""
Attention forward function
"""
if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
return super().forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs)
if self.guidance_scale > 1.0:
qu, qc = q[0:2], q[2:4]
ku, kc = k[0:2], k[2:4]
vu, vc = v[0:2], v[2:4]
            # concatenate the source- and target-branch queries along the token dimension so a single attention call can be used
qu = torch.cat([qu[0:1], qu[1:2]], dim=2)
qc = torch.cat([qc[0:1], qc[1:2]], dim=2)
out_u = F.scaled_dot_product_attention(qu, ku[0:1], vu[0:1], attn_mask=None, dropout_p=0.0, is_causal=False)
            out_u = torch.cat(out_u.chunk(2, dim=2), dim=0) # split the attention outputs back into source and target batches
out_u = rearrange(out_u, 'b h n d -> b n (h d)')
out_c = F.scaled_dot_product_attention(qc, kc[0:1], vc[0:1], attn_mask=None, dropout_p=0.0, is_causal=False)
            out_c = torch.cat(out_c.chunk(2, dim=2), dim=0) # split the attention outputs back into source and target batches
out_c = rearrange(out_c, 'b h n d -> b n (h d)')
out = torch.cat([out_u, out_c], dim=0)
else:
q = torch.cat([q[0:1], q[1:2]], dim=2)
out = F.scaled_dot_product_attention(q, k[0:1], v[0:1], attn_mask=None, dropout_p=0.0, is_causal=False)
            out = torch.cat(out.chunk(2, dim=2), dim=0) # split the attention outputs back into source and target batches
out = rearrange(out, 'b h n d -> b n (h d)')
return out
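# Hedged demo of the mutual self-attention path: with a (source, target) batch of two
# and guidance_scale == 1.0, both branches' queries attend to the *source* branch's
# keys/values, and the output keeps the (batch, tokens, channels) layout. Shapes are
# illustrative.
def _demo_mutual_self_attention():
    editor = MutualSelfAttentionControl(start_step=0, start_layer=0, total_steps=1, guidance_scale=1.0)
    q = k = v = torch.randn(2, 8, 16, 40)   # (batch=2, heads, tokens, head_dim)
    out = editor.forward(q, k, v, is_cross=False, place_in_unet="up", num_heads=8)
    assert out.shape == (2, 16, 8 * 40)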
# forward function for default attention processor
# modified from __call__ function of AttnProcessor in diffusers
def override_attn_proc_forward(attn, editor, place_in_unet):
def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
"""
The attention is similar to the original implementation of LDM CrossAttention class
except adding some modifications on the attention
"""
if encoder_hidden_states is not None:
context = encoder_hidden_states
if attention_mask is not None:
mask = attention_mask
        # diffusers wraps the output projection in a ModuleList of [Linear, Dropout];
        # keep only the linear projection here
        to_out = attn.to_out
        if isinstance(to_out, nn.ModuleList):
            to_out = to_out[0]
h = attn.heads
q = attn.to_q(x)
is_cross = context is not None
context = context if is_cross else x
k = attn.to_k(context)
v = attn.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
        # the only difference from the default attention processor: route q/k/v through the editor
out = editor(
q, k, v, is_cross, place_in_unet,
attn.heads, scale=attn.scale)
return to_out(out)
return forward
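# Hedged sketch: patching a single attention module with the override above. This
# assumes diffusers' Attention module layout (to_q/to_k/to_v/to_out, heads, scale);
# the sizes below are illustrative only.
def _demo_override_attn_proc():
    from diffusers.models.attention_processor import Attention
    attn = Attention(query_dim=320, heads=8, dim_head=40)
    attn.forward = override_attn_proc_forward(attn, AttentionBase(), place_in_unet="down")
    x = torch.randn(1, 64, 320)   # (batch, tokens, channels)
    out = attn(x)                 # self-attention through the patched forward
    assert out.shape == x.shape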
# forward function for lora attention processor
# modified from __call__ function of LoRAAttnProcessor2_0 in diffusers v0.17.1
def override_lora_attn_proc_forward(attn, editor, place_in_unet):
def forward(hidden_states, encoder_hidden_states=None, attention_mask=None):
residual = hidden_states
input_ndim = hidden_states.ndim
is_cross = encoder_hidden_states is not None
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
# query = attn.to_q(hidden_states) + lora_scale * attn.to_q.lora_layer(hidden_states)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# key = attn.to_k(encoder_hidden_states) + lora_scale * attn.to_k.lora_layer(encoder_hidden_states)
# value = attn.to_v(encoder_hidden_states) + lora_scale * attn.to_v.lora_layer(encoder_hidden_states)
key, value = attn.to_k(encoder_hidden_states), attn.to_v(encoder_hidden_states)
query, key, value = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=attn.heads), (query, key, value))
        # the only difference from the default LoRA attention processor: route query/key/value through the editor
hidden_states = editor(
query, key, value, is_cross, place_in_unet,
attn.heads, scale=attn.scale)
# linear proj
# hidden_states = attn.to_out[0](hidden_states) + lora_scale * attn.to_out[0].lora_layer(hidden_states)
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
return forward
def register_attention_editor_diffusers(model, editor: AttentionBase, attn_processor='attn_proc'):
"""
    Register an attention editor to a Diffusers pipeline; adapted from [Prompt-to-Prompt]
"""
def register_editor(net, count, place_in_unet):
for name, subnet in net.named_children():
if net.__class__.__name__ == 'Attention': # spatial Transformer layer
if attn_processor == 'attn_proc':
net.forward = override_attn_proc_forward(net, editor, place_in_unet)
elif attn_processor == 'lora_attn_proc':
net.forward = override_lora_attn_proc_forward(net, editor, place_in_unet)
else:
raise NotImplementedError("not implemented")
return count + 1
elif hasattr(net, 'children'):
count = register_editor(subnet, count, place_in_unet)
return count
cross_att_count = 0
for net_name, net in model.unet.named_children():
if "down" in net_name:
cross_att_count += register_editor(net, 0, "down")
elif "mid" in net_name:
cross_att_count += register_editor(net, 0, "mid")
elif "up" in net_name:
cross_att_count += register_editor(net, 0, "up")
editor.num_att_layers = cross_att_count
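# Hedged usage sketch (mirrors how run_drag uses this further below): hook a
# MutualSelfAttentionControl editor into every spatial attention layer of a loaded
# pipeline. `pipeline` is any object exposing a diffusers UNet at `pipeline.unet`.
def _example_register_editor(pipeline):
    editor = MutualSelfAttentionControl(start_step=4, start_layer=10,
                                        total_steps=50, guidance_scale=1.0)
    register_attention_editor_diffusers(pipeline, editor, attn_processor='attn_proc')
    return editor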
# *************************************************************************
# Copyright (2023) Bytedance Inc.
#
# Copyright (2023) DragDiffusion Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# *************************************************************************
import copy
import torch
import torch.nn.functional as F
def point_tracking(F0,
F1,
handle_points,
handle_points_init,
args):
with torch.no_grad():
_, _, max_r, max_c = F0.shape
for i in range(len(handle_points)):
pi0, pi = handle_points_init[i], handle_points[i]
f0 = F0[:, :, int(pi0[0]), int(pi0[1])]
r1, r2 = max(0,int(pi[0])-args.r_p), min(max_r,int(pi[0])+args.r_p+1)
c1, c2 = max(0,int(pi[1])-args.r_p), min(max_c,int(pi[1])+args.r_p+1)
F1_neighbor = F1[:, :, r1:r2, c1:c2]
all_dist = (f0.unsqueeze(dim=-1).unsqueeze(dim=-1) - F1_neighbor).abs().sum(dim=1)
all_dist = all_dist.squeeze(dim=0)
row, col = divmod(all_dist.argmin().item(), all_dist.shape[-1])
# handle_points[i][0] = pi[0] - args.r_p + row
# handle_points[i][1] = pi[1] - args.r_p + col
handle_points[i][0] = r1 + row
handle_points[i][1] = c1 + col
return handle_points
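# Hedged example (illustrative only): tracking a single handle point between two
# synthetic feature maps in which a distinctive feature moves from (10, 10) to (12, 11).
# `args` only needs the point-tracking radius r_p here.
def _demo_point_tracking():
    from types import SimpleNamespace
    F0 = torch.zeros(1, 4, 32, 32)
    F0[:, :, 10, 10] = 1.0                       # feature at (row=10, col=10)
    F1 = torch.zeros(1, 4, 32, 32)
    F1[:, :, 12, 11] = 1.0                       # same feature moved to (12, 11)
    handle = [torch.tensor([10.0, 10.0])]
    tracked = point_tracking(F0, F1, handle, [torch.tensor([10.0, 10.0])], SimpleNamespace(r_p=3))
    assert [int(v) for v in tracked[0]] == [12, 11]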
def check_handle_reach_target(handle_points,
target_points):
# dist = (torch.cat(handle_points,dim=0) - torch.cat(target_points,dim=0)).norm(dim=-1)
all_dist = list(map(lambda p,q: (p-q).norm(), handle_points, target_points))
return (torch.tensor(all_dist) < 2.0).all()
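# Small illustration: a handle point counts as "reached" when it is within a distance
# of 2 from its target, matching the threshold used above.
def _demo_check_handle_reach_target():
    handles = [torch.tensor([10.0, 10.0])]
    assert check_handle_reach_target(handles, [torch.tensor([11.0, 10.0])])       # distance 1
    assert not check_handle_reach_target(handles, [torch.tensor([20.0, 10.0])])   # distance 10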
# obtain the bilinearly interpolated feature patch covering the (possibly fractional) box [y1:y2, x1:x2]
def interpolate_feature_patch(feat,
y1,
y2,
x1,
x2):
x1_floor = torch.floor(x1).long()
x1_cell = x1_floor + 1
dx = torch.floor(x2).long() - torch.floor(x1).long()
y1_floor = torch.floor(y1).long()
y1_cell = y1_floor + 1
dy = torch.floor(y2).long() - torch.floor(y1).long()
wa = (x1_cell.float() - x1) * (y1_cell.float() - y1)
wb = (x1_cell.float() - x1) * (y1 - y1_floor.float())
wc = (x1 - x1_floor.float()) * (y1_cell.float() - y1)
wd = (x1 - x1_floor.float()) * (y1 - y1_floor.float())
Ia = feat[:, :, y1_floor : y1_floor+dy, x1_floor : x1_floor+dx]
Ib = feat[:, :, y1_cell : y1_cell+dy, x1_floor : x1_floor+dx]
Ic = feat[:, :, y1_floor : y1_floor+dy, x1_cell : x1_cell+dx]
Id = feat[:, :, y1_cell : y1_cell+dy, x1_cell : x1_cell+dx]
return Ia * wa + Ib * wb + Ic * wc + Id * wd
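# Hedged example: extracting a 2x2 patch at fractional coordinates from a small feature
# map; the result is the bilinear blend of the four neighbouring integer-aligned patches.
def _demo_interpolate_feature_patch():
    feat = torch.arange(36.0).reshape(1, 1, 6, 6)
    patch = interpolate_feature_patch(feat,
                                      torch.tensor(1.5), torch.tensor(3.5),   # y1, y2
                                      torch.tensor(2.5), torch.tensor(4.5))   # x1, x2
    assert patch.shape == (1, 1, 2, 2)
    assert torch.isclose(patch[0, 0, 0, 0], torch.tensor(11.5))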
def drag_diffusion_update(model,
init_code,
text_embeddings,
t,
handle_points,
target_points,
mask,
args):
assert len(handle_points) == len(target_points), \
"number of handle point must equals target points"
if text_embeddings is None:
text_embeddings = model.get_text_embeddings(args.prompt)
# the init output feature of unet
with torch.no_grad():
unet_output, F0 = model.forward_unet_features(init_code, t,
encoder_hidden_states=text_embeddings,
layer_idx=args.unet_feature_idx, interp_res_h=args.sup_res_h, interp_res_w=args.sup_res_w)
x_prev_0,_ = model.step(unet_output, t, init_code)
# init_code_orig = copy.deepcopy(init_code)
# prepare optimizable init_code and optimizer
init_code.requires_grad_(True)
optimizer = torch.optim.Adam([init_code], lr=args.lr)
# prepare for point tracking and background regularization
handle_points_init = copy.deepcopy(handle_points)
interp_mask = F.interpolate(mask, (init_code.shape[2],init_code.shape[3]), mode='nearest')
using_mask = interp_mask.sum() != 0.0
# prepare amp scaler for mixed-precision training
scaler = torch.cuda.amp.GradScaler()
for step_idx in range(args.n_pix_step):
with torch.autocast(device_type='cuda', dtype=torch.float16):
unet_output, F1 = model.forward_unet_features(init_code, t,
encoder_hidden_states=text_embeddings,
layer_idx=args.unet_feature_idx, interp_res_h=args.sup_res_h, interp_res_w=args.sup_res_w)
x_prev_updated,_ = model.step(unet_output, t, init_code)
# do point tracking to update handle points before computing motion supervision loss
if step_idx != 0:
handle_points = point_tracking(F0, F1, handle_points, handle_points_init, args)
print('new handle points', handle_points)
# break if all handle points have reached the targets
if check_handle_reach_target(handle_points, target_points):
break
loss = 0.0
_, _, max_r, max_c = F0.shape
for i in range(len(handle_points)):
pi, ti = handle_points[i], target_points[i]
                # skip if the handle point is already within a distance of 2 of its target
if (ti - pi).norm() < 2.:
continue
di = (ti - pi) / (ti - pi).norm()
# motion supervision
# with boundary protection
r1, r2 = max(0,int(pi[0])-args.r_m), min(max_r,int(pi[0])+args.r_m+1)
c1, c2 = max(0,int(pi[1])-args.r_m), min(max_c,int(pi[1])+args.r_m+1)
f0_patch = F1[:,:,r1:r2, c1:c2].detach()
f1_patch = interpolate_feature_patch(F1,r1+di[0],r2+di[0],c1+di[1],c2+di[1])
# original code, without boundary protection
# f0_patch = F1[:,:,int(pi[0])-args.r_m:int(pi[0])+args.r_m+1, int(pi[1])-args.r_m:int(pi[1])+args.r_m+1].detach()
# f1_patch = interpolate_feature_patch(F1, pi[0] + di[0], pi[1] + di[1], args.r_m)
loss += ((2*args.r_m+1)**2)*F.l1_loss(f0_patch, f1_patch)
# masked region must stay unchanged
if using_mask:
loss += args.lam * ((x_prev_updated-x_prev_0)*(1.0-interp_mask)).abs().sum()
# loss += args.lam * ((init_code_orig-init_code)*(1.0-interp_mask)).abs().sum()
print('loss total=%f'%(loss.item()))
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
return init_code
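# Hedged sketch of the `args` namespace that drag_diffusion_update expects; the field
# names mirror those populated by run_drag further below, and the values here are
# purely illustrative defaults.
def _example_drag_update_args():
    from types import SimpleNamespace
    return SimpleNamespace(
        prompt="a photo of a dog",        # illustrative prompt
        unet_feature_idx=[3],             # which UNet decoder feature map to supervise
        sup_res_h=256, sup_res_w=256,     # supervision resolution (half of a 512x512 image)
        r_m=1, r_p=3,                     # motion-supervision and point-tracking radii
        lam=0.1, lr=0.01,                 # mask-regularization weight and latent learning rate
        n_pix_step=80,                    # number of drag-optimization steps
    )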
def drag_diffusion_update_gen(model,
init_code,
text_embeddings,
t,
handle_points,
target_points,
mask,
args):
assert len(handle_points) == len(target_points), \
"number of handle point must equals target points"
if text_embeddings is None:
text_embeddings = model.get_text_embeddings(args.prompt)
    # if using classifier-free guidance, build the negative-prompt embedding and concatenate it with the conditional embedding
if args.guidance_scale > 1.0:
unconditional_input = model.tokenizer(
[args.neg_prompt],
padding="max_length",
max_length=77,
return_tensors="pt"
)
unconditional_emb = model.text_encoder(unconditional_input.input_ids.to(text_embeddings.device))[0].detach()
text_embeddings = torch.cat([unconditional_emb, text_embeddings], dim=0)
# the init output feature of unet
with torch.no_grad():
if args.guidance_scale > 1.:
model_inputs_0 = copy.deepcopy(torch.cat([init_code] * 2))
else:
model_inputs_0 = copy.deepcopy(init_code)
unet_output, F0 = model.forward_unet_features(model_inputs_0, t, encoder_hidden_states=text_embeddings,
layer_idx=args.unet_feature_idx, interp_res_h=args.sup_res_h, interp_res_w=args.sup_res_w)
if args.guidance_scale > 1.:
# strategy 1: discard the unconditional branch feature maps
# F0 = F0[1].unsqueeze(dim=0)
# strategy 2: concat pos and neg branch feature maps for motion-sup and point tracking
# F0 = torch.cat([F0[0], F0[1]], dim=0).unsqueeze(dim=0)
# strategy 3: concat pos and neg branch feature maps with guidance_scale consideration
coef = args.guidance_scale / (2*args.guidance_scale - 1.0)
F0 = torch.cat([(1-coef)*F0[0], coef*F0[1]], dim=0).unsqueeze(dim=0)
unet_output_uncon, unet_output_con = unet_output.chunk(2, dim=0)
unet_output = unet_output_uncon + args.guidance_scale * (unet_output_con - unet_output_uncon)
x_prev_0,_ = model.step(unet_output, t, init_code)
# init_code_orig = copy.deepcopy(init_code)
# prepare optimizable init_code and optimizer
init_code.requires_grad_(True)
optimizer = torch.optim.Adam([init_code], lr=args.lr)
# prepare for point tracking and background regularization
handle_points_init = copy.deepcopy(handle_points)
interp_mask = F.interpolate(mask, (init_code.shape[2],init_code.shape[3]), mode='nearest')
using_mask = interp_mask.sum() != 0.0
# prepare amp scaler for mixed-precision training
scaler = torch.cuda.amp.GradScaler()
for step_idx in range(args.n_pix_step):
with torch.autocast(device_type='cuda', dtype=torch.float16):
if args.guidance_scale > 1.:
model_inputs = init_code.repeat(2,1,1,1)
else:
model_inputs = init_code
unet_output, F1 = model.forward_unet_features(model_inputs, t, encoder_hidden_states=text_embeddings,
layer_idx=args.unet_feature_idx, interp_res_h=args.sup_res_h, interp_res_w=args.sup_res_w)
if args.guidance_scale > 1.:
# strategy 1: discard the unconditional branch feature maps
# F1 = F1[1].unsqueeze(dim=0)
# strategy 2: concat positive and negative branch feature maps for motion-sup and point tracking
# F1 = torch.cat([F1[0], F1[1]], dim=0).unsqueeze(dim=0)
# strategy 3: concat pos and neg branch feature maps with guidance_scale consideration
coef = args.guidance_scale / (2*args.guidance_scale - 1.0)
F1 = torch.cat([(1-coef)*F1[0], coef*F1[1]], dim=0).unsqueeze(dim=0)
unet_output_uncon, unet_output_con = unet_output.chunk(2, dim=0)
unet_output = unet_output_uncon + args.guidance_scale * (unet_output_con - unet_output_uncon)
x_prev_updated,_ = model.step(unet_output, t, init_code)
# do point tracking to update handle points before computing motion supervision loss
if step_idx != 0:
handle_points = point_tracking(F0, F1, handle_points, handle_points_init, args)
print('new handle points', handle_points)
# break if all handle points have reached the targets
if check_handle_reach_target(handle_points, target_points):
break
loss = 0.0
_, _, max_r, max_c = F0.shape
for i in range(len(handle_points)):
pi, ti = handle_points[i], target_points[i]
                # skip if the handle point is already within a distance of 2 of its target
if (ti - pi).norm() < 2.:
continue
di = (ti - pi) / (ti - pi).norm()
# motion supervision
# with boundary protection
r1, r2 = max(0,int(pi[0])-args.r_m), min(max_r,int(pi[0])+args.r_m+1)
c1, c2 = max(0,int(pi[1])-args.r_m), min(max_c,int(pi[1])+args.r_m+1)
f0_patch = F1[:,:,r1:r2, c1:c2].detach()
f1_patch = interpolate_feature_patch(F1,r1+di[0],r2+di[0],c1+di[1],c2+di[1])
# original code, without boundary protection
# f0_patch = F1[:,:,int(pi[0])-args.r_m:int(pi[0])+args.r_m+1, int(pi[1])-args.r_m:int(pi[1])+args.r_m+1].detach()
# f1_patch = interpolate_feature_patch(F1, pi[0] + di[0], pi[1] + di[1], args.r_m)
loss += ((2*args.r_m+1)**2)*F.l1_loss(f0_patch, f1_patch)
# masked region must stay unchanged
if using_mask:
loss += args.lam * ((x_prev_updated-x_prev_0)*(1.0-interp_mask)).abs().sum()
# loss += args.lam * ((init_code_orig - init_code)*(1.0-interp_mask)).abs().sum()
print('loss total=%f'%(loss.item()))
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
return init_code
# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..
# *************************************************************************
import torch
import torch.fft as fft
from diffusers.models.unet_2d_condition import logger
from diffusers.utils import is_torch_version
from typing import Any, Dict, List, Optional, Tuple, Union
def isinstance_str(x: object, cls_name: str):
"""
Checks whether x has any class *named* cls_name in its ancestry.
Doesn't require access to the class's implementation.
Useful for patching!
"""
for _cls in x.__class__.__mro__:
if _cls.__name__ == cls_name:
return True
return False
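# Quick illustration: the check matches by class *name* anywhere in the MRO, so it
# works for classes we cannot (or do not want to) import and compare against directly.
def _demo_isinstance_str():
    import torch.nn as nn
    layer = nn.Linear(4, 4)
    assert isinstance_str(layer, "Linear") and isinstance_str(layer, "Module")
    assert not isinstance_str(layer, "Conv2d")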
def Fourier_filter(x, threshold, scale):
dtype = x.dtype
x = x.type(torch.float32)
# FFT
x_freq = fft.fftn(x, dim=(-2, -1))
x_freq = fft.fftshift(x_freq, dim=(-2, -1))
B, C, H, W = x_freq.shape
    mask = torch.ones((B, C, H, W), device=x.device)
crow, ccol = H // 2, W //2
mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
x_freq = x_freq * mask
# IFFT
x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
x_filtered = x_filtered.type(dtype)
return x_filtered
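# Hedged example: the filter scales the low-frequency band (within `threshold` of the
# spectrum centre) by `scale` and leaves the shape and dtype of the input unchanged.
def _demo_fourier_filter():
    x = torch.randn(1, 4, 32, 32)
    y = Fourier_filter(x, threshold=1, scale=0.5)
    assert y.shape == x.shape and y.dtype == x.dtype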
def register_upblock2d(model):
def up_forward(self):
def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
#print(f"in upblock2d, hidden states shape: {hidden_states.shape}")
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
return forward
for i, upsample_block in enumerate(model.unet.up_blocks):
if isinstance_str(upsample_block, "UpBlock2D"):
upsample_block.forward = up_forward(upsample_block)
def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
def up_forward(self):
def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
#print(f"in free upblock2d, hidden states shape: {hidden_states.shape}")
# --------------- FreeU code -----------------------
# Only operate on the first two stages
if hidden_states.shape[1] == 1280:
hidden_states[:,:640] = hidden_states[:,:640] * self.b1
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
if hidden_states.shape[1] == 640:
hidden_states[:,:320] = hidden_states[:,:320] * self.b2
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
# ---------------------------------------------------------
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
return forward
for i, upsample_block in enumerate(model.unet.up_blocks):
if isinstance_str(upsample_block, "UpBlock2D"):
upsample_block.forward = up_forward(upsample_block)
setattr(upsample_block, 'b1', b1)
setattr(upsample_block, 'b2', b2)
setattr(upsample_block, 's1', s1)
setattr(upsample_block, 's2', s2)
def register_crossattn_upblock2d(model):
def up_forward(self):
def forward(
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
#print(f"in crossatten upblock2d, hidden states shape: {hidden_states.shape}")
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
None, # timestep
None, # class_labels
cross_attention_kwargs,
attention_mask,
encoder_attention_mask,
**ckpt_kwargs,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
return forward
for i, upsample_block in enumerate(model.unet.up_blocks):
if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
upsample_block.forward = up_forward(upsample_block)
def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
def up_forward(self):
def forward(
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
#print(f"in free crossatten upblock2d, hidden states shape: {hidden_states.shape}")
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# --------------- FreeU code -----------------------
# Only operate on the first two stages
if hidden_states.shape[1] == 1280:
hidden_states[:,:640] = hidden_states[:,:640] * self.b1
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
if hidden_states.shape[1] == 640:
hidden_states[:,:320] = hidden_states[:,:320] * self.b2
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
# ---------------------------------------------------------
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
None, # timestep
None, # class_labels
cross_attention_kwargs,
attention_mask,
encoder_attention_mask,
**ckpt_kwargs,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
# hidden_states = attn(
# hidden_states,
# encoder_hidden_states=encoder_hidden_states,
# cross_attention_kwargs=cross_attention_kwargs,
# encoder_attention_mask=encoder_attention_mask,
# return_dict=False,
# )[0]
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
)[0]
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
return forward
for i, upsample_block in enumerate(model.unet.up_blocks):
if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
upsample_block.forward = up_forward(upsample_block)
setattr(upsample_block, 'b1', b1)
setattr(upsample_block, 'b2', b2)
setattr(upsample_block, 's1', s1)
setattr(upsample_block, 's2', s2)
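# Hedged usage sketch (mirrors how gen_img applies FreeU further below): boost the
# backbone channels with b1/b2 and damp the skip connections with s1/s2. `pipeline`
# is any object exposing a diffusers UNet at `pipeline.unet`.
def _example_apply_freeu(pipeline, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
    register_free_upblock2d(pipeline, b1=b1, b2=b2, s1=s1, s2=s2)
    register_free_crossattn_upblock2d(pipeline, b1=b1, b2=b2, s1=s1, s2=s2)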
# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..
# *************************************************************************
from PIL import Image
import os
import numpy as np
from einops import rearrange
import torch
import torch.nn.functional as F
from torchvision import transforms
from accelerate import Accelerator
from accelerate.utils import set_seed
from transformers import AutoTokenizer, PretrainedConfig
import diffusers
from diffusers import (
AutoencoderKL,
DDPMScheduler,
DiffusionPipeline,
DPMSolverMultistepScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
from diffusers.models.attention_processor import (
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
SlicedAttnAddedKVProcessor,
)
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import unet_lora_state_dict
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
# Will raise an error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.24.0")
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path,
subfolder="text_encoder",
revision=revision,
)
model_class = text_encoder_config.architectures[0]
if model_class == "CLIPTextModel":
from transformers import CLIPTextModel
return CLIPTextModel
elif model_class == "RobertaSeriesModelWithTransformation":
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
return RobertaSeriesModelWithTransformation
elif model_class == "T5EncoderModel":
from transformers import T5EncoderModel
return T5EncoderModel
else:
raise ValueError(f"{model_class} is not supported.")
def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
if tokenizer_max_length is not None:
max_length = tokenizer_max_length
else:
max_length = tokenizer.model_max_length
text_inputs = tokenizer(
prompt,
truncation=True,
padding="max_length",
max_length=max_length,
return_tensors="pt",
)
return text_inputs
def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=False):
text_input_ids = input_ids.to(text_encoder.device)
if text_encoder_use_attention_mask:
attention_mask = attention_mask.to(text_encoder.device)
else:
attention_mask = None
prompt_embeds = text_encoder(
text_input_ids,
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
return prompt_embeds
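# Hedged usage sketch: encoding a prompt with the two helpers above. The checkpoint
# path is a placeholder; any Stable Diffusion 1.x style repo (or local directory)
# with CLIP tokenizer/text-encoder subfolders should work.
def _example_encode_prompt(model_path="path/to/stable-diffusion-checkpoint"):
    from transformers import CLIPTextModel
    tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer", use_fast=False)
    text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder")
    text_inputs = tokenize_prompt(tokenizer, "a photo of a cat")
    prompt_embeds = encode_prompt(text_encoder, text_inputs.input_ids, text_inputs.attention_mask)
    return prompt_embeds  # (1, 77, hidden_dim) for SD 1.x text encoders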
# image: input image (numpy array), not yet pre-processed
# prompt: the user-input prompt describing the image
# model_path: path of the base Stable Diffusion model
# vae_path: path of the VAE ("default" to use the VAE bundled with the model)
# save_lora_path: directory in which to save the trained LoRA weights
# lora_step: number of LoRA training steps
# lora_lr: learning rate for LoRA training
# lora_batch_size: batch size for LoRA training
# lora_rank: rank of the LoRA layers
# progress: gradio progress bar
# save_interval: interval (in steps) at which to save intermediate LoRA checkpoints; -1 disables intermediate saving
def train_lora(image,
prompt,
model_path,
vae_path,
save_lora_path,
lora_step,
lora_lr,
lora_batch_size,
lora_rank,
progress,
save_interval=-1):
# initialize accelerator
accelerator = Accelerator(
gradient_accumulation_steps=1,
mixed_precision='fp16'
)
set_seed(0)
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(
model_path,
subfolder="tokenizer",
revision=None,
use_fast=False,
)
# initialize the model
noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
text_encoder_cls = import_model_class_from_model_name_or_path(model_path, revision=None)
text_encoder = text_encoder_cls.from_pretrained(
model_path, subfolder="text_encoder", revision=None
)
if vae_path == "default":
vae = AutoencoderKL.from_pretrained(
model_path, subfolder="vae", revision=None
)
else:
vae = AutoencoderKL.from_pretrained(vae_path)
unet = UNet2DConditionModel.from_pretrained(
model_path, subfolder="unet", revision=None
)
pipeline = StableDiffusionPipeline.from_pretrained(
pretrained_model_name_or_path=model_path,
vae=vae,
unet=unet,
text_encoder=text_encoder,
scheduler=noise_scheduler,
torch_dtype=torch.float16)
# set device and dtype
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)
unet.to(device, dtype=torch.float16)
vae.to(device, dtype=torch.float16)
text_encoder.to(device, dtype=torch.float16)
# Set correct lora layers
unet_lora_parameters = []
for attn_processor_name, attn_processor in unet.attn_processors.items():
# Parse the attention module.
attn_module = unet
for n in attn_processor_name.split(".")[:-1]:
attn_module = getattr(attn_module, n)
# Set the `lora_layer` attribute of the attention-related matrices.
attn_module.to_q.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_q.in_features,
out_features=attn_module.to_q.out_features,
rank=lora_rank
)
)
attn_module.to_k.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_k.in_features,
out_features=attn_module.to_k.out_features,
rank=lora_rank
)
)
attn_module.to_v.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_v.in_features,
out_features=attn_module.to_v.out_features,
rank=lora_rank
)
)
attn_module.to_out[0].set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_out[0].in_features,
out_features=attn_module.to_out[0].out_features,
rank=lora_rank,
)
)
# Accumulate the LoRA params to optimize.
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
attn_module.add_k_proj.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.add_k_proj.in_features,
out_features=attn_module.add_k_proj.out_features,
                    rank=lora_rank,
)
)
attn_module.add_v_proj.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.add_v_proj.in_features,
out_features=attn_module.add_v_proj.out_features,
                    rank=lora_rank,
)
)
unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters())
# Optimizer creation
params_to_optimize = (unet_lora_parameters)
optimizer = torch.optim.AdamW(
params_to_optimize,
lr=lora_lr,
betas=(0.9, 0.999),
weight_decay=1e-2,
eps=1e-08,
)
lr_scheduler = get_scheduler(
"constant",
optimizer=optimizer,
num_warmup_steps=0,
num_training_steps=lora_step,
num_cycles=1,
power=1.0,
)
# prepare accelerator
# unet_lora_layers = accelerator.prepare_model(unet_lora_layers)
# optimizer = accelerator.prepare_optimizer(optimizer)
# lr_scheduler = accelerator.prepare_scheduler(lr_scheduler)
unet,optimizer,lr_scheduler = accelerator.prepare(unet,optimizer,lr_scheduler)
# initialize text embeddings
with torch.no_grad():
text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None)
text_embedding = encode_prompt(
text_encoder,
text_inputs.input_ids,
text_inputs.attention_mask,
text_encoder_use_attention_mask=False
)
text_embedding = text_embedding.repeat(lora_batch_size, 1, 1)
# initialize image transforms
image_transforms_pil = transforms.Compose(
[
transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.RandomCrop(512),
]
)
image_transforms_tensor = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
for step in progress.tqdm(range(lora_step), desc="training LoRA"):
unet.train()
image_batch = []
image_pil_batch = []
for _ in range(lora_batch_size):
# first store pil image
image_transformed = image_transforms_pil(Image.fromarray(image))
image_pil_batch.append(image_transformed)
# then store tensor image
image_transformed = image_transforms_tensor(image_transformed).to(device, dtype=torch.float16)
image_transformed = image_transformed.unsqueeze(dim=0)
image_batch.append(image_transformed)
# repeat the image_transformed to enable multi-batch training
image_batch = torch.cat(image_batch, dim=0)
latents_dist = vae.encode(image_batch).latent_dist
model_input = latents_dist.sample() * vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(model_input)
bsz, channels, height, width = model_input.shape
# Sample a random timestep for each image
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
)
timesteps = timesteps.long()
# Add noise to the model input according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
# Predict the noise residual
model_pred = unet(noisy_model_input,
timesteps,
text_embedding).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(model_input, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
if save_interval > 0 and (step + 1) % save_interval == 0:
save_lora_path_intermediate = os.path.join(save_lora_path, str(step+1))
if not os.path.isdir(save_lora_path_intermediate):
os.mkdir(save_lora_path_intermediate)
# unet = unet.to(torch.float32)
# unwrap_model is used to remove all special modules added when doing distributed training
# so here, there is no need to call unwrap_model
# unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
unet_lora_layers = unet_lora_state_dict(unet)
LoraLoaderMixin.save_lora_weights(
save_directory=save_lora_path_intermediate,
unet_lora_layers=unet_lora_layers,
text_encoder_lora_layers=None,
)
# unet = unet.to(torch.float16)
# save the trained lora
# unet = unet.to(torch.float32)
# unwrap_model is used to remove all special modules added when doing distributed training
# so here, there is no need to call unwrap_model
# unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
unet_lora_layers = unet_lora_state_dict(unet)
LoraLoaderMixin.save_lora_weights(
save_directory=save_lora_path,
unet_lora_layers=unet_lora_layers,
text_encoder_lora_layers=None,
)
return
# *************************************************************************
# Copyright (2023) Bytedance Inc.
#
# Copyright (2023) DragDiffusion Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# *************************************************************************
import os
import cv2
import numpy as np
import gradio as gr
from copy import deepcopy
from einops import rearrange
from types import SimpleNamespace
import datetime
import PIL
from PIL import Image
from PIL.ImageOps import exif_transpose
import torch
import torch.nn.functional as F
from diffusers import DDIMScheduler, AutoencoderKL, DPMSolverMultistepScheduler
from diffusers.models.embeddings import ImageProjection
from drag_pipeline import DragPipeline
from torchvision.utils import save_image
from pytorch_lightning import seed_everything
from .drag_utils import drag_diffusion_update, drag_diffusion_update_gen
from .lora_utils import train_lora
from .attn_utils import register_attention_editor_diffusers, MutualSelfAttentionControl
from .freeu_utils import register_free_upblock2d, register_free_crossattn_upblock2d
# -------------- general UI functionality --------------
def clear_all(length=480):
return gr.Image.update(value=None, height=length, width=length, interactive=True), \
gr.Image.update(value=None, height=length, width=length, interactive=False), \
gr.Image.update(value=None, height=length, width=length, interactive=False), \
[], None, None
def clear_all_gen(length=480):
return gr.Image.update(value=None, height=length, width=length, interactive=False), \
gr.Image.update(value=None, height=length, width=length, interactive=False), \
gr.Image.update(value=None, height=length, width=length, interactive=False), \
[], None, None, None
def mask_image(image,
mask,
color=[255,0,0],
alpha=0.5):
""" Overlay mask on image for visualization purpose.
Args:
image (H, W, 3) or (H, W): input image
mask (H, W): mask to be overlaid
color: the color of overlaid mask
alpha: the transparency of the mask
"""
out = deepcopy(image)
img = deepcopy(image)
img[mask == 1] = color
out = cv2.addWeighted(img, alpha, out, 1-alpha, 0, out)
return out
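# Small illustration: overlaying a square mask in red on a black image; only the
# masked pixels pick up the overlay colour.
def _demo_mask_image():
    image = np.zeros((8, 8, 3), dtype=np.uint8)
    mask = np.zeros((8, 8), dtype=np.uint8)
    mask[2:5, 2:5] = 1
    overlaid = mask_image(image, mask, color=[255, 0, 0], alpha=0.5)
    assert overlaid.shape == image.shape and overlaid[3, 3, 0] > 0 and overlaid[0, 0, 0] == 0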
def store_img(img, length=512):
image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255.
height,width,_ = image.shape
image = Image.fromarray(image)
image = exif_transpose(image)
image = image.resize((length,int(length*height/width)), PIL.Image.BILINEAR)
mask = cv2.resize(mask, (length,int(length*height/width)), interpolation=cv2.INTER_NEAREST)
image = np.array(image)
if mask.sum() > 0:
mask = np.uint8(mask > 0)
masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
else:
masked_img = image.copy()
# when new image is uploaded, `selected_points` should be empty
return image, [], gr.Image.update(value=masked_img, interactive=True), mask
# once the user uploads an image, the original image is stored in `original_image`;
# the same image is displayed in `input_image` for point-clicking purposes
def store_img_gen(img):
image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255.
image = Image.fromarray(image)
image = exif_transpose(image)
image = np.array(image)
if mask.sum() > 0:
mask = np.uint8(mask > 0)
masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
else:
masked_img = image.copy()
# when new image is uploaded, `selected_points` should be empty
return image, [], masked_img, mask
# the user clicks on the image to select points, which are then drawn on the image
def get_points(img,
sel_pix,
evt: gr.SelectData):
# collect the selected point
sel_pix.append(evt.index)
# draw points
points = []
for idx, point in enumerate(sel_pix):
if idx % 2 == 0:
# draw a red circle at the handle point
cv2.circle(img, tuple(point), 10, (255, 0, 0), -1)
else:
            # draw a blue circle at the target point
cv2.circle(img, tuple(point), 10, (0, 0, 255), -1)
points.append(tuple(point))
# draw an arrow from handle point to target point
if len(points) == 2:
cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5)
points = []
return img if isinstance(img, np.ndarray) else np.array(img)
# clear all handle/target points
def undo_points(original_image,
mask):
if mask.sum() > 0:
mask = np.uint8(mask > 0)
masked_img = mask_image(original_image, 1 - mask, color=[0, 0, 0], alpha=0.3)
else:
masked_img = original_image.copy()
return masked_img, []
# ------------------------------------------------------
# ----------- dragging user-input image utils -----------
def train_lora_interface(original_image,
prompt,
model_path,
vae_path,
lora_path,
lora_step,
lora_lr,
lora_batch_size,
lora_rank,
progress=gr.Progress()):
train_lora(
original_image,
prompt,
model_path,
vae_path,
lora_path,
lora_step,
lora_lr,
lora_batch_size,
lora_rank,
progress)
return "Training LoRA Done!"
def preprocess_image(image,
device,
dtype=torch.float32):
image = torch.from_numpy(image).float() / 127.5 - 1 # [-1, 1]
image = rearrange(image, "h w c -> 1 c h w")
image = image.to(device, dtype)
return image
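# Minimal check of preprocess_image: an HxWx3 uint8 array becomes a 1x3xHxW float
# tensor scaled to [-1, 1] on the requested device.
def _demo_preprocess_image():
    img = np.zeros((64, 64, 3), dtype=np.uint8)
    out = preprocess_image(img, torch.device("cpu"))
    assert out.shape == (1, 3, 64, 64)
    assert out.min() >= -1.0 and out.max() <= 1.0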
def run_drag(source_image,
image_with_clicks,
mask,
prompt,
points,
inversion_strength,
lam,
latent_lr,
n_pix_step,
model_path,
vae_path,
lora_path,
start_step,
start_layer,
save_dir="./results"
):
# initialize model
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
beta_schedule="scaled_linear", clip_sample=False,
set_alpha_to_one=False, steps_offset=1)
model = DragPipeline.from_pretrained(model_path, scheduler=scheduler, torch_dtype=torch.float16)
# call this function to override unet forward function,
# so that intermediate features are returned after forward
model.modify_unet_forward()
# set vae
if vae_path != "default":
model.vae = AutoencoderKL.from_pretrained(
vae_path
).to(model.vae.device, model.vae.dtype)
    # offload the model to CPU to save some GPU memory
model.enable_model_cpu_offload()
# initialize parameters
    seed = 42 # fixed random seed for reproducibility
seed_everything(seed)
args = SimpleNamespace()
args.prompt = prompt
args.points = points
args.n_inference_step = 50
args.n_actual_inference_step = round(inversion_strength * args.n_inference_step)
args.guidance_scale = 1.0
args.unet_feature_idx = [3]
args.r_m = 1
args.r_p = 3
args.lam = lam
args.lr = latent_lr
args.n_pix_step = n_pix_step
full_h, full_w = source_image.shape[:2]
args.sup_res_h = int(0.5*full_h)
args.sup_res_w = int(0.5*full_w)
print(args)
source_image = preprocess_image(source_image, device, dtype=torch.float16)
image_with_clicks = preprocess_image(image_with_clicks, device)
# preparing editing meta data (handle, target, mask)
mask = torch.from_numpy(mask).float() / 255.
mask[mask > 0.0] = 1.0
mask = rearrange(mask, "h w -> 1 1 h w").cuda()
mask = F.interpolate(mask, (args.sup_res_h, args.sup_res_w), mode="nearest")
handle_points = []
target_points = []
    # points from the UI are in (x, y) order; convert to (row, col) at the supervision resolution
for idx, point in enumerate(points):
cur_point = torch.tensor([point[1]/full_h*args.sup_res_h, point[0]/full_w*args.sup_res_w])
cur_point = torch.round(cur_point)
if idx % 2 == 0:
handle_points.append(cur_point)
else:
target_points.append(cur_point)
print('handle points:', handle_points)
print('target points:', target_points)
# set lora
if lora_path == "":
print("applying default parameters")
model.unet.set_default_attn_processor()
else:
print("applying lora: " + lora_path)
model.unet.load_attn_procs(lora_path)
# obtain text embeddings
text_embeddings = model.get_text_embeddings(prompt)
# invert the source image
# the latent code resolution is too small, only 64*64
invert_code = model.invert(source_image,
prompt,
encoder_hidden_states=text_embeddings,
guidance_scale=args.guidance_scale,
num_inference_steps=args.n_inference_step,
num_actual_inference_steps=args.n_actual_inference_step)
# empty cache to save memory
torch.cuda.empty_cache()
init_code = invert_code
init_code_orig = deepcopy(init_code)
model.scheduler.set_timesteps(args.n_inference_step)
t = model.scheduler.timesteps[args.n_inference_step - args.n_actual_inference_step]
# feature shape: [1280,16,16], [1280,32,32], [640,64,64], [320,64,64]
# convert dtype to float for optimization
init_code = init_code.float()
text_embeddings = text_embeddings.float()
model.unet = model.unet.float()
updated_init_code = drag_diffusion_update(
model,
init_code,
text_embeddings,
t,
handle_points,
target_points,
mask,
args)
updated_init_code = updated_init_code.half()
text_embeddings = text_embeddings.half()
model.unet = model.unet.half()
# empty cache to save memory
torch.cuda.empty_cache()
# hijack the attention module
# inject the reference branch to guide the generation
editor = MutualSelfAttentionControl(start_step=start_step,
start_layer=start_layer,
total_steps=args.n_inference_step,
guidance_scale=args.guidance_scale)
if lora_path == "":
register_attention_editor_diffusers(model, editor, attn_processor='attn_proc')
else:
register_attention_editor_diffusers(model, editor, attn_processor='lora_attn_proc')
# inference the synthesized image
gen_image = model(
prompt=args.prompt,
encoder_hidden_states=torch.cat([text_embeddings]*2, dim=0),
batch_size=2,
latents=torch.cat([init_code_orig, updated_init_code], dim=0),
guidance_scale=args.guidance_scale,
num_inference_steps=args.n_inference_step,
num_actual_inference_steps=args.n_actual_inference_step
)[1].unsqueeze(dim=0)
# resize gen_image into the size of source_image
    # we do this because the shape of gen_image is rounded to multiples of 8
gen_image = F.interpolate(gen_image, (full_h, full_w), mode='bilinear')
# save the original image, user editing instructions, synthesized image
save_result = torch.cat([
source_image.float() * 0.5 + 0.5,
torch.ones((1,3,full_h,25)).cuda(),
image_with_clicks.float() * 0.5 + 0.5,
torch.ones((1,3,full_h,25)).cuda(),
gen_image[0:1].float()
], dim=-1)
if not os.path.isdir(save_dir):
os.mkdir(save_dir)
save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
save_image(save_result, os.path.join(save_dir, save_prefix + '.png'))
out_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0]
out_image = (out_image * 255).astype(np.uint8)
return out_image
# -------------------------------------------------------
# ----------- dragging generated image utils -----------
# once the user has generated an image,
# it will be displayed in the mask-drawing and point-clicking areas
def gen_img(
length, # length of the window displaying the image
height, # height of the generated image
width, # width of the generated image
n_inference_step,
scheduler_name,
seed,
guidance_scale,
prompt,
neg_prompt,
model_path,
vae_path,
lora_path,
b1,
b2,
s1,
s2):
# initialize model
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = DragPipeline.from_pretrained(model_path, torch_dtype=torch.float16).to(device)
if scheduler_name == "DDIM":
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
beta_schedule="scaled_linear", clip_sample=False,
set_alpha_to_one=False, steps_offset=1)
elif scheduler_name == "DPM++2M":
scheduler = DPMSolverMultistepScheduler.from_config(
model.scheduler.config
)
elif scheduler_name == "DPM++2M_karras":
scheduler = DPMSolverMultistepScheduler.from_config(
model.scheduler.config, use_karras_sigmas=True
)
else:
raise NotImplementedError("scheduler name not correct")
model.scheduler = scheduler
# call this function to override unet forward function,
# so that intermediate features are returned after forward
model.modify_unet_forward()
# set vae
if vae_path != "default":
model.vae = AutoencoderKL.from_pretrained(
vae_path
).to(model.vae.device, model.vae.dtype)
# set lora
#if lora_path != "":
# print("applying lora for image generation: " + lora_path)
# model.unet.load_attn_procs(lora_path)
if lora_path != "":
print("applying lora: " + lora_path)
model.load_lora_weights(lora_path, weight_name="lora.safetensors")
# apply FreeU
if b1 != 1.0 or b2!=1.0 or s1!=1.0 or s2!=1.0:
print('applying FreeU')
register_free_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s2)
register_free_crossattn_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s2)
else:
print('do not apply FreeU')
# initialize init noise
seed_everything(seed)
init_noise = torch.randn([1, 4, height // 8, width // 8], device=device, dtype=model.vae.dtype)
gen_image, intermediate_latents = model(prompt=prompt,
neg_prompt=neg_prompt,
num_inference_steps=n_inference_step,
latents=init_noise,
guidance_scale=guidance_scale,
return_intermediates=True)
gen_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0]
gen_image = (gen_image * 255).astype(np.uint8)
if height < width:
# need to do this due to Gradio's bug
return gr.Image.update(value=gen_image, height=int(length*height/width), width=length, interactive=True), \
gr.Image.update(height=int(length*height/width), width=length, interactive=True), \
gr.Image.update(height=int(length*height/width), width=length), \
None, \
intermediate_latents
else:
return gr.Image.update(value=gen_image, height=length, width=length, interactive=True), \
gr.Image.update(value=None, height=length, width=length, interactive=True), \
gr.Image.update(value=None, height=length, width=length), \
None, \
intermediate_latents
def run_drag_gen(
n_inference_step,
scheduler_name,
source_image,
image_with_clicks,
intermediate_latents_gen,
guidance_scale,
mask,
prompt,
neg_prompt,
points,
inversion_strength,
lam,
latent_lr,
n_pix_step,
model_path,
vae_path,
lora_path,
start_step,
start_layer,
b1,
b2,
s1,
s2,
save_dir="./results"):
# initialize model
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = DragPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
if scheduler_name == "DDIM":
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
beta_schedule="scaled_linear", clip_sample=False,
set_alpha_to_one=False, steps_offset=1)
elif scheduler_name == "DPM++2M":
scheduler = DPMSolverMultistepScheduler.from_config(
model.scheduler.config
)
elif scheduler_name == "DPM++2M_karras":
scheduler = DPMSolverMultistepScheduler.from_config(
model.scheduler.config, use_karras_sigmas=True
)
else:
raise NotImplementedError("scheduler name not correct")
model.scheduler = scheduler
# call this function to override unet forward function,
# so that intermediate features are returned after forward
model.modify_unet_forward()
# set vae
if vae_path != "default":
model.vae = AutoencoderKL.from_pretrained(
vae_path
).to(model.vae.device, model.vae.dtype)
    # offload the model to CPU to save some GPU memory
model.enable_model_cpu_offload()
# initialize parameters
    seed = 42 # fixed random seed for reproducibility
seed_everything(seed)
args = SimpleNamespace()
args.prompt = prompt
args.neg_prompt = neg_prompt
args.points = points
args.n_inference_step = n_inference_step
args.n_actual_inference_step = round(n_inference_step * inversion_strength)
args.guidance_scale = guidance_scale
args.unet_feature_idx = [3]
full_h, full_w = source_image.shape[:2]
args.sup_res_h = int(0.5*full_h)
args.sup_res_w = int(0.5*full_w)
args.r_m = 1
args.r_p = 3
args.lam = lam
args.lr = latent_lr
args.n_pix_step = n_pix_step
print(args)
source_image = preprocess_image(source_image, device)
image_with_clicks = preprocess_image(image_with_clicks, device)
if lora_path != "":
print("applying lora: " + lora_path)
model.load_lora_weights(lora_path, weight_name="lora.safetensors")
# preparing editing meta data (handle, target, mask)
mask = torch.from_numpy(mask).float() / 255.
mask[mask > 0.0] = 1.0
mask = rearrange(mask, "h w -> 1 1 h w").cuda()
mask = F.interpolate(mask, (args.sup_res_h, args.sup_res_w), mode="nearest")
handle_points = []
target_points = []
    # points from the UI are in (x, y) order; convert to (row, col) at the supervision resolution
for idx, point in enumerate(points):
cur_point = torch.tensor([point[1]/full_h*args.sup_res_h, point[0]/full_w*args.sup_res_w])
cur_point = torch.round(cur_point)
if idx % 2 == 0:
handle_points.append(cur_point)
else:
target_points.append(cur_point)
print('handle points:', handle_points)
print('target points:', target_points)
# apply FreeU
if b1 != 1.0 or b2!=1.0 or s1!=1.0 or s2!=1.0:
print('applying FreeU')
register_free_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s2)
register_free_crossattn_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s2)
else:
print('do not apply FreeU')
# obtain text embeddings
text_embeddings = model.get_text_embeddings(prompt)
model.scheduler.set_timesteps(args.n_inference_step)
t = model.scheduler.timesteps[args.n_inference_step - args.n_actual_inference_step]
init_code = deepcopy(intermediate_latents_gen[args.n_inference_step - args.n_actual_inference_step])
init_code_orig = deepcopy(init_code)
# feature shape: [1280,16,16], [1280,32,32], [640,64,64], [320,64,64]
# update according to the given supervision
torch.cuda.empty_cache()
init_code = init_code.to(torch.float32)
text_embeddings = text_embeddings.to(torch.float32)
model.unet = model.unet.to(torch.float32)
updated_init_code = drag_diffusion_update_gen(model, init_code,
text_embeddings, t, handle_points, target_points, mask, args)
updated_init_code = updated_init_code.to(torch.float16)
text_embeddings = text_embeddings.to(torch.float16)
model.unet = model.unet.to(torch.float16)
torch.cuda.empty_cache()
# hijack the attention module
# inject the reference branch to guide the generation
editor = MutualSelfAttentionControl(start_step=start_step,
start_layer=start_layer,
total_steps=args.n_inference_step,
guidance_scale=args.guidance_scale)
if lora_path == "":
register_attention_editor_diffusers(model, editor, attn_processor='attn_proc')
else:
register_attention_editor_diffusers(model, editor, attn_processor='lora_attn_proc')
# inference the synthesized image
gen_image = model(
prompt=args.prompt,
neg_prompt=args.neg_prompt,
batch_size=2, # batch size is 2 because we have reference init_code and updated init_code
latents=torch.cat([init_code_orig, updated_init_code], dim=0),
guidance_scale=args.guidance_scale,
num_inference_steps=args.n_inference_step,
num_actual_inference_steps=args.n_actual_inference_step
)[1].unsqueeze(dim=0)
# resize gen_image into the size of source_image
    # we do this because the shape of gen_image is rounded to multiples of 8
gen_image = F.interpolate(gen_image, (full_h, full_w), mode='bilinear')
# save the original image, user editing instructions, synthesized image
save_result = torch.cat([
source_image * 0.5 + 0.5,
torch.ones((1,3,full_h,25)).cuda(),
image_with_clicks * 0.5 + 0.5,
torch.ones((1,3,full_h,25)).cuda(),
gen_image[0:1]
], dim=-1)
if not os.path.isdir(save_dir):
os.mkdir(save_dir)
save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
save_image(save_result, os.path.join(save_dir, save_prefix + '.png'))
out_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0]
out_image = (out_image * 255).astype(np.uint8)
return out_image
# ------------------------------------------------------