Commit 4dce43cc authored by Patrick von Platen's avatar Patrick von Platen
Browse files

Merge branch 'main' of https://github.com/huggingface/diffusers into main

parents 559b8cbf d10441d8
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!) # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
export PYTHONPATH = src export PYTHONPATH = src
check_dirs := tests src utils check_dirs := examples tests src utils
modified_only_fixup: modified_only_fixup:
$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs))) $(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
......
...@@ -8,14 +8,23 @@ import PIL.Image ...@@ -8,14 +8,23 @@ import PIL.Image
from accelerate import Accelerator from accelerate import Accelerator
from datasets import load_dataset from datasets import load_dataset
from diffusers import DDPM, DDPMScheduler, UNetModel from diffusers import DDPM, DDPMScheduler, UNetModel
from torchvision.transforms import CenterCrop, Compose, Lambda, RandomHorizontalFlip, Resize, ToTensor from torchvision.transforms import (
Compose,
InterpolationMode,
Lambda,
RandomCrop,
RandomHorizontalFlip,
RandomVerticalFlip,
Resize,
ToTensor,
)
from tqdm.auto import tqdm from tqdm.auto import tqdm
from transformers import get_linear_schedule_with_warmup from transformers import get_linear_schedule_with_warmup
def set_seed(seed): def set_seed(seed):
torch.backends.cudnn.deterministic = True # torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False # torch.backends.cudnn.benchmark = False
torch.manual_seed(seed) torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed_all(seed)
np.random.seed(seed) np.random.seed(seed)
...@@ -30,13 +39,13 @@ model = UNetModel( ...@@ -30,13 +39,13 @@ model = UNetModel(
attn_resolutions=(16,), attn_resolutions=(16,),
ch=128, ch=128,
ch_mult=(1, 2, 2, 2), ch_mult=(1, 2, 2, 2),
dropout=0.1, dropout=0.0,
num_res_blocks=2, num_res_blocks=2,
resamp_with_conv=True, resamp_with_conv=True,
resolution=32 resolution=32,
) )
noise_scheduler = DDPMScheduler(timesteps=1000) noise_scheduler = DDPMScheduler(timesteps=1000)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0002) optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
num_epochs = 100 num_epochs = 100
batch_size = 64 batch_size = 64
...@@ -44,14 +53,15 @@ gradient_accumulation_steps = 2 ...@@ -44,14 +53,15 @@ gradient_accumulation_steps = 2
augmentations = Compose( augmentations = Compose(
[ [
Resize(32), Resize(32, interpolation=InterpolationMode.BILINEAR),
CenterCrop(32),
RandomHorizontalFlip(), RandomHorizontalFlip(),
RandomVerticalFlip(),
RandomCrop(32),
ToTensor(), ToTensor(),
Lambda(lambda x: x * 2 - 1), Lambda(lambda x: x * 2 - 1),
] ]
) )
dataset = load_dataset("huggan/pokemon", split="train") dataset = load_dataset("huggan/flowers-102-categories", split="train")
def transforms(examples): def transforms(examples):
...@@ -59,24 +69,24 @@ def transforms(examples): ...@@ -59,24 +69,24 @@ def transforms(examples):
return {"input": images} return {"input": images}
dataset = dataset.shuffle(seed=0)
dataset.set_transform(transforms) dataset.set_transform(transforms)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False) train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
#lr_scheduler = get_linear_schedule_with_warmup( lr_scheduler = get_linear_schedule_with_warmup(
# optimizer=optimizer, optimizer=optimizer,
# num_warmup_steps=1000, num_warmup_steps=500,
# num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps, num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
#) )
model, optimizer, train_dataloader = accelerator.prepare( model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader model, optimizer, train_dataloader, lr_scheduler
) )
for epoch in range(num_epochs): for epoch in range(num_epochs):
model.train() model.train()
pbar = tqdm(total=len(train_dataloader), unit="ba") pbar = tqdm(total=len(train_dataloader), unit="ba")
pbar.set_description(f"Epoch {epoch}") pbar.set_description(f"Epoch {epoch}")
losses = []
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
clean_images = batch["input"] clean_images = batch["input"]
noisy_images = torch.empty_like(clean_images) noisy_images = torch.empty_like(clean_images)
...@@ -101,10 +111,12 @@ for epoch in range(num_epochs): ...@@ -101,10 +111,12 @@ for epoch in range(num_epochs):
accelerator.backward(loss) accelerator.backward(loss)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step() optimizer.step()
# lr_scheduler.step() lr_scheduler.step()
optimizer.zero_grad() optimizer.zero_grad()
loss = loss.detach().item()
losses.append(loss)
pbar.update(1) pbar.update(1)
pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"]) pbar.set_postfix(loss=loss, avg_loss=np.mean(losses), lr=optimizer.param_groups[0]["lr"])
optimizer.step() optimizer.step()
...@@ -124,5 +136,5 @@ for epoch in range(num_epochs): ...@@ -124,5 +136,5 @@ for epoch in range(num_epochs):
image_pil = PIL.Image.fromarray(image_processed[0]) image_pil = PIL.Image.fromarray(image_processed[0])
# save image # save image
pipeline.save_pretrained("./poke-ddpm") pipeline.save_pretrained("./flowers-ddpm")
image_pil.save(f"./poke-ddpm/test_{epoch}.png") image_pil.save(f"./flowers-ddpm/test_{epoch}.png")
...@@ -225,11 +225,11 @@ class ConfigMixin: ...@@ -225,11 +225,11 @@ class ConfigMixin:
text = reader.read() text = reader.read()
return json.loads(text) return json.loads(text)
# def __eq__(self, other): def __eq__(self, other):
# return self.__dict__ == other.__dict__ return self.__dict__ == other.__dict__
# def __repr__(self): def __repr__(self):
# return f"{self.__class__.__name__} {self.to_json_string()}" return f"{self.__class__.__name__} {self.to_json_string()}"
@property @property
def config(self) -> Dict[str, Any]: def config(self) -> Dict[str, Any]:
......
...@@ -19,7 +19,7 @@ import unittest ...@@ -19,7 +19,7 @@ import unittest
import torch import torch
from diffusers import DDIM, DDPM, DDIMScheduler, DDPMScheduler, LatentDiffusion, UNetModel, PNDM, PNDMScheduler from diffusers import DDIM, DDPM, PNDM, DDIMScheduler, DDPMScheduler, LatentDiffusion, PNDMScheduler, UNetModel
from diffusers.configuration_utils import ConfigMixin from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.testing_utils import floats_tensor, slow, torch_device from diffusers.testing_utils import floats_tensor, slow, torch_device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment