Commit c50c08d9 authored by mashun1's avatar mashun1
Browse files

ootd

parent fb08b1e6
...@@ -65,7 +65,6 @@ class OOTDiffusion(L.LightningModule): ...@@ -65,7 +65,6 @@ class OOTDiffusion(L.LightningModule):
self.power = power self.power = power
self.init_models() self.init_models()
self.automatic_optimization = False
def init_models(self): def init_models(self):
self.vae = AutoencoderKL.from_pretrained( self.vae = AutoencoderKL.from_pretrained(
...@@ -76,16 +75,18 @@ class OOTDiffusion(L.LightningModule): ...@@ -76,16 +75,18 @@ class OOTDiffusion(L.LightningModule):
self.unet_garm = UNetGarm2DConditionModel.from_pretrained( self.unet_garm = UNetGarm2DConditionModel.from_pretrained(
self.unet_path, self.unet_path,
subfolder="unet_garm", subfolder="unet",
# subfolder="unet_garm",
torcch_dtype=torch.float32, torcch_dtype=torch.float32,
use_safetensors=True # use_safetensors=True
) )
self.unet_vton = UNetVton2DConditionModel.from_pretrained( self.unet_vton = UNetVton2DConditionModel.from_pretrained(
self.unet_path, self.unet_path,
subfolder="unet_vton", # subfolder="unet_vton",
subfolder="unet",
torch_dtype=torch.float32, torch_dtype=torch.float32,
use_safetensors=True # use_safetensors=True
) )
# 修改模型通道数,适应输入数据 # 修改模型通道数,适应输入数据
...@@ -263,18 +264,9 @@ class OOTDiffusion(L.LightningModule): ...@@ -263,18 +264,9 @@ class OOTDiffusion(L.LightningModule):
return loss return loss
def training_step(self, batch): def training_step(self, batch):
opt = self.optimizers()
sch = self.lr_schedulers()
loss = self(batch) loss = self(batch)
self.manual_backward(loss)
self.log('loss', loss, prog_bar=True) self.log('loss', loss, prog_bar=True)
opt.step()
sch.step()
opt.zero_grad()
return loss return loss
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment