Commit 39e48a15 authored by bailuo

Improve accuracy and update the README

parent 7dc08a7d
@@ -88,7 +88,14 @@ DragBench<br>
## Training
```
# LoRA fine-tuning
python run_lora_training.py
# Train the drag diffusion model and write out the results
python run_drag_diffusion.py
# If huggingface.co is unreachable, run `export HF_ENDPOINT=https://hf-mirror.com` first
```
Alternatively, use the webui.
## Inference
<!-- Download the model weights:
@@ -100,6 +107,7 @@ python scripts/download_model.py
Visual webui inference:
```
python drag_ui.py --listen
# If huggingface.co is unreachable, run `export HF_ENDPOINT=https://hf-mirror.com` first
```
<div align=center>
<img src="./doc/webui.png" width=600/>
@@ -136,15 +144,32 @@ PS: adjust the Drag and LoRA parameters as needed.
```
python run_lora_training.py
python run_drag_diffusion.py
python run_eval_similarity.py --eval_root /your_dragging_results
python run_eval_point_matching.py --eval_root /your_dragging_results
# PS: adjust the file paths in the scripts above as needed
```
| K100_AI accelerator | 1-lpips ↑ | clip sim ↑ | mean distance ↓ |
| :----- | :----- | :---- | :---- |
| paper | 0.885 | 0.977 | 35.260 |
| optimized | 0.869 | 0.975 | 29.824 |
<!-- | cell | cell | cell | -->
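For reference, the 1-lpips and clip sim columns can be reproduced with the `lpips` package and a CLIP image encoder. The snippet below is a minimal sketch rather than the repo's exact evaluation code in run_eval_similarity.py; the backbone choices (`net='alex'`, `openai/clip-vit-base-patch32`) are assumptions.
```
import torch
import lpips
from transformers import CLIPModel, CLIPProcessor

# Assumed backbones; the repo's evaluation script may use different ones.
lpips_fn = lpips.LPIPS(net='alex')
clip = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
proc = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')

def one_minus_lpips(a, b):
    # a, b: (1, 3, H, W) tensors scaled to [-1, 1]; higher means more similar
    with torch.no_grad():
        return 1.0 - lpips_fn(a, b).item()

def clip_sim(img_a, img_b):
    # Cosine similarity of CLIP image embeddings; higher means more similar
    with torch.no_grad():
        ea = clip.get_image_features(**proc(images=img_a, return_tensors='pt'))
        eb = clip.get_image_features(**proc(images=img_b, return_tensors='pt'))
    return torch.nn.functional.cosine_similarity(ea, eb).item()
```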
PS: two optimizations: 1) per-layer rank selection in the LoRA training stage; 2) multi-layer feature fusion in the drag stage (sketched below).
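Point 2 corresponds to the `unet_feature_idx = [2,3]` change further down in this commit. A minimal sketch of the idea, with a hypothetical `fuse_unet_features` helper and illustrative tensor shapes (not taken from the repo): the selected UNet decoder feature maps are upsampled to a common resolution and concatenated before motion supervision and point tracking.
```
import torch
import torch.nn.functional as F

def fuse_unet_features(feature_maps, target_hw):
    # Upsample each selected feature map to a common (H, W), then concatenate
    # along channels so point tracking sees several semantic levels at once.
    upsampled = [
        F.interpolate(f, size=target_hw, mode='bilinear', align_corners=False)
        for f in feature_maps
    ]
    return torch.cat(upsampled, dim=1)

# Illustrative shapes for unet_feature_idx = [2, 3] (assumed):
f2 = torch.randn(1, 640, 32, 32)
f3 = torch.randn(1, 320, 64, 64)
fused = fuse_unet_features([f2, f3], target_hw=(64, 64))  # -> (1, 960, 64, 64)
```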
<div align=center>
<img src="./doc/对比1.jpg" width=600/>
</div>
<div align=center>
<img src="./doc/对比2.jpg" width=600/>
</div>
<div align=center>
<img src="./doc/对比3.jpg" width=600/>
</div>
<div align=center>
<img src="./doc/对比4.jpg" width=600/>
<div>Paper results vs. optimized results</div>
</div>
## Application Scenarios
### Algorithm Category
...
@@ -96,7 +96,7 @@ def run_drag(source_image,
    args.n_actual_inference_step = round(inversion_strength * args.n_inference_step)
    args.guidance_scale = 1.0
    args.unet_feature_idx = unet_feature_idx
    args.r_m = 1
    args.r_p = 3
@@ -207,7 +207,7 @@ if __name__ == '__main__':
    parser.add_argument('--lora_steps', type=int, default=80, help='number of lora fine-tuning steps')
    parser.add_argument('--inv_strength', type=float, default=0.7, help='inversion strength')
    parser.add_argument('--latent_lr', type=float, default=0.01, help='latent learning rate')
    parser.add_argument('--unet_feature_idx', type=int, default=[2,3], help='feature idx of unet features')
    args = parser.parse_args()

    all_category = [
...
@@ -65,6 +65,7 @@ if __name__ == '__main__':
    all_dist = []
    for cat in all_category:
        all_dist_ = []
        for file_name in os.listdir(os.path.join(original_img_root, cat)):
            if file_name == '.DS_Store':
                continue
@@ -124,4 +125,8 @@ if __name__ == '__main__':
            dist = (tp - torch.tensor(max_rc)).float().norm()
            all_dist.append(dist)
            all_dist_.append(dist)
        print(cat + ' mean distance: ', torch.tensor(all_dist_).mean().item())
    print(target_root + ' mean distance: ', torch.tensor(all_dist).mean().item())
@@ -107,9 +107,10 @@ if __name__ == '__main__':
            all_clip_sim.append(cur_clip_sim.cpu().numpy())
            all_clip_sim_.append(cur_clip_sim.cpu().numpy())
        # print(cat)
        print(cat + ' avg lpips: ', np.mean(all_lpips_))
        print(cat + ' avg clip sim', np.mean(all_clip_sim_))
    print(target_root)
    print('avg lpips: ', np.mean(all_lpips))
    print('avg 1-lpips: ', 1.0 - np.mean(all_lpips))
    print('avg clip sim', np.mean(all_clip_sim))
@@ -36,6 +36,8 @@ import sys
sys.path.insert(0, '../')
from utils.lora_utils import train_lora

os.environ['CURL_CA_BUNDLE'] = ''
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

if __name__ == '__main__':
    all_category = [
...
@@ -164,6 +164,9 @@ def train_lora(image,
    vae.to(device, dtype=torch.float16)
    text_encoder.to(device, dtype=torch.float16)

    # Per-processor LoRA ranks; down: 4+4+4, up: 6+6+6, mid: 1+1 attention processors
    lora_rank_list = [4,4,4,4, 8,8,8,8, 16,16,16,16, 16,16,16,16,16,16, 8,8,8,8,8,8, 4,4,4,4,4,4, 32,32]
    lora_rank_inx = 0

    # Set correct lora layers
    unet_lora_parameters = []
    for attn_processor_name, attn_processor in unet.attn_processors.items():
@@ -172,6 +175,8 @@ def train_lora(image,
        for n in attn_processor_name.split(".")[:-1]:
            attn_module = getattr(attn_module, n)

        lora_rank = lora_rank_list[lora_rank_inx] * 2

        # Set the `lora_layer` attribute of the attention-related matrices.
        attn_module.to_q.set_lora_layer(
            LoRALinearLayer(
@@ -213,19 +218,21 @@ def train_lora(image,
            LoRALinearLayer(
                in_features=attn_module.add_k_proj.in_features,
                out_features=attn_module.add_k_proj.out_features,
                rank=lora_rank,
            )
        )
        attn_module.add_v_proj.set_lora_layer(
            LoRALinearLayer(
                in_features=attn_module.add_v_proj.in_features,
                out_features=attn_module.add_v_proj.out_features,
                rank=lora_rank,
            )
        )
        unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters())
        unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters())

        lora_rank_inx = lora_rank_inx + 1

    # Optimizer creation
    params_to_optimize = (unet_lora_parameters)
...
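Taken together, the LoRA change swaps the single global `args.rank` for the per-processor schedule above. The sketch below shows the wiring in isolation, under two stated assumptions: the checkpoint is Stable Diffusion 1.5 (`runwayml/stable-diffusion-v1-5` is a guess at what the repo loads), and `unet.attn_processors` iterates in the down/up/mid order that the in-code comment and the index counter rely on. `LoRALinearLayer` and `set_lora_layer` are the same older-diffusers APIs the repo already uses.
```
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.lora import LoRALinearLayer  # older diffusers API, as in lora_utils.py

# 32 attention processors in the SD-1.5 UNet, ranks per the schedule above.
lora_rank_list = [4]*4 + [8]*4 + [16]*4 + [16]*6 + [8]*6 + [4]*6 + [32]*2

unet = UNet2DConditionModel.from_pretrained(
    'runwayml/stable-diffusion-v1-5', subfolder='unet')  # checkpoint name assumed

for idx, (name, _) in enumerate(unet.attn_processors.items()):
    # Walk from the UNet root to the Attention module that owns this processor.
    attn_module = unet
    for part in name.split('.')[:-1]:
        attn_module = getattr(attn_module, part)
    rank = lora_rank_list[idx] * 2  # train_lora doubles the listed rank
    attn_module.to_q.set_lora_layer(LoRALinearLayer(
        in_features=attn_module.to_q.in_features,
        out_features=attn_module.to_q.out_features,
        rank=rank,
    ))
    # to_k, to_v and to_out[0] are wired the same way in train_lora.
```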