Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
OOTDiffusion_pytorch
Commits
c50c08d9
Commit
c50c08d9
authored
May 22, 2024
by
mashun1
Browse files
ootd
parent
fb08b1e6
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
14 deletions
+6
-14
train/model/ootd_ori.py
train/model/ootd_ori.py
+6
-14
No files found.
train/model/ootd_ori.py
View file @
c50c08d9
...
...
@@ -65,7 +65,6 @@ class OOTDiffusion(L.LightningModule):
self
.
power
=
power
self
.
init_models
()
self
.
automatic_optimization
=
False
def
init_models
(
self
):
self
.
vae
=
AutoencoderKL
.
from_pretrained
(
...
...
@@ -76,16 +75,18 @@ class OOTDiffusion(L.LightningModule):
self
.
unet_garm
=
UNetGarm2DConditionModel
.
from_pretrained
(
self
.
unet_path
,
subfolder
=
"unet_garm"
,
subfolder
=
"unet"
,
# subfolder="unet_garm",
torcch_dtype
=
torch
.
float32
,
use_safetensors
=
True
#
use_safetensors=True
)
self
.
unet_vton
=
UNetVton2DConditionModel
.
from_pretrained
(
self
.
unet_path
,
subfolder
=
"unet_vton"
,
# subfolder="unet_vton",
subfolder
=
"unet"
,
torch_dtype
=
torch
.
float32
,
use_safetensors
=
True
#
use_safetensors=True
)
# 修改模型通道数,适应输入数据
...
...
@@ -263,18 +264,9 @@ class OOTDiffusion(L.LightningModule):
return
loss
def
training_step
(
self
,
batch
):
opt
=
self
.
optimizers
()
sch
=
self
.
lr_schedulers
()
loss
=
self
(
batch
)
self
.
manual_backward
(
loss
)
self
.
log
(
'loss'
,
loss
,
prog_bar
=
True
)
opt
.
step
()
sch
.
step
()
opt
.
zero_grad
()
return
loss
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment