"docs/source/vscode:/vscode.git/clone" did not exist on "b4cbbd5ed2691d51100aa2991d7ad3d82e50cd70"
Commit de22d4cd authored by Patrick von Platen's avatar Patrick von Platen
Browse files

make sure config attributes are only accessed via the config in schedulers

parent 8c1f5197
...@@ -258,10 +258,6 @@ class ConfigMixin: ...@@ -258,10 +258,6 @@ class ConfigMixin:
class FrozenDict(OrderedDict): class FrozenDict(OrderedDict):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
# remove `None`
args = (a for a in args if a is not None)
kwargs = {k: v for k, v in kwargs if v is not None}
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
for key, value in self.items(): for key, value in self.items():
......
# Pipelines
- Pipelines are a collection of end-to-end diffusion systems that can be used out-of-the-box
- Pipelines should stay as close as possible to their original implementation
- Pipelines can include components of other libraries, such as text encoders.
## API
TODO(Patrick, Anton, Suraj)
## Examples
- DDPM for unconditional image generation in [pipeline_ddpm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_ddpm.py).
- DDIM for unconditional image generation in [pipeline_ddim](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_ddim.py).
- PNDM for unconditional image generation in [pipeline_pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py).
- Latent diffusion for text to image generation / conditional image generation in [pipeline_latent_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_latent_diffusion.py).
- Glide for text to image generation / conditional image generation in [pipeline_glide](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_glide.py).
- BDDM for spectrogram-to-sound vocoding in [pipeline_bddm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_bddm.py).
- Grad-TTS for text to audio generation / conditional audio generation in [pipeline_grad_tts](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_grad_tts.py).
...@@ -46,7 +46,7 @@ class LatentDiffusion(DiffusionPipeline): ...@@ -46,7 +46,7 @@ class LatentDiffusion(DiffusionPipeline):
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device) text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
text_embedding = self.bert(text_input.input_ids)[0] text_embedding = self.bert(text_input.input_ids)[0]
num_trained_timesteps = self.noise_scheduler.timesteps num_trained_timesteps = self.noise_scheduler.config.timesteps
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps) inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
image = self.noise_scheduler.sample_noise( image = self.noise_scheduler.sample_noise(
......
...@@ -291,7 +291,7 @@ class BDDM(DiffusionPipeline): ...@@ -291,7 +291,7 @@ class BDDM(DiffusionPipeline):
# Sample gaussian noise to begin loop # Sample gaussian noise to begin loop
audio = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device) audio = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)
timestep_values = self.noise_scheduler.timestep_values timestep_values = self.noise_scheduler.get_timestep_values()
num_prediction_steps = len(self.noise_scheduler) num_prediction_steps = len(self.noise_scheduler)
for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps): for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
# 1. predict noise residual # 1. predict noise residual
......
...@@ -32,7 +32,7 @@ class DDIM(DiffusionPipeline): ...@@ -32,7 +32,7 @@ class DDIM(DiffusionPipeline):
if torch_device is None: if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu" torch_device = "cuda" if torch.cuda.is_available() else "cpu"
num_trained_timesteps = self.noise_scheduler.timesteps num_trained_timesteps = self.noise_scheduler.config.timesteps
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps) inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
self.unet.to(torch_device) self.unet.to(torch_device)
......
...@@ -897,7 +897,7 @@ class LatentDiffusion(DiffusionPipeline): ...@@ -897,7 +897,7 @@ class LatentDiffusion(DiffusionPipeline):
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device) text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
text_embedding = self.bert(text_input.input_ids)[0] text_embedding = self.bert(text_input.input_ids)[0]
num_trained_timesteps = self.noise_scheduler.timesteps num_trained_timesteps = self.noise_scheduler.config.timesteps
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps) inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
image = torch.randn( image = torch.randn(
......
...@@ -61,7 +61,6 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin): ...@@ -61,7 +61,6 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
timesteps=timesteps, timesteps=timesteps,
beta_schedule=beta_schedule, beta_schedule=beta_schedule,
) )
self.timesteps = int(timesteps)
if beta_schedule == "squaredcos_cap_v2": if beta_schedule == "squaredcos_cap_v2":
# GLIDE cosine schedule # GLIDE cosine schedule
...@@ -94,4 +93,4 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin): ...@@ -94,4 +93,4 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
return torch.randn(shape, generator=generator).to(device) return torch.randn(shape, generator=generator).to(device)
def __len__(self): def __len__(self):
return self.timesteps return self.config.timesteps
...@@ -37,10 +37,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -37,10 +37,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
beta_start=beta_start, beta_start=beta_start,
beta_end=beta_end, beta_end=beta_end,
beta_schedule=beta_schedule, beta_schedule=beta_schedule,
trained_betas=trained_betas,
timestep_values=timestep_values,
clip_sample=clip_sample,
) )
self.timesteps = int(timesteps)
self.timestep_values = timestep_values # save the fixed timestep values for BDDM
self.clip_sample = clip_sample
if beta_schedule == "linear": if beta_schedule == "linear":
self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end) self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
...@@ -81,6 +81,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -81,6 +81,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# ) # )
# self.alphas = 1.0 - self.betas # self.alphas = 1.0 - self.betas
# self.alphas_cumprod = np.cumprod(self.alphas, axis=0) # self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
def get_timestep_values(self):
    """Return the fixed timestep values stored in the scheduler config.

    These were previously kept as ``self.timestep_values`` (saved for BDDM,
    per the removed constructor comment); after this change they are read
    exclusively from ``self.config``, and this accessor preserves the
    existing caller interface (e.g. the BDDM pipeline).
    """
    return self.config.timestep_values
def get_alpha(self, time_step): def get_alpha(self, time_step):
return self.alphas[time_step] return self.alphas[time_step]
...@@ -96,7 +98,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -96,7 +98,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
def get_orig_t(self, t, num_inference_steps): def get_orig_t(self, t, num_inference_steps):
if t < 0: if t < 0:
return -1 return -1
return self.timesteps // num_inference_steps * t return self.config.timesteps // num_inference_steps * t
def get_variance(self, t, num_inference_steps): def get_variance(self, t, num_inference_steps):
orig_t = self.get_orig_t(t, num_inference_steps) orig_t = self.get_orig_t(t, num_inference_steps)
...@@ -137,7 +139,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -137,7 +139,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5) pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
# 4. Clip "predicted x_0" # 4. Clip "predicted x_0"
if self.clip_sample: if self.config.clip_sample:
pred_original_sample = self.clip(pred_original_sample, -1, 1) pred_original_sample = self.clip(pred_original_sample, -1, 1)
# 5. compute variance: "sigma_t(η)" -> see formula (16) # 5. compute variance: "sigma_t(η)" -> see formula (16)
...@@ -158,4 +160,4 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -158,4 +160,4 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
return pred_prev_sample return pred_prev_sample
def __len__(self): def __len__(self):
return self.timesteps return self.config.timesteps
...@@ -43,10 +43,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -43,10 +43,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
variance_type=variance_type, variance_type=variance_type,
clip_sample=clip_sample, clip_sample=clip_sample,
) )
self.timesteps = int(timesteps)
self.timestep_values = timestep_values # save the fixed timestep values for BDDM
self.clip_sample = clip_sample
self.variance_type = variance_type
if trained_betas is not None: if trained_betas is not None:
self.betas = np.asarray(trained_betas) self.betas = np.asarray(trained_betas)
...@@ -83,6 +79,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -83,6 +79,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# #
# #
# self.register_buffer("log_variance", log_variance.to(torch.float32)) # self.register_buffer("log_variance", log_variance.to(torch.float32))
def get_timestep_values(self):
    """Return the fixed timestep values stored in the scheduler config.

    Replaces direct access to the former ``self.timestep_values`` attribute
    (saved for BDDM, per the removed constructor comment) so that config
    attributes are only read via ``self.config``.
    """
    return self.config.timestep_values
def get_alpha(self, time_step): def get_alpha(self, time_step):
return self.alphas[time_step] return self.alphas[time_step]
...@@ -105,9 +103,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -105,9 +103,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.get_beta(t) variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.get_beta(t)
# hacks - were probs added for training stability # hacks - were probs added for training stability
if self.variance_type == "fixed_small": if self.config.variance_type == "fixed_small":
variance = self.clip(variance, min_value=1e-20) variance = self.clip(variance, min_value=1e-20)
elif self.variance_type == "fixed_large": elif self.config.variance_type == "fixed_large":
variance = self.get_beta(t) variance = self.get_beta(t)
return variance return variance
...@@ -124,7 +122,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -124,7 +122,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5) pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
# 3. Clip "predicted x_0" # 3. Clip "predicted x_0"
if self.clip_sample: if self.config.clip_sample:
pred_original_sample = self.clip(pred_original_sample, -1, 1) pred_original_sample = self.clip(pred_original_sample, -1, 1)
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
...@@ -145,4 +143,4 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -145,4 +143,4 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return noisy_sample return noisy_sample
def __len__(self): def __len__(self):
return self.timesteps return self.config.timesteps
...@@ -30,8 +30,6 @@ class GradTTSScheduler(SchedulerMixin, ConfigMixin): ...@@ -30,8 +30,6 @@ class GradTTSScheduler(SchedulerMixin, ConfigMixin):
beta_start=beta_start, beta_start=beta_start,
beta_end=beta_end, beta_end=beta_end,
) )
self.timesteps = int(timesteps)
self.set_format(tensor_format=tensor_format) self.set_format(tensor_format=tensor_format)
def sample_noise(self, timestep): def sample_noise(self, timestep):
...@@ -46,4 +44,4 @@ class GradTTSScheduler(SchedulerMixin, ConfigMixin): ...@@ -46,4 +44,4 @@ class GradTTSScheduler(SchedulerMixin, ConfigMixin):
return xt return xt
def __len__(self): def __len__(self):
return self.timesteps return len(self.config.timesteps)
...@@ -35,7 +35,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -35,7 +35,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
beta_end=beta_end, beta_end=beta_end,
beta_schedule=beta_schedule, beta_schedule=beta_schedule,
) )
self.timesteps = int(timesteps)
if beta_schedule == "linear": if beta_schedule == "linear":
self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end) self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
...@@ -82,10 +81,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -82,10 +81,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
if num_inference_steps in self.warmup_time_steps: if num_inference_steps in self.warmup_time_steps:
return self.warmup_time_steps[num_inference_steps] return self.warmup_time_steps[num_inference_steps]
inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps)) inference_step_times = list(range(0, self.config.timesteps, self.config.timesteps // num_inference_steps))
warmup_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile( warmup_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
np.array([0, self.timesteps // num_inference_steps // 2]), self.pndm_order np.array([0, self.config.timesteps // num_inference_steps // 2]), self.pndm_order
) )
self.warmup_time_steps[num_inference_steps] = list(reversed(warmup_time_steps[:-1].repeat(2)[1:-1])) self.warmup_time_steps[num_inference_steps] = list(reversed(warmup_time_steps[:-1].repeat(2)[1:-1]))
...@@ -95,7 +94,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -95,7 +94,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
if num_inference_steps in self.time_steps: if num_inference_steps in self.time_steps:
return self.time_steps[num_inference_steps] return self.time_steps[num_inference_steps]
inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps)) inference_step_times = list(range(0, self.config.timesteps, self.config.timesteps // num_inference_steps))
self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3])) self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3]))
return self.time_steps[num_inference_steps] return self.time_steps[num_inference_steps]
...@@ -148,4 +147,4 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -148,4 +147,4 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return x_next return x_next
def __len__(self): def __len__(self):
return self.timesteps return self.config.timesteps
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment