Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
OOTDiffusion_pytorch
Commits
f7d14f4f
Commit
f7d14f4f
authored
May 30, 2024
by
mashun1
Browse files
ootd
parent
8a13970e
Changes
8
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
1138397 additions
and
14 deletions
+1138397
-14
.gitignore
.gitignore
+2
-1
README.md
README.md
+5
-1
ootd/inference_ootd_hd.py
ootd/inference_ootd_hd.py
+10
-2
run/images_output/mask.jpg
run/images_output/mask.jpg
+0
-0
run/images_output/out_dc_0.png
run/images_output/out_dc_0.png
+0
-0
run/trace.json
run/trace.json
+1138366
-0
train/data/viton.py
train/data/viton.py
+11
-7
train/main.py
train/main.py
+3
-3
No files found.
.gitignore
View file @
f7d14f4f
...
...
@@ -9,4 +9,5 @@ checkpoints/
train.txt
VITON*
eval_output
eval_ootd.py
\ No newline at end of file
eval_ootd.py
metrics_aigc
\ No newline at end of file
README.md
View file @
f7d14f4f
...
...
@@ -164,7 +164,11 @@ https://hf-mirror.com/openai/clip-vit-large-patch14/tree/main
### 精度
待补充
|ssim|lpips|
|:---:|:---:|
|0.86|0.075|
注意:该精度在size=(512, 384)条件下训练及测试得到,与官方实现(未开源)可能存在不同。
## 应用场景
...
...
ootd/inference_ootd_hd.py
View file @
f7d14f4f
...
...
@@ -30,8 +30,8 @@ sys.path.append(str(OOTD_ROOT))
# VIT_PATH = "../checkpoints/clip-vit-large-patch14"
VIT_PATH
=
os
.
path
.
join
(
OOTD_ROOT
,
"checkpoints/clip-vit-large-patch14"
)
VAE_PATH
=
"../checkpoints/ootd"
UNET_PATH
=
"../checkpoints/ootd/ootd_hd/checkpoint-36000"
#
UNET_PATH = "../train/checkpoints"
#
UNET_PATH = "../checkpoints/ootd/ootd_hd/checkpoint-36000"
UNET_PATH
=
"../train/
ckpts_bak/
checkpoints"
MODEL_PATH
=
"../checkpoints/ootd"
class
OOTDiffusionHD
:
...
...
@@ -123,6 +123,8 @@ class OOTDiffusionHD:
else
:
raise
ValueError
(
"model_type must be
\'
hd
\'
or
\'
dc
\'
!"
)
# start = time.time()
# with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as p:
images
=
self
.
pipe
(
prompt_embeds
=
prompt_embeds
,
image_garm
=
image_garm
,
image_vton
=
image_vton
,
...
...
@@ -133,5 +135,11 @@ class OOTDiffusionHD:
num_images_per_prompt
=
num_samples
,
generator
=
generator
,
).
images
# print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
# p.export_chrome_trace("trace.json")
# end = time.time()
# print(f"Inference time: {end - start} s")
return
images
run/images_output/mask.jpg
View replaced file @
8a13970e
View file @
f7d14f4f
30.3 KB
|
W:
|
H:
23.7 KB
|
W:
|
H:
2-up
Swipe
Onion skin
run/images_output/out_dc_0.png
View replaced file @
8a13970e
View file @
f7d14f4f
720 KB
|
W:
|
H:
682 KB
|
W:
|
H:
2-up
Swipe
Onion skin
run/trace.json
0 → 100644
View file @
f7d14f4f
This diff is collapsed.
Click to expand it.
train/data/viton.py
View file @
f7d14f4f
...
...
@@ -45,16 +45,19 @@ class VITONDataset(data.Dataset):
self
.
c_names
[
'paired'
]
=
img_names
def
get_parse_agnostic
(
self
,
parse
,
pose_data
):
# parse 语义分割图
# pose_data 姿势信息
parse_array
=
np
.
array
(
parse
)
parse_upper
=
((
parse_array
==
5
).
astype
(
np
.
float32
)
+
(
parse_array
==
6
).
astype
(
np
.
float32
)
+
(
parse_array
==
7
).
astype
(
np
.
float32
))
(
parse_array
==
7
).
astype
(
np
.
float32
))
# 这里是什么形式,应该是一张图且图中仅有这些部位
parse_neck
=
(
parse_array
==
10
).
astype
(
np
.
float32
)
r
=
10
agnostic
=
parse
.
copy
()
# mask arms
# 14表示左臂,15表示右臂
for
parse_id
,
pose_ids
in
[(
14
,
[
2
,
5
,
6
,
7
]),
(
15
,
[
5
,
2
,
3
,
4
])]:
mask_arm
=
Image
.
new
(
'L'
,
(
self
.
load_width
,
self
.
load_height
),
'black'
)
mask_arm_draw
=
ImageDraw
.
Draw
(
mask_arm
)
...
...
@@ -129,18 +132,18 @@ class VITONDataset(data.Dataset):
def
__getitem__
(
self
,
index
):
img_name
=
self
.
img_names
[
index
]
c_name
=
{}
c
=
{}
cm
=
{}
c
=
{}
# 衣物
cm
=
{}
# 衣物的mask
for
key
in
self
.
c_names
:
c_name
[
key
]
=
self
.
c_names
[
key
][
index
]
c
[
key
]
=
Image
.
open
(
osp
.
join
(
self
.
data_path
,
'cloth'
,
c_name
[
key
])).
convert
(
'RGB'
)
c
[
key
]
=
transforms
.
Resize
(
self
.
load_width
,
interpolation
=
2
)(
c
[
key
])
c
[
key
]
=
Image
.
open
(
osp
.
join
(
self
.
data_path
,
'cloth'
,
c_name
[
key
])).
convert
(
'RGB'
)
# 读取衣服图像
c
[
key
]
=
transforms
.
Resize
(
self
.
load_width
,
interpolation
=
2
)(
c
[
key
])
# 修改宽度
cm
[
key
]
=
Image
.
open
(
osp
.
join
(
self
.
data_path
,
'cloth-mask'
,
c_name
[
key
]))
cm
[
key
]
=
transforms
.
Resize
(
self
.
load_width
,
interpolation
=
0
)(
cm
[
key
])
c
[
key
]
=
self
.
transform
(
c
[
key
])
# [-1,1]
cm_array
=
np
.
array
(
cm
[
key
])
cm_array
=
(
cm_array
>=
128
).
astype
(
np
.
float32
)
cm_array
=
(
cm_array
>=
128
).
astype
(
np
.
float32
)
# 二值化
cm
[
key
]
=
torch
.
from_numpy
(
cm_array
)
# [0,1]
cm
[
key
].
unsqueeze_
(
0
)
...
...
@@ -157,7 +160,7 @@ class VITONDataset(data.Dataset):
pose_data
=
np
.
array
(
pose_data
)
pose_data
=
pose_data
.
reshape
((
-
1
,
3
))[:,
:
2
]
# load parsing image
# load parsing image
语义分割图
parse_name
=
img_name
.
replace
(
'.jpg'
,
'.png'
)
parse
=
Image
.
open
(
osp
.
join
(
self
.
data_path
,
'image-parse-v3'
,
parse_name
))
parse
=
transforms
.
Resize
(
self
.
load_width
,
interpolation
=
0
)(
parse
)
...
...
@@ -179,6 +182,7 @@ class VITONDataset(data.Dataset):
11
:
[
'socks'
,
[
8
]],
12
:
[
'noise'
,
[
3
,
11
]]
}
# 不同通道表示不同类别
parse_agnostic_map
=
torch
.
zeros
(
20
,
self
.
load_height
,
self
.
load_width
,
dtype
=
torch
.
float
)
parse_agnostic_map
.
scatter_
(
0
,
parse_agnostic
,
1.0
)
new_parse_agnostic_map
=
torch
.
zeros
(
self
.
semantic_nc
,
self
.
load_height
,
self
.
load_width
,
dtype
=
torch
.
float
)
...
...
train/main.py
View file @
f7d14f4f
...
...
@@ -59,15 +59,15 @@ def main():
args
.
lr_scheduler
)
trainer
=
L
.
Trainer
(
max_epochs
=
5
0
,
max_epochs
=
10
0
,
accelerator
=
'auto'
,
log_every_n_steps
=
1
,
callbacks
=
[
ModelCheckpoint
(
every_n_train_steps
=
6
000
,
save_top_k
=-
1
,
save_last
=
True
)],
callbacks
=
[
ModelCheckpoint
(
every_n_train_steps
=
5
000
,
save_top_k
=-
1
,
save_last
=
True
)],
precision
=
"16-mixed"
,
accumulate_grad_batches
=
32
,
)
trainer
.
fit
(
model
,
dm
,
ckpt_path
=
"lightning_logs/version_
6
/checkpoints/
last
.ckpt"
)
trainer
.
fit
(
model
,
dm
,
ckpt_path
=
"lightning_logs/version_
11
/checkpoints/
epoch=54-step=10000
.ckpt"
)
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment