Commit de22d4cd authored by Patrick von Platen

make sure config attributes are only accessed via the config in schedulers

parent 8c1f5197
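In short, scheduler configuration values should have a single source of truth: the frozen config populated by `ConfigMixin`, not mirrored instance attributes. A minimal sketch of the pattern this commit enforces (illustrative only; `ExampleScheduler` and the `SimpleNamespace` stand-in are not part of the diff, which uses `ConfigMixin`/`FrozenDict`):

```python
from types import SimpleNamespace


class ExampleScheduler:
    """Illustrative stand-in for a scheduler built on ConfigMixin."""

    def __init__(self, timesteps=1000, clip_sample=True):
        # Before this commit, config values were also mirrored as attributes:
        #     self.timesteps = int(timesteps)
        # After it, they live only on the (frozen) config object:
        self.config = SimpleNamespace(timesteps=int(timesteps), clip_sample=clip_sample)

    def __len__(self):
        return self.config.timesteps  # accessed via the config, not self.timesteps


scheduler = ExampleScheduler()
assert len(scheduler) == 1000
```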
......@@ -258,10 +258,6 @@ class ConfigMixin:
class FrozenDict(OrderedDict):
    def __init__(self, *args, **kwargs):
-        # remove `None`
-        args = (a for a in args if a is not None)
-        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        super().__init__(*args, **kwargs)

        for key, value in self.items():
......
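The config object itself is a `FrozenDict`, so values registered at construction time cannot drift afterwards. A rough sketch of the idea (a simplified stand-in, not the actual class from configuration_utils.py):

```python
from collections import OrderedDict


class FrozenDictSketch(OrderedDict):
    """Simplified illustration: entries become read-only attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        for key, value in self.items():
            super().__setattr__(key, value)  # expose entries as attributes
        super().__setattr__("_frozen", True)

    def __setattr__(self, name, value):
        if getattr(self, "_frozen", False):
            raise AttributeError(f"FrozenDictSketch is frozen; cannot set {name!r}")
        super().__setattr__(name, value)


config = FrozenDictSketch(timesteps=1000, clip_sample=True)
assert config.timesteps == 1000
try:
    config.timesteps = 500  # mutation is rejected once frozen
except AttributeError:
    pass
```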
# Pipelines
- Pipelines are a collection of end-to-end diffusion systems that can be used out-of-the-box
- Pipelines should stay as close as possible to their original implementation
- Pipelines can include components of other libraries, such as text encoders.
## API
TODO(Patrick, Anton, Suraj)
## Examples
- DDPM for unconditional image generation in [pipeline_ddpm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_ddpm.py).
- DDIM for unconditional image generation in [pipeline_ddim](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_ddim.py).
- PNDM for unconditional image generation in [pipeline_pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py).
- Latent diffusion for text-to-image generation / conditional image generation in [pipeline_latent_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_latent_diffusion.py).
- GLIDE for text-to-image generation / conditional image generation in [pipeline_glide](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_glide.py).
- BDDM for spectrogram-to-sound vocoding in [pipeline_bddm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_bddm.py).
- Grad-TTS for text-to-audio generation / conditional audio generation in [pipeline_grad_tts](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_grad_tts.py).
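As a rough usage sketch for the list above (the export name, checkpoint id, and call signature are assumptions about the API at this commit, not verified against it):

```python
import torch
from diffusers import DDPM  # assumed export name for the DDPM pipeline

# checkpoint id is a placeholder; any DDPM checkpoint with a matching config works
ddpm = DDPM.from_pretrained("fusing/ddpm-lsun-bedroom")

generator = torch.manual_seed(0)
image = ddpm(batch_size=1, generator=generator)  # assumed __call__ signature
print(image.shape)
```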
......@@ -46,7 +46,7 @@ class LatentDiffusion(DiffusionPipeline):
        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
        text_embedding = self.bert(text_input.input_ids)[0]

-        num_trained_timesteps = self.noise_scheduler.timesteps
+        num_trained_timesteps = self.noise_scheduler.config.timesteps
        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)

        image = self.noise_scheduler.sample_noise(
......
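The `inference_step_times` line that recurs across these pipelines subsamples the training timesteps at a fixed stride. A standalone check, assuming the common default of 1000 trained timesteps and 50 inference steps:

```python
num_trained_timesteps = 1000  # scheduler.config.timesteps under the common default
num_inference_steps = 50

inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
assert list(inference_step_times)[:4] == [0, 20, 40, 60]
assert len(inference_step_times) == 50  # 50 evenly spaced training timesteps
```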
......@@ -291,7 +291,7 @@ class BDDM(DiffusionPipeline):
        # Sample gaussian noise to begin loop
        audio = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)

-        timestep_values = self.noise_scheduler.timestep_values
+        timestep_values = self.noise_scheduler.get_timestep_values()
        num_prediction_steps = len(self.noise_scheduler)
        for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
            # 1. predict noise residual
......
......@@ -32,7 +32,7 @@ class DDIM(DiffusionPipeline):
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

-        num_trained_timesteps = self.noise_scheduler.timesteps
+        num_trained_timesteps = self.noise_scheduler.config.timesteps
        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)

        self.unet.to(torch_device)
......
......@@ -897,7 +897,7 @@ class LatentDiffusion(DiffusionPipeline):
        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
        text_embedding = self.bert(text_input.input_ids)[0]

-        num_trained_timesteps = self.noise_scheduler.timesteps
+        num_trained_timesteps = self.noise_scheduler.config.timesteps
        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)

        image = torch.randn(
......
......@@ -61,7 +61,6 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
            timesteps=timesteps,
            beta_schedule=beta_schedule,
        )
-        self.timesteps = int(timesteps)

        if beta_schedule == "squaredcos_cap_v2":
            # GLIDE cosine schedule
......@@ -94,4 +93,4 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
        return torch.randn(shape, generator=generator).to(device)

    def __len__(self):
-        return self.timesteps
+        return self.config.timesteps
......@@ -37,10 +37,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
+            trained_betas=trained_betas,
+            timestep_values=timestep_values,
+            clip_sample=clip_sample,
        )
-        self.timesteps = int(timesteps)
-        self.timestep_values = timestep_values  # save the fixed timestep values for BDDM
-        self.clip_sample = clip_sample

        if beta_schedule == "linear":
            self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
......@@ -81,6 +81,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
        # )
        # self.alphas = 1.0 - self.betas
        # self.alphas_cumprod = np.cumprod(self.alphas, axis=0)

+    def get_timestep_values(self):
+        return self.config.timestep_values

    def get_alpha(self, time_step):
        return self.alphas[time_step]
......@@ -96,7 +98,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
    def get_orig_t(self, t, num_inference_steps):
        if t < 0:
            return -1

-        return self.timesteps // num_inference_steps * t
+        return self.config.timesteps // num_inference_steps * t

    def get_variance(self, t, num_inference_steps):
        orig_t = self.get_orig_t(t, num_inference_steps)
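`get_orig_t` maps an inference-step index back onto the training timestep grid, with -1 as a sentinel for the step before the first one. A standalone check, assuming `config.timesteps` is 1000:

```python
def get_orig_t(t, num_inference_steps, timesteps=1000):
    # DDIMScheduler.get_orig_t with self.config.timesteps inlined as a parameter
    if t < 0:
        return -1
    return timesteps // num_inference_steps * t


assert get_orig_t(-1, 50) == -1   # sentinel: "previous" timestep of step 0
assert get_orig_t(10, 50) == 200  # 1000 // 50 * 10
assert get_orig_t(49, 50) == 980  # last inference step lands near the end of training
```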
......@@ -137,7 +139,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
        pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)

        # 4. Clip "predicted x_0"
-        if self.clip_sample:
+        if self.config.clip_sample:
            pred_original_sample = self.clip(pred_original_sample, -1, 1)

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
......@@ -158,4 +160,4 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
        return pred_prev_sample

    def __len__(self):
-        return self.timesteps
+        return self.config.timesteps
......@@ -43,10 +43,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
            variance_type=variance_type,
            clip_sample=clip_sample,
        )
-        self.timesteps = int(timesteps)
-        self.timestep_values = timestep_values  # save the fixed timestep values for BDDM
-        self.clip_sample = clip_sample
-        self.variance_type = variance_type

        if trained_betas is not None:
            self.betas = np.asarray(trained_betas)
......@@ -83,6 +79,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
        #
        #
        # self.register_buffer("log_variance", log_variance.to(torch.float32))

+    def get_timestep_values(self):
+        return self.config.timestep_values

    def get_alpha(self, time_step):
        return self.alphas[time_step]
......@@ -105,9 +103,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.get_beta(t)

        # hacks - were probably added for training stability
-        if self.variance_type == "fixed_small":
+        if self.config.variance_type == "fixed_small":
            variance = self.clip(variance, min_value=1e-20)
-        elif self.variance_type == "fixed_large":
+        elif self.config.variance_type == "fixed_large":
            variance = self.get_beta(t)

        return variance
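The two branches differ only at the edges: "fixed_small" keeps the true posterior variance but clips it away from zero (it is exactly zero at t = 0), while "fixed_large" substitutes beta_t. A self-contained numeric sketch with a toy linear schedule (values illustrative, not tied to this commit's defaults):

```python
import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)  # toy linear schedule
alphas_cumprod = np.cumprod(1.0 - betas)


def variance_at(t, variance_type):
    alpha_prod_t = alphas_cumprod[t]
    alpha_prod_t_prev = alphas_cumprod[t - 1] if t > 0 else 1.0
    variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * betas[t]
    if variance_type == "fixed_small":
        return max(variance, 1e-20)  # mirrors self.clip(variance, min_value=1e-20)
    return betas[t]  # "fixed_large"


assert variance_at(0, "fixed_small") == 1e-20        # posterior variance is 0 at t=0, the clip bites
assert variance_at(500, "fixed_large") == betas[500]
assert variance_at(500, "fixed_small") < betas[500]  # posterior variance is slightly smaller
```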
......@@ -124,7 +122,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
        pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)

        # 3. Clip "predicted x_0"
-        if self.clip_sample:
+        if self.config.clip_sample:
            pred_original_sample = self.clip(pred_original_sample, -1, 1)

        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
......@@ -145,4 +143,4 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
        return noisy_sample

    def __len__(self):
-        return self.timesteps
+        return self.config.timesteps
......@@ -30,8 +30,6 @@ class GradTTSScheduler(SchedulerMixin, ConfigMixin):
            beta_start=beta_start,
            beta_end=beta_end,
        )
-        self.timesteps = int(timesteps)

        self.set_format(tensor_format=tensor_format)

    def sample_noise(self, timestep):
......@@ -46,4 +44,4 @@ class GradTTSScheduler(SchedulerMixin, ConfigMixin):
        return xt

    def __len__(self):
-        return self.timesteps
+        return len(self.config.timesteps)
......@@ -35,7 +35,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
            beta_end=beta_end,
            beta_schedule=beta_schedule,
        )
-        self.timesteps = int(timesteps)

        if beta_schedule == "linear":
            self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
......@@ -82,10 +81,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
        if num_inference_steps in self.warmup_time_steps:
            return self.warmup_time_steps[num_inference_steps]

-        inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps))
+        inference_step_times = list(range(0, self.config.timesteps, self.config.timesteps // num_inference_steps))

        warmup_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
-            np.array([0, self.timesteps // num_inference_steps // 2]), self.pndm_order
+            np.array([0, self.config.timesteps // num_inference_steps // 2]), self.pndm_order
        )
        self.warmup_time_steps[num_inference_steps] = list(reversed(warmup_time_steps[:-1].repeat(2)[1:-1]))
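The warmup arithmetic is easiest to read with concrete numbers. Assuming `config.timesteps = 1000`, 50 inference steps, and `pndm_order = 4`, the repeat/tile combination interleaves half-strides into the last few step times, and the final slice-and-reverse doubles each interior time for the Runge-Kutta-style warmup:

```python
import numpy as np

timesteps, num_inference_steps, pndm_order = 1000, 50, 4  # pndm_order = 4 assumed

inference_step_times = list(range(0, timesteps, timesteps // num_inference_steps))
warmup = np.array(inference_step_times[-pndm_order:]).repeat(2) + np.tile(
    np.array([0, timesteps // num_inference_steps // 2]), pndm_order
)
assert warmup.tolist() == [920, 930, 940, 950, 960, 970, 980, 990]

final = [int(t) for t in reversed(warmup[:-1].repeat(2)[1:-1])]
assert final == [980, 970, 970, 960, 960, 950, 950, 940, 940, 930, 930, 920]
# each interior time appears twice: the warmup revisits it on the way down
```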
......@@ -95,7 +94,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
        if num_inference_steps in self.time_steps:
            return self.time_steps[num_inference_steps]

-        inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps))
+        inference_step_times = list(range(0, self.config.timesteps, self.config.timesteps // num_inference_steps))
        self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3]))

        return self.time_steps[num_inference_steps]
......@@ -148,4 +147,4 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
        return x_next

    def __len__(self):
-        return self.timesteps
+        return self.config.timesteps