Commit 7fe05bb3 authored by anton-l

Bugfixes for the training example

parent 1fd02631
-import random
-import numpy as np
+import os
 import torch
-import PIL.Image
+import argparse
 import torch.nn.functional as F
+import PIL.Image
 
 from accelerate import Accelerator
 from datasets import load_dataset
 from diffusers import DDPM, DDPMScheduler, UNetModel
@@ -14,7 +14,6 @@ from torchvision.transforms import (
     Lambda,
     RandomCrop,
     RandomHorizontalFlip,
-    RandomVerticalFlip,
     Resize,
     ToTensor,
 )
@@ -22,119 +21,126 @@ from tqdm.auto import tqdm
 from transformers import get_linear_schedule_with_warmup
 
 
-def set_seed(seed):
-    # torch.backends.cudnn.deterministic = True
-    # torch.backends.cudnn.benchmark = False
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-    np.random.seed(seed)
-    random.seed(seed)
-
-
-set_seed(0)
-
-accelerator = Accelerator()
-model = UNetModel(
-    attn_resolutions=(16,),
-    ch=128,
-    ch_mult=(1, 2, 2, 2),
-    dropout=0.0,
-    num_res_blocks=2,
-    resamp_with_conv=True,
-    resolution=32,
-)
-noise_scheduler = DDPMScheduler(timesteps=1000)
-optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
-
-num_epochs = 100
-batch_size = 64
-gradient_accumulation_steps = 2
-
-augmentations = Compose(
-    [
-        Resize(32, interpolation=InterpolationMode.BILINEAR),
-        RandomHorizontalFlip(),
-        RandomVerticalFlip(),
-        RandomCrop(32),
-        ToTensor(),
-        Lambda(lambda x: x * 2 - 1),
-    ]
-)
-dataset = load_dataset("huggan/flowers-102-categories", split="train")
-
-
-def transforms(examples):
-    images = [augmentations(image.convert("RGB")) for image in examples["image"]]
-    return {"input": images}
-
-
-dataset.set_transform(transforms)
-train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
-
-lr_scheduler = get_linear_schedule_with_warmup(
-    optimizer=optimizer,
-    num_warmup_steps=500,
-    num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
-)
-
-model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-    model, optimizer, train_dataloader, lr_scheduler
-)
-
-for epoch in range(num_epochs):
-    model.train()
-    pbar = tqdm(total=len(train_dataloader), unit="ba")
-    pbar.set_description(f"Epoch {epoch}")
-    losses = []
-    for step, batch in enumerate(train_dataloader):
-        clean_images = batch["input"]
-        noisy_images = torch.empty_like(clean_images)
-        noise_samples = torch.empty_like(clean_images)
-        bsz = clean_images.shape[0]
-
-        timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
-        for idx in range(bsz):
-            noise = torch.randn((3, 32, 32)).to(clean_images.device)
-            noise_samples[idx] = noise
-            noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx])
-
-        if step % gradient_accumulation_steps == 0:
-            with accelerator.no_sync(model):
-                output = model(noisy_images, timesteps)
-                # predict the noise
-                loss = F.l1_loss(output, noise_samples)
-                accelerator.backward(loss)
-        else:
-            output = model(noisy_images, timesteps)
-            loss = F.l1_loss(output, clean_images)
-            accelerator.backward(loss)
-            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-            optimizer.step()
-            lr_scheduler.step()
-            optimizer.zero_grad()
-        loss = loss.detach().item()
-        losses.append(loss)
-        pbar.update(1)
-        pbar.set_postfix(loss=loss, avg_loss=np.mean(losses), lr=optimizer.param_groups[0]["lr"])
-
-    optimizer.step()
-    # eval
-    model.eval()
-    with torch.no_grad():
-        pipeline = DDPM(unet=model, noise_scheduler=noise_scheduler)
-        generator = torch.Generator()
-        generator = generator.manual_seed(0)
-        # run pipeline in inference (sample random noise and denoise)
-        image = pipeline(generator=generator)
-
-        # process image to PIL
-        image_processed = image.cpu().permute(0, 2, 3, 1)
-        image_processed = (image_processed + 1.0) * 127.5
-        image_processed = image_processed.type(torch.uint8).numpy()
-        image_pil = PIL.Image.fromarray(image_processed[0])
-
-        # save image
-        pipeline.save_pretrained("./flowers-ddpm")
-        image_pil.save(f"./flowers-ddpm/test_{epoch}.png")
+def main(args):
+    accelerator = Accelerator(mixed_precision=args.mixed_precision)
+
+    model = UNetModel(
+        attn_resolutions=(16,),
+        ch=128,
+        ch_mult=(1, 2, 4, 8),
+        dropout=0.0,
+        num_res_blocks=2,
+        resamp_with_conv=True,
+        resolution=64,
+    )
+    noise_scheduler = DDPMScheduler(timesteps=1000)
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+
+    num_epochs = 100
+    batch_size = 16
+    gradient_accumulation_steps = 1
+
+    augmentations = Compose(
+        [
+            Resize(64, interpolation=InterpolationMode.BILINEAR),
+            RandomCrop(64),
+            RandomHorizontalFlip(),
+            ToTensor(),
+            Lambda(lambda x: x * 2 - 1),
+        ]
+    )
+    dataset = load_dataset("huggan/pokemon", split="train")
+
+    def transforms(examples):
+        images = [augmentations(image.convert("RGB")) for image in examples["image"]]
+        return {"input": images}
+
+    dataset.set_transform(transforms)
+    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+    lr_scheduler = get_linear_schedule_with_warmup(
+        optimizer=optimizer,
+        num_warmup_steps=500,
+        num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
+    )
+
+    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        model, optimizer, train_dataloader, lr_scheduler
+    )
+
+    for epoch in range(num_epochs):
+        model.train()
+        with tqdm(total=len(train_dataloader), unit="ba") as pbar:
+            pbar.set_description(f"Epoch {epoch}")
+            for step, batch in enumerate(train_dataloader):
+                clean_images = batch["input"]
+                noisy_images = torch.empty_like(clean_images)
+                noise_samples = torch.empty_like(clean_images)
+                bsz = clean_images.shape[0]
+
+                timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
+                for idx in range(bsz):
+                    noise = torch.randn(clean_images.shape[1:]).to(clean_images.device)
+                    noise_samples[idx] = noise
+                    noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx])
+
+                if step % gradient_accumulation_steps != 0:
+                    with accelerator.no_sync(model):
+                        output = model(noisy_images, timesteps)
+                        # predict the noise
+                        loss = F.mse_loss(output, noise_samples)
+                        accelerator.backward(loss)
+                else:
+                    output = model(noisy_images, timesteps)
+                    loss = F.mse_loss(output, noise_samples)
+                    accelerator.backward(loss)
+                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+                    optimizer.step()
+                    lr_scheduler.step()
+                    optimizer.zero_grad()
+                pbar.update(1)
+                pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
+
+        optimizer.step()
+
+        torch.distributed.barrier()
+        if args.local_rank in [-1, 0]:
+            model.eval()
+            with torch.no_grad():
+                pipeline = DDPM(unet=model.module, noise_scheduler=noise_scheduler)
+                generator = torch.Generator()
+                generator = generator.manual_seed(0)
+                # run pipeline in inference (sample random noise and denoise)
+                image = pipeline(generator=generator)
+
+                # process image to PIL
+                image_processed = image.cpu().permute(0, 2, 3, 1)
+                image_processed = (image_processed + 1.0) * 127.5
+                image_processed = image_processed.type(torch.uint8).numpy()
+                image_pil = PIL.Image.fromarray(image_processed[0])
+
+                # save image
+                pipeline.save_pretrained("./pokemon-ddpm")
+                image_pil.save(f"./pokemon-ddpm/test_{epoch}.png")
+        torch.distributed.barrier()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Simple example of training script.")
+    parser.add_argument("--local_rank", type=int)
+    parser.add_argument(
+        "--mixed_precision",
+        type=str,
+        default="no",
+        choices=["no", "fp16", "bf16"],
+        help="Whether to use mixed precision. Choose"
+        "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+        "and an Nvidia Ampere GPU.",
+    )
+    args = parser.parse_args()
+
+    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    if env_local_rank != -1 and env_local_rank != args.local_rank:
+        args.local_rank = env_local_rank
+
+    main(args)
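
As background for the training loop in this diff: noise_scheduler.forward_step(clean_image, noise, t) produces the noisy input for timestep t, and the model is then trained (now with F.mse_loss instead of F.l1_loss) to predict that noise. The snippet below is only an illustrative sketch of the standard DDPM forward (noising) process that such a step computes; it is not the diffusers implementation, and the function name and the betas schedule are placeholders.

import torch


def forward_step_sketch(clean_image, noise, timestep, betas):
    # Cumulative product of (1 - beta_t): the "alpha bar" schedule from the DDPM paper.
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    alpha_bar = alphas_cumprod[timestep]
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
    return alpha_bar.sqrt() * clean_image + (1.0 - alpha_bar).sqrt() * noise


# Example: noise a single 3x64x64 image at timestep 500 of a 1000-step linear schedule.
betas = torch.linspace(1e-4, 0.02, 1000)
x0 = torch.randn(3, 64, 64)  # stand-in for a training image scaled to [-1, 1]
noise = torch.randn_like(x0)
x_t = forward_step_sketch(x0, noise, 500, betas)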
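
Usage note: with the new --mixed_precision and --local_rank arguments, the example can be launched through the accelerate CLI, e.g. accelerate launch train_ddpm.py --mixed_precision fp16, after running accelerate config once to describe the machine. The file name train_ddpm.py is only a placeholder for wherever this example script lives in the repository.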