Commit ca72c1f8 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

Merge branch 'main' of https://github.com/huggingface/diffusers into main

parents 059a6e9d 55d29ab7
...@@ -6,7 +6,7 @@ __version__ = "0.0.3" ...@@ -6,7 +6,7 @@ __version__ = "0.0.3"
from .modeling_utils import ModelMixin from .modeling_utils import ModelMixin
from .models.unet import UNetModel from .models.unet import UNetModel
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel from .models.unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .models.unet_ldm import UNetLDMModel from .models.unet_ldm import UNetLDMModel
from .pipeline_utils import DiffusionPipeline from .pipeline_utils import DiffusionPipeline
from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM, BDDM from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM, BDDM
......
...@@ -17,5 +17,5 @@ ...@@ -17,5 +17,5 @@
# limitations under the License. # limitations under the License.
from .unet import UNetModel from .unet import UNetModel
from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel from .unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .unet_ldm import UNetLDMModel from .unet_ldm import UNetLDMModel
...@@ -63,8 +63,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -63,8 +63,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
self.alphas = 1.0 - self.betas self.alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(self.alphas, axis=0) self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1 - self.alphas_cumprod)
self.one = np.array(1.0) self.one = np.array(1.0)
self.set_format(tensor_format=tensor_format) self.set_format(tensor_format=tensor_format)
...@@ -141,7 +139,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -141,7 +139,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return pred_prev_image return pred_prev_image
def forward_step(self, original_image, noise, t): def forward_step(self, original_image, noise, t):
noisy_image = self.sqrt_alphas_cumprod[t] * original_image + self.sqrt_one_minus_alphas_cumprod[t] * noise sqrt_alpha_prod = self.get_alpha_prod(t) ** 0.5
sqrt_one_minus_alpha_prod = (1 - self.get_alpha_prod(t)) ** 0.5
noisy_image = sqrt_alpha_prod * original_image + sqrt_one_minus_alpha_prod * noise
return noisy_image return noisy_image
def __len__(self): def __len__(self):
......
...@@ -24,20 +24,28 @@ def set_seed(seed): ...@@ -24,20 +24,28 @@ def set_seed(seed):
set_seed(0) set_seed(0)
accelerator = Accelerator(mixed_precision="fp16") accelerator = Accelerator()
model = UNetModel(ch=128, ch_mult=(1, 2, 4, 8), resolution=64) model = UNetModel(
attn_resolutions=(16,),
ch=128,
ch_mult=(1, 2, 2, 2),
dropout=0.1,
num_res_blocks=2,
resamp_with_conv=True,
resolution=32
)
noise_scheduler = DDPMScheduler(timesteps=1000) noise_scheduler = DDPMScheduler(timesteps=1000)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) optimizer = torch.optim.Adam(model.parameters(), lr=0.0002)
num_epochs = 100 num_epochs = 100
batch_size = 8 batch_size = 64
gradient_accumulation_steps = 8 gradient_accumulation_steps = 2
augmentations = Compose( augmentations = Compose(
[ [
Resize(64), Resize(32),
CenterCrop(64), CenterCrop(32),
RandomHorizontalFlip(), RandomHorizontalFlip(),
ToTensor(), ToTensor(),
Lambda(lambda x: x * 2 - 1), Lambda(lambda x: x * 2 - 1),
...@@ -55,14 +63,14 @@ dataset = dataset.shuffle(seed=0) ...@@ -55,14 +63,14 @@ dataset = dataset.shuffle(seed=0)
dataset.set_transform(transforms) dataset.set_transform(transforms)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False) train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
lr_scheduler = get_linear_schedule_with_warmup( #lr_scheduler = get_linear_schedule_with_warmup(
optimizer=optimizer, # optimizer=optimizer,
num_warmup_steps=1000, # num_warmup_steps=1000,
num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps, # num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
) #)
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( model, optimizer, train_dataloader = accelerator.prepare(
model, optimizer, train_dataloader, lr_scheduler model, optimizer, train_dataloader
) )
for epoch in range(num_epochs): for epoch in range(num_epochs):
...@@ -72,24 +80,28 @@ for epoch in range(num_epochs): ...@@ -72,24 +80,28 @@ for epoch in range(num_epochs):
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
clean_images = batch["input"] clean_images = batch["input"]
noisy_images = torch.empty_like(clean_images) noisy_images = torch.empty_like(clean_images)
noise_samples = torch.empty_like(clean_images)
bsz = clean_images.shape[0] bsz = clean_images.shape[0]
timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long() timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
for idx in range(bsz): for idx in range(bsz):
noise = torch.randn_like(clean_images[0]).to(clean_images.device) noise = torch.randn((3, 32, 32)).to(clean_images.device)
noise_samples[idx] = noise
noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx]) noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx])
if step % gradient_accumulation_steps == 0: if step % gradient_accumulation_steps == 0:
with accelerator.no_sync(model): with accelerator.no_sync(model):
output = model(noisy_images, timesteps) output = model(noisy_images, timesteps)
loss = F.l1_loss(output, clean_images) # predict the noise
loss = F.l1_loss(output, noise_samples)
accelerator.backward(loss) accelerator.backward(loss)
else: else:
output = model(noisy_images, timesteps) output = model(noisy_images, timesteps)
loss = F.l1_loss(output, clean_images) loss = F.l1_loss(output, clean_images)
accelerator.backward(loss) accelerator.backward(loss)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step() optimizer.step()
lr_scheduler.step() # lr_scheduler.step()
optimizer.zero_grad() optimizer.zero_grad()
pbar.update(1) pbar.update(1)
pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"]) pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment