Commit 418888a5 authored by anton-l

Pokemon DDPM training

parent 55d29ab7
@@ -8,14 +8,14 @@ import PIL.Image
 from accelerate import Accelerator
 from datasets import load_dataset
 from diffusers import DDPM, DDPMScheduler, UNetModel
-from torchvision.transforms import CenterCrop, Compose, Lambda, RandomHorizontalFlip, Resize, ToTensor
+from torchvision.transforms import InterpolationMode, CenterCrop, Compose, Lambda, RandomRotation, RandomHorizontalFlip, Resize, ToTensor
 from tqdm.auto import tqdm
 from transformers import get_linear_schedule_with_warmup


 def set_seed(seed):
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.benchmark = False
+    #torch.backends.cudnn.deterministic = True
+    #torch.backends.cudnn.benchmark = False
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     np.random.seed(seed)
@@ -30,13 +30,13 @@ model = UNetModel(
     attn_resolutions=(16,),
     ch=128,
     ch_mult=(1, 2, 2, 2),
-    dropout=0.1,
+    dropout=0.0,
     num_res_blocks=2,
     resamp_with_conv=True,
     resolution=32
 )
 noise_scheduler = DDPMScheduler(timesteps=1000)
-optimizer = torch.optim.Adam(model.parameters(), lr=0.0002)
+optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

 num_epochs = 100
 batch_size = 64
@@ -44,9 +44,10 @@ gradient_accumulation_steps = 2

 augmentations = Compose(
     [
-        Resize(32),
-        CenterCrop(32),
         RandomHorizontalFlip(),
+        RandomRotation(15, interpolation=InterpolationMode.BILINEAR, fill=1),
+        Resize(32, interpolation=InterpolationMode.BILINEAR),
+        CenterCrop(32),
         ToTensor(),
         Lambda(lambda x: x * 2 - 1),
     ]
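
Note (not part of the commit): a quick standalone way to sanity-check the output shape and value range of the reordered augmentation chain; the dummy image and its size are made up for illustration.

import torch
from PIL import Image
from torchvision.transforms import (
    CenterCrop, Compose, InterpolationMode, Lambda,
    RandomHorizontalFlip, RandomRotation, Resize, ToTensor,
)

# Same chain as the new side of the hunk above.
augmentations = Compose(
    [
        RandomHorizontalFlip(),
        RandomRotation(15, interpolation=InterpolationMode.BILINEAR, fill=1),
        Resize(32, interpolation=InterpolationMode.BILINEAR),
        CenterCrop(32),
        ToTensor(),
        Lambda(lambda x: x * 2 - 1),
    ]
)

img = Image.new("RGB", (96, 96), (255, 255, 255))  # hypothetical dummy input
x = augmentations(img)
print(x.shape, x.min().item(), x.max().item())  # torch.Size([3, 32, 32]), values in [-1, 1]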
@@ -59,24 +60,24 @@ def transforms(examples):
     return {"input": images}


-dataset = dataset.shuffle(seed=0)
 dataset.set_transform(transforms)
-train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
+train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

-#lr_scheduler = get_linear_schedule_with_warmup(
-#    optimizer=optimizer,
-#    num_warmup_steps=1000,
-#    num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
-#)
+lr_scheduler = get_linear_schedule_with_warmup(
+    optimizer=optimizer,
+    num_warmup_steps=500,
+    num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
+)

-model, optimizer, train_dataloader = accelerator.prepare(
-    model, optimizer, train_dataloader
+model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+    model, optimizer, train_dataloader, lr_scheduler
 )

 for epoch in range(num_epochs):
     model.train()
     pbar = tqdm(total=len(train_dataloader), unit="ba")
     pbar.set_description(f"Epoch {epoch}")
+    losses = []
     for step, batch in enumerate(train_dataloader):
         clean_images = batch["input"]
         noisy_images = torch.empty_like(clean_images)
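
Note (not part of the commit): the scheduler enabled here ramps the learning rate linearly from 0 to the optimizer's lr over num_warmup_steps, then decays it linearly to 0 at num_training_steps. A toy sketch with made-up step counts to see the shape; 3e-4 matches the diff above.

import torch
from transformers import get_linear_schedule_with_warmup

# Throwaway parameter and optimizer, only to inspect the schedule.
opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=3e-4)
sched = get_linear_schedule_with_warmup(opt, num_warmup_steps=500, num_training_steps=5000)

for step in range(1, 5001):
    opt.step()    # optimizer steps first, then the scheduler
    sched.step()
    if step in (250, 500, 2750, 5000):
        print(step, opt.param_groups[0]["lr"])
# 250 -> ~1.5e-4 (warming up), 500 -> 3e-4 (peak), 2750 -> ~1.5e-4, 5000 -> 0.0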
@@ -101,10 +102,12 @@ for epoch in range(num_epochs):
         accelerator.backward(loss)

         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
         optimizer.step()
-        # lr_scheduler.step()
+        lr_scheduler.step()
         optimizer.zero_grad()
+        loss = loss.detach().item()
+        losses.append(loss)
         pbar.update(1)
-        pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
+        pbar.set_postfix(loss=loss, avg_loss=np.mean(losses), lr=optimizer.param_groups[0]["lr"])

     optimizer.step()
...
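
Note (not part of the commit): for background on the truncated loop around noisy_images, DDPM training draws a random timestep per image, noises the clean image with the closed-form forward process, and trains the UNet to predict the added noise. A plain-torch sketch of that forward step follows; it uses the standard DDPM formulation and commonly used default beta values, which are assumptions here, not the commit's DDPMScheduler API.

import torch

timesteps = 1000
# Linear beta schedule with the values commonly used for DDPM (assumed).
betas = torch.linspace(1e-4, 0.02, timesteps)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def add_noise(clean_images, noise, t):
    # q(x_t | x_0): x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps
    abar = alphas_cumprod[t].view(-1, 1, 1, 1)
    return abar.sqrt() * clean_images + (1.0 - abar).sqrt() * noise

clean_images = torch.randn(4, 3, 32, 32)           # stand-in batch scaled to [-1, 1]
noise = torch.randn_like(clean_images)
t = torch.randint(0, timesteps, (clean_images.shape[0],))
noisy_images = add_noise(clean_images, noise, t)   # what the UNet learns to denoise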