Unverified Commit b63419a2 authored by Robert Dargavel Smith's avatar Robert Dargavel Smith Committed by GitHub
Browse files

AudioDiffusionPipeline - fix encode method after config changes (#3114)

* config fixes

* deprecate get_input_dims
parent eb29dbad
...@@ -51,21 +51,6 @@ class AudioDiffusionPipeline(DiffusionPipeline): ...@@ -51,21 +51,6 @@ class AudioDiffusionPipeline(DiffusionPipeline):
super().__init__() super().__init__()
self.register_modules(unet=unet, scheduler=scheduler, mel=mel, vqvae=vqvae) self.register_modules(unet=unet, scheduler=scheduler, mel=mel, vqvae=vqvae)
def get_input_dims(self) -> Tuple:
"""Returns dimension of input image
Returns:
`Tuple`: (height, width)
"""
input_module = self.vqvae if self.vqvae is not None else self.unet
# For backwards compatibility
sample_size = (
(input_module.config.sample_size, input_module.config.sample_size)
if type(input_module.config.sample_size) == int
else input_module.config.sample_size
)
return sample_size
def get_default_steps(self) -> int: def get_default_steps(self) -> int:
"""Returns default number of steps recommended for inference """Returns default number of steps recommended for inference
...@@ -123,8 +108,6 @@ class AudioDiffusionPipeline(DiffusionPipeline): ...@@ -123,8 +108,6 @@ class AudioDiffusionPipeline(DiffusionPipeline):
# For backwards compatibility # For backwards compatibility
if type(self.unet.config.sample_size) == int: if type(self.unet.config.sample_size) == int:
self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size) self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size)
input_dims = self.get_input_dims()
self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0])
if noise is None: if noise is None:
noise = randn_tensor( noise = randn_tensor(
( (
...@@ -234,7 +217,7 @@ class AudioDiffusionPipeline(DiffusionPipeline): ...@@ -234,7 +217,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
sample = torch.Tensor(sample).to(self.device) sample = torch.Tensor(sample).to(self.device)
for t in self.progress_bar(torch.flip(self.scheduler.timesteps, (0,))): for t in self.progress_bar(torch.flip(self.scheduler.timesteps, (0,))):
prev_timestep = t - self.scheduler.num_train_timesteps // self.scheduler.num_inference_steps prev_timestep = t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
alpha_prod_t = self.scheduler.alphas_cumprod[t] alpha_prod_t = self.scheduler.alphas_cumprod[t]
alpha_prod_t_prev = ( alpha_prod_t_prev = (
self.scheduler.alphas_cumprod[prev_timestep] self.scheduler.alphas_cumprod[prev_timestep]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment