Commit 7dc08a7d authored by bailuo's avatar bailuo
Browse files

init

parents
# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..
# *************************************************************************
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
class AttentionBase:
def __init__(self):
self.cur_step = 0
self.num_att_layers = -1
self.cur_att_layer = 0
def after_step(self):
pass
def __call__(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
out = self.forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs)
self.cur_att_layer += 1
if self.cur_att_layer == self.num_att_layers:
self.cur_att_layer = 0
self.cur_step += 1
# after step
self.after_step()
return out
def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = rearrange(out, 'b h n d -> b n (h d)')
return out
def reset(self):
self.cur_step = 0
self.cur_att_layer = 0
class MutualSelfAttentionControl(AttentionBase):
def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, guidance_scale=7.5):
"""
Mutual self-attention control for Stable-Diffusion model
Args:
start_step: the step to start mutual self-attention control
start_layer: the layer to start mutual self-attention control
layer_idx: list of the layers to apply mutual self-attention control
step_idx: list the steps to apply mutual self-attention control
total_steps: the total number of steps
"""
super().__init__()
self.total_steps = total_steps
self.start_step = start_step
self.start_layer = start_layer
self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, 16))
self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps))
# store the guidance scale to decide whether there are unconditional branch
self.guidance_scale = guidance_scale
print("step_idx: ", self.step_idx)
print("layer_idx: ", self.layer_idx)
def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
"""
Attention forward function
"""
if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
return super().forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs)
if self.guidance_scale > 1.0:
qu, qc = q[0:2], q[2:4]
ku, kc = k[0:2], k[2:4]
vu, vc = v[0:2], v[2:4]
# merge queries of source and target branch into one so we can use torch API
qu = torch.cat([qu[0:1], qu[1:2]], dim=2)
qc = torch.cat([qc[0:1], qc[1:2]], dim=2)
out_u = F.scaled_dot_product_attention(qu, ku[0:1], vu[0:1], attn_mask=None, dropout_p=0.0, is_causal=False)
out_u = torch.cat(out_u.chunk(2, dim=2), dim=0) # split the queries into source and target batch
out_u = rearrange(out_u, 'b h n d -> b n (h d)')
out_c = F.scaled_dot_product_attention(qc, kc[0:1], vc[0:1], attn_mask=None, dropout_p=0.0, is_causal=False)
out_c = torch.cat(out_c.chunk(2, dim=2), dim=0) # split the queries into source and target batch
out_c = rearrange(out_c, 'b h n d -> b n (h d)')
out = torch.cat([out_u, out_c], dim=0)
else:
q = torch.cat([q[0:1], q[1:2]], dim=2)
out = F.scaled_dot_product_attention(q, k[0:1], v[0:1], attn_mask=None, dropout_p=0.0, is_causal=False)
out = torch.cat(out.chunk(2, dim=2), dim=0) # split the queries into source and target batch
out = rearrange(out, 'b h n d -> b n (h d)')
return out
# forward function for default attention processor
# modified from __call__ function of AttnProcessor in diffusers
def override_attn_proc_forward(attn, editor, place_in_unet):
def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
"""
The attention is similar to the original implementation of LDM CrossAttention class
except adding some modifications on the attention
"""
if encoder_hidden_states is not None:
context = encoder_hidden_states
if attention_mask is not None:
mask = attention_mask
to_out = attn.to_out
if isinstance(to_out, nn.modules.container.ModuleList):
to_out = attn.to_out[0]
else:
to_out = attn.to_out
h = attn.heads
q = attn.to_q(x)
is_cross = context is not None
context = context if is_cross else x
k = attn.to_k(context)
v = attn.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
# the only difference
out = editor(
q, k, v, is_cross, place_in_unet,
attn.heads, scale=attn.scale)
return to_out(out)
return forward
# forward function for lora attention processor
# modified from __call__ function of LoRAAttnProcessor2_0 in diffusers v0.17.1
def override_lora_attn_proc_forward(attn, editor, place_in_unet):
def forward(hidden_states, encoder_hidden_states=None, attention_mask=None):
residual = hidden_states
input_ndim = hidden_states.ndim
is_cross = encoder_hidden_states is not None
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
# query = attn.to_q(hidden_states) + lora_scale * attn.to_q.lora_layer(hidden_states)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# key = attn.to_k(encoder_hidden_states) + lora_scale * attn.to_k.lora_layer(encoder_hidden_states)
# value = attn.to_v(encoder_hidden_states) + lora_scale * attn.to_v.lora_layer(encoder_hidden_states)
key, value = attn.to_k(encoder_hidden_states), attn.to_v(encoder_hidden_states)
query, key, value = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=attn.heads), (query, key, value))
# the only difference
hidden_states = editor(
query, key, value, is_cross, place_in_unet,
attn.heads, scale=attn.scale)
# linear proj
# hidden_states = attn.to_out[0](hidden_states) + lora_scale * attn.to_out[0].lora_layer(hidden_states)
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
return forward
def register_attention_editor_diffusers(model, editor: AttentionBase, attn_processor='attn_proc'):
"""
Register a attention editor to Diffuser Pipeline, refer from [Prompt-to-Prompt]
"""
def register_editor(net, count, place_in_unet):
for name, subnet in net.named_children():
if net.__class__.__name__ == 'Attention': # spatial Transformer layer
if attn_processor == 'attn_proc':
net.forward = override_attn_proc_forward(net, editor, place_in_unet)
elif attn_processor == 'lora_attn_proc':
net.forward = override_lora_attn_proc_forward(net, editor, place_in_unet)
else:
raise NotImplementedError("not implemented")
return count + 1
elif hasattr(net, 'children'):
count = register_editor(subnet, count, place_in_unet)
return count
cross_att_count = 0
for net_name, net in model.unet.named_children():
if "down" in net_name:
cross_att_count += register_editor(net, 0, "down")
elif "mid" in net_name:
cross_att_count += register_editor(net, 0, "mid")
elif "up" in net_name:
cross_att_count += register_editor(net, 0, "up")
editor.num_att_layers = cross_att_count
# *************************************************************************
# Copyright (2023) Bytedance Inc.
#
# Copyright (2023) DragDiffusion Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# *************************************************************************
import copy
import torch
import torch.nn.functional as F
def point_tracking(F0,
F1,
handle_points,
handle_points_init,
args):
with torch.no_grad():
_, _, max_r, max_c = F0.shape
for i in range(len(handle_points)):
pi0, pi = handle_points_init[i], handle_points[i]
f0 = F0[:, :, int(pi0[0]), int(pi0[1])]
r1, r2 = max(0,int(pi[0])-args.r_p), min(max_r,int(pi[0])+args.r_p+1)
c1, c2 = max(0,int(pi[1])-args.r_p), min(max_c,int(pi[1])+args.r_p+1)
F1_neighbor = F1[:, :, r1:r2, c1:c2]
all_dist = (f0.unsqueeze(dim=-1).unsqueeze(dim=-1) - F1_neighbor).abs().sum(dim=1)
all_dist = all_dist.squeeze(dim=0)
row, col = divmod(all_dist.argmin().item(), all_dist.shape[-1])
# handle_points[i][0] = pi[0] - args.r_p + row
# handle_points[i][1] = pi[1] - args.r_p + col
handle_points[i][0] = r1 + row
handle_points[i][1] = c1 + col
return handle_points
def check_handle_reach_target(handle_points,
target_points):
# dist = (torch.cat(handle_points,dim=0) - torch.cat(target_points,dim=0)).norm(dim=-1)
all_dist = list(map(lambda p,q: (p-q).norm(), handle_points, target_points))
return (torch.tensor(all_dist) < 2.0).all()
# obtain the bilinear interpolated feature patch centered around (x, y) with radius r
def interpolate_feature_patch(feat,
y1,
y2,
x1,
x2):
x1_floor = torch.floor(x1).long()
x1_cell = x1_floor + 1
dx = torch.floor(x2).long() - torch.floor(x1).long()
y1_floor = torch.floor(y1).long()
y1_cell = y1_floor + 1
dy = torch.floor(y2).long() - torch.floor(y1).long()
wa = (x1_cell.float() - x1) * (y1_cell.float() - y1)
wb = (x1_cell.float() - x1) * (y1 - y1_floor.float())
wc = (x1 - x1_floor.float()) * (y1_cell.float() - y1)
wd = (x1 - x1_floor.float()) * (y1 - y1_floor.float())
Ia = feat[:, :, y1_floor : y1_floor+dy, x1_floor : x1_floor+dx]
Ib = feat[:, :, y1_cell : y1_cell+dy, x1_floor : x1_floor+dx]
Ic = feat[:, :, y1_floor : y1_floor+dy, x1_cell : x1_cell+dx]
Id = feat[:, :, y1_cell : y1_cell+dy, x1_cell : x1_cell+dx]
return Ia * wa + Ib * wb + Ic * wc + Id * wd
def drag_diffusion_update(model,
init_code,
text_embeddings,
t,
handle_points,
target_points,
mask,
args):
assert len(handle_points) == len(target_points), \
"number of handle point must equals target points"
if text_embeddings is None:
text_embeddings = model.get_text_embeddings(args.prompt)
# the init output feature of unet
with torch.no_grad():
unet_output, F0 = model.forward_unet_features(init_code, t,
encoder_hidden_states=text_embeddings,
layer_idx=args.unet_feature_idx, interp_res_h=args.sup_res_h, interp_res_w=args.sup_res_w)
x_prev_0,_ = model.step(unet_output, t, init_code)
# init_code_orig = copy.deepcopy(init_code)
# prepare optimizable init_code and optimizer
init_code.requires_grad_(True)
optimizer = torch.optim.Adam([init_code], lr=args.lr)
# prepare for point tracking and background regularization
handle_points_init = copy.deepcopy(handle_points)
interp_mask = F.interpolate(mask, (init_code.shape[2],init_code.shape[3]), mode='nearest')
using_mask = interp_mask.sum() != 0.0
# prepare amp scaler for mixed-precision training
scaler = torch.cuda.amp.GradScaler()
for step_idx in range(args.n_pix_step):
with torch.autocast(device_type='cuda', dtype=torch.float16):
unet_output, F1 = model.forward_unet_features(init_code, t,
encoder_hidden_states=text_embeddings,
layer_idx=args.unet_feature_idx, interp_res_h=args.sup_res_h, interp_res_w=args.sup_res_w)
x_prev_updated,_ = model.step(unet_output, t, init_code)
# do point tracking to update handle points before computing motion supervision loss
if step_idx != 0:
handle_points = point_tracking(F0, F1, handle_points, handle_points_init, args)
print('new handle points', handle_points)
# break if all handle points have reached the targets
if check_handle_reach_target(handle_points, target_points):
break
loss = 0.0
_, _, max_r, max_c = F0.shape
for i in range(len(handle_points)):
pi, ti = handle_points[i], target_points[i]
# skip if the distance between target and source is less than 1
if (ti - pi).norm() < 2.:
continue
di = (ti - pi) / (ti - pi).norm()
# motion supervision
# with boundary protection
r1, r2 = max(0,int(pi[0])-args.r_m), min(max_r,int(pi[0])+args.r_m+1)
c1, c2 = max(0,int(pi[1])-args.r_m), min(max_c,int(pi[1])+args.r_m+1)
f0_patch = F1[:,:,r1:r2, c1:c2].detach()
f1_patch = interpolate_feature_patch(F1,r1+di[0],r2+di[0],c1+di[1],c2+di[1])
# original code, without boundary protection
# f0_patch = F1[:,:,int(pi[0])-args.r_m:int(pi[0])+args.r_m+1, int(pi[1])-args.r_m:int(pi[1])+args.r_m+1].detach()
# f1_patch = interpolate_feature_patch(F1, pi[0] + di[0], pi[1] + di[1], args.r_m)
loss += ((2*args.r_m+1)**2)*F.l1_loss(f0_patch, f1_patch)
# masked region must stay unchanged
if using_mask:
loss += args.lam * ((x_prev_updated-x_prev_0)*(1.0-interp_mask)).abs().sum()
# loss += args.lam * ((init_code_orig-init_code)*(1.0-interp_mask)).abs().sum()
print('loss total=%f'%(loss.item()))
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
return init_code
def drag_diffusion_update_gen(model,
init_code,
text_embeddings,
t,
handle_points,
target_points,
mask,
args):
assert len(handle_points) == len(target_points), \
"number of handle point must equals target points"
if text_embeddings is None:
text_embeddings = model.get_text_embeddings(args.prompt)
# positive prompt embedding
if args.guidance_scale > 1.0:
unconditional_input = model.tokenizer(
[args.neg_prompt],
padding="max_length",
max_length=77,
return_tensors="pt"
)
unconditional_emb = model.text_encoder(unconditional_input.input_ids.to(text_embeddings.device))[0].detach()
text_embeddings = torch.cat([unconditional_emb, text_embeddings], dim=0)
# the init output feature of unet
with torch.no_grad():
if args.guidance_scale > 1.:
model_inputs_0 = copy.deepcopy(torch.cat([init_code] * 2))
else:
model_inputs_0 = copy.deepcopy(init_code)
unet_output, F0 = model.forward_unet_features(model_inputs_0, t, encoder_hidden_states=text_embeddings,
layer_idx=args.unet_feature_idx, interp_res_h=args.sup_res_h, interp_res_w=args.sup_res_w)
if args.guidance_scale > 1.:
# strategy 1: discard the unconditional branch feature maps
# F0 = F0[1].unsqueeze(dim=0)
# strategy 2: concat pos and neg branch feature maps for motion-sup and point tracking
# F0 = torch.cat([F0[0], F0[1]], dim=0).unsqueeze(dim=0)
# strategy 3: concat pos and neg branch feature maps with guidance_scale consideration
coef = args.guidance_scale / (2*args.guidance_scale - 1.0)
F0 = torch.cat([(1-coef)*F0[0], coef*F0[1]], dim=0).unsqueeze(dim=0)
unet_output_uncon, unet_output_con = unet_output.chunk(2, dim=0)
unet_output = unet_output_uncon + args.guidance_scale * (unet_output_con - unet_output_uncon)
x_prev_0,_ = model.step(unet_output, t, init_code)
# init_code_orig = copy.deepcopy(init_code)
# prepare optimizable init_code and optimizer
init_code.requires_grad_(True)
optimizer = torch.optim.Adam([init_code], lr=args.lr)
# prepare for point tracking and background regularization
handle_points_init = copy.deepcopy(handle_points)
interp_mask = F.interpolate(mask, (init_code.shape[2],init_code.shape[3]), mode='nearest')
using_mask = interp_mask.sum() != 0.0
# prepare amp scaler for mixed-precision training
scaler = torch.cuda.amp.GradScaler()
for step_idx in range(args.n_pix_step):
with torch.autocast(device_type='cuda', dtype=torch.float16):
if args.guidance_scale > 1.:
model_inputs = init_code.repeat(2,1,1,1)
else:
model_inputs = init_code
unet_output, F1 = model.forward_unet_features(model_inputs, t, encoder_hidden_states=text_embeddings,
layer_idx=args.unet_feature_idx, interp_res_h=args.sup_res_h, interp_res_w=args.sup_res_w)
if args.guidance_scale > 1.:
# strategy 1: discard the unconditional branch feature maps
# F1 = F1[1].unsqueeze(dim=0)
# strategy 2: concat positive and negative branch feature maps for motion-sup and point tracking
# F1 = torch.cat([F1[0], F1[1]], dim=0).unsqueeze(dim=0)
# strategy 3: concat pos and neg branch feature maps with guidance_scale consideration
coef = args.guidance_scale / (2*args.guidance_scale - 1.0)
F1 = torch.cat([(1-coef)*F1[0], coef*F1[1]], dim=0).unsqueeze(dim=0)
unet_output_uncon, unet_output_con = unet_output.chunk(2, dim=0)
unet_output = unet_output_uncon + args.guidance_scale * (unet_output_con - unet_output_uncon)
x_prev_updated,_ = model.step(unet_output, t, init_code)
# do point tracking to update handle points before computing motion supervision loss
if step_idx != 0:
handle_points = point_tracking(F0, F1, handle_points, handle_points_init, args)
print('new handle points', handle_points)
# break if all handle points have reached the targets
if check_handle_reach_target(handle_points, target_points):
break
loss = 0.0
_, _, max_r, max_c = F0.shape
for i in range(len(handle_points)):
pi, ti = handle_points[i], target_points[i]
# skip if the distance between target and source is less than 1
if (ti - pi).norm() < 2.:
continue
di = (ti - pi) / (ti - pi).norm()
# motion supervision
# with boundary protection
r1, r2 = max(0,int(pi[0])-args.r_m), min(max_r,int(pi[0])+args.r_m+1)
c1, c2 = max(0,int(pi[1])-args.r_m), min(max_c,int(pi[1])+args.r_m+1)
f0_patch = F1[:,:,r1:r2, c1:c2].detach()
f1_patch = interpolate_feature_patch(F1,r1+di[0],r2+di[0],c1+di[1],c2+di[1])
# original code, without boundary protection
# f0_patch = F1[:,:,int(pi[0])-args.r_m:int(pi[0])+args.r_m+1, int(pi[1])-args.r_m:int(pi[1])+args.r_m+1].detach()
# f1_patch = interpolate_feature_patch(F1, pi[0] + di[0], pi[1] + di[1], args.r_m)
loss += ((2*args.r_m+1)**2)*F.l1_loss(f0_patch, f1_patch)
# masked region must stay unchanged
if using_mask:
loss += args.lam * ((x_prev_updated-x_prev_0)*(1.0-interp_mask)).abs().sum()
# loss += args.lam * ((init_code_orig - init_code)*(1.0-interp_mask)).abs().sum()
print('loss total=%f'%(loss.item()))
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
return init_code
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment