Unverified Commit e3d71ad8 authored by Omar Sanseviero's avatar Omar Sanseviero Committed by GitHub
Browse files

Minor nits to Dance DIffusion (#4012)

Update pipeline_dance_diffusion.py
parent 68f61a07
......@@ -30,9 +30,9 @@ class DanceDiffusionPipeline(DiffusionPipeline):
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Parameters:
unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded image.
unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded audio.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
A scheduler to be used in combination with `unet` to denoise the encoded audio. Can be one of
[`IPNDMScheduler`].
"""
......@@ -54,7 +54,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
batch_size (`int`, *optional*, defaults to 1):
The number of audio samples to generate.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality audio sample at
The number of denoising steps. More denoising steps usually lead to a higher-quality audio sample at
the expense of slower inference.
generator (`torch.Generator`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
......@@ -67,7 +67,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
Returns:
[`~pipelines.AudioPipelineOutput`] or `tuple`: [`~pipelines.utils.AudioPipelineOutput`] if `return_dict` is
True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated audio.
"""
if audio_length_in_s is None:
......@@ -94,7 +94,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
)
sample_size = int(sample_size)
dtype = next(iter(self.unet.parameters())).dtype
dtype = next(self.unet.parameters()).dtype
shape = (batch_size, self.unet.config.in_channels, sample_size)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
......@@ -112,7 +112,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
# 1. predict noise model_output
model_output = self.unet(audio, t).sample
# 2. compute previous image: x_t -> t_t-1
# 2. compute previous audio sample: x_t -> t_t-1
audio = self.scheduler.step(model_output, t, audio).prev_sample
audio = audio.clamp(-1, 1).float().cpu().numpy()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment