Commit 66e662c1 authored by bailuo's avatar bailuo
Browse files

init & optimize

parents
Pipeline #2116 failed with stages
in 0 seconds
# DragNoise
DragNoise 模型,利用扩散模型进行基于点的交互式图像编辑,强大而快速。
## 论文
`Drag Your Noise: Interactive Point-based Editing via Diffusion Semantic Propagation`
- https://arxiv.org/abs/2404.01050
- CVPR 2024
## 模型结构
<!-- 此处一句话简要介绍模型结构 -->
<!-- DragNoise利用扩散模型进行基于点的交互式图像编辑。主要是基于 StyleGAN 模型架构: -->
<div align=center>
<img src="./doc/overview.png"/>
<div >DragNoise</div>
</div>
## 算法原理
<!-- 我们介绍了Dragnoise,提供了强大而加速的编辑,而无需重试潜在地图。Dragnoise的核心原理在于利用每个U-NET作为语义编辑器的预测噪声输出。这种方法以两个关键观察为基础:首先,U-NET的瓶颈本质上具有具有互动编辑的理想的语义丰富特征;其次,高级语义是在脱诺过程早期建立的,在后续阶段显示出最小的差异。利用这些见解,在单个降解步骤中进行了Dragnoise编辑扩散语义,并有效地传播了这些变化,从而确保了扩散编辑的稳定性和效率。 -->
DragNoise 算法整体思路沿着 DragDiffusion 算法,与此不同的是针对扩散模型中的 “middle-block replacement” 的操作进行探索。该操作从某个去噪时间步开始,将不同层的特征复制到所有后续 timestep 的对应层。通过观察 DDIM inversion 重建图像的效果,探索扩散模型在何时以及何处学习到何种层次的语义信息。
通过实验发现,bottleneck 特征是一种最优扩散语义表示,适合于高效编辑。由于它可以在早期 timestep 中有效地被编辑,因此操纵 bottleneck 特征可以平滑地传播到后面的去噪步骤,从而确保结果图像扩散语义的完整性。此外,由于优化 bottleneck 的路径短,有效地避免了梯度消失问题。
<!-- <div align=center>
<img src="./doc/pipeline.png"/>
<div >DragNoise</div>
</div> -->
## 环境配置
```
mv dragnoise_pytorch dragnoise # 去框架名后缀
# docker 的 -v 路径、docker_name 和 imageID 根据实际情况修改
# pip 安装时如果出现下载慢可以尝试别的镜像源
```
### Docker(方法一)
<!-- 此处提供[光源](https://www.sourcefind.cn/#/service-details)拉取docker镜像的地址与使用步骤 -->
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.2-py3.10 # 本镜像imageID为:2f1f619d0182
docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=16G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --network=host --name docker_name imageID bash
cd /your_code_path/dragnoise
pip install -r requirements.txt
```
### Dockerfile(方法二)
<!-- 此处提供dockerfile的使用方法 -->
```
cd /your_code_path/dragnoise/docker
docker build --no-cache -t codestral:latest .
docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=16G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --network=host --name docker_name imageID bash
cd /your_code_path/dragnoise
pip install -r requirements.txt
```
### Anaconda(方法三)
<!-- 此处提供本地配置、编译的详细步骤,例如: -->
关于本项目DCU显卡所需的特殊深度学习库可从[光合](https://developer.hpccube.com/tool/)开发者社区下载安装。
```
DTK驱动: dtk24.04.2
python: python3.10
pytorch: 2.1.0
```
`Tips:以上DTK驱动、python、pytorch等DCU相关工具版本需要严格一一对应`
其它非深度学习库参照requirements.txt安装:
```
pip install -r requirements.txt
```
## 数据集
测试数据集 [DragBench](https://github.com/Yujun-Shi/DragDiffusion/releases/download/v0.1.1/DragBench.zip) 或者从 [`SCNet`](http://113.200.138.88:18080/aidatasets/dragbench) 上下载。\
下载后放在 ./drag_bench_evaluation/drag_bench_data 并解压,文件构成:
<br>
DragBench<br>
--- animals<br>
------ JH_2023-09-14-1820-16<br>
------ JH_2023-09-14-1821-23<br>
------ JH_2023-09-14-1821-58<br>
------ ...<br>
--- art_work<br>
--- building_city_view<br>
--- ...<br>
--- other_objects<br>
<br>
## 训练
```
# LoRA微调
python run_lora_training.py
# 训练drag扩散模型,并输出结果
python run_drag_diffusion.py
# 如果出现huggingface访问不通,请执行 `export HF_ENDPOINT=https://hf-mirror.com`
```
亦或者webui界面。
## 推理
<!-- 下载模型权重:
```
python scripts/download_model.py
```
或者从 [SCNet](http://113.200.138.88:18080/aimodels/findsource-dependency/stylegan2_pytorch) 上快速下载,并放在 /checkpoints 文件夹下。 -->
可视化webui推理:
```
python drag_ui.py --listen
# 如果出现huggingface访问不通,请执行 `export HF_ENDPOINT=https://hf-mirror.com`
```
<div align=center>
<img src="./doc/webui.jpg" width=600/>
<div >webui界面</div>
</div>
1、上传图片;\
2、输入提示;\
3、LoRA训练;\
4、通过鼠标选择要编辑的区域;\
5、通过鼠标标记点位;\
6、运行。\
ps:Drag以及LoRA的一些参数自行视情况修改。
## result
<!-- 此处填算法效果测试图(包括输入、输出) -->
<div align=center>
<img src="./doc/input.jpg" width=600/>
<div >输入</div>
</div>
<div align=center>
<img src="./doc/result.png" width=600/>
<div >推理结果</div>
</div>
<!-- <div align=center>
<img src="./doc/image (1).png" width=600/>
<div >输出</div>
</div> -->
### 精度
测试集 `DragBench`,如上所述下载并解压好。
```
python run_lora_training.py
python run_drag_diffusion.py
python run_eval_similarity.py
python run_eval_point_matching.py
# ps:上述脚本的一些文件路径自行根据情况修改
```
| 加速卡K100_AI | 1-lpips ↑ | clip sim ↑| mean distance ↓ |
| :-----| :----- | :---- | :---- |
| paper | 0.894 | 0.971 | 33.404 |
| 优化后 | 0.885 | 0.967 | 30.206 |
<!-- | 单元格 | 单元格 | 单元格 | -->
ps:优化两点 1)lora训练部分的rank选择;2)drag部分采用了多层融合。
<!-- <div align=center>
<img src="./doc/result2.jpg" width=600/>
</div>
<div align=center>
<img src="./doc/result3.jpg" width=600/>
</div>
<div align=center>
<img src="./doc/result4.jpg" width=600/>
</div>
<div align=center>
<img src="./doc/result5.jpg" width=600/>
<div >paper结果和优化后的结果</div>
</div> -->
## 应用场景
### 算法类别
<!-- 超出以上分类的类别命名也可参考此网址中的类别名:https://huggingface.co/ \ -->
`AIGC`
### 热点应用行业
<!-- 应用行业的填写需要做大量调研,从而为使用者提供专业、全面的推荐,除特殊算法,通常推荐数量>=3。 -->
`零售,制造,电商,医疗,教育`
<!-- ## 预训练权重 -->
<!-- - 此处填写预训练权重在公司内部的下载地址(预训练权重存放中心为:[SCNet AIModels](http://113.200.138.88:18080/aimodels) ,模型用到的各预训练权重请分别填上具体地址。),过小权重文件可打包到项目里。
- 此处填写公开预训练权重官网下载地址(非必须)。 -->
## 源码仓库及问题反馈
<!-- - 此处填本项目gitlab地址 -->
- https://developer.sourcefind.cn/codes/modelzoo/dragnoise_pytorch
## 参考资料
- https://github.com/haofengl/DragNoise
<p align="center">
<h1 align="center">Drag Your Noise: Interactive Point-based Editing via Diffusion Semantic Propagation</h1>
<p align="center">
<strong>Haofeng Liu</strong>
&nbsp;&nbsp;
<strong>Chenshu Xu</strong>
&nbsp;&nbsp;
<strong>Yifei Yang</strong>
&nbsp;&nbsp;
<strong>Lihua Zeng</strong>
&nbsp;&nbsp;
<a href="http://www.shengfenghe.com/"><strong>Shengfeng He</strong></a>
</p>
<p align="center">
<a href="https://arxiv.org/abs/2404.01050"><img alt='arXiv' src="https://img.shields.io/badge/arXiv-2404.01050-b31b1b.svg"></a>
<a href="https://www.youtube.com/watch?v=gKq0s_CvCAg&t=1s"><img alt='page' src="https://img.shields.io/badge/YouTube-orange"></a>
<a href="https://space.bilibili.com/386002941?spm_id_from=333.1007.0.0"><img alt='page' src="https://img.shields.io/badge/Bilibili-red"></a>
</p>
<div align="center">
<img src="./image/GIF/robot-noise.gif", width="23%">
<img src="./image/GIF/robot-diffusion.gif", width="23%">
<img src="./image/GIF/mountain-noise.gif", width="23%">
<img src="./image/GIF/mountain-diffusion.gif", width="23%">
<p align="left">&emsp;&emsp;&emsp;&emsp;&emsp;DragNoise&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;DragDiffusion&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;DragNoise&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;DragDiffusion</p>
</div>
<div align="center">
<img src="./image/GIF/girl-Noise.gif", width="23%">
<img src="./image/GIF/girl-diffusion.gif", width="23%">
<img src="./image/GIF/cake-noise-min.gif", width="23%">
<img src="./image/GIF/cake-diffusion-min.gif", width="23%">
<p align="left">&emsp;&emsp;&emsp;&emsp;&emsp;DragNoise&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;DragDiffusion&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;DragNoise&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;DragDiffusion</p>
</div>
<div align="center">
<img src="./image/GIF/tom-noise-min.gif", width="23%">
<img src="./image/GIF/tom-diffusion-min.gif", width="23%">
<img src="./image/GIF/oldman-noise.gif", width="22.4%">
<img src="./image/GIF/oldman-diffusion-min.gif", width="22.4%">
<p align="left">&emsp;&emsp;&emsp;&emsp;&emsp;DragNoise&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;DragDiffusion&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;DragNoise&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;DragDiffusion</p>
</div>
<div align="center">
<img src="./image/GIF/boy-noise.gif", width="23%">
<img src="./image/GIF/mouth-noise.gif", width="23%">
<img src="./image/GIF/tiger-noise-min.gif", width="17.1%">
<img src="./image/GIF/road-noise-min.gif", width="26.4%">
<p align="left">&emsp;&emsp;&emsp;&emsp;&emsp;DragNoise&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;DragNoise&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;DragNoise&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;DragNoise</p>
</div>
</p>
## News and Update
* [Apr 5th] v1.0.0 Release.
## Installation
It is recommended to run our code on a Nvidia GPU with a linux system. Currently, it requires around 14 GB GPU memory to run our method.
To install the required libraries, simply run the following command:
```
conda env create -f environment.yaml
conda activate dragnoise
```
## Run DragNoise
To start with, in command line, run the following to start the gradio user interface:
```
python3 drag_ui.py
```
Basically, it consists of the following steps:
### Dragging Input Real Images
#### 1) train a LoRA
* Drop our input image into the left-most box.
* Input a prompt describing the image in the "prompt" field
* Click the "Train LoRA" button to train a LoRA given the input image
#### 2) do "drag" editing
* Draw a mask in the left-most box to specify the editable areas. (optional)
* Click handle and target points in the middle box. Also, you may reset all points by clicking "Undo point".
* Click the "Run" button to run our algorithm. Edited results will be displayed in the right-most box.
## More result
<a id="more reault">
<div align="center">
<img src="./image/image2.png", width="90%">
<img src="./image/image1.png", width="90%">
<img src="./image/image3.png", width="90%">
<img src="./image/image4.png", width="90%">
<img src="./image/image5.png", width="90%">
</div>
</a>
## License
Code related to the Drag algorithm is under Apache 2.0 license.
## BibTeX
If you find our repo helpful, please consider leaving a star or cite our paper :
```bibtex
@misc{liu2024drag,
title={Drag Your Noise: Interactive Point-based Editing via Diffusion Semantic Propagation},
author={Haofeng Liu and Chenshu Xu and Yifei Yang and Lihua Zeng and Shengfeng He},
year={2024},
eprint={2404.01050},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
## Contact
For any questions on this project, please contact liuhaofeng2022@163.com
## Acknowledgement
This work is inspired by the amazing [DragGAN](https://vcai.mpi-inf.mpg.de/projects/DragGAN/). We also benefit from the codebase of [DragDiffusion](https://github.com/Yujun-Shi/DragDiffusion).
## Related Links
* [Drag Your GAN: Interactive Point-based Manipulation on the Generative Image Manifold](https://vcai.mpi-inf.mpg.de/projects/DragGAN/)
* [DragDiffusion: Harnessing Diffusion Models for Interactive Point-based Image Editing](https://github.com/Yujun-Shi/DragDiffusion)
* [DragonDiffusion: Enabling Drag-style Manipulation on Diffusion Models](https://mc-e.github.io/project/DragonDiffusion/)
* [FreeDrag: Point Tracking is Not You Need for Interactive Point-based Image Editing](https://lin-chen.site/projects/freedrag/)
## Common Issues and Solutions
1) For users struggling in loading models from huggingface due to internet constraint, please 1) follow this [links](https://zhuanlan.zhihu.com/p/475260268) and download the model into the directory "local\_pretrained\_models"; 2) Run "drag\_ui.py" and select the directory to your pretrained model in "Algorithm Parameters -> Base Model Config -> Diffusion Model Path".
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.2-py3.10
ENV DEBIAN_FRONTEND=noninteractive
# COPY requirements.txt requirements.txt
# RUN pip3 install -r requirements.txt -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
# How to Evaluate with DragBench
### Step 1: extract dataset
Extract [DragBench](https://github.com/Yujun-Shi/DragDiffusion/releases/download/v0.1.1/DragBench.zip) into the folder "drag_bench_data".
Resulting directory hierarchy should look like the following:
<br>
drag_bench_data<br>
--- animals<br>
------ JH_2023-09-14-1820-16<br>
------ JH_2023-09-14-1821-23<br>
------ JH_2023-09-14-1821-58<br>
------ ...<br>
--- art_work<br>
--- building_city_view<br>
--- ...<br>
--- other_objects<br>
<br>
### Step 2: train LoRA.
Train one LoRA on each image in drag_bench_data.
To do this, simply execute "run_lora_training.py".
Trained LoRAs will be saved in "drag_bench_lora"
### Step 3: run dragging results
To run dragging results of DragDiffusion on images in "drag_bench_data", simply execute "run_drag_diffusion.py".
Results will be saved in "drag_diffusion_res".
### Step 4: evaluate mean distance and similarity.
To evaluate LPIPS score before and after dragging, execute "run_eval_similarity.py"
To evaluate mean distance between target points and the final position of handle points (estimated by DIFT), execute "run_eval_point_matching.py"
# Expand the Dataset
Here we also provided the labeling tool used by us in the file "labeling_tool.py".
Run this file to get the user interface for labeling your images with drag instructions.
\ No newline at end of file
# code credit: https://github.com/Tsingularity/dift/blob/main/src/models/dift_sd.py
from diffusers import StableDiffusionPipeline
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from typing import Any, Callable, Dict, List, Optional, Union
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers import DDIMScheduler
import gc
from PIL import Image
class MyUNet2DConditionModel(UNet2DConditionModel):
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
up_ft_indices,
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None):
r"""
Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
"""
# By default samples have to be AT least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
default_overall_up_factor = 2**self.num_upsamplers
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
# logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# prepare attention_mask
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
# 2. pre-process
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. mid
if self.mid_block is not None:
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
# 5. up
up_ft = {}
for i, upsample_block in enumerate(self.up_blocks):
if i > np.max(up_ft_indices):
break
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
upsample_size=upsample_size,
attention_mask=attention_mask,
)
else:
sample = upsample_block(
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
)
if i in up_ft_indices:
up_ft[i] = sample.detach()
output = {}
output['up_ft'] = up_ft
return output
class OneStepSDPipeline(StableDiffusionPipeline):
@torch.no_grad()
def __call__(
self,
img_tensor,
t,
up_ft_indices,
negative_prompt: Optional[Union[str, List[str]]] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None
):
device = self._execution_device
latents = self.vae.encode(img_tensor).latent_dist.sample() * self.vae.config.scaling_factor
t = torch.tensor(t, dtype=torch.long, device=device)
noise = torch.randn_like(latents).to(device)
latents_noisy = self.scheduler.add_noise(latents, noise, t)
unet_output = self.unet(latents_noisy,
t,
up_ft_indices,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs)
return unet_output
class SDFeaturizer:
def __init__(self, sd_id='stabilityai/stable-diffusion-2-1'):
unet = MyUNet2DConditionModel.from_pretrained(sd_id, subfolder="unet")
onestep_pipe = OneStepSDPipeline.from_pretrained(sd_id, unet=unet, safety_checker=None)
onestep_pipe.vae.decoder = None
onestep_pipe.scheduler = DDIMScheduler.from_pretrained(sd_id, subfolder="scheduler")
gc.collect()
onestep_pipe = onestep_pipe.to("cuda")
onestep_pipe.enable_attention_slicing()
# onestep_pipe.enable_xformers_memory_efficient_attention()
self.pipe = onestep_pipe
@torch.no_grad()
def forward(self,
img_tensor,
prompt,
t=261,
up_ft_index=1,
ensemble_size=8):
'''
Args:
img_tensor: should be a single torch tensor in the shape of [1, C, H, W] or [C, H, W]
prompt: the prompt to use, a string
t: the time step to use, should be an int in the range of [0, 1000]
up_ft_index: which upsampling block of the U-Net to extract feature, you can choose [0, 1, 2, 3]
ensemble_size: the number of repeated images used in the batch to extract features
Return:
unet_ft: a torch tensor in the shape of [1, c, h, w]
'''
img_tensor = img_tensor.repeat(ensemble_size, 1, 1, 1).cuda() # ensem, c, h, w
prompt_embeds = self.pipe.encode_prompt(
prompt=prompt,
device='cuda',
num_images_per_prompt=1,
do_classifier_free_guidance=False) # [1, 77, dim]
prompt_embeds = prompt_embeds[0].repeat(ensemble_size, 1, 1)
unet_ft_all = self.pipe(
img_tensor=img_tensor,
t=t,
up_ft_indices=[up_ft_index],
prompt_embeds=prompt_embeds)
unet_ft = unet_ft_all['up_ft'][up_ft_index] # ensem, c, h, w
unet_ft = unet_ft.mean(0, keepdim=True) # 1,c,h,w
return unet_ft
# *************************************************************************
# Copyright (2023) Bytedance Inc.
#
# Copyright (2023) DragDiffusion Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# *************************************************************************
import cv2
import numpy as np
import PIL
from PIL import Image
from PIL.ImageOps import exif_transpose
import os
import gradio as gr
import datetime
import pickle
from copy import deepcopy
LENGTH=480 # length of the square area displaying/editing images
def clear_all(length=480):
return gr.Image.update(value=None, height=length, width=length), \
gr.Image.update(value=None, height=length, width=length), \
[], None, None
def mask_image(image,
mask,
color=[255,0,0],
alpha=0.5):
""" Overlay mask on image for visualization purpose.
Args:
image (H, W, 3) or (H, W): input image
mask (H, W): mask to be overlaid
color: the color of overlaid mask
alpha: the transparency of the mask
"""
out = deepcopy(image)
img = deepcopy(image)
img[mask == 1] = color
out = cv2.addWeighted(img, alpha, out, 1-alpha, 0, out)
return out
def store_img(img, length=512):
image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255.
height,width,_ = image.shape
image = Image.fromarray(image)
image = exif_transpose(image)
image = image.resize((length,int(length*height/width)), PIL.Image.BILINEAR)
mask = cv2.resize(mask, (length,int(length*height/width)), interpolation=cv2.INTER_NEAREST)
image = np.array(image)
if mask.sum() > 0:
mask = np.uint8(mask > 0)
masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
else:
masked_img = image.copy()
# when new image is uploaded, `selected_points` should be empty
return image, [], masked_img, mask
# user click the image to get points, and show the points on the image
def get_points(img,
sel_pix,
evt: gr.SelectData):
# collect the selected point
sel_pix.append(evt.index)
# draw points
points = []
for idx, point in enumerate(sel_pix):
if idx % 2 == 0:
# draw a red circle at the handle point
cv2.circle(img, tuple(point), 10, (255, 0, 0), -1)
else:
# draw a blue circle at the handle point
cv2.circle(img, tuple(point), 10, (0, 0, 255), -1)
points.append(tuple(point))
# draw an arrow from handle point to target point
if len(points) == 2:
cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5)
points = []
return img if isinstance(img, np.ndarray) else np.array(img)
# clear all handle/target points
def undo_points(original_image,
mask):
if mask.sum() > 0:
mask = np.uint8(mask > 0)
masked_img = mask_image(original_image, 1 - mask, color=[0, 0, 0], alpha=0.3)
else:
masked_img = original_image.copy()
return masked_img, []
def save_all(category,
source_image,
image_with_clicks,
mask,
labeler,
prompt,
points,
root_dir='./drag_bench_data'):
if not os.path.isdir(root_dir):
os.mkdir(root_dir)
if not os.path.isdir(os.path.join(root_dir, category)):
os.mkdir(os.path.join(root_dir, category))
save_prefix = labeler + '_' + datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
save_dir = os.path.join(root_dir, category, save_prefix)
if not os.path.isdir(save_dir):
os.mkdir(save_dir)
# save images
Image.fromarray(source_image).save(os.path.join(save_dir, 'original_image.png'))
Image.fromarray(image_with_clicks).save(os.path.join(save_dir, 'user_drag.png'))
# save meta data
meta_data = {
'prompt' : prompt,
'points' : points,
'mask' : mask,
}
with open(os.path.join(save_dir, 'meta_data.pkl'), 'wb') as f:
pickle.dump(meta_data, f)
return save_prefix + " saved!"
with gr.Blocks() as demo:
# UI components for editing real images
with gr.Tab(label="Editing Real Image"):
mask = gr.State(value=None) # store mask
selected_points = gr.State([]) # store points
original_image = gr.State(value=None) # store original input image
with gr.Row():
with gr.Column():
gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""")
canvas = gr.Image(type="numpy", tool="sketch", label="Draw Mask",
show_label=True, height=LENGTH, width=LENGTH) # for mask painting
with gr.Column():
gr.Markdown("""<p style="text-align: center; font-size: 20px">Click Points</p>""")
input_image = gr.Image(type="numpy", label="Click Points",
show_label=True, height=LENGTH, width=LENGTH) # for points clicking
with gr.Row():
labeler = gr.Textbox(label="Labeler")
category = gr.Dropdown(value="art_work",
label="Image Category",
choices=[
'art_work',
'land_scape',
'building_city_view',
'building_countryside_view',
'animals',
'human_head',
'human_upper_body',
'human_full_body',
'interior_design',
'other_objects',
]
)
prompt = gr.Textbox(label="Prompt")
save_status = gr.Textbox(label="display saving status")
with gr.Row():
undo_button = gr.Button("undo points")
clear_all_button = gr.Button("clear all")
save_button = gr.Button("save")
# event definition
# event for dragging user-input real image
canvas.edit(
store_img,
[canvas],
[original_image, selected_points, input_image, mask]
)
input_image.select(
get_points,
[input_image, selected_points],
[input_image],
)
undo_button.click(
undo_points,
[original_image, mask],
[input_image, selected_points]
)
clear_all_button.click(
clear_all,
[gr.Number(value=LENGTH, visible=False, precision=0)],
[canvas,
input_image,
selected_points,
original_image,
mask]
)
save_button.click(
save_all,
[category,
original_image,
input_image,
mask,
labeler,
prompt,
selected_points,],
[save_status]
)
demo.queue().launch(share=True, debug=True)
# *************************************************************************
# Copyright (2023) Bytedance Inc.
#
# Copyright (2023) DragDiffusion Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# *************************************************************************
# run results_0 of DragDiffusion
import os
import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import PIL
from PIL import Image
from copy import deepcopy
from einops import rearrange
from types import SimpleNamespace
from diffusers import DDIMScheduler, AutoencoderKL
from torchvision.utils import save_image
from pytorch_lightning import seed_everything
import sys
sys.path.insert(0, '../')
from drag_pipeline import DragPipeline
from utils.drag_utils import drag_diffusion_update
from utils.attn_utils import register_attention_editor_diffusers, MutualSelfAttentionControl
def preprocess_image(image,
device):
image = torch.from_numpy(image).float() / 127.5 - 1 # [-1, 1]
image = rearrange(image, "h w c -> 1 c h w")
image = image.to(device)
return image
# copy the run_drag function to here
def run_drag(source_image,
# image_with_clicks,
mask,
prompt,
points,
inversion_strength,
end_step,
lam,
latent_lr,
n_pix_step,
model_path,
vae_path,
lora_path,
start_step,
start_layer,
# save_dir="./results"
):
# initialize model
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
beta_schedule="scaled_linear", clip_sample=False,
set_alpha_to_one=False, steps_offset=1)
model = DragPipeline.from_pretrained(model_path, scheduler=scheduler).to(device)
# call this function to override unet forward function,
# so that intermediate features are returned after forward
model.modify_unet_forward()
# set vae
if vae_path != "default":
model.vae = AutoencoderKL.from_pretrained(
vae_path
).to(model.vae.device, model.vae.dtype)
# initialize parameters
seed = 42 # random seed used by a lot of people for unknown reason
seed_everything(seed)
args = SimpleNamespace()
args.prompt = prompt
args.points = points
args.n_inference_step = 50
args.n_actual_inference_step = round(inversion_strength * args.n_inference_step)
args.guidance_scale = 1.0
args.unet_feature_idx = [3,4]
args.r_m = 1
args.r_p = 3
args.end_step = end_step
args.lam = lam
args.lr = latent_lr
args.n_pix_step = n_pix_step
full_h, full_w = source_image.shape[:2]
args.sup_res_h = int(0.5*full_h)
args.sup_res_w = int(0.5*full_w)
print(args)
source_image = preprocess_image(source_image, device)
# image_with_clicks = preprocess_image(image_with_clicks, device)
# set lora
if lora_path == "":
print("nolora applying default parameters")
model.unet.set_default_attn_processor()
else:
print("applying lora: " + lora_path)
model.unet.load_attn_procs(lora_path)
# invert the source image
# the latent code resolution is too small, only 64*64
invert_code = model.invert(source_image,
prompt,
guidance_scale=args.guidance_scale,
num_inference_steps=args.n_inference_step,
num_actual_inference_steps=args.n_actual_inference_step)
mask = torch.from_numpy(mask).float() / 255.
mask[mask > 0.0] = 1.0
mask = rearrange(mask, "h w -> 1 1 h w").cuda()
mask = F.interpolate(mask, (args.sup_res_h, args.sup_res_w), mode="nearest")
handle_points = []
target_points = []
# grads_means = []
# here, the point is in x,y coordinate
for idx, point in enumerate(points):
cur_point = torch.tensor([point[1]/full_h*args.sup_res_h, point[0]/full_w*args.sup_res_w])
cur_point = torch.round(cur_point)
if idx % 2 == 0:
handle_points.append(cur_point)
else:
target_points.append(cur_point)
print('handle points:', handle_points)
print('target points:', target_points)
init_code = invert_code
init_code_orig = deepcopy(init_code)
model.scheduler.set_timesteps(args.n_inference_step)
t = model.scheduler.timesteps[args.n_inference_step - args.n_actual_inference_step]
# update according to the given supervision
updated_init_code, h_feature, h_features = drag_diffusion_update(model, init_code, t,
handle_points, target_points, mask, args)
# inference the synthesized image
gen_image = model(
prompt=args.prompt,
h_feature=h_feature,
end_step=args.end_step,
batch_size=2,
latents=torch.cat([init_code_orig, updated_init_code], dim=0),
guidance_scale=args.guidance_scale,
num_inference_steps=args.n_inference_step,
num_actual_inference_steps=args.n_actual_inference_step
)[1].unsqueeze(dim=0)
# resize gen_image into the size of source_image
# we do this because shape of gen_image will be rounded to multipliers of 8
gen_image = F.interpolate(gen_image, (full_h, full_w), mode='bilinear')
# save the original image, user editing instructions, synthesized image
# save_result = torch.cat([
# source_image * 0.5 + 0.5,
# torch.ones((1,3,full_h,25)).cuda(),
# image_with_clicks * 0.5 + 0.5,
# torch.ones((1,3,full_h,25)).cuda(),
# gen_image[0:1]
# ], dim=-1)
# if not os.path.isdir(save_dir):
# os.mkdir(save_dir)
# save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
# save_image(save_result, os.path.join(save_dir, save_prefix + '.png'))
out_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0]
out_image = (out_image * 255).astype(np.uint8)
return out_image
if __name__ == '__main__':
all_category = [
'art_work',
'land_scape',
'building_city_view',
'building_countryside_view',
'animals',
'human_head',
'human_upper_body',
'human_full_body',
'interior_design',
'other_objects',
]
# assume root_dir and lora_dir are valid directory
root_dir = '/home/bailuo/code/DragNoise/../DragDiffusion/drag_bench_evaluation/drag_bench_data/DragBench'
lora_dir = '/home/bailuo/code/DragNoise/drag_bench_evaluation/drag_bench_lora_lora_rank_list*2'
result_dir = '/home/bailuo/code/DragNoise/drag_bench_evaluation/drag_diffusion_res_lora_rank_list*2_[3,4]'
# mkdir if necessary
if not os.path.isdir(result_dir):
os.mkdir(result_dir)
for cat in all_category:
os.mkdir(os.path.join(result_dir,cat))
grads_means = []
for cat in all_category:
file_dir = os.path.join(root_dir, cat)
for sample_name in os.listdir(file_dir):
if sample_name == '.DS_Store':
continue
sample_path = os.path.join(file_dir, sample_name)
# read image file
source_image = Image.open(os.path.join(sample_path, 'original_image.png'))
source_image = np.array(source_image)
# load meta data
with open(os.path.join(sample_path, 'meta_data.pkl'), 'rb') as f:
meta_data = pickle.load(f)
prompt = meta_data['prompt']
mask = meta_data['mask']
print(mask.shape)
points = meta_data['points']
# load lora
# using LoRA @ 200 steps
lora_path = os.path.join(lora_dir, cat, sample_name, str(200))
print("applying lora: " + lora_path)
out_image = run_drag(
source_image,
mask,
prompt,
points,
inversion_strength=0.7,
end_step=0,
lam=0.2,
latent_lr=0.02,
n_pix_step=80,
model_path="botp/stable-diffusion-v1-5",
# model_path="/home/bailuo/models/models--botp--stable-diffusion-v1-5",
vae_path="default",
lora_path=lora_path,
start_step=0,
start_layer=10,
)
save_dir = os.path.join(result_dir, cat, sample_name)
if not os.path.isdir(save_dir):
os.mkdir(save_dir)
Image.fromarray(out_image).save(os.path.join(save_dir, 'dragged_image.png'))
# *************************************************************************
# Copyright (2023) Bytedance Inc.
#
# Copyright (2023) DragDiffusion Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# *************************************************************************
# run evaluation of mean distance between the desired target points and the position of final handle points
import os
import pickle
import numpy as np
import PIL
from PIL import Image
from torchvision.transforms import PILToTensor
import torch
import torch.nn as nn
import torch.nn.functional as F
from dift_sd import SDFeaturizer
from pytorch_lightning import seed_everything
if __name__ == '__main__':
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# using SD-2.1
dift = SDFeaturizer('stabilityai/stable-diffusion-2-1')
all_category = [
'art_work',
'land_scape',
'building_city_view',
'building_countryside_view',
'animals',
'human_head',
'human_upper_body',
'human_full_body',
'interior_design',
'other_objects',
]
original_img_root = '../../DragDiffusion/drag_bench_evaluation/drag_bench_data/DragBench'
# you may put more root path of your results here
evaluate_root = ['drag_diffusion_res_lora_rank_list*2_[3,4]']
for target_root in evaluate_root:
# fixing the seed for semantic correspondence
seed_everything(42)
all_dist = []
for cat in all_category:
all_dist_ = []
for file_name in os.listdir(os.path.join(original_img_root, cat)):
if file_name == '.DS_Store':
continue
with open(os.path.join(original_img_root, cat, file_name, 'meta_data.pkl'), 'rb') as f:
meta_data = pickle.load(f)
prompt = meta_data['prompt']
points = meta_data['points']
# here, the point is in x,y coordinate
handle_points = []
target_points = []
for idx, point in enumerate(points):
# from now on, the point is in row,col coordinate
cur_point = torch.tensor([point[1], point[0]])
if idx % 2 == 0:
handle_points.append(cur_point)
else:
target_points.append(cur_point)
source_image_path = os.path.join(original_img_root, cat, file_name, 'original_image.png')
dragged_image_path = os.path.join(target_root, cat, file_name, 'dragged_image.png')
source_image_PIL = Image.open(source_image_path)
dragged_image_PIL = Image.open(dragged_image_path)
dragged_image_PIL = dragged_image_PIL.resize(source_image_PIL.size,PIL.Image.BILINEAR)
source_image_tensor = (PILToTensor()(source_image_PIL) / 255.0 - 0.5) * 2
dragged_image_tensor = (PILToTensor()(dragged_image_PIL) / 255.0 - 0.5) * 2
_, H, W = source_image_tensor.shape
ft_source = dift.forward(source_image_tensor,
prompt=prompt,
t=261,
up_ft_index=1,
ensemble_size=8)
ft_source = F.interpolate(ft_source, (H, W), mode='bilinear')
ft_dragged = dift.forward(dragged_image_tensor,
prompt=prompt,
t=261,
up_ft_index=1,
ensemble_size=8)
ft_dragged = F.interpolate(ft_dragged, (H, W), mode='bilinear')
cos = nn.CosineSimilarity(dim=1)
for pt_idx in range(len(handle_points)):
hp = handle_points[pt_idx]
tp = target_points[pt_idx]
num_channel = ft_source.size(1)
src_vec = ft_source[0, :, hp[0], hp[1]].view(1, num_channel, 1, 1)
cos_map = cos(src_vec, ft_dragged).cpu().numpy()[0] # H, W
max_rc = np.unravel_index(cos_map.argmax(), cos_map.shape) # the matched row,col
# calculate distance
dist = (tp - torch.tensor(max_rc)).float().norm()
all_dist.append(dist)
all_dist_.append(dist)
print(cat + ' mean distance: ', torch.tensor(all_dist_).mean().item())
print(target_root + ' mean distance: ', torch.tensor(all_dist).mean().item())
# *************************************************************************
# Copyright (2023) Bytedance Inc.
#
# Copyright (2023) DragDiffusion Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# *************************************************************************
# evaluate similarity between images before and after dragging
import os
from einops import rearrange
import numpy as np
import PIL
from PIL import Image
import torch
import torch.nn.functional as F
import lpips
import clip
def preprocess_image(image,
device):
image = torch.from_numpy(image).float() / 127.5 - 1 # [-1, 1]
image = rearrange(image, "h w c -> 1 c h w")
image = image.to(device)
return image
if __name__ == '__main__':
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# lpip metric
loss_fn_alex = lpips.LPIPS(net='alex').to(device)
# load clip model
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device, jit=False)
all_category = [
'art_work',
'land_scape',
'building_city_view',
'building_countryside_view',
'animals',
'human_head',
'human_upper_body',
'human_full_body',
'interior_design',
'other_objects',
]
original_img_root = '../../DragDiffusion/drag_bench_evaluation/drag_bench_data/DragBench'
# you may put more root path of your results here
evaluate_root = ['drag_diffusion_res_lora_rank_list*2_[3,4]']
for target_root in evaluate_root:
all_lpips = []
all_clip_sim = []
for cat in all_category:
all_lpips_ = []
all_clip_sim_ = []
for file_name in os.listdir(os.path.join(original_img_root, cat)):
if file_name == '.DS_Store':
continue
source_image_path = os.path.join(original_img_root, cat, file_name, 'original_image.png')
dragged_image_path = os.path.join(target_root, cat, file_name, 'dragged_image.png')
source_image_PIL = Image.open(source_image_path)
dragged_image_PIL = Image.open(dragged_image_path)
dragged_image_PIL = dragged_image_PIL.resize(source_image_PIL.size,PIL.Image.BILINEAR)
source_image = preprocess_image(np.array(source_image_PIL), device)
dragged_image = preprocess_image(np.array(dragged_image_PIL), device)
# compute LPIP
with torch.no_grad():
source_image_224x224 = F.interpolate(source_image, (224,224), mode='bilinear')
dragged_image_224x224 = F.interpolate(dragged_image, (224,224), mode='bilinear')
cur_lpips = loss_fn_alex(source_image_224x224, dragged_image_224x224)
all_lpips.append(cur_lpips.item())
all_lpips_.append(cur_lpips.item())
# compute CLIP similarity
source_image_clip = clip_preprocess(source_image_PIL).unsqueeze(0).to(device)
dragged_image_clip = clip_preprocess(dragged_image_PIL).unsqueeze(0).to(device)
with torch.no_grad():
source_feature = clip_model.encode_image(source_image_clip)
dragged_feature = clip_model.encode_image(dragged_image_clip)
source_feature /= source_feature.norm(dim=-1, keepdim=True)
dragged_feature /= dragged_feature.norm(dim=-1, keepdim=True)
cur_clip_sim = (source_feature * dragged_feature).sum()
all_clip_sim.append(cur_clip_sim.cpu().numpy())
all_clip_sim_.append(cur_clip_sim.cpu().numpy())
print(cat + ' lpips: ', np.mean(all_lpips_))
print(cat + ' clip sim', np.mean(all_clip_sim_))
print(target_root)
print('avg lpips: ', np.mean(all_lpips))
print('avg 1-lpips: ', 1.0 - np.mean(all_lpips))
print('avg clip sim', np.mean(all_clip_sim))
# *************************************************************************
# Copyright (2023) Bytedance Inc.
#
# Copyright (2023) DragDiffusion Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# *************************************************************************
import os
import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import PIL
from PIL import Image
from copy import deepcopy
from einops import rearrange
from types import SimpleNamespace
import tqdm
import sys
sys.path.insert(0, '../')
from utils.lora_utils import train_lora
os.environ['CURL_CA_BUNDLE'] = ''
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
if __name__ == '__main__':
all_category = [
'art_work',
'land_scape',
'building_city_view',
'building_countryside_view',
'animals',
'human_head',
'human_upper_body',
'human_full_body',
'interior_design',
'other_objects',
]
# assume root_dir and lora_dir are valid directory
root_dir = '/home/bailuo/code/DragNoise/../DragDiffusion/drag_bench_evaluation/drag_bench_data/DragBench'
lora_dir = '/home/bailuo/code/DragNoise/drag_bench_evaluation/drag_bench_lora_lora_rank_list*2'
# mkdir if necessary
if not os.path.isdir(lora_dir):
os.mkdir(lora_dir)
for cat in all_category:
os.mkdir(os.path.join(lora_dir,cat))
for cat in all_category:
print(cat)
file_dir = os.path.join(root_dir, cat)
for sample_name in os.listdir(file_dir):
print(sample_name)
if sample_name == '.DS_Store':
continue
sample_path = os.path.join(file_dir, sample_name)
# read image file
source_image = Image.open(os.path.join(sample_path, 'original_image.png'))
source_image = np.array(source_image)
# load meta data
with open(os.path.join(sample_path, 'meta_data.pkl'), 'rb') as f:
meta_data = pickle.load(f)
prompt = meta_data['prompt']
# train and save lora
save_lora_path = os.path.join(lora_dir, cat, sample_name)
if not os.path.isdir(save_lora_path):
os.mkdir(save_lora_path)
# you may also increase the number of lora_step here to train longer
train_lora(source_image, prompt,
model_path="botp/stable-diffusion-v1-5",
# model_path="~/.cache/huggingface/hub/models--botp--stable-diffusion-v1-5",
vae_path="default", save_lora_path=save_lora_path,
lora_step=200, lora_lr=0.0002, lora_batch_size=4, lora_rank=16, progress=tqdm, save_interval=100)
This diff is collapsed.
import os
import gradio as gr
from utils.ui_utils import get_points, undo_points
from utils.ui_utils import clear_all, store_img, train_lora_interface, run_drag
from utils.ui_utils import clear_all_gen, store_img_gen
LENGTH=480 # length of the square area displaying/editing images
with gr.Blocks(title='DragNoise', theme=gr.themes.Monochrome()) as demo:
# layout definition
with gr.Row():
gr.Markdown("""
# Official Implementation of [DragNoise](https://github.com/haofengl/DragNoise)
""")
# UI components for editing images
with gr.Tab(label=''):
mask = gr.State(value=None) # store mask
selected_points = gr.State([]) # store points
original_image = gr.State(value=None) # store original input image
with gr.Row():
with gr.Column():
canvas = gr.Image(type="numpy", tool="sketch", label="Draw Mask",
show_label=True, height=LENGTH, width=2*LENGTH) # for mask painting
with gr.Column():
input_image = gr.Image(type="numpy", label="Click Points",
show_label=True, height=LENGTH, width=LENGTH) # for points clicking
with gr.Column():
output_image = gr.Image(type="numpy", label="Editing Results",
show_label=True, height=LENGTH, width=LENGTH)
with gr.Row():
undo_button = gr.Button("Undo point")
run_button = gr.Button("Run")
# general parameters
with gr.Row():
prompt = gr.Textbox(label="Prompt")
lora_path = gr.Textbox(value="./lora_tmp", label="LoRA path")
lora_status_bar = gr.Textbox(label="display LoRA training status")
train_lora_button = gr.Button(value='Train LoRA', scale=0.3)
# algorithm specific parameters
with gr.Tab("Drag Config"):
with gr.Row():
n_pix_step = gr.Number(
value=60,
label="Maximum Number of Iterations",
precision=0)
inversion_strength = gr.Number(value=0.7, label='Initial Timestep')
end_step = gr.Number(value=0, label='End Timestep')
lam = gr.Number(value=0.1, label="Lambda")
latent_lr = gr.Number(value=0.01, label="Learning Rate")
start_step = gr.Number(value=0, label="start_step", precision=0, visible=False)
start_layer = gr.Number(value=10, label="start_layer", precision=0, visible=False)
with gr.Tab("Base Model Config"):
with gr.Row():
local_models_dir = 'local_pretrained_models'
local_models_choice = \
[os.path.join(local_models_dir,d) for d in os.listdir(local_models_dir) if os.path.isdir(os.path.join(local_models_dir,d))]
model_path = gr.Dropdown(value="botp/stable-diffusion-v1-5",
label="Diffusion Model Path",
choices=[
"botp/stable-diffusion-v1-5",
] + local_models_choice
)
vae_path = gr.Dropdown(value="default",
label="VAE choice",
choices=["default",
"stabilityai/sd-vae-ft-mse"] + local_models_choice
)
with gr.Tab("LoRA Parameters"):
with gr.Row():
lora_step = gr.Number(value=60, label="LoRA training steps", precision=0)
lora_lr = gr.Number(value=0.0005, label="LoRA learning rate")
lora_batch_size = gr.Number(value=4, label="LoRA batch size", precision=0)
lora_rank = gr.Number(value=16, label="LoRA rank", precision=0)
with gr.Row():
gr.Markdown("""
# Guideline
* DragNoise features semantic editing and generally does not require the use of a Draw Mask.
* First, select the local Stable Diffusion Model in the Base Model Config section, and then train LoRA, which is essential.
* The Maximum Number of Iterations indicates the maximum iteration steps for image editing. If it's a long-distance editing operation, you can increase this value appropriately.
* The Learning Rate represents the rate of latent update during the editing process. If it's a long-distance editing operation, you can increase this value appropriately.
* Lambda represents the weight of retaining the original image during the editing process. If the editing result is significantly distorted, you can increase this value appropriately.
* The Initial Timestep indicates the degree of DDIM inversion. If you want to control significant changes in objects, you can increase this value appropriately.
* The End Timestep indicates the end time step of denoise propagation in the editing result. If you want to enhance image detail characteristics, you can increase this value to 10.
""")
# event definition
# event for dragging user-input real image
canvas.edit(
store_img,
[canvas],
[original_image, selected_points, input_image, mask]
)
input_image.select(
get_points,
[input_image, selected_points],
[input_image],
)
undo_button.click(
undo_points,
[original_image, mask],
[input_image, selected_points]
)
train_lora_button.click(
train_lora_interface,
[original_image,
prompt,
model_path,
vae_path,
lora_path,
lora_step,
lora_lr,
lora_batch_size,
lora_rank],
[lora_status_bar]
)
run_button.click(
run_drag,
[original_image,
input_image,
mask,
prompt,
selected_points,
inversion_strength,
end_step,
lam,
latent_lr,
n_pix_step,
model_path,
vae_path,
lora_path,
start_step,
start_layer,
],
[output_image]
)
demo.queue().launch(share=True, debug=True)
name: dragnoise
channels:
- pytorch
- defaults
dependencies:
- python=3.8.5
- pip=22.3.1
# - cudatoolkit=11.7
- pip:
- torch==2.0.0
- torchvision==0.15.1
- gradio==3.41.1
- pydantic==2.0.2
- albumentations==1.3.0
- opencv-contrib-python==4.3.0.36
- imageio==2.9.0
- imageio-ffmpeg==0.4.2
- pytorch-lightning==1.5.0
- omegaconf==2.3.0
- test-tube>=0.7.5
- streamlit==1.12.1
- einops==0.6.0
- transformers==4.27.0
- webdataset==0.2.5
- kornia==0.6
- open_clip_torch==2.16.0
- invisible-watermark>=0.1.5
- streamlit-drawable-canvas==0.8.0
- torchmetrics==0.6.0
- timm==0.6.12
- addict==2.4.0
- yapf==0.32.0
- prettytable==3.6.0
- safetensors==0.2.7
- basicsr==1.4.2
- accelerate==0.17.0
- decord==0.6.0
- diffusers==0.17.1
- moviepy==1.0.3
- opencv_python==4.7.0.68
- Pillow==9.4.0
- scikit_image==0.19.3
- scipy==1.10.1
- tensorboardX==2.6
- tqdm==4.64.1
- numpy==1.24.1
icon.png

68.4 KB

Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment