Unverified Commit 59f0ce82 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Dance Diffusion] Better naming (#981)

uP
parent 365ff8f7
...@@ -47,7 +47,7 @@ class DanceDiffusionPipeline(DiffusionPipeline): ...@@ -47,7 +47,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
batch_size: int = 1, batch_size: int = 1,
num_inference_steps: int = 100, num_inference_steps: int = 100,
generator: Optional[torch.Generator] = None, generator: Optional[torch.Generator] = None,
sample_length_in_s: Optional[float] = None, audio_length_in_s: Optional[float] = None,
return_dict: bool = True, return_dict: bool = True,
) -> Union[AudioPipelineOutput, Tuple]: ) -> Union[AudioPipelineOutput, Tuple]:
r""" r"""
...@@ -60,6 +60,9 @@ class DanceDiffusionPipeline(DiffusionPipeline): ...@@ -60,6 +60,9 @@ class DanceDiffusionPipeline(DiffusionPipeline):
generator (`torch.Generator`, *optional*): generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic. deterministic.
audio_length_in_s (`float`, *optional*, defaults to `self.unet.config.sample_size/self.unet.config.sample_rate`):
The length of the generated audio sample in seconds. Note that the output of the pipeline, *i.e.*
`sample_size`, will be `audio_length_in_s` * `self.unet.sample_rate`.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipeline_utils.AudioPipelineOutput`] instead of a plain tuple. Whether or not to return a [`~pipeline_utils.AudioPipelineOutput`] instead of a plain tuple.
...@@ -69,23 +72,23 @@ class DanceDiffusionPipeline(DiffusionPipeline): ...@@ -69,23 +72,23 @@ class DanceDiffusionPipeline(DiffusionPipeline):
generated images. generated images.
""" """
if sample_length_in_s is None: if audio_length_in_s is None:
sample_length_in_s = self.unet.sample_size / self.unet.sample_rate audio_length_in_s = self.unet.config.sample_size / self.unet.config.sample_rate
sample_size = sample_length_in_s * self.unet.sample_rate sample_size = audio_length_in_s * self.unet.sample_rate
down_scale_factor = 2 ** len(self.unet.up_blocks) down_scale_factor = 2 ** len(self.unet.up_blocks)
if sample_size < 3 * down_scale_factor: if sample_size < 3 * down_scale_factor:
raise ValueError( raise ValueError(
f"{sample_length_in_s} is too small. Make sure it's bigger or equal to" f"{audio_length_in_s} is too small. Make sure it's bigger or equal to"
f" {3 * down_scale_factor / self.unet.sample_rate}." f" {3 * down_scale_factor / self.unet.sample_rate}."
) )
original_sample_size = int(sample_size) original_sample_size = int(sample_size)
if sample_size % down_scale_factor != 0: if sample_size % down_scale_factor != 0:
sample_size = ((sample_length_in_s * self.unet.sample_rate) // down_scale_factor + 1) * down_scale_factor sample_size = ((audio_length_in_s * self.unet.sample_rate) // down_scale_factor + 1) * down_scale_factor
logger.info( logger.info(
f"{sample_length_in_s} is increased to {sample_size / self.unet.sample_rate} so that it can be handled" f"{audio_length_in_s} is increased to {sample_size / self.unet.sample_rate} so that it can be handled"
f" by the model. It will be cut to {original_sample_size / self.unet.sample_rate} after the denoising" f" by the model. It will be cut to {original_sample_size / self.unet.sample_rate} after the denoising"
" process." " process."
) )
......
...@@ -91,7 +91,7 @@ class PipelineIntegrationTests(unittest.TestCase): ...@@ -91,7 +91,7 @@ class PipelineIntegrationTests(unittest.TestCase):
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device=device).manual_seed(0) generator = torch.Generator(device=device).manual_seed(0)
output = pipe(generator=generator, num_inference_steps=100, sample_length_in_s=4.096) output = pipe(generator=generator, num_inference_steps=100, audio_length_in_s=4.096)
audio = output.audios audio = output.audios
audio_slice = audio[0, -3:, -3:] audio_slice = audio[0, -3:, -3:]
...@@ -108,7 +108,7 @@ class PipelineIntegrationTests(unittest.TestCase): ...@@ -108,7 +108,7 @@ class PipelineIntegrationTests(unittest.TestCase):
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device=device).manual_seed(0) generator = torch.Generator(device=device).manual_seed(0)
output = pipe(generator=generator, num_inference_steps=100, sample_length_in_s=4.096) output = pipe(generator=generator, num_inference_steps=100, audio_length_in_s=4.096)
audio = output.audios audio = output.audios
audio_slice = audio[0, -3:, -3:] audio_slice = audio[0, -3:, -3:]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment