Commit 77451b3b authored by anton-l

tune ddpm training

parent a82d2592
@@ -6,7 +6,7 @@ __version__ = "0.0.3"
 from .modeling_utils import ModelMixin
 from .models.unet import UNetModel
-from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
+from .models.unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .models.unet_ldm import UNetLDMModel
 from .pipeline_utils import DiffusionPipeline
 from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, BDDM
@@ -17,5 +17,5 @@
 # limitations under the License.
 from .unet import UNetModel
-from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
+from .unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .unet_ldm import UNetLDMModel
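With both `__init__.py` files updated, the base `GLIDEUNetModel` is re-exported alongside the super-resolution and text-to-image variants, so all three can be imported from the package root. A minimal import sketch (nothing beyond the names added in this diff is assumed):

from diffusers import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel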
@@ -63,8 +63,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
-        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
-        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1 - self.alphas_cumprod)
         self.one = np.array(1.0)
         self.set_format(tensor_format=tensor_format)
@@ -141,7 +139,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         return pred_prev_image

     def forward_step(self, original_image, noise, t):
-        noisy_image = self.sqrt_alphas_cumprod[t] * original_image + self.sqrt_one_minus_alphas_cumprod[t] * noise
+        sqrt_alpha_prod = self.get_alpha_prod(t) ** 0.5
+        sqrt_one_minus_alpha_prod = (1 - self.get_alpha_prod(t)) ** 0.5
+        noisy_image = sqrt_alpha_prod * original_image + sqrt_one_minus_alpha_prod * noise
         return noisy_image

     def __len__(self):
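The reworked `forward_step` is the standard DDPM forward (noising) step, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, computed from the cumulative alpha product instead of the two precomputed sqrt tables removed above. A minimal standalone sketch of the same computation, assuming `alphas_cumprod` is the NumPy array built in `__init__` (the `get_alpha_prod` helper itself is not shown in this diff):

import numpy as np

def forward_step_sketch(original_image, noise, t, alphas_cumprod):
    # alpha_bar_t: cumulative product of (1 - beta) up to timestep t
    alpha_prod = alphas_cumprod[t]
    sqrt_alpha_prod = alpha_prod ** 0.5
    sqrt_one_minus_alpha_prod = (1 - alpha_prod) ** 0.5
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    return sqrt_alpha_prod * original_image + sqrt_one_minus_alpha_prod * noise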
@@ -24,20 +24,28 @@ def set_seed(seed):
 set_seed(0)
-accelerator = Accelerator(mixed_precision="fp16")
-model = UNetModel(ch=128, ch_mult=(1, 2, 4, 8), resolution=64)
+accelerator = Accelerator()
+model = UNetModel(
+    attn_resolutions=(16,),
+    ch=128,
+    ch_mult=(1, 2, 2, 2),
+    dropout=0.1,
+    num_res_blocks=2,
+    resamp_with_conv=True,
+    resolution=32
+)
 noise_scheduler = DDPMScheduler(timesteps=1000)
-optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
+optimizer = torch.optim.Adam(model.parameters(), lr=0.0002)
 num_epochs = 100
-batch_size = 8
-gradient_accumulation_steps = 8
+batch_size = 64
+gradient_accumulation_steps = 2
 augmentations = Compose(
     [
-        Resize(64),
-        CenterCrop(64),
+        Resize(32),
+        CenterCrop(32),
         RandomHorizontalFlip(),
         ToTensor(),
         Lambda(lambda x: x * 2 - 1),
@@ -55,14 +63,14 @@ dataset = dataset.shuffle(seed=0)
 dataset.set_transform(transforms)
 train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
-lr_scheduler = get_linear_schedule_with_warmup(
-    optimizer=optimizer,
-    num_warmup_steps=1000,
-    num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
-)
+#lr_scheduler = get_linear_schedule_with_warmup(
+# optimizer=optimizer,
+# num_warmup_steps=1000,
+# num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
+#)
-model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-    model, optimizer, train_dataloader, lr_scheduler
+model, optimizer, train_dataloader = accelerator.prepare(
+    model, optimizer, train_dataloader
 )
 for epoch in range(num_epochs):
@@ -72,24 +80,28 @@ for epoch in range(num_epochs):
         for step, batch in enumerate(train_dataloader):
             clean_images = batch["input"]
             noisy_images = torch.empty_like(clean_images)
+            noise_samples = torch.empty_like(clean_images)
             bsz = clean_images.shape[0]
             timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
             for idx in range(bsz):
-                noise = torch.randn_like(clean_images[0]).to(clean_images.device)
+                noise = torch.randn((3, 32, 32)).to(clean_images.device)
+                noise_samples[idx] = noise
                 noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx])
             if step % gradient_accumulation_steps == 0:
                 with accelerator.no_sync(model):
                     output = model(noisy_images, timesteps)
-                    loss = F.l1_loss(output, clean_images)
+                    # predict the noise
+                    loss = F.l1_loss(output, noise_samples)
                     accelerator.backward(loss)
             else:
                 output = model(noisy_images, timesteps)
                 loss = F.l1_loss(output, clean_images)
                 accelerator.backward(loss)
                 torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                 optimizer.step()
-                lr_scheduler.step()
+                # lr_scheduler.step()
                 optimizer.zero_grad()
             pbar.update(1)
             pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
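The loss now targets the sampled noise rather than the clean image, i.e. the usual DDPM epsilon-prediction objective, and the per-sample loop above fills `noise_samples` so that target is available. A rough batched sketch of what one training step computes, assuming `alphas_cumprod` is available as a torch tensor on the same device (the committed scheduler only exposes the per-sample `forward_step`, so this is an illustration rather than the code in this diff):

import torch
import torch.nn.functional as F

def train_step_sketch(model, clean_images, alphas_cumprod, num_timesteps):
    bsz = clean_images.shape[0]
    timesteps = torch.randint(0, num_timesteps, (bsz,), device=clean_images.device).long()
    noise = torch.randn_like(clean_images)
    # broadcast sqrt(alpha_bar_t) and sqrt(1 - alpha_bar_t) over the batch
    alpha_prod = alphas_cumprod[timesteps].view(bsz, 1, 1, 1)
    noisy_images = alpha_prod.sqrt() * clean_images + (1 - alpha_prod).sqrt() * noise
    # the model predicts the noise; L1 loss against the sampled noise
    output = model(noisy_images, timesteps)
    return F.l1_loss(output, noise)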