Commit 7dc08a7d authored by bailuo's avatar bailuo
Browse files

init

parents
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright {yyyy} {name of copyright owner}
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================
Apache DragDiffusion Subcomponents:
The Apache DragDiffusion project contains subcomponents with separate copyright
notices and license terms. Your use of the source code for the these
subcomponents is subject to the terms and conditions of the following
licenses.
========================================================================
Apache 2.0 licenses
========================================================================
The following components are provided under the Apache License. See project link for details.
The text of each license is the standard Apache 2.0 license.
files from lora: https://github.com/huggingface/diffusers/blob/v0.17.1/examples/dreambooth/train_dreambooth_lora.py apache 2.0
\ No newline at end of file
# DragDiffusion
DragDiffusion 模型,利用扩散模型进行基于点的交互式图像编辑,允许用户将图像中的任意点“拖动”到目标位置,以精确控制姿势、形状、表情和布局。
## 论文
`DragDiffusion: Harnessing Diffusion Models for Interactive Point-based Image Editing`
- https://arxiv.org/abs/2306.14435
- CVPR 2024
## 模型结构
<!-- 此处一句话简要介绍模型结构 -->
<!-- DragDiffusion利用扩散模型进行基于点的交互式图像编辑。主要是基于 StyleGAN 模型架构: -->
<div align=center>
<img src="./doc/overview.png"/>
<div >DragDiffusion</div>
</div>
## 算法原理
DragDiffusion 算法受 DragGAN 的启发,把编解码的图像重建部分利用上大规模预训练扩散模型,极大提升了基于点的交互式编辑在现实世界场景中的适用性。
(1)先通过LoRA微调SD模型,数据集为用户输入的图像。目的是在编辑过程中(其实也是生成过程)更好的保留输入图像中物体和风格特征。\
(2)通过运动监督(Motion Supervision)和点跟踪(Point Tracking)实现对扩散 latent 进行优化,确保多步迭代过程中更加的精准和有效。\
(3)在最后一步去噪的过程中为了保证统一以及质量,从 MasaCtrl 中汲取灵感,提出利用自注意力模块的属性来引导去噪过程。
在编辑过程中,需要增加正则项确保非编辑区域(编辑mask区域外)不变。
<!-- <div align=center>
<img src="./doc/pipeline.png"/>
<div >DragDiffusion</div>
</div> -->
## 环境配置
```
mv dragdiffusion_pytoch dragdiffusion # 去框架名后缀
# docker的-v 路径、docker_name和imageID根据实际情况修改
# pip安装时如果出现下载慢可以尝试别的镜像源
```
### Docker(方法一)
<!-- 此处提供[光源](https://www.sourcefind.cn/#/service-details)拉取docker镜像的地址与使用步骤 -->
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.2-py3.10 # 本镜像imageID为:2f1f619d0182
docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=16G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --network=host --name docker_name imageID bash
cd /your_code_path/dragdiffusion
pip install -r requirements.txt
```
### Dockerfile(方法二)
<!-- 此处提供dockerfile的使用方法 -->
```
cd /your_code_path/dragdiffusion/docker
docker build --no-cache -t codestral:latest .
docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=16G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --network=host --name docker_name imageID bash
cd /your_code_path/dragdiffusion
pip install -r requirements.txt
```
### Anaconda(方法三)
<!-- 此处提供本地配置、编译的详细步骤,例如: -->
关于本项目DCU显卡所需的特殊深度学习库可从[光合](https://developer.hpccube.com/tool/)开发者社区下载安装。
```
DTK驱动: dtk24.04.2
python: python3.10
pytorch: 2.1.0
```
`Tips:以上DTK驱动、python、pytorch等DCU相关工具版本需要严格一一对应`
其它非深度学习库参照requirements.txt安装:
```
pip install -r requirements.txt
```
## 数据集
测试数据集 [DragBench](https://github.com/Yujun-Shi/DragDiffusion/releases/download/v0.1.1/DragBench.zip) 或者从 [`SCNet`](http://113.200.138.88:18080/aidatasets/dragbench) 上下载。\
下载后放在 ./drag_bench_evaluation/drag_bench_data 并解压,文件构成:
<br>
DragBench<br>
--- animals<br>
------ JH_2023-09-14-1820-16<br>
------ JH_2023-09-14-1821-23<br>
------ JH_2023-09-14-1821-58<br>
------ ...<br>
--- art_work<br>
--- building_city_view<br>
--- ...<br>
--- other_objects<br>
<br>
## 训练
推理中有一步LoRA微调,详情见webui。
## 推理
<!-- 下载模型权重:
```
python scripts/download_model.py
```
或者从 [SCNet](http://113.200.138.88:18080/aimodels/findsource-dependency/stylegan2_pytorch) 上快速下载,并放在 /checkpoints 文件夹下。 -->
可视化webui推理:
```
python drag_ui.py --listen
```
<div align=center>
<img src="./doc/webui.png" width=600/>
<div >webui界面</div>
</div>
1、上传图片;\
2、输入提示;\
3、LoRA训练;\
4、通过鼠标选择要编辑的区域;\
5、通过鼠标标记点位;\
6、运行。\
ps:Drag以及LoRA的一些参数自行视情况修改。
## result
<!-- 此处填算法效果测试图(包括输入、输出) -->
<div align=center>
<img src="./doc/result1.png" width=600/>
<div >推理结果</div>
</div>
<!-- <div align=center>
<img src="./doc/image (1).png" width=600/>
<div >输出</div>
</div> -->
### 精度
测试集 `DragBench`,如上所述下载并解压好。
```
python run_lora_training.py
python run_drag_diffusion.py
python run_eval_similarity.py
# ps:上述脚本的一些文件路径自行根据情况修改
```
| 加速卡 | lpips | clip sim |
| :-----| :----- | :---- |
| K100_AI | 0.115 | 0.977 |
<!-- | 单元格 | 单元格 | 单元格 | -->
## 应用场景
### 算法类别
<!-- 超出以上分类的类别命名也可参考此网址中的类别名:https://huggingface.co/ \ -->
`AIGC`
### 热点应用行业
<!-- 应用行业的填写需要做大量调研,从而为使用者提供专业、全面的推荐,除特殊算法,通常推荐数量>=3。 -->
`零售,制造,电商,医疗,教育`
<!-- ## 预训练权重 -->
<!-- - 此处填写预训练权重在公司内部的下载地址(预训练权重存放中心为:[SCNet AIModels](http://113.200.138.88:18080/aimodels) ,模型用到的各预训练权重请分别填上具体地址。),过小权重文件可打包到项目里。
- 此处填写公开预训练权重官网下载地址(非必须)。 -->
## 源码仓库及问题反馈
<!-- - 此处填本项目gitlab地址 -->
- https://developer.sourcefind.cn/codes/modelzoo/dragdiffusion_pytorch
## 参考资料
- https://github.com/XingangPan/DragDiffusion
<p align="center">
<h1 align="center">DragDiffusion: Harnessing Diffusion Models for Interactive Point-based Image Editing</h1>
<p align="center">
<a href="https://yujun-shi.github.io/"><strong>Yujun Shi</strong></a>
&nbsp;&nbsp;
<strong>Chuhui Xue</strong>
&nbsp;&nbsp;
<strong>Jun Hao Liew</strong>
&nbsp;&nbsp;
<strong>Jiachun Pan</strong>
&nbsp;&nbsp;
<br>
<strong>Hanshu Yan</strong>
&nbsp;&nbsp;
<strong>Wenqing Zhang</strong>
&nbsp;&nbsp;
<a href="https://vyftan.github.io/"><strong>Vincent Y. F. Tan</strong></a>
&nbsp;&nbsp;
<a href="https://songbai.site/"><strong>Song Bai</strong></a>
</p>
<br>
<div align="center">
<img src="./release-doc/asset/counterfeit-1.png", width="700">
<img src="./release-doc/asset/counterfeit-2.png", width="700">
<img src="./release-doc/asset/majix_realistic.png", width="700">
</div>
<div align="center">
<img src="./release-doc/asset/github_video.gif", width="700">
</div>
<p align="center">
<a href="https://arxiv.org/abs/2306.14435"><img alt='arXiv' src="https://img.shields.io/badge/arXiv-2306.14435-b31b1b.svg"></a>
<a href="https://yujun-shi.github.io/projects/dragdiffusion.html"><img alt='page' src="https://img.shields.io/badge/Project-Website-orange"></a>
<a href="https://twitter.com/YujunPeiyangShi"><img alt='Twitter' src="https://img.shields.io/twitter/follow/YujunPeiyangShi?label=%40YujunPeiyangShi"></a>
</p>
<br>
</p>
## Disclaimer
This is a research project, NOT a commercial product. Users are granted the freedom to create images using this tool, but they are expected to comply with local laws and utilize it in a responsible manner. The developers do not assume any responsibility for potential misuse by users.
## News and Update
* [Jan 29th] Update to support diffusers==0.24.0!
* [Oct 23rd] Code and data of DragBench are released! Please check README under "drag_bench_evaluation" for details.
* [Oct 16th] Integrate [FreeU](https://chenyangsi.top/FreeU/) when dragging generated image.
* [Oct 3rd] Speeding up LoRA training when editing real images. (**Now only around 20s on A100!**)
* [Sept 3rd] v0.1.0 Release.
* Enable **Dragging Diffusion-Generated Images.**
* Introducing a new guidance mechanism that **greatly improve quality of dragging results.** (Inspired by [MasaCtrl](https://ljzycmd.github.io/projects/MasaCtrl/))
* Enable Dragging Images with arbitrary aspect ratio
* Adding support for DPM++Solver (Generated Images)
* [July 18th] v0.0.1 Release.
* Integrate LoRA training into the User Interface. **No need to use training script and everything can be conveniently done in UI!**
* Optimize User Interface layout.
* Enable using better VAE for eyes and faces (See [this](https://stable-diffusion-art.com/how-to-use-vae/))
* [July 8th] v0.0.0 Release.
* Implement Basic function of DragDiffusion
## Installation
It is recommended to run our code on a Nvidia GPU with a linux system. We have not yet tested on other configurations. Currently, it requires around 14 GB GPU memory to run our method. We will continue to optimize memory efficiency
To install the required libraries, simply run the following command:
```
conda env create -f environment.yaml
conda activate dragdiff
```
## Run DragDiffusion
To start with, in command line, run the following to start the gradio user interface:
```
python3 drag_ui.py
```
You may check our [GIF above](https://github.com/Yujun-Shi/DragDiffusion/blob/main/release-doc/asset/github_video.gif) that demonstrate the usage of UI in a step-by-step manner.
Basically, it consists of the following steps:
### Case 1: Dragging Input Real Images
#### 1) train a LoRA
* Drop our input image into the left-most box.
* Input a prompt describing the image in the "prompt" field
* Click the "Train LoRA" button to train a LoRA given the input image
#### 2) do "drag" editing
* Draw a mask in the left-most box to specify the editable areas.
* Click handle and target points in the middle box. Also, you may reset all points by clicking "Undo point".
* Click the "Run" button to run our algorithm. Edited results will be displayed in the right-most box.
### Case 2: Dragging Diffusion-Generated Images
#### 1) generate an image
* Fill in the generation parameters (e.g., positive/negative prompt, parameters under Generation Config & FreeU Parameters).
* Click "Generate Image".
#### 2) do "drag" on the generated image
* Draw a mask in the left-most box to specify the editable areas
* Click handle points and target points in the middle box.
* Click the "Run" button to run our algorithm. Edited results will be displayed in the right-most box.
<!---
## Explanation for parameters in the user interface:
#### General Parameters
|Parameter|Explanation|
|-----|------|
|prompt|The prompt describing the user input image (This will be used to train the LoRA and conduct "drag" editing).|
|lora_path|The directory where the trained LoRA will be saved.|
#### Algorithm Parameters
These parameters are collapsed by default as we normally do not have to tune them. Here are the explanations:
* Base Model Config
|Parameter|Explanation|
|-----|------|
|Diffusion Model Path|The path to the diffusion models. By default we are using "botp/stable-diffusion-v1-5". We will add support for more models in the future.|
|VAE Choice|The Choice of VAE. Now there are two choices, one is "default", which will use the original VAE. Another choice is "stabilityai/sd-vae-ft-mse", which can improve results on images with human eyes and faces (see [explanation](https://stable-diffusion-art.com/how-to-use-vae/))|
* Drag Parameters
|Parameter|Explanation|
|-----|------|
|n_pix_step|Maximum number of steps of motion supervision. **Increase this if handle points have not been "dragged" to desired position.**|
|lam|The regularization coefficient controlling unmasked region stays unchanged. Increase this value if the unmasked region has changed more than what was desired (do not have to tune in most cases).|
|n_actual_inference_step|Number of DDIM inversion steps performed (do not have to tune in most cases).|
* LoRA Parameters
|Parameter|Explanation|
|-----|------|
|LoRA training steps|Number of LoRA training steps (do not have to tune in most cases).|
|LoRA learning rate|Learning rate of LoRA (do not have to tune in most cases)|
|LoRA rank|Rank of the LoRA (do not have to tune in most cases).|
--->
## License
Code related to the DragDiffusion algorithm is under Apache 2.0 license.
## BibTeX
If you find our repo helpful, please consider leaving a star or cite our paper :)
```bibtex
@article{shi2023dragdiffusion,
title={DragDiffusion: Harnessing Diffusion Models for Interactive Point-based Image Editing},
author={Shi, Yujun and Xue, Chuhui and Pan, Jiachun and Zhang, Wenqing and Tan, Vincent YF and Bai, Song},
journal={arXiv preprint arXiv:2306.14435},
year={2023}
}
```
## Contact
For any questions on this project, please contact [Yujun](https://yujun-shi.github.io/) (shi.yujun@u.nus.edu)
## Acknowledgement
This work is inspired by the amazing [DragGAN](https://vcai.mpi-inf.mpg.de/projects/DragGAN/). The lora training code is modified from an [example](https://github.com/huggingface/diffusers/blob/v0.17.1/examples/dreambooth/train_dreambooth_lora.py) of diffusers. Image samples are collected from [unsplash](https://unsplash.com/), [pexels](https://www.pexels.com/zh-cn/), [pixabay](https://pixabay.com/). Finally, a huge shout-out to all the amazing open source diffusion models and libraries.
## Related Links
* [Drag Your GAN: Interactive Point-based Manipulation on the Generative Image Manifold](https://vcai.mpi-inf.mpg.de/projects/DragGAN/)
* [MasaCtrl: Tuning-free Mutual Self-Attention Control for Consistent Image Synthesis and Editing](https://ljzycmd.github.io/projects/MasaCtrl/)
* [Emergent Correspondence from Image Diffusion](https://diffusionfeatures.github.io/)
* [DragonDiffusion: Enabling Drag-style Manipulation on Diffusion Models](https://mc-e.github.io/project/DragonDiffusion/)
* [FreeDrag: Point Tracking is Not You Need for Interactive Point-based Image Editing](https://lin-chen.site/projects/freedrag/)
## Common Issues and Solutions
1) For users struggling in loading models from huggingface due to internet constraint, please 1) follow this [links](https://zhuanlan.zhihu.com/p/475260268) and download the model into the directory "local\_pretrained\_models"; 2) Run "drag\_ui.py" and select the directory to your pretrained model in "Algorithm Parameters -> Base Model Config -> Diffusion Model Path".
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.2-py3.10
ENV DEBIAN_FRONTEND=noninteractive
# COPY requirements.txt requirements.txt
# RUN pip3 install -r requirements.txt -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
# How to Evaluate with DragBench
### Step 1: extract dataset
Extract [DragBench](https://github.com/Yujun-Shi/DragDiffusion/releases/download/v0.1.1/DragBench.zip) into the folder "drag_bench_data".
Resulting directory hierarchy should look like the following:
<br>
drag_bench_data<br>
--- animals<br>
------ JH_2023-09-14-1820-16<br>
------ JH_2023-09-14-1821-23<br>
------ JH_2023-09-14-1821-58<br>
------ ...<br>
--- art_work<br>
--- building_city_view<br>
--- ...<br>
--- other_objects<br>
<br>
### Step 2: train LoRA.
Train one LoRA on each image in drag_bench_data.
To do this, simply execute "run_lora_training.py".
Trained LoRAs will be saved in "drag_bench_lora"
### Step 3: run dragging results
To run dragging results of DragDiffusion on images in "drag_bench_data", simply execute "run_drag_diffusion.py".
Results will be saved in "drag_diffusion_res".
### Step 4: evaluate mean distance and similarity.
To evaluate LPIPS score before and after dragging, execute "run_eval_similarity.py"
To evaluate mean distance between target points and the final position of handle points (estimated by DIFT), execute "run_eval_point_matching.py"
# Expand the Dataset
Here we also provided the labeling tool used by us in the file "labeling_tool.py".
Run this file to get the user interface for labeling your images with drag instructions.
\ No newline at end of file
# *************************************************************************
# Copyright (2023) Bytedance Inc.
#
# Copyright (2023) DragDiffusion Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# *************************************************************************
import os
import numpy as np
import pickle
import sys
sys.path.insert(0, '../')
if __name__ == '__main__':
all_category = [
'art_work',
'land_scape',
'building_city_view',
'building_countryside_view',
'animals',
'human_head',
'human_upper_body',
'human_full_body',
'interior_design',
'other_objects',
]
# assume root_dir and lora_dir are valid directory
root_dir = 'drag_bench_data'
num_samples, num_pair_points = 0, 0
for cat in all_category:
file_dir = os.path.join(root_dir, cat)
for sample_name in os.listdir(file_dir):
if sample_name == '.DS_Store':
continue
sample_path = os.path.join(file_dir, sample_name)
# load meta data
with open(os.path.join(sample_path, 'meta_data.pkl'), 'rb') as f:
meta_data = pickle.load(f)
points = meta_data['points']
num_samples += 1
num_pair_points += len(points) // 2
print(num_samples)
print(num_pair_points)
\ No newline at end of file
# code credit: https://github.com/Tsingularity/dift/blob/main/src/models/dift_sd.py
from diffusers import StableDiffusionPipeline
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from typing import Any, Callable, Dict, List, Optional, Union
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers import DDIMScheduler
import gc
from PIL import Image
class MyUNet2DConditionModel(UNet2DConditionModel):
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
up_ft_indices,
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None):
r"""
Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
"""
# By default samples have to be AT least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
default_overall_up_factor = 2**self.num_upsamplers
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
# logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# prepare attention_mask
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
# 2. pre-process
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. mid
if self.mid_block is not None:
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
# 5. up
up_ft = {}
for i, upsample_block in enumerate(self.up_blocks):
if i > np.max(up_ft_indices):
break
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
upsample_size=upsample_size,
attention_mask=attention_mask,
)
else:
sample = upsample_block(
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
)
if i in up_ft_indices:
up_ft[i] = sample.detach()
output = {}
output['up_ft'] = up_ft
return output
class OneStepSDPipeline(StableDiffusionPipeline):
@torch.no_grad()
def __call__(
self,
img_tensor,
t,
up_ft_indices,
negative_prompt: Optional[Union[str, List[str]]] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None
):
device = self._execution_device
latents = self.vae.encode(img_tensor).latent_dist.sample() * self.vae.config.scaling_factor
t = torch.tensor(t, dtype=torch.long, device=device)
noise = torch.randn_like(latents).to(device)
latents_noisy = self.scheduler.add_noise(latents, noise, t)
unet_output = self.unet(latents_noisy,
t,
up_ft_indices,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs)
return unet_output
class SDFeaturizer:
def __init__(self, sd_id='stabilityai/stable-diffusion-2-1'):
unet = MyUNet2DConditionModel.from_pretrained(sd_id, subfolder="unet")
onestep_pipe = OneStepSDPipeline.from_pretrained(sd_id, unet=unet, safety_checker=None)
onestep_pipe.vae.decoder = None
onestep_pipe.scheduler = DDIMScheduler.from_pretrained(sd_id, subfolder="scheduler")
gc.collect()
onestep_pipe = onestep_pipe.to("cuda")
onestep_pipe.enable_attention_slicing()
# onestep_pipe.enable_xformers_memory_efficient_attention()
self.pipe = onestep_pipe
@torch.no_grad()
def forward(self,
img_tensor,
prompt,
t=261,
up_ft_index=1,
ensemble_size=8):
'''
Args:
img_tensor: should be a single torch tensor in the shape of [1, C, H, W] or [C, H, W]
prompt: the prompt to use, a string
t: the time step to use, should be an int in the range of [0, 1000]
up_ft_index: which upsampling block of the U-Net to extract feature, you can choose [0, 1, 2, 3]
ensemble_size: the number of repeated images used in the batch to extract features
Return:
unet_ft: a torch tensor in the shape of [1, c, h, w]
'''
img_tensor = img_tensor.repeat(ensemble_size, 1, 1, 1).cuda() # ensem, c, h, w
prompt_embeds = self.pipe._encode_prompt(
prompt=prompt,
device='cuda',
num_images_per_prompt=1,
do_classifier_free_guidance=False) # [1, 77, dim]
prompt_embeds = prompt_embeds.repeat(ensemble_size, 1, 1)
unet_ft_all = self.pipe(
img_tensor=img_tensor,
t=t,
up_ft_indices=[up_ft_index],
prompt_embeds=prompt_embeds)
unet_ft = unet_ft_all['up_ft'][up_ft_index] # ensem, c, h, w
unet_ft = unet_ft.mean(0, keepdim=True) # 1,c,h,w
return unet_ft
# *************************************************************************
# Copyright (2023) Bytedance Inc.
#
# Copyright (2023) DragDiffusion Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# *************************************************************************
import cv2
import numpy as np
import PIL
from PIL import Image
from PIL.ImageOps import exif_transpose
import os
import gradio as gr
import datetime
import pickle
from copy import deepcopy
LENGTH=480 # length of the square area displaying/editing images
def clear_all(length=480):
return gr.Image.update(value=None, height=length, width=length), \
gr.Image.update(value=None, height=length, width=length), \
[], None, None
def mask_image(image,
mask,
color=[255,0,0],
alpha=0.5):
""" Overlay mask on image for visualization purpose.
Args:
image (H, W, 3) or (H, W): input image
mask (H, W): mask to be overlaid
color: the color of overlaid mask
alpha: the transparency of the mask
"""
out = deepcopy(image)
img = deepcopy(image)
img[mask == 1] = color
out = cv2.addWeighted(img, alpha, out, 1-alpha, 0, out)
return out
def store_img(img, length=512):
image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255.
height,width,_ = image.shape
image = Image.fromarray(image)
image = exif_transpose(image)
image = image.resize((length,int(length*height/width)), PIL.Image.BILINEAR)
mask = cv2.resize(mask, (length,int(length*height/width)), interpolation=cv2.INTER_NEAREST)
image = np.array(image)
if mask.sum() > 0:
mask = np.uint8(mask > 0)
masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
else:
masked_img = image.copy()
# when new image is uploaded, `selected_points` should be empty
return image, [], masked_img, mask
# user click the image to get points, and show the points on the image
def get_points(img,
sel_pix,
evt: gr.SelectData):
# collect the selected point
sel_pix.append(evt.index)
# draw points
points = []
for idx, point in enumerate(sel_pix):
if idx % 2 == 0:
# draw a red circle at the handle point
cv2.circle(img, tuple(point), 10, (255, 0, 0), -1)
else:
# draw a blue circle at the handle point
cv2.circle(img, tuple(point), 10, (0, 0, 255), -1)
points.append(tuple(point))
# draw an arrow from handle point to target point
if len(points) == 2:
cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5)
points = []
return img if isinstance(img, np.ndarray) else np.array(img)
# clear all handle/target points
def undo_points(original_image,
mask):
if mask.sum() > 0:
mask = np.uint8(mask > 0)
masked_img = mask_image(original_image, 1 - mask, color=[0, 0, 0], alpha=0.3)
else:
masked_img = original_image.copy()
return masked_img, []
def save_all(category,
source_image,
image_with_clicks,
mask,
labeler,
prompt,
points,
root_dir='./drag_bench_data'):
if not os.path.isdir(root_dir):
os.mkdir(root_dir)
if not os.path.isdir(os.path.join(root_dir, category)):
os.mkdir(os.path.join(root_dir, category))
save_prefix = labeler + '_' + datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
save_dir = os.path.join(root_dir, category, save_prefix)
if not os.path.isdir(save_dir):
os.mkdir(save_dir)
# save images
Image.fromarray(source_image).save(os.path.join(save_dir, 'original_image.png'))
Image.fromarray(image_with_clicks).save(os.path.join(save_dir, 'user_drag.png'))
# save meta data
meta_data = {
'prompt' : prompt,
'points' : points,
'mask' : mask,
}
with open(os.path.join(save_dir, 'meta_data.pkl'), 'wb') as f:
pickle.dump(meta_data, f)
return save_prefix + " saved!"
with gr.Blocks() as demo:
# UI components for editing real images
with gr.Tab(label="Editing Real Image"):
mask = gr.State(value=None) # store mask
selected_points = gr.State([]) # store points
original_image = gr.State(value=None) # store original input image
with gr.Row():
with gr.Column():
gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""")
canvas = gr.Image(type="numpy", tool="sketch", label="Draw Mask",
show_label=True, height=LENGTH, width=LENGTH) # for mask painting
with gr.Column():
gr.Markdown("""<p style="text-align: center; font-size: 20px">Click Points</p>""")
input_image = gr.Image(type="numpy", label="Click Points",
show_label=True, height=LENGTH, width=LENGTH) # for points clicking
with gr.Row():
labeler = gr.Textbox(label="Labeler")
category = gr.Dropdown(value="art_work",
label="Image Category",
choices=[
'art_work',
'land_scape',
'building_city_view',
'building_countryside_view',
'animals',
'human_head',
'human_upper_body',
'human_full_body',
'interior_design',
'other_objects',
]
)
prompt = gr.Textbox(label="Prompt")
save_status = gr.Textbox(label="display saving status")
with gr.Row():
undo_button = gr.Button("undo points")
clear_all_button = gr.Button("clear all")
save_button = gr.Button("save")
# event definition
# event for dragging user-input real image
canvas.edit(
store_img,
[canvas],
[original_image, selected_points, input_image, mask]
)
input_image.select(
get_points,
[input_image, selected_points],
[input_image],
)
undo_button.click(
undo_points,
[original_image, mask],
[input_image, selected_points]
)
clear_all_button.click(
clear_all,
[gr.Number(value=LENGTH, visible=False, precision=0)],
[canvas,
input_image,
selected_points,
original_image,
mask]
)
save_button.click(
save_all,
[category,
original_image,
input_image,
mask,
labeler,
prompt,
selected_points,],
[save_status]
)
demo.queue().launch(share=True, debug=True)
# *************************************************************************
# Copyright (2023) Bytedance Inc.
#
# Copyright (2023) DragDiffusion Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# *************************************************************************
# run results of DragDiffusion
import argparse
import os
import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import PIL
from PIL import Image
from copy import deepcopy
from einops import rearrange
from types import SimpleNamespace
from diffusers import DDIMScheduler, AutoencoderKL
from torchvision.utils import save_image
from pytorch_lightning import seed_everything
import sys
sys.path.insert(0, '../')
from drag_pipeline import DragPipeline
from utils.drag_utils import drag_diffusion_update
from utils.attn_utils import register_attention_editor_diffusers, MutualSelfAttentionControl
def preprocess_image(image,
device):
image = torch.from_numpy(image).float() / 127.5 - 1 # [-1, 1]
image = rearrange(image, "h w c -> 1 c h w")
image = image.to(device)
return image
# copy the run_drag function to here
def run_drag(source_image,
# image_with_clicks,
mask,
prompt,
points,
inversion_strength,
lam,
latent_lr,
unet_feature_idx,
n_pix_step,
model_path,
vae_path,
lora_path,
start_step,
start_layer,
# save_dir="./results"
):
# initialize model
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
beta_schedule="scaled_linear", clip_sample=False,
set_alpha_to_one=False, steps_offset=1)
model = DragPipeline.from_pretrained(model_path, scheduler=scheduler).to(device)
# call this function to override unet forward function,
# so that intermediate features are returned after forward
model.modify_unet_forward()
# set vae
if vae_path != "default":
model.vae = AutoencoderKL.from_pretrained(
vae_path
).to(model.vae.device, model.vae.dtype)
# initialize parameters
seed = 42 # random seed used by a lot of people for unknown reason
seed_everything(seed)
args = SimpleNamespace()
args.prompt = prompt
args.points = points
args.n_inference_step = 50
args.n_actual_inference_step = round(inversion_strength * args.n_inference_step)
args.guidance_scale = 1.0
args.unet_feature_idx = [unet_feature_idx]
args.r_m = 1
args.r_p = 3
args.lam = lam
args.lr = latent_lr
args.n_pix_step = n_pix_step
full_h, full_w = source_image.shape[:2]
args.sup_res_h = int(0.5*full_h)
args.sup_res_w = int(0.5*full_w)
print(args)
source_image = preprocess_image(source_image, device)
# image_with_clicks = preprocess_image(image_with_clicks, device)
# set lora
if lora_path == "":
print("applying default parameters")
model.unet.set_default_attn_processor()
else:
print("applying lora: " + lora_path)
model.unet.load_attn_procs(lora_path)
# invert the source image
# the latent code resolution is too small, only 64*64
invert_code = model.invert(source_image,
prompt,
guidance_scale=args.guidance_scale,
num_inference_steps=args.n_inference_step,
num_actual_inference_steps=args.n_actual_inference_step)
mask = torch.from_numpy(mask).float() / 255.
mask[mask > 0.0] = 1.0
mask = rearrange(mask, "h w -> 1 1 h w").cuda()
mask = F.interpolate(mask, (args.sup_res_h, args.sup_res_w), mode="nearest")
handle_points = []
target_points = []
# here, the point is in x,y coordinate
for idx, point in enumerate(points):
cur_point = torch.tensor([point[1]/full_h*args.sup_res_h, point[0]/full_w*args.sup_res_w])
cur_point = torch.round(cur_point)
if idx % 2 == 0:
handle_points.append(cur_point)
else:
target_points.append(cur_point)
print('handle points:', handle_points)
print('target points:', target_points)
init_code = invert_code
init_code_orig = deepcopy(init_code)
model.scheduler.set_timesteps(args.n_inference_step)
t = model.scheduler.timesteps[args.n_inference_step - args.n_actual_inference_step]
# feature shape: [1280,16,16], [1280,32,32], [640,64,64], [320,64,64]
# update according to the given supervision
updated_init_code = drag_diffusion_update(model, init_code,
None, t, handle_points, target_points, mask, args)
# hijack the attention module
# inject the reference branch to guide the generation
editor = MutualSelfAttentionControl(start_step=start_step,
start_layer=start_layer,
total_steps=args.n_inference_step,
guidance_scale=args.guidance_scale)
if lora_path == "":
register_attention_editor_diffusers(model, editor, attn_processor='attn_proc')
else:
register_attention_editor_diffusers(model, editor, attn_processor='lora_attn_proc')
# inference the synthesized image
gen_image = model(
prompt=args.prompt,
batch_size=2,
latents=torch.cat([init_code_orig, updated_init_code], dim=0),
guidance_scale=args.guidance_scale,
num_inference_steps=args.n_inference_step,
num_actual_inference_steps=args.n_actual_inference_step
)[1].unsqueeze(dim=0)
# resize gen_image into the size of source_image
# we do this because shape of gen_image will be rounded to multipliers of 8
gen_image = F.interpolate(gen_image, (full_h, full_w), mode='bilinear')
# save the original image, user editing instructions, synthesized image
# save_result = torch.cat([
# source_image * 0.5 + 0.5,
# torch.ones((1,3,full_h,25)).cuda(),
# image_with_clicks * 0.5 + 0.5,
# torch.ones((1,3,full_h,25)).cuda(),
# gen_image[0:1]
# ], dim=-1)
# if not os.path.isdir(save_dir):
# os.mkdir(save_dir)
# save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
# save_image(save_result, os.path.join(save_dir, save_prefix + '.png'))
out_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0]
out_image = (out_image * 255).astype(np.uint8)
return out_image
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="setting arguments")
parser.add_argument('--lora_steps', type=int, default=80, help='number of lora fine-tuning steps')
parser.add_argument('--inv_strength', type=float, default=0.7, help='inversion strength')
parser.add_argument('--latent_lr', type=float, default=0.01, help='latent learning rate')
parser.add_argument('--unet_feature_idx', type=int, default=3, help='feature idx of unet features')
args = parser.parse_args()
all_category = [
'art_work',
'land_scape',
'building_city_view',
'building_countryside_view',
'animals',
'human_head',
'human_upper_body',
'human_full_body',
'interior_design',
'other_objects',
]
# assume root_dir and lora_dir are valid directory
root_dir = 'drag_bench_data/DragBench'
lora_dir = 'drag_bench_lora'
result_dir = 'drag_diffusion_res' + \
'_' + str(args.lora_steps) + \
'_' + str(args.inv_strength) + \
'_' + str(args.latent_lr) + \
'_' + str(args.unet_feature_idx)
# mkdir if necessary
if not os.path.isdir(result_dir):
os.mkdir(result_dir)
for cat in all_category:
os.mkdir(os.path.join(result_dir,cat))
for cat in all_category:
file_dir = os.path.join(root_dir, cat)
for sample_name in os.listdir(file_dir):
if sample_name == '.DS_Store':
continue
sample_path = os.path.join(file_dir, sample_name)
# read image file
source_image = Image.open(os.path.join(sample_path, 'original_image.png'))
source_image = np.array(source_image)
# load meta data
with open(os.path.join(sample_path, 'meta_data.pkl'), 'rb') as f:
meta_data = pickle.load(f)
prompt = meta_data['prompt']
mask = meta_data['mask']
points = meta_data['points']
# load lora
lora_path = os.path.join(lora_dir, cat, sample_name, str(args.lora_steps))
print("applying lora: " + lora_path)
out_image = run_drag(
source_image,
mask,
prompt,
points,
inversion_strength=args.inv_strength,
lam=0.1,
latent_lr=args.latent_lr,
unet_feature_idx=args.unet_feature_idx,
n_pix_step=80,
model_path="botp/stable-diffusion-v1-5",
vae_path="default",
lora_path=lora_path,
start_step=0,
start_layer=10,
)
save_dir = os.path.join(result_dir, cat, sample_name)
if not os.path.isdir(save_dir):
os.mkdir(save_dir)
Image.fromarray(out_image).save(os.path.join(save_dir, 'dragged_image.png'))
# *************************************************************************
# Copyright (2023) Bytedance Inc.
#
# Copyright (2023) DragDiffusion Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# *************************************************************************
# run evaluation of mean distance between the desired target points and the position of final handle points
import argparse
import os
import pickle
import numpy as np
import PIL
from PIL import Image
from torchvision.transforms import PILToTensor
import torch
import torch.nn as nn
import torch.nn.functional as F
from dift_sd import SDFeaturizer
from pytorch_lightning import seed_everything
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="setting arguments")
parser.add_argument('--eval_root',
action='append',
help='root of dragging results for evaluation',
required=True)
args = parser.parse_args()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# using SD-2.1
dift = SDFeaturizer('stabilityai/stable-diffusion-2-1')
all_category = [
'art_work',
'land_scape',
'building_city_view',
'building_countryside_view',
'animals',
'human_head',
'human_upper_body',
'human_full_body',
'interior_design',
'other_objects',
]
original_img_root = 'drag_bench_data/'
for target_root in args.eval_root:
# fixing the seed for semantic correspondence
seed_everything(42)
all_dist = []
for cat in all_category:
for file_name in os.listdir(os.path.join(original_img_root, cat)):
if file_name == '.DS_Store':
continue
with open(os.path.join(original_img_root, cat, file_name, 'meta_data.pkl'), 'rb') as f:
meta_data = pickle.load(f)
prompt = meta_data['prompt']
points = meta_data['points']
# here, the point is in x,y coordinate
handle_points = []
target_points = []
for idx, point in enumerate(points):
# from now on, the point is in row,col coordinate
cur_point = torch.tensor([point[1], point[0]])
if idx % 2 == 0:
handle_points.append(cur_point)
else:
target_points.append(cur_point)
source_image_path = os.path.join(original_img_root, cat, file_name, 'original_image.png')
dragged_image_path = os.path.join(target_root, cat, file_name, 'dragged_image.png')
source_image_PIL = Image.open(source_image_path)
dragged_image_PIL = Image.open(dragged_image_path)
dragged_image_PIL = dragged_image_PIL.resize(source_image_PIL.size,PIL.Image.BILINEAR)
source_image_tensor = (PILToTensor()(source_image_PIL) / 255.0 - 0.5) * 2
dragged_image_tensor = (PILToTensor()(dragged_image_PIL) / 255.0 - 0.5) * 2
_, H, W = source_image_tensor.shape
ft_source = dift.forward(source_image_tensor,
prompt=prompt,
t=261,
up_ft_index=1,
ensemble_size=8)
ft_source = F.interpolate(ft_source, (H, W), mode='bilinear')
ft_dragged = dift.forward(dragged_image_tensor,
prompt=prompt,
t=261,
up_ft_index=1,
ensemble_size=8)
ft_dragged = F.interpolate(ft_dragged, (H, W), mode='bilinear')
cos = nn.CosineSimilarity(dim=1)
for pt_idx in range(len(handle_points)):
hp = handle_points[pt_idx]
tp = target_points[pt_idx]
num_channel = ft_source.size(1)
src_vec = ft_source[0, :, hp[0], hp[1]].view(1, num_channel, 1, 1)
cos_map = cos(src_vec, ft_dragged).cpu().numpy()[0] # H, W
max_rc = np.unravel_index(cos_map.argmax(), cos_map.shape) # the matched row,col
# calculate distance
dist = (tp - torch.tensor(max_rc)).float().norm()
all_dist.append(dist)
print(target_root + ' mean distance: ', torch.tensor(all_dist).mean().item())
# *************************************************************************
# Copyright (2023) Bytedance Inc.
#
# Copyright (2023) DragDiffusion Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# *************************************************************************
# evaluate similarity between images before and after dragging
import argparse
import os
from einops import rearrange
import numpy as np
import PIL
from PIL import Image
import torch
import torch.nn.functional as F
import lpips
import clip
def preprocess_image(image,
device):
image = torch.from_numpy(image).float() / 127.5 - 1 # [-1, 1]
image = rearrange(image, "h w c -> 1 c h w")
image = image.to(device)
return image
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="setting arguments")
parser.add_argument('--eval_root',
action='append',
help='root of dragging results for evaluation',
required=True)
args = parser.parse_args()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# lpip metric
loss_fn_alex = lpips.LPIPS(net='alex').to(device)
# load clip model
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device, jit=False)
all_category = [
'art_work',
'land_scape',
'building_city_view',
'building_countryside_view',
'animals',
'human_head',
'human_upper_body',
'human_full_body',
'interior_design',
'other_objects',
]
original_img_root = 'drag_bench_data/DragBench'
for target_root in args.eval_root:
all_lpips = []
all_clip_sim = []
for cat in all_category:
all_lpips_ = []
all_clip_sim_ = []
for file_name in os.listdir(os.path.join(original_img_root, cat)):
if file_name == '.DS_Store':
continue
source_image_path = os.path.join(original_img_root, cat, file_name, 'original_image.png')
dragged_image_path = os.path.join(target_root, cat, file_name, 'dragged_image.png')
source_image_PIL = Image.open(source_image_path)
dragged_image_PIL = Image.open(dragged_image_path)
dragged_image_PIL = dragged_image_PIL.resize(source_image_PIL.size,PIL.Image.BILINEAR)
source_image = preprocess_image(np.array(source_image_PIL), device)
dragged_image = preprocess_image(np.array(dragged_image_PIL), device)
# compute LPIP
with torch.no_grad():
source_image_224x224 = F.interpolate(source_image, (224,224), mode='bilinear')
dragged_image_224x224 = F.interpolate(dragged_image, (224,224), mode='bilinear')
cur_lpips = loss_fn_alex(source_image_224x224, dragged_image_224x224)
all_lpips.append(cur_lpips.item())
all_lpips_.append(cur_lpips.item())
# compute CLIP similarity
source_image_clip = clip_preprocess(source_image_PIL).unsqueeze(0).to(device)
dragged_image_clip = clip_preprocess(dragged_image_PIL).unsqueeze(0).to(device)
with torch.no_grad():
source_feature = clip_model.encode_image(source_image_clip)
dragged_feature = clip_model.encode_image(dragged_image_clip)
source_feature /= source_feature.norm(dim=-1, keepdim=True)
dragged_feature /= dragged_feature.norm(dim=-1, keepdim=True)
cur_clip_sim = (source_feature * dragged_feature).sum()
all_clip_sim.append(cur_clip_sim.cpu().numpy())
all_clip_sim_.append(cur_clip_sim.cpu().numpy())
print(cat)
print('avg lpips: ', np.mean(all_lpips_))
print('avg clip sim', np.mean(all_clip_sim_))
print(target_root)
print('avg lpips: ', np.mean(all_lpips))
print('avg clip sim', np.mean(all_clip_sim))
# *************************************************************************
# Copyright (2023) Bytedance Inc.
#
# Copyright (2023) DragDiffusion Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# *************************************************************************
import os
import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import PIL
from PIL import Image
from copy import deepcopy
from einops import rearrange
from types import SimpleNamespace
import tqdm
import sys
sys.path.insert(0, '../')
from utils.lora_utils import train_lora
if __name__ == '__main__':
all_category = [
'art_work',
'land_scape',
'building_city_view',
'building_countryside_view',
'animals',
'human_head',
'human_upper_body',
'human_full_body',
'interior_design',
'other_objects',
]
# assume root_dir and lora_dir are valid directory
root_dir = 'drag_bench_data/DragBench'
lora_dir = 'drag_bench_lora'
# mkdir if necessary
if not os.path.isdir(lora_dir):
os.mkdir(lora_dir)
for cat in all_category:
os.mkdir(os.path.join(lora_dir,cat))
for cat in all_category:
file_dir = os.path.join(root_dir, cat)
for sample_name in os.listdir(file_dir):
if sample_name == '.DS_Store':
continue
sample_path = os.path.join(file_dir, sample_name)
# read image file
source_image = Image.open(os.path.join(sample_path, 'original_image.png'))
source_image = np.array(source_image)
# load meta data
with open(os.path.join(sample_path, 'meta_data.pkl'), 'rb') as f:
meta_data = pickle.load(f)
prompt = meta_data['prompt']
# train and save lora
save_lora_path = os.path.join(lora_dir, cat, sample_name)
if not os.path.isdir(save_lora_path):
os.mkdir(save_lora_path)
# you may also increase the number of lora_step here to train longer
train_lora(source_image, prompt,
model_path="botp/stable-diffusion-v1-5",
vae_path="default", save_lora_path=save_lora_path,
lora_step=80, lora_lr=0.0005, lora_batch_size=4, lora_rank=16, progress=tqdm, save_interval=10)
# *************************************************************************
# Copyright (2023) Bytedance Inc.
#
# Copyright (2023) DragDiffusion Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# *************************************************************************
import torch
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
from PIL import Image
from typing import Any, Dict, List, Optional, Tuple, Union
from diffusers import StableDiffusionPipeline
# override unet forward
# The only difference from diffusers:
# return intermediate UNet features of all UpSample blocks
def override_forward(self):
def forward(
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_intermediates: bool = False,
):
# By default samples have to be AT least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
default_overall_up_factor = 2**self.num_upsamplers
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None
for dim in sample.shape[-2:]:
if dim % default_overall_up_factor != 0:
# Forward upsample size to force interpolation output size.
forward_upsample_size = True
break
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None:
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
# `Timesteps` does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
# `Timesteps` does not contain any weights and will always return f32 tensors
# there might be better ways to encapsulate this.
class_labels = class_labels.to(dtype=sample.dtype)
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
if self.config.class_embeddings_concat:
emb = torch.cat([emb, class_emb], dim=-1)
else:
emb = emb + class_emb
if self.config.addition_embed_type == "text":
aug_emb = self.add_embedding(encoder_hidden_states)
elif self.config.addition_embed_type == "text_image":
# Kandinsky 2.1 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
)
image_embs = added_cond_kwargs.get("image_embeds")
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
aug_emb = self.add_embedding(text_embs, image_embs)
elif self.config.addition_embed_type == "text_time":
# SDXL - style
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
elif self.config.addition_embed_type == "image":
# Kandinsky 2.2 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
)
image_embs = added_cond_kwargs.get("image_embeds")
aug_emb = self.add_embedding(image_embs)
elif self.config.addition_embed_type == "image_hint":
# Kandinsky 2.2 - style
if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
)
image_embs = added_cond_kwargs.get("image_embeds")
hint = added_cond_kwargs.get("hint")
aug_emb, hint = self.add_embedding(image_embs, hint)
sample = torch.cat([sample, hint], dim=1)
emb = emb + aug_emb if aug_emb is not None else emb
if self.time_embed_act is not None:
emb = self.time_embed_act(emb)
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
# Kadinsky 2.1 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
# Kandinsky 2.2 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
# 2. pre-process
sample = self.conv_in(sample)
# 2.5 GLIGEN position net
if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
cross_attention_kwargs = cross_attention_kwargs.copy()
gligen_args = cross_attention_kwargs.pop("gligen")
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
# 3. down
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
# if USE_PEFT_BACKEND:
# # weight the lora layers by setting `lora_scale` for each PEFT layer
# scale_lora_layers(self, lora_scale)
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
is_adapter = down_intrablock_additional_residuals is not None
# maintain backward compatibility for legacy usage, where
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg
# but can only use one or the other
if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
deprecate(
"T2I should not use down_block_additional_residuals",
"1.3.0",
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
standard_warn=False,
)
down_intrablock_additional_residuals = down_block_additional_residuals
is_adapter = True
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
# For t2i-adapter CrossAttnDownBlock2D
additional_residuals = {}
if is_adapter and len(down_intrablock_additional_residuals) > 0:
additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
**additional_residuals,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
if is_adapter and len(down_intrablock_additional_residuals) > 0:
sample += down_intrablock_additional_residuals.pop(0)
down_block_res_samples += res_samples
if is_controlnet:
new_down_block_res_samples = ()
for down_block_res_sample, down_block_additional_residual in zip(
down_block_res_samples, down_block_additional_residuals
):
down_block_res_sample = down_block_res_sample + down_block_additional_residual
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
down_block_res_samples = new_down_block_res_samples
# 4. mid
if self.mid_block is not None:
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
)
else:
sample = self.mid_block(sample, emb)
# To support T2I-Adapter-XL
if (
is_adapter
and len(down_intrablock_additional_residuals) > 0
and sample.shape == down_intrablock_additional_residuals[0].shape
):
sample += down_intrablock_additional_residuals.pop(0)
if is_controlnet:
sample = sample + mid_block_additional_residual
all_intermediate_features = [sample]
# 5. up
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
upsample_size=upsample_size,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
)
else:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
scale=lora_scale,
)
all_intermediate_features.append(sample)
# 6. post-process
if self.conv_norm_out:
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
# if USE_PEFT_BACKEND:
# # remove `lora_scale` from each PEFT layer
# unscale_lora_layers(self, lora_scale)
# only difference from diffusers, return intermediate results
if return_intermediates:
return sample, all_intermediate_features
else:
return sample
return forward
class DragPipeline(StableDiffusionPipeline):
# must call this function when initialize
def modify_unet_forward(self):
self.unet.forward = override_forward(self.unet)
def inv_step(
self,
model_output: torch.FloatTensor,
timestep: int,
x: torch.FloatTensor,
eta=0.,
verbose=False
):
"""
Inverse sampling for DDIM Inversion
"""
if verbose:
print("timestep: ", timestep)
next_step = timestep
timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)
alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
beta_prod_t = 1 - alpha_prod_t
pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
return x_next, pred_x0
def step(
self,
model_output: torch.FloatTensor,
timestep: int,
x: torch.FloatTensor,
):
"""
predict the sample of the next step in the denoise process.
"""
prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep > 0 else self.scheduler.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output
x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir
return x_prev, pred_x0
@torch.no_grad()
def image2latent(self, image):
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if type(image) is Image:
image = np.array(image)
image = torch.from_numpy(image).float() / 127.5 - 1
image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
# input image density range [-1, 1]
latents = self.vae.encode(image)['latent_dist'].mean
latents = latents * 0.18215
return latents
@torch.no_grad()
def latent2image(self, latents, return_type='np'):
latents = 1 / 0.18215 * latents.detach()
image = self.vae.decode(latents)['sample']
if return_type == 'np':
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
image = (image * 255).astype(np.uint8)
elif return_type == "pt":
image = (image / 2 + 0.5).clamp(0, 1)
return image
def latent2image_grad(self, latents):
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents)['sample']
return image # range [-1, 1]
@torch.no_grad()
def get_text_embeddings(self, prompt):
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# text embeddings
text_input = self.tokenizer(
prompt,
padding="max_length",
max_length=77,
return_tensors="pt"
)
text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
return text_embeddings
# get all intermediate features and then do bilinear interpolation
# return features in the layer_idx list
def forward_unet_features(
self,
z,
t,
encoder_hidden_states,
layer_idx=[0],
interp_res_h=256,
interp_res_w=256):
unet_output, all_intermediate_features = self.unet(
z,
t,
encoder_hidden_states=encoder_hidden_states,
return_intermediates=True
)
all_return_features = []
for idx in layer_idx:
feat = all_intermediate_features[idx]
feat = F.interpolate(feat, (interp_res_h, interp_res_w), mode='bilinear')
all_return_features.append(feat)
return_features = torch.cat(all_return_features, dim=1)
return unet_output, return_features
@torch.no_grad()
def __call__(
self,
prompt,
encoder_hidden_states=None,
batch_size=1,
height=512,
width=512,
num_inference_steps=50,
num_actual_inference_steps=None,
guidance_scale=7.5,
latents=None,
neg_prompt=None,
return_intermediates=False,
**kwds):
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if encoder_hidden_states is None:
if isinstance(prompt, list):
batch_size = len(prompt)
elif isinstance(prompt, str):
if batch_size > 1:
prompt = [prompt] * batch_size
# text embeddings
encoder_hidden_states = self.get_text_embeddings(prompt)
# define initial latents if not predefined
if latents is None:
latents_shape = (batch_size, self.unet.in_channels, height//8, width//8)
latents = torch.randn(latents_shape, device=DEVICE, dtype=self.vae.dtype)
# unconditional embedding for classifier free guidance
if guidance_scale > 1.:
if neg_prompt:
uc_text = neg_prompt
else:
uc_text = ""
unconditional_embeddings = self.get_text_embeddings([uc_text]*batch_size)
encoder_hidden_states = torch.cat([unconditional_embeddings, encoder_hidden_states], dim=0)
print("latents shape: ", latents.shape)
# iterative sampling
self.scheduler.set_timesteps(num_inference_steps)
# print("Valid timesteps: ", reversed(self.scheduler.timesteps))
if return_intermediates:
latents_list = [latents]
for i, t in enumerate(tqdm(self.scheduler.timesteps, desc="DDIM Sampler")):
if num_actual_inference_steps is not None and i < num_inference_steps - num_actual_inference_steps:
continue
if guidance_scale > 1.:
model_inputs = torch.cat([latents] * 2)
else:
model_inputs = latents
# predict the noise
noise_pred = self.unet(
model_inputs,
t,
encoder_hidden_states=encoder_hidden_states,
)
if guidance_scale > 1.0:
noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon)
# compute the previous noise sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if return_intermediates:
latents_list.append(latents)
image = self.latent2image(latents, return_type="pt")
if return_intermediates:
return image, latents_list
return image
@torch.no_grad()
def invert(
self,
image: torch.Tensor,
prompt,
encoder_hidden_states=None,
num_inference_steps=50,
num_actual_inference_steps=None,
guidance_scale=7.5,
eta=0.0,
return_intermediates=False,
**kwds):
"""
invert a real image into noise map with determinisc DDIM inversion
"""
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
batch_size = image.shape[0]
if encoder_hidden_states is None:
if isinstance(prompt, list):
if batch_size == 1:
image = image.expand(len(prompt), -1, -1, -1)
elif isinstance(prompt, str):
if batch_size > 1:
prompt = [prompt] * batch_size
encoder_hidden_states = self.get_text_embeddings(prompt)
# define initial latents
latents = self.image2latent(image)
# unconditional embedding for classifier free guidance
if guidance_scale > 1.:
max_length = text_input.input_ids.shape[-1]
unconditional_input = self.tokenizer(
[""] * batch_size,
padding="max_length",
max_length=77,
return_tensors="pt"
)
unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0]
encoder_hidden_states = torch.cat([unconditional_embeddings, encoder_hidden_states], dim=0)
print("latents shape: ", latents.shape)
# interative sampling
self.scheduler.set_timesteps(num_inference_steps)
print("Valid timesteps: ", reversed(self.scheduler.timesteps))
# print("attributes: ", self.scheduler.__dict__)
latents_list = [latents]
pred_x0_list = [latents]
for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")):
if num_actual_inference_steps is not None and i >= num_actual_inference_steps:
continue
if guidance_scale > 1.:
model_inputs = torch.cat([latents] * 2)
else:
model_inputs = latents
# predict the noise
noise_pred = self.unet(model_inputs,
t,
encoder_hidden_states=encoder_hidden_states,
)
if guidance_scale > 1.:
noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon)
# compute the previous noise sample x_t-1 -> x_t
latents, pred_x0 = self.inv_step(noise_pred, t, latents)
latents_list.append(latents)
pred_x0_list.append(pred_x0)
if return_intermediates:
# return the intermediate laters during inversion
# pred_x0_list = [self.latent2image(img, return_type="pt") for img in pred_x0_list]
return latents, latents_list
return latents
# *************************************************************************
# Copyright (2023) Bytedance Inc.
#
# Copyright (2023) DragDiffusion Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# *************************************************************************
import os
import gradio as gr
from utils.ui_utils import get_points, undo_points
from utils.ui_utils import clear_all, store_img, train_lora_interface, run_drag
from utils.ui_utils import clear_all_gen, store_img_gen, gen_img, run_drag_gen
LENGTH=480 # length of the square area displaying/editing images
with gr.Blocks() as demo:
# layout definition
with gr.Row():
gr.Markdown("""
# Official Implementation of [DragDiffusion](https://arxiv.org/abs/2306.14435)
""")
# UI components for editing real images
with gr.Tab(label="Editing Real Image"):
mask = gr.State(value=None) # store mask
selected_points = gr.State([]) # store points
original_image = gr.State(value=None) # store original input image
with gr.Row():
with gr.Column():
gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""")
canvas = gr.Image(type="numpy", tool="sketch", label="Draw Mask",
show_label=True, height=LENGTH, width=LENGTH) # for mask painting
train_lora_button = gr.Button("Train LoRA")
with gr.Column():
gr.Markdown("""<p style="text-align: center; font-size: 20px">Click Points</p>""")
input_image = gr.Image(type="numpy", label="Click Points",
show_label=True, height=LENGTH, width=LENGTH, interactive=False) # for points clicking
undo_button = gr.Button("Undo point")
with gr.Column():
gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing Results</p>""")
output_image = gr.Image(type="numpy", label="Editing Results",
show_label=True, height=LENGTH, width=LENGTH, interactive=False)
with gr.Row():
run_button = gr.Button("Run")
clear_all_button = gr.Button("Clear All")
# general parameters
with gr.Row():
prompt = gr.Textbox(label="Prompt")
lora_path = gr.Textbox(value="./lora_tmp", label="LoRA path")
lora_status_bar = gr.Textbox(label="display LoRA training status")
# algorithm specific parameters
with gr.Tab("Drag Config"):
with gr.Row():
n_pix_step = gr.Number(
value=80,
label="number of pixel steps",
info="Number of gradient descent (motion supervision) steps on latent.",
precision=0)
lam = gr.Number(value=0.1, label="lam", info="regularization strength on unmasked areas")
# n_actual_inference_step = gr.Number(value=40, label="optimize latent step", precision=0)
inversion_strength = gr.Slider(0, 1.0,
value=0.7,
label="inversion strength",
info="The latent at [inversion-strength * total-sampling-steps] is optimized for dragging.")
latent_lr = gr.Number(value=0.01, label="latent lr")
start_step = gr.Number(value=0, label="start_step", precision=0, visible=False)
start_layer = gr.Number(value=10, label="start_layer", precision=0, visible=False)
with gr.Tab("Base Model Config"):
with gr.Row():
local_models_dir = 'local_pretrained_models'
local_models_choice = \
[os.path.join(local_models_dir,d) for d in os.listdir(local_models_dir) if os.path.isdir(os.path.join(local_models_dir,d))]
model_path = gr.Dropdown(value="botp/stable-diffusion-v1-5",
label="Diffusion Model Path",
choices=[
"botp/stable-diffusion-v1-5",
"gsdf/Counterfeit-V2.5",
"stablediffusionapi/anything-v5",
"SG161222/Realistic_Vision_V2.0",
] + local_models_choice
)
vae_path = gr.Dropdown(value="default",
label="VAE choice",
choices=["default",
"stabilityai/sd-vae-ft-mse"] + local_models_choice
)
with gr.Tab("LoRA Parameters"):
with gr.Row():
lora_step = gr.Number(value=80, label="LoRA training steps", precision=0)
lora_lr = gr.Number(value=0.0005, label="LoRA learning rate")
lora_batch_size = gr.Number(value=4, label="LoRA batch size", precision=0)
lora_rank = gr.Number(value=16, label="LoRA rank", precision=0)
# UI components for editing generated images
with gr.Tab(label="Editing Generated Image"):
mask_gen = gr.State(value=None) # store mask
selected_points_gen = gr.State([]) # store points
original_image_gen = gr.State(value=None) # store the diffusion-generated image
intermediate_latents_gen = gr.State(value=None) # store the intermediate diffusion latent during generation
with gr.Row():
with gr.Column():
gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""")
canvas_gen = gr.Image(type="numpy", tool="sketch", label="Draw Mask",
show_label=True, height=LENGTH, width=LENGTH, interactive=False) # for mask painting
gen_img_button = gr.Button("Generate Image")
with gr.Column():
gr.Markdown("""<p style="text-align: center; font-size: 20px">Click Points</p>""")
input_image_gen = gr.Image(type="numpy", label="Click Points",
show_label=True, height=LENGTH, width=LENGTH, interactive=False) # for points clicking
undo_button_gen = gr.Button("Undo point")
with gr.Column():
gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing Results</p>""")
output_image_gen = gr.Image(type="numpy", label="Editing Results",
show_label=True, height=LENGTH, width=LENGTH, interactive=False)
with gr.Row():
run_button_gen = gr.Button("Run")
clear_all_button_gen = gr.Button("Clear All")
# general parameters
with gr.Row():
pos_prompt_gen = gr.Textbox(label="Positive Prompt")
neg_prompt_gen = gr.Textbox(label="Negative Prompt")
with gr.Tab("Generation Config"):
with gr.Row():
local_models_dir = 'local_pretrained_models'
local_models_choice = \
[os.path.join(local_models_dir,d) for d in os.listdir(local_models_dir) if os.path.isdir(os.path.join(local_models_dir,d))]
model_path_gen = gr.Dropdown(value="botp/stable-diffusion-v1-5",
label="Diffusion Model Path",
choices=[
"botp/stable-diffusion-v1-5",
"gsdf/Counterfeit-V2.5",
"emilianJR/majicMIX_realistic_v6",
"SG161222/Realistic_Vision_V2.0",
"stablediffusionapi/anything-v5",
"stablediffusionapi/interiordesignsuperm",
"stablediffusionapi/dvarch",
] + local_models_choice
)
vae_path_gen = gr.Dropdown(value="default",
label="VAE choice",
choices=["default",
"stabilityai/sd-vae-ft-mse"] + local_models_choice
)
lora_path_gen = gr.Textbox(value="", label="LoRA path")
gen_seed = gr.Number(value=65536, label="Generation Seed", precision=0)
height = gr.Number(value=512, label="Height", precision=0)
width = gr.Number(value=512, label="Width", precision=0)
guidance_scale = gr.Number(value=7.5, label="CFG Scale")
scheduler_name_gen = gr.Dropdown(
value="DDIM",
label="Scheduler",
choices=[
"DDIM",
"DPM++2M",
"DPM++2M_karras"
]
)
n_inference_step_gen = gr.Number(value=50, label="Total Sampling Steps", precision=0)
with gr.Tab("FreeU Parameters"):
with gr.Row():
b1_gen = gr.Slider(label='b1',
info='1st stage backbone factor',
minimum=1,
maximum=1.6,
step=0.05,
value=1.0)
b2_gen = gr.Slider(label='b2',
info='2nd stage backbone factor',
minimum=1,
maximum=1.6,
step=0.05,
value=1.0)
s1_gen = gr.Slider(label='s1',
info='1st stage skip factor',
minimum=0,
maximum=1,
step=0.05,
value=1.0)
s2_gen = gr.Slider(label='s2',
info='2nd stage skip factor',
minimum=0,
maximum=1,
step=0.05,
value=1.0)
with gr.Tab(label="Drag Config"):
with gr.Row():
n_pix_step_gen = gr.Number(
value=80,
label="Number of Pixel Steps",
info="Number of gradient descent (motion supervision) steps on latent.",
precision=0)
lam_gen = gr.Number(value=0.1, label="lam", info="regularization strength on unmasked areas")
# n_actual_inference_step_gen = gr.Number(value=40, label="optimize latent step", precision=0)
inversion_strength_gen = gr.Slider(0, 1.0,
value=0.7,
label="Inversion Strength",
info="The latent at [inversion-strength * total-sampling-steps] is optimized for dragging.")
latent_lr_gen = gr.Number(value=0.01, label="latent lr")
start_step_gen = gr.Number(value=0, label="start_step", precision=0, visible=False)
start_layer_gen = gr.Number(value=10, label="start_layer", precision=0, visible=False)
# event definition
# event for dragging user-input real image
canvas.edit(
store_img,
[canvas],
[original_image, selected_points, input_image, mask]
)
input_image.select(
get_points,
[input_image, selected_points],
[input_image],
)
undo_button.click(
undo_points,
[original_image, mask],
[input_image, selected_points]
)
train_lora_button.click(
train_lora_interface,
[original_image,
prompt,
model_path,
vae_path,
lora_path,
lora_step,
lora_lr,
lora_batch_size,
lora_rank],
[lora_status_bar]
)
run_button.click(
run_drag,
[original_image,
input_image,
mask,
prompt,
selected_points,
inversion_strength,
lam,
latent_lr,
n_pix_step,
model_path,
vae_path,
lora_path,
start_step,
start_layer,
],
[output_image]
)
clear_all_button.click(
clear_all,
[gr.Number(value=LENGTH, visible=False, precision=0)],
[canvas,
input_image,
output_image,
selected_points,
original_image,
mask]
)
# event for dragging generated image
canvas_gen.edit(
store_img_gen,
[canvas_gen],
[original_image_gen, selected_points_gen, input_image_gen, mask_gen]
)
input_image_gen.select(
get_points,
[input_image_gen, selected_points_gen],
[input_image_gen],
)
gen_img_button.click(
gen_img,
[
gr.Number(value=LENGTH, visible=False, precision=0),
height,
width,
n_inference_step_gen,
scheduler_name_gen,
gen_seed,
guidance_scale,
pos_prompt_gen,
neg_prompt_gen,
model_path_gen,
vae_path_gen,
lora_path_gen,
b1_gen,
b2_gen,
s1_gen,
s2_gen,
],
[canvas_gen, input_image_gen, output_image_gen, mask_gen, intermediate_latents_gen]
)
undo_button_gen.click(
undo_points,
[original_image_gen, mask_gen],
[input_image_gen, selected_points_gen]
)
run_button_gen.click(
run_drag_gen,
[
n_inference_step_gen,
scheduler_name_gen,
original_image_gen, # the original image generated by the diffusion model
input_image_gen, # image with clicking, masking, etc.
intermediate_latents_gen,
guidance_scale,
mask_gen,
pos_prompt_gen,
neg_prompt_gen,
selected_points_gen,
inversion_strength_gen,
lam_gen,
latent_lr_gen,
n_pix_step_gen,
model_path_gen,
vae_path_gen,
lora_path_gen,
start_step_gen,
start_layer_gen,
b1_gen,
b2_gen,
s1_gen,
s2_gen,
],
[output_image_gen]
)
clear_all_button_gen.click(
clear_all_gen,
[gr.Number(value=LENGTH, visible=False, precision=0)],
[canvas_gen,
input_image_gen,
output_image_gen,
selected_points_gen,
original_image_gen,
mask_gen,
intermediate_latents_gen,
]
)
demo.queue().launch(share=True, server_name="0.0.0.0", debug=True)
name: dragdiff
channels:
- pytorch
- defaults
- nvidia
dependencies:
- python=3.8.5
- pip=22.3.1
- cudatoolkit=11.7
- pip:
- torch==2.0.0
- torchvision==0.15.1
- gradio==3.41.1
- pydantic==2.0.2
- albumentations==1.3.0
- opencv-contrib-python==4.3.0.36
- imageio==2.9.0
- imageio-ffmpeg==0.4.2
- pytorch-lightning==1.5.0
- omegaconf==2.3.0
- test-tube>=0.7.5
- streamlit==1.12.1
- einops==0.6.0
- transformers==4.27.0
- webdataset==0.2.5
- kornia==0.6
- open_clip_torch==2.16.0
- invisible-watermark>=0.1.5
- streamlit-drawable-canvas==0.8.0
- torchmetrics==0.6.0
- timm==0.6.12
- addict==2.4.0
- yapf==0.32.0
- prettytable==3.6.0
- safetensors==0.3.1
- basicsr==1.4.2
- accelerate==0.17.0
- decord==0.6.0
- diffusers==0.24.0
- moviepy==1.0.3
- opencv_python==4.7.0.68
- Pillow==9.4.0
- scikit_image==0.19.3
- scipy==1.10.1
- tensorboardX==2.6
- tqdm==4.64.1
- numpy==1.24.1
icon.png

68.4 KB

You may put your pretrained model here.
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment