Commit f7d14f4f authored by mashun1

ootd

parent 8a13970e
@@ -9,4 +9,5 @@ checkpoints/
 train.txt
 VITON*
 eval_output
-eval_ootd.py
\ No newline at end of file
+eval_ootd.py
+metrics_aigc
\ No newline at end of file
@@ -164,7 +164,11 @@ https://hf-mirror.com/openai/clip-vit-large-patch14/tree/main
 ### Accuracy
-To be added
+| ssim | lpips |
+|:---:|:---:|
+| 0.86 | 0.075 |
+
+Note: these figures were obtained by training and testing at size=(512, 384), so they may differ from the official implementation, which has not been open-sourced.
 ## Application scenarios
......
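For context on the ssim/lpips numbers added above, here is a minimal, hypothetical evaluation sketch (not the repository's eval_ootd.py, which is excluded by the .gitignore change): it assumes the torchmetrics package and paired prediction/ground-truth tensors at the stated (512, 384) resolution.

```python
# Hypothetical sketch, not the repository's eval script: compute SSIM and LPIPS
# over batches of generated try-on images and their ground-truth counterparts.
import torch
from torchmetrics.image import StructuralSimilarityIndexMeasure, LearnedPerceptualImagePatchSimilarity

def evaluate_pairs(pred: torch.Tensor, target: torch.Tensor):
    """pred, target: (N, 3, 512, 384) tensors with values in [0, 1]."""
    ssim = StructuralSimilarityIndexMeasure(data_range=1.0)
    lpips = LearnedPerceptualImagePatchSimilarity(net_type="alex", normalize=True)
    return ssim(pred, target).item(), lpips(pred, target).item()

# Higher SSIM and lower LPIPS indicate closer agreement with ground truth:
# ssim_val, lpips_val = evaluate_pairs(pred_batch, gt_batch)
```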
@@ -30,8 +30,8 @@ sys.path.append(str(OOTD_ROOT))
 # VIT_PATH = "../checkpoints/clip-vit-large-patch14"
 VIT_PATH = os.path.join(OOTD_ROOT, "checkpoints/clip-vit-large-patch14")
 VAE_PATH = "../checkpoints/ootd"
-UNET_PATH = "../checkpoints/ootd/ootd_hd/checkpoint-36000"
-# UNET_PATH = "../train/checkpoints"
+# UNET_PATH = "../checkpoints/ootd/ootd_hd/checkpoint-36000"
+UNET_PATH = "../train/ckpts_bak/checkpoints"
 MODEL_PATH = "../checkpoints/ootd"
 class OOTDiffusionHD:
@@ -123,6 +123,8 @@ class OOTDiffusionHD:
         else:
             raise ValueError("model_type must be \'hd\' or \'dc\'!")
+        # start = time.time()
+        # with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as p:
         images = self.pipe(prompt_embeds=prompt_embeds,
                            image_garm=image_garm,
                            image_vton=image_vton,
@@ -133,5 +135,11 @@ class OOTDiffusionHD:
                            num_images_per_prompt=num_samples,
                            generator=generator,
                            ).images
+        # print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
+        # p.export_chrome_trace("trace.json")
+        # end = time.time()
+        # print(f"Inference time: {end - start} s")
         return images
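The commented-out lines above sketch how inference was timed and profiled. As a self-contained illustration of that pattern (a hypothetical helper, relying only on the standard torch.profiler API), profiling a pipeline call and exporting a Chrome trace might look like:

```python
# Hypothetical helper illustrating the commented-out profiling pattern above.
import torch

def profile_pipe_call(pipe, **pipe_kwargs):
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA],
    ) as p:
        images = pipe(**pipe_kwargs).images
    # Per-op breakdown sorted by CUDA self time, plus a trace viewable in chrome://tracing.
    print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
    p.export_chrome_trace("trace.json")
    return images
```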
run/images_output/mask.jpg: binary image replaced (30.3 KB → 23.7 KB)
run/images_output/out_dc_0.png: binary image replaced (720 KB → 682 KB)
@@ -45,16 +45,19 @@ class VITONDataset(data.Dataset):
         self.c_names['paired'] = img_names
     def get_parse_agnostic(self, parse, pose_data):
+        # parse: semantic segmentation map of the person
+        # pose_data: pose keypoint coordinates
         parse_array = np.array(parse)
         parse_upper = ((parse_array == 5).astype(np.float32) +
                        (parse_array == 6).astype(np.float32) +
-                       (parse_array == 7).astype(np.float32))
+                       (parse_array == 7).astype(np.float32))  # single map that is nonzero only at these upper-body parts
         parse_neck = (parse_array == 10).astype(np.float32)
         r = 10
         agnostic = parse.copy()
         # mask arms
+        # parse id 14 = left arm, 15 = right arm
         for parse_id, pose_ids in [(14, [2, 5, 6, 7]), (15, [5, 2, 3, 4])]:
             mask_arm = Image.new('L', (self.load_width, self.load_height), 'black')
             mask_arm_draw = ImageDraw.Draw(mask_arm)
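The loop above (its body is truncated in this hunk) builds an arm mask from pose keypoints. A rough, hypothetical sketch of that idea, drawing thick segments between consecutive keypoints on a blank grayscale image, might look like:

```python
# Hypothetical sketch of the arm-masking idea; not the exact truncated loop body.
from PIL import Image, ImageDraw

def draw_limb_mask(pose_data, pose_ids, width, height, r=10):
    """pose_data: (K, 2) array of keypoint (x, y); pose_ids: keypoint indices along one arm."""
    mask = Image.new('L', (width, height), 'black')
    draw = ImageDraw.Draw(mask)
    for i, j in zip(pose_ids[:-1], pose_ids[1:]):
        (x1, y1), (x2, y2) = pose_data[i], pose_data[j]
        if (x1, y1) == (0.0, 0.0) or (x2, y2) == (0.0, 0.0):
            continue  # skip undetected keypoints
        draw.line([x1, y1, x2, y2], 'white', width=r * 2)                  # limb segment
        draw.ellipse((x2 - r, y2 - r, x2 + r, y2 + r), 'white', 'white')   # round joint
    return mask
```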
@@ -129,18 +132,18 @@ class VITONDataset(data.Dataset):
     def __getitem__(self, index):
         img_name = self.img_names[index]
         c_name = {}
-        c = {}
-        cm = {}
+        c = {}   # garment images
+        cm = {}  # garment masks
         for key in self.c_names:
             c_name[key] = self.c_names[key][index]
-            c[key] = Image.open(osp.join(self.data_path, 'cloth', c_name[key])).convert('RGB')
-            c[key] = transforms.Resize(self.load_width, interpolation=2)(c[key])
+            c[key] = Image.open(osp.join(self.data_path, 'cloth', c_name[key])).convert('RGB')  # load the garment image
+            c[key] = transforms.Resize(self.load_width, interpolation=2)(c[key])  # resize to the target width
             cm[key] = Image.open(osp.join(self.data_path, 'cloth-mask', c_name[key]))
             cm[key] = transforms.Resize(self.load_width, interpolation=0)(cm[key])
             c[key] = self.transform(c[key])  # [-1,1]
             cm_array = np.array(cm[key])
-            cm_array = (cm_array >= 128).astype(np.float32)
+            cm_array = (cm_array >= 128).astype(np.float32)  # binarize the mask
             cm[key] = torch.from_numpy(cm_array)  # [0,1]
             cm[key].unsqueeze_(0)
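As a standalone illustration of the garment preprocessing annotated above: the image is resized, normalized to [-1, 1], and its mask is binarized to {0, 1}. The file paths and the 384-pixel target width below are assumptions for the example, not values taken from this hunk.

```python
# Hypothetical standalone sketch of the garment preprocessing commented above.
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

load_width = 384  # assumed target width

to_norm_tensor = transforms.Compose([
    transforms.ToTensor(),                                   # [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # -> [-1, 1]
])

cloth = Image.open("cloth.jpg").convert("RGB")                            # hypothetical path
cloth = transforms.Resize(load_width, interpolation=2)(cloth)             # bilinear resize
cloth = to_norm_tensor(cloth)                                             # (3, H, 384) in [-1, 1]

cloth_mask = Image.open("cloth-mask.jpg")                                 # hypothetical path
cloth_mask = transforms.Resize(load_width, interpolation=0)(cloth_mask)   # nearest-neighbour resize
cloth_mask = (np.array(cloth_mask) >= 128).astype(np.float32)             # binarize
cloth_mask = torch.from_numpy(cloth_mask).unsqueeze(0)                    # (1, H, 384) in {0, 1}
```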
@@ -157,7 +160,7 @@ class VITONDataset(data.Dataset):
         pose_data = np.array(pose_data)
         pose_data = pose_data.reshape((-1, 3))[:, :2]
-        # load parsing image
+        # load parsing image (semantic segmentation map)
         parse_name = img_name.replace('.jpg', '.png')
         parse = Image.open(osp.join(self.data_path, 'image-parse-v3', parse_name))
         parse = transforms.Resize(self.load_width, interpolation=0)(parse)
@@ -179,6 +182,7 @@ class VITONDataset(data.Dataset):
             11: ['socks', [8]],
             12: ['noise', [3, 11]]
         }
+        # one-hot encoding: each channel corresponds to one semantic category
         parse_agnostic_map = torch.zeros(20, self.load_height, self.load_width, dtype=torch.float)
         parse_agnostic_map.scatter_(0, parse_agnostic, 1.0)
         new_parse_agnostic_map = torch.zeros(self.semantic_nc, self.load_height, self.load_width, dtype=torch.float)
......
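The new comment above refers to the scatter_ call that turns an integer label map into a one-hot map. A tiny worked example of that operation:

```python
# Minimal worked example of one-hot encoding a label map with scatter_,
# mirroring parse_agnostic_map.scatter_(0, parse_agnostic, 1.0) above.
import torch

labels = torch.tensor([[[0, 2],
                        [1, 2]]])                        # (1, H, W) integer class ids
one_hot = torch.zeros(3, 2, 2).scatter_(0, labels, 1.0)  # (num_classes, H, W)
print(one_hot[2])  # tensor([[0., 1.], [0., 1.]]) -> pixels labeled class 2
```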
@@ -59,15 +59,15 @@ def main():
                       args.lr_scheduler)
     trainer = L.Trainer(
-        max_epochs=50,
+        max_epochs=100,
         accelerator='auto',
         log_every_n_steps=1,
-        callbacks=[ModelCheckpoint(every_n_train_steps=6000, save_top_k=-1, save_last=True)],
+        callbacks=[ModelCheckpoint(every_n_train_steps=5000, save_top_k=-1, save_last=True)],
         precision="16-mixed",
         accumulate_grad_batches=32,
     )
-    trainer.fit(model, dm, ckpt_path="lightning_logs/version_6/checkpoints/last.ckpt")
+    trainer.fit(model, dm, ckpt_path="lightning_logs/version_11/checkpoints/epoch=54-step=10000.ckpt")
 if __name__ == "__main__":
......
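For reference on the updated trainer settings: with accumulate_grad_batches=32, one optimizer step consumes 32 micro-batches, so the effective batch size is 32 times the per-device batch size (the per-device value is not shown in this diff). A hedged arithmetic sketch:

```python
# Hedged sketch of the effective-batch arithmetic implied by the settings above;
# per_device_batch is a hypothetical value, not taken from this diff.
per_device_batch = 1
accumulate_grad_batches = 32
effective_batch = per_device_batch * accumulate_grad_batches  # samples consumed per optimizer step
print(effective_batch)  # 32
```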