Commit 1cf7933e authored by anton-l

Framework-agnostic timestep broadcasting

parent 0e13d329
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 import PIL.Image
 from accelerate import Accelerator
 from datasets import load_dataset
-from diffusers import DDPM, DDPMScheduler, UNetModel
+from diffusers import DDPMPipeline, DDPMScheduler, UNetModel
 from diffusers.hub_utils import init_git_repo, push_to_hub
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
@@ -71,7 +71,7 @@ def main(args):
         model, optimizer, train_dataloader, lr_scheduler
     )
-    ema_model = EMAModel(model, inv_gamma=1.0, power=3 / 4)
+    ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
     if args.push_to_hub:
         repo = init_git_repo(args, at_init=True)
@@ -133,7 +133,7 @@ def main(args):
        # Generate a sample image for visual inspection
        if accelerator.is_main_process:
            with torch.no_grad():
-               pipeline = DDPM(
+               pipeline = DDPMPipeline(
                    unet=accelerator.unwrap_model(ema_model.averaged_model), noise_scheduler=noise_scheduler
                )
@@ -172,6 +172,9 @@ if __name__ == "__main__":
     parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
     parser.add_argument("--lr", type=float, default=1e-4)
     parser.add_argument("--warmup_steps", type=int, default=500)
+    parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
+    parser.add_argument("--ema_power", type=float, default=3 / 4)
+    parser.add_argument("--ema_max_decay", type=float, default=0.999)
     parser.add_argument("--push_to_hub", action="store_true")
     parser.add_argument("--hub_token", type=str, default=None)
     parser.add_argument("--hub_model_id", type=str, default=None)
...
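Editorial note on the new EMA flags (not part of the commit): a minimal sketch of the warmup-style decay they control, assuming EMAModel computes its decay as 1 - (1 + step / inv_gamma) ** -power capped at max_value, which is the usual EMA-warmup rule. The formula is an assumption, not taken from this diff.

# Editorial sketch (not from this commit): assumed EMA warmup decay rule.
def ema_decay(step, inv_gamma=1.0, power=3 / 4, max_value=0.999):
    # Assumption: decay = 1 - (1 + step / inv_gamma) ** -power, capped at max_value.
    return min(max_value, 1 - (1 + step / inv_gamma) ** -power)

for step in (1, 10, 100, 1_000, 10_000):
    print(step, round(ema_decay(step), 4))
# With the defaults above, the decay ramps from roughly 0.41 at step 1 toward the 0.999 cap.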
@@ -144,16 +144,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         return pred_prev_sample

     def training_step(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor):
-        if timesteps.dim() != 1:
-            raise ValueError("`timesteps` must be a 1D tensor")
-        device = original_samples.device
-        batch_size = original_samples.shape[0]
-        timesteps = timesteps.reshape(batch_size, 1, 1, 1)
-
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
-        noisy_samples = sqrt_alpha_prod.to(device) * original_samples + sqrt_one_minus_alpha_prod.to(device) * noise
+        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples

     def __len__(self):
...
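Editorial usage sketch of the new call pattern (not part of the commit): timesteps can now stay a plain 1-D batch, and the scheduler broadcasts the per-sample coefficients itself via match_shape. The constructor arguments below (tensor_format="pt" with default betas) are assumptions about the scheduler's signature at this point in the codebase.

import torch
from diffusers import DDPMScheduler

# Assumed constructor arguments; only the training_step call pattern matters here.
noise_scheduler = DDPMScheduler(tensor_format="pt")

clean_images = torch.randn(8, 3, 64, 64)            # batch of training images
noise = torch.randn_like(clean_images)
timesteps = torch.randint(0, 1000, (8,)).long()     # 1-D: one timestep per sample

# No manual reshape to (batch, 1, 1, 1) is needed anymore; match_shape
# broadcasts the per-timestep coefficients over the image dimensions.
noisy_images = noise_scheduler.training_step(clean_images, noise, timesteps)
print(noisy_images.shape)  # torch.Size([8, 3, 64, 64])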
@@ -14,6 +14,8 @@
 import numpy as np
 import torch

+from typing import Union
+

 SCHEDULER_CONFIG_NAME = "scheduler_config.json"
@@ -50,3 +52,29 @@ class SchedulerMixin:
            return torch.log(tensor)

        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+    def match_shape(
+        self,
+        values: Union[np.ndarray, torch.Tensor],
+        broadcast_array: Union[np.ndarray, torch.Tensor],
+    ):
+        """
+        Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.
+
+        Args:
+            values: an array or tensor of values to extract.
+            broadcast_array: an array with a larger shape of K dimensions with the batch
+                dimension equal to the length of values.
+
+        Returns:
+            a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+        """
+        tensor_format = getattr(self, "tensor_format", "pt")
+        values = values.flatten()
+        while len(values.shape) < len(broadcast_array.shape):
+            values = values[..., None]
+        if tensor_format == "pt":
+            values = values.to(broadcast_array.device)
+
+        return values
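To see what "framework-agnostic" buys here (editorial sketch, not part of the commit): the same helper handles NumPy arrays and torch tensors, because flatten() and [..., None] indexing exist in both, and only the "pt" branch touches a device. The import path and the bare subclassing below are assumptions about the package layout.

import numpy as np
import torch

# Assumed import path for the mixin shown above.
from diffusers.schedulers.scheduling_utils import SchedulerMixin

class NumpyScheduler(SchedulerMixin):   # hypothetical stand-in
    tensor_format = "np"

class TorchScheduler(SchedulerMixin):   # hypothetical stand-in
    tensor_format = "pt"

coeffs_np = np.random.rand(4)                      # 1-D per-sample coefficients
images_np = np.random.rand(4, 3, 32, 32)
print(NumpyScheduler().match_shape(coeffs_np, images_np).shape)   # (4, 1, 1, 1)

coeffs_pt = torch.rand(4)
images_pt = torch.rand(4, 3, 32, 32)
print(TorchScheduler().match_shape(coeffs_pt, images_pt).shape)   # torch.Size([4, 1, 1, 1])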