Commit 66e662c1 authored by bailuo's avatar bailuo
Browse files

init & optimize

parents
Pipeline #2116 failed with stages
in 0 seconds
This diff is collapsed.
# Fine-tune a DreamBooth-LoRA adapter on the sample images so the Stable
# Diffusion v1.5 backbone reconstructs the edited subject faithfully.
# Short schedule: 200 steps, constant LR 2e-4, rank-16 LoRA, checkpoint every 100.
export SAMPLE_DIR="lora/samples/sculpture"
export OUTPUT_DIR="lora/lora_ckpt/sculpture_lora"
export MODEL_NAME="botp/stable-diffusion-v1-5"
export LORA_RANK=16
accelerate launch lora/train_dreambooth_lora.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$SAMPLE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of a sculpture" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --checkpointing_steps=100 \
  --learning_rate=2e-4 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=200 \
  --lora_rank=$LORA_RANK \
  --seed="0"
# 模型唯一标识
modelCode=1142
# 模型名称
modelName=dragnoise_pytorch
# 模型描述
modelDescription=利用扩散模型进行基于点的交互式图像编辑。
# 应用场景
appScenario=AIGC,零售,制造,电商,医疗,教育
# 框架类型
frameType=pytorch
# torch==2.0.0
# torchvision==0.15.1
gradio==3.41.1
pydantic==2.0.2
albumentations==1.3.0
opencv-contrib-python
imageio==2.9.0
imageio-ffmpeg
pytorch-lightning
omegaconf
test-tube>=0.7.5
streamlit
einops==0.6.0
transformers==4.33.0
webdataset
kornia
open_clip_torch
invisible-watermark>=0.1.5
streamlit-drawable-canvas==0.8.0
torchmetrics==0.6.0
timm==0.6.12
addict==2.4.0
yapf==0.32.0
prettytable==3.6.0
safetensors==0.3.1
#basicsr==1.4.2
accelerate==0.21.0
decord==0.6.0
diffusers==0.24.0
moviepy==1.0.3
opencv_python==4.7.0.68
Pillow==9.4.0
scikit_image==0.19.3
scipy==1.10.1
tensorboardX==2.6
tqdm>=4.66.3
numpy==1.26.4
#tokenizers>=0.19.1
lpips
#clip
openai-clip
\ No newline at end of file
# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..
# *************************************************************************
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
class AttentionBase:
    """Bookkeeping base class for attention editors.

    Tracks which denoising step and which attention layer the current call
    belongs to; subclasses override ``forward`` to modify the attention.
    """

    def __init__(self):
        self.cur_step = 0         # index of the current denoising step
        self.num_att_layers = -1  # total attention layers, filled in by the registrar
        self.cur_att_layer = 0    # index of the attention layer currently executing

    def after_step(self):
        """Hook invoked once every attention layer of a step has run."""
        pass

    def __call__(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
        result = self.forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers:
            # a full pass over the UNet finished: advance to the next step
            self.cur_att_layer = 0
            self.cur_step += 1
            self.after_step()
        return result

    def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
        """Default behavior: plain scaled-dot-product attention, heads merged."""
        attn_out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
        return rearrange(attn_out, 'b h n d -> b n (h d)')

    def reset(self):
        """Rewind the step / layer counters to their initial state."""
        self.cur_step = 0
        self.cur_att_layer = 0
class MutualSelfAttentionControl(AttentionBase):
    def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, guidance_scale=7.5):
        """
        Mutual self-attention control for Stable-Diffusion model
        Args:
            start_step: the step to start mutual self-attention control
            start_layer: the layer to start mutual self-attention control
            layer_idx: list of the layers to apply mutual self-attention control
            step_idx: list the steps to apply mutual self-attention control
            total_steps: the total number of steps
            guidance_scale: classifier-free guidance scale; > 1.0 signals that
                q/k/v arrive with both an unconditional and a conditional branch
        """
        super().__init__()
        self.total_steps = total_steps
        self.start_step = start_step
        self.start_layer = start_layer
        # 16 is presumably the number of counted attention layers in the UNet
        # for this setup — TODO confirm against editor.num_att_layers
        self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, 16))
        self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps))
        # store the guidance scale to decide whether there are unconditional branch
        self.guidance_scale = guidance_scale
        print("step_idx: ", self.step_idx)
        print("layer_idx: ", self.layer_idx)

    def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
        """
        Attention forward function.

        Mutual self-attention: target-branch queries attend to the SOURCE
        branch's keys/values (index 0 of each branch) so the edited image keeps
        the source appearance.  Falls back to vanilla attention for cross
        attention and for steps/layers outside the configured ranges.
        """
        if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
            return super().forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs)
        if self.guidance_scale > 1.0:
            # batch layout assumed [uncond-src, uncond-tgt, cond-src, cond-tgt]
            qu, qc = q[0:2], q[2:4]
            ku, kc = k[0:2], k[2:4]
            vu, vc = v[0:2], v[2:4]
            # merge queries of source and target branch into one so we can use torch API
            qu = torch.cat([qu[0:1], qu[1:2]], dim=2)
            qc = torch.cat([qc[0:1], qc[1:2]], dim=2)
            # both merged query sets attend only to the source branch's k/v
            out_u = F.scaled_dot_product_attention(qu, ku[0:1], vu[0:1], attn_mask=None, dropout_p=0.0, is_causal=False)
            out_u = torch.cat(out_u.chunk(2, dim=2), dim=0)  # split the queries into source and target batch
            out_u = rearrange(out_u, 'b h n d -> b n (h d)')
            out_c = F.scaled_dot_product_attention(qc, kc[0:1], vc[0:1], attn_mask=None, dropout_p=0.0, is_causal=False)
            out_c = torch.cat(out_c.chunk(2, dim=2), dim=0)  # split the queries into source and target batch
            out_c = rearrange(out_c, 'b h n d -> b n (h d)')
            out = torch.cat([out_u, out_c], dim=0)
        else:
            # no CFG: batch layout assumed [src, tgt]
            q = torch.cat([q[0:1], q[1:2]], dim=2)
            out = F.scaled_dot_product_attention(q, k[0:1], v[0:1], attn_mask=None, dropout_p=0.0, is_causal=False)
            out = torch.cat(out.chunk(2, dim=2), dim=0)  # split the queries into source and target batch
            out = rearrange(out, 'b h n d -> b n (h d)')
        return out
# forward function for default attention processor
# modified from __call__ function of AttnProcessor in diffusers
def override_attn_proc_forward(attn, editor, place_in_unet):
    """Build a replacement ``forward`` for a diffusers ``Attention`` module
    that routes the attention computation through ``editor``.

    Args:
        attn: the diffusers Attention module being patched.
        editor: an AttentionBase-style callable receiving (q, k, v, ...).
        place_in_unet: "down" / "mid" / "up", forwarded to the editor.

    Returns:
        A closure with the same signature the LDM CrossAttention forward uses.
    """
    def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
        """
        The attention is similar to the original implementation of LDM CrossAttention class
        except adding some modifications on the attention
        """
        if encoder_hidden_states is not None:
            context = encoder_hidden_states
        if attention_mask is not None:
            mask = attention_mask

        # attn.to_out is ModuleList([Linear, Dropout]) in diffusers; only the
        # projection is applied here.  (Previous code assigned to_out twice —
        # the first assignment was dead.)
        if isinstance(attn.to_out, nn.ModuleList):
            to_out = attn.to_out[0]
        else:
            to_out = attn.to_out

        h = attn.heads
        q = attn.to_q(x)
        is_cross = context is not None
        context = context if is_cross else x
        k = attn.to_k(context)
        v = attn.to_v(context)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        # the only difference from the stock processor: delegate to the editor
        out = editor(
            q, k, v, is_cross, place_in_unet,
            attn.heads, scale=attn.scale)
        return to_out(out)
    return forward
# forward function for lora attention processor
# modified from __call__ function of LoRAAttnProcessor2_0 in diffusers v0.17.1
def override_lora_attn_proc_forward(attn, editor, place_in_unet):
    """Build a replacement ``forward`` for a LoRA-augmented diffusers
    ``Attention`` module that routes the attention math through ``editor``.
    Mirrors LoRAAttnProcessor2_0.__call__ except for the editor call."""
    def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, lora_scale=1.0):
        residual = hidden_states
        input_ndim = hidden_states.ndim
        is_cross = encoder_hidden_states is not None
        # 4-D input means (B, C, H, W) feature maps: flatten to a sequence
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
        # base projection plus the scaled LoRA delta
        query = attn.to_q(hidden_states) + lora_scale * attn.processor.to_q_lora(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        key = attn.to_k(encoder_hidden_states) + lora_scale * attn.processor.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + lora_scale * attn.processor.to_v_lora(encoder_hidden_states)
        query, key, value = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=attn.heads), (query, key, value))
        # the only difference from the stock processor: delegate to the editor
        # NOTE(review): attention_mask is prepared above but not forwarded to
        # the editor — confirm masking is intentionally unused on this path
        hidden_states = editor(
            query, key, value, is_cross, place_in_unet,
            attn.heads, scale=attn.scale)
        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + lora_scale * attn.processor.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
    return forward
def register_attention_editor_diffusers(model, editor: AttentionBase, attn_processor='attn_proc'):
    """
    Register a attention editor to Diffuser Pipeline, refer from [Prompt-to-Prompt]

    Walks the UNet, replaces the forward of every ``Attention`` module with a
    closure that routes q/k/v through ``editor``, and records the total count
    on ``editor.num_att_layers`` so the editor can track step boundaries.

    Args:
        model: pipeline exposing ``model.unet``.
        editor: AttentionBase instance to install.
        attn_processor: 'attn_proc' for plain modules, 'lora_attn_proc' for
            LoRA-augmented ones.
    """
    def register_editor(net, count, place_in_unet):
        # Recursively patch Attention modules.  (The previous version tested
        # net.__class__ instead of subnet.__class__, so patching only worked
        # by accident when the recursion re-entered the Attention module.)
        for name, subnet in net.named_children():
            if subnet.__class__.__name__ == 'Attention':  # spatial Transformer layer
                if attn_processor == 'attn_proc':
                    subnet.forward = override_attn_proc_forward(subnet, editor, place_in_unet)
                elif attn_processor == 'lora_attn_proc':
                    subnet.forward = override_lora_attn_proc_forward(subnet, editor, place_in_unet)
                else:
                    raise NotImplementedError("not implemented")
                count += 1
            elif hasattr(subnet, 'children'):
                count = register_editor(subnet, count, place_in_unet)
        return count

    cross_att_count = 0
    for net_name, net in model.unet.named_children():
        if "down" in net_name:
            cross_att_count += register_editor(net, 0, "down")
        elif "mid" in net_name:
            cross_att_count += register_editor(net, 0, "mid")
        elif "up" in net_name:
            cross_att_count += register_editor(net, 0, "up")
    editor.num_att_layers = cross_att_count
import copy
import time
import math
import numpy as np
import torch
import torch.nn.functional as F
# import copy
from torch.utils.tensorboard import SummaryWriter
def point_tracking(F0,
                   F1,
                   handle_points,
                   handle_points_init,
                   args):
    """Update each handle point by nearest-neighbour feature matching.

    For every handle, the feature vector sampled from F0 at the point's
    initial location is matched (L1 distance over channels) against a
    (2*r_p+1)^2 neighbourhood of the current location in F1; the point moves
    to the best match.  Points are modified in place and returned.
    """
    with torch.no_grad():
        for idx, (pi, pi0) in enumerate(zip(handle_points, handle_points_init)):
            template = F0[:, :, int(pi0[0]), int(pi0[1])]
            row_lo, row_hi = int(pi[0]) - args.r_p, int(pi[0]) + args.r_p + 1
            col_lo, col_hi = int(pi[1]) - args.r_p, int(pi[1]) + args.r_p + 1
            neighborhood = F1[:, :, row_lo:row_hi, col_lo:col_hi]
            # L1 distance per spatial position, channels summed, batch squeezed
            dist = (template[..., None, None] - neighborhood).abs().sum(dim=1).squeeze(dim=0)
            # WARNING: no boundary protection right now
            best_row, best_col = divmod(dist.argmin().item(), dist.shape[-1])
            handle_points[idx][0] = pi[0] - args.r_p + best_row
            handle_points[idx][1] = pi[1] - args.r_p + best_col
    return handle_points
def check_handle_reach_target(handle_points,
                              target_points):
    """Return True when every handle point lies within 2.0 of its target."""
    distances = [(h - t).norm() for h, t in zip(handle_points, target_points)]
    return (torch.tensor(distances) < 2.0).all()
# obtain the bilinear interpolated feature patch centered around (x, y) with radius r
def interpolate_feature_patch(feat,
                              y,
                              x,
                              r):
    """Bilinearly interpolate a (2r+1)x(2r+1) patch of ``feat`` centred at
    the fractional coordinate (y, x); y and x are 0-d tensors."""
    y0 = torch.floor(y).long()
    x0 = torch.floor(x).long()
    y1, x1 = y0 + 1, x0 + 1
    dy = y - y0.float()
    dx = x - x0.float()

    def patch(row, col):
        # integer-centred patch extraction
        return feat[:, :, row - r:row + r + 1, col - r:col + r + 1]

    return ((1 - dx) * (1 - dy) * patch(y0, x0)
            + (1 - dx) * dy * patch(y1, x0)
            + dx * (1 - dy) * patch(y0, x1)
            + dx * dy * patch(y1, x1))
def drag_diffusion_update(model,
                          init_code,
                          t,
                          handle_points,
                          target_points,
                          mask,
                          args):
    """Drag optimization loop: optimizes the UNet bottleneck feature
    ``h_feature`` (not the latent — see the commented-out lines) so that
    handle points move toward target points, with a mask term keeping the
    unedited region unchanged.

    Args:
        model: pipeline exposing get_text_embeddings / forward_unet_features / step.
        init_code: latent code at timestep t (kept fixed here).
        t: the diffusion timestep used for every UNet call.
        handle_points / target_points: lists of 2-element (row, col) tensors.
        mask: editable-region mask, interpolated to the latent resolution.
        args: namespace providing prompt, lr, lam, r_m, r_p, n_pix_step,
            unet_feature_idx, sup_res_h, sup_res_w.

    Returns:
        (init_code, final h_feature, list of per-step h_feature copies).
    """
    assert len(handle_points) == len(target_points), \
        "number of handle point must equals target points"
    text_emb = model.get_text_embeddings(args.prompt).detach()
    # the init output feature of unet
    with torch.no_grad():
        unet_output, F0, h_feature = model.forward_unet_features(init_code, t, encoder_hidden_states=text_emb,
            layer_idx=args.unet_feature_idx, interp_res_h=args.sup_res_h, interp_res_w=args.sup_res_w)
        x_prev_0, _ = model.step(unet_output, t, init_code)
        # prepare optimizable init_code and optimizer
        # init_code.requires_grad_(True)
        # optimizer = torch.optim.Adam([init_code], lr=args.lr)

    # optimize the bottleneck feature instead of the latent code
    h_feature.requires_grad_(True)
    optimizer = torch.optim.Adam([h_feature], lr=args.lr)

    # prepare for point tracking and background regularization
    handle_points_init = copy.deepcopy(handle_points)
    interp_mask = F.interpolate(mask, (init_code.shape[2], init_code.shape[3]), mode='nearest')

    # bookkeeping for debugging/visualization
    h_features = []
    loss_list = []
    new_target = []
    point = []
    for i in range(len(handle_points)):
        pi, ti = handle_points[i], target_points[i]
        point.append(pi)
        point.append(ti)

    # prepare amp scaler for mixed-precision training
    scaler = torch.cuda.amp.GradScaler()
    for step_idx in range(args.n_pix_step):
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            unet_output, F1, h_feature = model.forward_unet_features(init_code, t, encoder_hidden_states=text_emb, h_feature=h_feature,
                layer_idx=args.unet_feature_idx, interp_res_h=args.sup_res_h, interp_res_w=args.sup_res_w)
            x_prev_updated, _ = model.step(unet_output, t, init_code)
            # NOTE(review): deepcopy of a grad-attached tensor snapshots the
            # feature each step; confirm this is intended for later inspection
            copy_h = copy.deepcopy(h_feature)
            h_features.append(copy_h)
            # do point tracking to update handle points before computing motion supervision loss
            if step_idx != 0:
                handle_points = point_tracking(F0, F1, handle_points, handle_points_init, args)
                print('{} _ new handle points:{}'.format(step_idx, handle_points))
            # break if all handle points have reached the targets
            if check_handle_reach_target(handle_points, target_points):
                break
            loss = 0.0
            for i in range(len(handle_points)):
                pi, ti = handle_points[i], target_points[i]
                pii = pi.tolist()
                new_target.append(pii)
                # skip if the distance between target and source is less than 1
                if (ti - pi).norm() < 2.:
                    continue
                # unit step direction from handle toward target
                di = (ti - pi) / (ti - pi).norm()
                # motion supervision: current patch (detached) should match the
                # patch one unit step toward the target
                f0_patch = F1[:, :, int(pi[0])-args.r_m:int(pi[0])+args.r_m+1, int(pi[1])-args.r_m:int(pi[1])+args.r_m+1].detach()
                f1_patch = interpolate_feature_patch(F1, pi[0] + di[0], pi[1] + di[1], args.r_m)
                loss += ((2*args.r_m+1)**2)*F.l1_loss(f0_patch, f1_patch)
            # masked region must stay unchanged
            loss += args.lam * ((x_prev_updated-x_prev_0)*(1.0-interp_mask)).abs().sum()
            loss_list.append(loss)
            # loss += args.lam * ((init_code_orig-init_code)*(1.0-interp_mask)).abs().sum()
            print('loss total=%f'%(loss.item()))
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
    return init_code, h_feature, h_features
# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..
# *************************************************************************
import torch
import torch.fft as fft
from diffusers.models.unet_2d_condition import logger
from diffusers.utils import is_torch_version
from typing import Any, Dict, List, Optional, Tuple, Union
def isinstance_str(x: object, cls_name: str):
    """
    Checks whether x has any class *named* cls_name in its ancestry.
    Doesn't require access to the class's implementation.
    Useful for patching!
    """
    return any(base.__name__ == cls_name for base in x.__class__.__mro__)
def Fourier_filter(x, threshold, scale):
    """Rescale the centred low-frequency band of ``x`` in the Fourier domain
    (the FreeU backbone/skip filter).

    Args:
        x: (B, C, H, W) feature map.
        threshold: half-width (in frequency bins) of the low-frequency box.
        scale: multiplier applied inside that box.

    Returns:
        Filtered features with the same shape and dtype as ``x``.
    """
    dtype = x.dtype
    x = x.type(torch.float32)  # run the FFT in fp32 for stability
    # FFT
    x_freq = fft.fftn(x, dim=(-2, -1))
    x_freq = fft.fftshift(x_freq, dim=(-2, -1))

    B, C, H, W = x_freq.shape
    # Build the mask on the input's device — the previous hard-coded .cuda()
    # crashed on CPU tensors.
    mask = torch.ones((B, C, H, W), device=x.device)
    crow, ccol = H // 2, W // 2
    mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
    x_freq = x_freq * mask

    # IFFT
    x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
    x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real

    return x_filtered.type(dtype)
def register_upblock2d(model):
    """Re-install an equivalent forward on every UpBlock2D of model.unet.

    Behaviourally identical to the stock diffusers UpBlock2D.forward; exists
    so the plain path mirrors the FreeU-patched variant.
    """
    def up_forward(self):
        def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
            for resnet in self.resnets:
                # pop the matching skip connection and concatenate on channels
                skip = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, skip], dim=1)

                if self.training and self.gradient_checkpointing:
                    def wrap(module):
                        def inner(*inputs):
                            return module(*inputs)
                        return inner
                    if is_torch_version(">=", "1.11.0"):
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            wrap(resnet), hidden_states, temb, use_reentrant=False
                        )
                    else:
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            wrap(resnet), hidden_states, temb
                        )
                else:
                    hidden_states = resnet(hidden_states, temb)

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)
            return hidden_states
        return forward

    for i, upsample_block in enumerate(model.unet.up_blocks):
        if isinstance_str(upsample_block, "UpBlock2D"):
            upsample_block.forward = up_forward(upsample_block)
def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
    """Patch every UpBlock2D of model.unet with a FreeU-modified forward.

    b1/b2 amplify the first half of the backbone channels at the 1280- and
    640-channel stages; s1/s2 damp the low frequencies of the corresponding
    skip connections via Fourier_filter.
    """
    def up_forward(self):
        def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
            for resnet in self.resnets:
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                #print(f"in free upblock2d, hidden states shape: {hidden_states.shape}")

                # --------------- FreeU code -----------------------
                # Only operate on the first two stages
                # NOTE: scales the backbone features in place before concat
                if hidden_states.shape[1] == 1280:
                    hidden_states[:,:640] = hidden_states[:,:640] * self.b1
                    res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
                if hidden_states.shape[1] == 640:
                    hidden_states[:,:320] = hidden_states[:,:320] * self.b2
                    res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
                # ---------------------------------------------------------

                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                if self.training and self.gradient_checkpointing:
                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)
                        return custom_forward
                    if is_torch_version(">=", "1.11.0"):
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
                        )
                    else:
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(resnet), hidden_states, temb
                        )
                else:
                    hidden_states = resnet(hidden_states, temb)
            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)
            return hidden_states
        return forward

    for i, upsample_block in enumerate(model.unet.up_blocks):
        if isinstance_str(upsample_block, "UpBlock2D"):
            upsample_block.forward = up_forward(upsample_block)
            # stash FreeU coefficients on the block so the closure can read them
            setattr(upsample_block, 'b1', b1)
            setattr(upsample_block, 'b2', b2)
            setattr(upsample_block, 's1', s1)
            setattr(upsample_block, 's2', s2)
def register_crossattn_upblock2d(model):
    """Re-install an equivalent forward on every CrossAttnUpBlock2D of
    model.unet (mirrors the stock diffusers implementation)."""
    def up_forward(self):
        def forward(
            hidden_states: torch.FloatTensor,
            res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            upsample_size: Optional[int] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            for resnet, attn in zip(self.resnets, self.attentions):
                # pop res hidden states
                #print(f"in crossatten upblock2d, hidden states shape: {hidden_states.shape}")
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                if self.training and self.gradient_checkpointing:
                    def create_custom_forward(module, return_dict=None):
                        def custom_forward(*inputs):
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                return module(*inputs)
                        return custom_forward
                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),
                        hidden_states,
                        temb,
                        **ckpt_kwargs,
                    )
                    # positional None args follow Transformer2DModel.forward's signature
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(attn, return_dict=False),
                        hidden_states,
                        encoder_hidden_states,
                        None,  # timestep
                        None,  # class_labels
                        cross_attention_kwargs,
                        attention_mask,
                        encoder_attention_mask,
                        **ckpt_kwargs,
                    )[0]
                else:
                    hidden_states = resnet(hidden_states, temb)
                    hidden_states = attn(
                        hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        cross_attention_kwargs=cross_attention_kwargs,
                        attention_mask=attention_mask,
                        encoder_attention_mask=encoder_attention_mask,
                        return_dict=False,
                    )[0]
            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)
            return hidden_states
        return forward

    for i, upsample_block in enumerate(model.unet.up_blocks):
        if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
            upsample_block.forward = up_forward(upsample_block)
def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
    """Patch every CrossAttnUpBlock2D of model.unet with a FreeU-modified
    forward: backbone-channel amplification (b1/b2) plus low-frequency damping
    of the skip connections (s1/s2) via Fourier_filter."""
    def up_forward(self):
        def forward(
            hidden_states: torch.FloatTensor,
            res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            upsample_size: Optional[int] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            for resnet, attn in zip(self.resnets, self.attentions):
                # pop res hidden states
                #print(f"in free crossatten upblock2d, hidden states shape: {hidden_states.shape}")
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]

                # --------------- FreeU code -----------------------
                # Only operate on the first two stages
                if hidden_states.shape[1] == 1280:
                    hidden_states[:,:640] = hidden_states[:,:640] * self.b1
                    res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
                if hidden_states.shape[1] == 640:
                    hidden_states[:,:320] = hidden_states[:,:320] * self.b2
                    res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
                # ---------------------------------------------------------

                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                if self.training and self.gradient_checkpointing:
                    def create_custom_forward(module, return_dict=None):
                        def custom_forward(*inputs):
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                return module(*inputs)
                        return custom_forward
                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),
                        hidden_states,
                        temb,
                        **ckpt_kwargs,
                    )
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(attn, return_dict=False),
                        hidden_states,
                        encoder_hidden_states,
                        None,  # timestep
                        None,  # class_labels
                        cross_attention_kwargs,
                        attention_mask,
                        encoder_attention_mask,
                        **ckpt_kwargs,
                    )[0]
                else:
                    hidden_states = resnet(hidden_states, temb)
                    # hidden_states = attn(
                    #     hidden_states,
                    #     encoder_hidden_states=encoder_hidden_states,
                    #     cross_attention_kwargs=cross_attention_kwargs,
                    #     encoder_attention_mask=encoder_attention_mask,
                    #     return_dict=False,
                    # )[0]
                    # NOTE(review): unlike the commented variant above, the
                    # active call drops attention/encoder masks — confirm that
                    # masking is intentionally unused on this path
                    hidden_states = attn(
                        hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        cross_attention_kwargs=cross_attention_kwargs,
                    )[0]
            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)
            return hidden_states
        return forward

    for i, upsample_block in enumerate(model.unet.up_blocks):
        if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
            upsample_block.forward = up_forward(upsample_block)
            # stash FreeU coefficients on the block so the closure can read them
            setattr(upsample_block, 'b1', b1)
            setattr(upsample_block, 'b2', b2)
            setattr(upsample_block, 's1', s1)
            setattr(upsample_block, 's2', s2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment