Commit 2d1f7de2 authored by patil-suraj's avatar patil-suraj
Browse files

Merge branch 'main' of https://github.com/huggingface/diffusers into main

parents bc72d297 77c80489
...@@ -97,7 +97,7 @@ superres_model = GLIDESuperResUNetModel( ...@@ -97,7 +97,7 @@ superres_model = GLIDESuperResUNetModel(
superres_model.load_state_dict(ups_state_dict, strict=False) superres_model.load_state_dict(ups_state_dict, strict=False)
upscale_scheduler = DDIMScheduler(timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02) upscale_scheduler = DDIMScheduler(timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02, tensor_format="pt")
glide = GLIDE( glide = GLIDE(
text_unet=text2im_model, text_unet=text2im_model,
......
...@@ -30,7 +30,6 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPo ...@@ -30,7 +30,6 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPo
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ( from transformers.utils import (
ModelOutput, ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
logging, logging,
replace_return_docstrings, replace_return_docstrings,
...@@ -872,31 +871,26 @@ class GLIDE(DiffusionPipeline): ...@@ -872,31 +871,26 @@ class GLIDE(DiffusionPipeline):
# Sample gaussian noise to begin loop # Sample gaussian noise to begin loop
image = torch.randn( image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), (batch_size, self.upscale_unet.in_channels // 2, self.upscale_unet.resolution, self.upscale_unet.resolution),
generator=generator, generator=generator,
) )
image = image.to(torch_device) image = image.to(torch_device) * upsample_temp
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf num_trained_timesteps = self.upscale_noise_scheduler.timesteps
# Ideally, read DDIM paper in-detail understanding inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)
# adapt the beta schedule to the number of steps
# Notation (<variable name> -> <name in paper> # self.upscale_noise_scheduler.rescale_betas(num_inference_steps_upscale)
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_image -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_image_direction -> "direction pointing to x_t"
# - pred_prev_image -> "x_t-1"
for t in tqdm.tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale): for t in tqdm.tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale):
# 1. predict noise residual # 1. predict noise residual
with torch.no_grad(): with torch.no_grad():
time_input = torch.tensor([t] * image.shape[0], device=torch_device) time_input = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
model_output = self.upscale_unet(image, time_input, low_res) model_output = self.upscale_unet(image, time_input, low_res)
noise_residual, pred_variance = torch.split(model_output, 3, dim=1) noise_residual, pred_variance = torch.split(model_output, 3, dim=1)
# 2. predict previous mean of image x_t-1 # 2. predict previous mean of image x_t-1
pred_prev_image = self.upscale_noise_scheduler.step( pred_prev_image = self.upscale_noise_scheduler.step(
noise_residual, image, t, num_inference_steps_upscale, eta noise_residual, image, t, num_inference_steps_upscale, eta, use_clipped_residual=True
) )
# 3. optionally sample variance # 3. optionally sample variance
...@@ -910,6 +904,6 @@ class GLIDE(DiffusionPipeline): ...@@ -910,6 +904,6 @@ class GLIDE(DiffusionPipeline):
# 4. set current image to prev_image: x_t -> x_t-1 # 4. set current image to prev_image: x_t -> x_t-1
image = pred_prev_image + variance image = pred_prev_image + variance
image = image.permute(0, 2, 3, 1) image = image.clamp(-1, 1).permute(0, 2, 3, 1)
return image return image
...@@ -74,14 +74,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -74,14 +74,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# #
# self.register_buffer("log_variance", log_variance.to(torch.float32)) # self.register_buffer("log_variance", log_variance.to(torch.float32))
def rescale_betas(self, num_timesteps): # def rescale_betas(self, num_timesteps):
if self.beta_schedule == "linear": # # GLIDE scaling
scale = self.timesteps / num_timesteps # if self.beta_schedule == "linear":
self.betas = linear_beta_schedule( # scale = self.timesteps / num_timesteps
num_timesteps, beta_start=self.beta_start * scale, beta_end=self.beta_end * scale # self.betas = linear_beta_schedule(
) # num_timesteps, beta_start=self.beta_start * scale, beta_end=self.beta_end * scale
self.alphas = 1.0 - self.betas # )
self.alphas_cumprod = np.cumprod(self.alphas, axis=0) # self.alphas = 1.0 - self.betas
# self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
def get_alpha(self, time_step): def get_alpha(self, time_step):
return self.alphas[time_step] return self.alphas[time_step]
...@@ -112,7 +113,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -112,7 +113,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
return variance return variance
def step(self, residual, image, t, num_inference_steps, eta): def step(self, residual, image, t, num_inference_steps, eta, use_clipped_residual=False):
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding # Ideally, read DDIM paper in-detail understanding
...@@ -146,6 +147,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -146,6 +147,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
variance = self.get_variance(t, num_inference_steps) variance = self.get_variance(t, num_inference_steps)
std_dev_t = eta * variance ** (0.5) std_dev_t = eta * variance ** (0.5)
if use_clipped_residual:
# the residual is always re-derived from the clipped x_0 in GLIDE
residual = (image - alpha_prod_t ** (0.5) * pred_original_image) / beta_prod_t ** (0.5)
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment