Commit 39e48a15 authored by bailuo

Improve accuracy and update the README

parent 7dc08a7d
......@@ -88,7 +88,14 @@ DragBench<br>
## Training
Inference includes a LoRA fine-tuning step; see the webui for details.
+```
+# LoRA fine-tuning
+python run_lora_training.py
+# Train the drag diffusion model and write out the results
+python run_drag_diffusion.py
+# If huggingface.co is unreachable, run `export HF_ENDPOINT=https://hf-mirror.com` first
+```
Alternatively, use the webui interface.
## Inference
<!-- Download the model weights:
......@@ -100,6 +107,7 @@ python scripts/download_model.py
Visualized inference via the webui:
```
python drag_ui.py --listen
+# If huggingface.co is unreachable, run `export HF_ENDPOINT=https://hf-mirror.com` first
```
<div align=center>
<img src="./doc/webui.png" width=600/>
......@@ -136,15 +144,32 @@ P.S. Adjust the Drag and LoRA parameters as needed.
```
python run_lora_training.py
python run_drag_diffusion.py
-python run_eval_similarity.py
+python run_eval_similarity.py --eval_root /your_dragging_results
+python run_eval_point_matching.py --eval_root /your_dragging_results
# P.S. Adjust the file paths in the scripts above as needed
```
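The table below reports 1-lpips (perceptual similarity between the source and dragged images; higher is better), clip sim (CLIP-based similarity between the two images; higher is better), and mean distance (pixel-space L2 distance between each target point and the tracked handle point, computed by run_eval_point_matching.py; lower is better). A minimal sketch of the 1-lpips score, assuming the `lpips` package and placeholder image tensors:

```python
# Minimal sketch of the 1-lpips score, assuming the `lpips` package;
# the random tensors are placeholders for the source and dragged images.
import lpips
import torch

loss_fn = lpips.LPIPS(net='alex')              # perceptual distance model

img_src = torch.rand(1, 3, 512, 512) * 2 - 1   # placeholder, values in [-1, 1]
img_drag = torch.rand(1, 3, 512, 512) * 2 - 1  # placeholder, values in [-1, 1]

d = loss_fn(img_src, img_drag).item()          # LPIPS distance, lower = more similar
print('1-lpips:', 1.0 - d)                     # higher is better, as reported below
```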
-| Accelerator | lpips | clip sim |
-| :----- | :----- | :---- |
-| K100_AI | 0.115 | 0.977 |
+| K100_AI accelerator | 1-lpips ↑ | clip sim ↑ | mean distance ↓ |
+| :----- | :----- | :---- | :---- |
+| paper | 0.885 | 0.977 | 35.260 |
+| optimized | 0.869 | 0.975 | 29.824 |
+P.S. Two optimizations: 1) per-layer rank selection in the LoRA training stage; 2) multi-layer feature fusion in the drag stage.
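To illustrate point 2, a minimal sketch of multi-layer feature fusion, with a hypothetical `feature_maps` list standing in for the UNet decoder activations that `--unet_feature_idx` selects from (the actual implementation lives in the drag pipeline):

```python
import torch
import torch.nn.functional as F

def fuse_unet_features(feature_maps, unet_feature_idx=(2, 3), size=(64, 64)):
    """Upsample the selected UNet feature maps to a common spatial size and
    concatenate them along the channel dimension, so point tracking sees both
    coarser and finer-grained features."""
    selected = [F.interpolate(feature_maps[i], size=size, mode='bilinear',
                              align_corners=False)
                for i in unet_feature_idx]
    return torch.cat(selected, dim=1)

# Hypothetical decoder activations with increasing resolution:
fmap = fuse_unet_features([torch.rand(1, 1280, 16, 16), torch.rand(1, 1280, 32, 32),
                           torch.rand(1, 640, 32, 32), torch.rand(1, 320, 64, 64)])
print(fmap.shape)  # torch.Size([1, 960, 64, 64])
```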
+<div align=center>
+<img src="./doc/对比1.jpg" width=600/>
+</div>
+<div align=center>
+<img src="./doc/对比2.jpg" width=600/>
+</div>
+<div align=center>
+<img src="./doc/对比3.jpg" width=600/>
+</div>
+<div align=center>
+<img src="./doc/对比4.jpg" width=600/>
+<div>Paper results and optimized results</div>
+</div>
## Application Scenarios
### Algorithm Category
......
......@@ -96,7 +96,7 @@ def run_drag(source_image,
args.n_actual_inference_step = round(inversion_strength * args.n_inference_step)
args.guidance_scale = 1.0
-args.unet_feature_idx = [unet_feature_idx]
+args.unet_feature_idx = unet_feature_idx
args.r_m = 1
args.r_p = 3
......@@ -207,7 +207,7 @@ if __name__ == '__main__':
parser.add_argument('--lora_steps', type=int, default=80, help='number of lora fine-tuning steps')
parser.add_argument('--inv_strength', type=float, default=0.7, help='inversion strength')
parser.add_argument('--latent_lr', type=float, default=0.01, help='latent learning rate')
-parser.add_argument('--unet_feature_idx', type=int, default=3, help='feature idx of unet features')
+parser.add_argument('--unet_feature_idx', type=int, nargs='+', default=[2,3], help='feature idx of unet features')
args = parser.parse_args()
all_category = [
......
......@@ -65,6 +65,7 @@ if __name__ == '__main__':
all_dist = []
for cat in all_category:
+all_dist_ = []  # distances for the current category
for file_name in os.listdir(os.path.join(original_img_root, cat)):
if file_name == '.DS_Store':
continue
......@@ -124,4 +125,8 @@ if __name__ == '__main__':
dist = (tp - torch.tensor(max_rc)).float().norm()
all_dist.append(dist)
+all_dist_.append(dist)
+print(cat + ' mean distance: ', torch.tensor(all_dist_).mean().item())
print(target_root + ' mean distance: ', torch.tensor(all_dist).mean().item())
......@@ -107,9 +107,10 @@ if __name__ == '__main__':
all_clip_sim.append(cur_clip_sim.cpu().numpy())
all_clip_sim_.append(cur_clip_sim.cpu().numpy())
-print(cat)
-print('avg lpips: ', np.mean(all_lpips_))
-print('avg clip sim', np.mean(all_clip_sim_))
+# print(cat)
+print(cat + ' avg lpips: ', np.mean(all_lpips_))
+print(cat + ' avg clip sim', np.mean(all_clip_sim_))
print(target_root)
-print('avg lpips: ', np.mean(all_lpips))
+print('avg 1-lpips: ', 1.0 - np.mean(all_lpips))
print('avg clip sim', np.mean(all_clip_sim))
......@@ -36,6 +36,8 @@ import sys
sys.path.insert(0, '../')
from utils.lora_utils import train_lora
+os.environ['CURL_CA_BUNDLE'] = ''
+os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
if __name__ == '__main__':
all_category = [
......
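One caveat about the two environment variables above: huggingface_hub typically reads `HF_ENDPOINT` once when it is first imported, and `from utils.lora_utils import train_lora` already pulls in diffusers (and with it huggingface_hub). Setting the variable earlier, or in the shell, is safer; a minimal sketch:

```python
import os

# Set the mirror before the first diffusers/transformers/huggingface_hub
# import, so the endpoint override is seen when the library initializes.
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
os.environ['CURL_CA_BUNDLE'] = ''

from utils.lora_utils import train_lora  # imports diffusers underneath
```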
......@@ -164,6 +164,9 @@ def train_lora(image,
vae.to(device, dtype=torch.float16)
text_encoder.to(device, dtype=torch.float16)
+lora_rank_list = [4,4,4,4, 8,8,8,8, 16,16,16,16, 16,16,16,16,16,16, 8,8,8,8,8,8, 4,4,4,4,4,4, 32,32]  # one rank per UNet attention module; down blocks: 4+4+4, up blocks: 6+6+6, mid: 1+1
+lora_rank_inx = 0
# Set correct lora layers
unet_lora_parameters = []
for attn_processor_name, attn_processor in unet.attn_processors.items():
......@@ -171,6 +174,8 @@ def train_lora(image,
attn_module = unet
for n in attn_processor_name.split(".")[:-1]:
attn_module = getattr(attn_module, n)
+lora_rank = lora_rank_list[lora_rank_inx] * 2  # scheduled rank for this attention module, doubled
# Set the `lora_layer` attribute of the attention-related matrices.
attn_module.to_q.set_lora_layer(
......@@ -213,18 +218,20 @@ def train_lora(image,
LoRALinearLayer(
in_features=attn_module.add_k_proj.in_features,
out_features=attn_module.add_k_proj.out_features,
-rank=args.rank,
+rank=lora_rank,
)
)
attn_module.add_v_proj.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.add_v_proj.in_features,
out_features=attn_module.add_v_proj.out_features,
-rank=args.rank,
+rank=lora_rank,
)
)
unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters())
+lora_rank_inx = lora_rank_inx + 1
# Optimizer creation
......
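For reference, a quick standalone check of the rank schedule above, assuming the Stable Diffusion 1.5 UNet that DragDiffusion builds on (16 transformer blocks, each exposing an attn1 and attn2 module):

```python
# One schedule entry per attention module; train_lora doubles each entry
# before constructing the LoRALinearLayer (rank=lora_rank_list[inx] * 2).
lora_rank_list = [4,4,4,4, 8,8,8,8, 16,16,16,16,
                  16,16,16,16,16,16, 8,8,8,8,8,8,
                  4,4,4,4,4,4, 32,32]
assert len(lora_rank_list) == 32                # 16 blocks x {attn1, attn2}
print(sorted({r * 2 for r in lora_rank_list}))  # effective ranks: [8, 16, 32, 64]
```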