Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
59f0ce82
Unverified
Commit
59f0ce82
authored
Oct 25, 2022
by
Patrick von Platen
Committed by
GitHub
Oct 25, 2022
Browse files
[Dance Diffusion] Better naming (#981)
uP
parent
365ff8f7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
9 deletions
+12
-9
src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py
...ers/pipelines/dance_diffusion/pipeline_dance_diffusion.py
+10
-7
tests/pipelines/dance_diffusion/test_dance_diffusion.py
tests/pipelines/dance_diffusion/test_dance_diffusion.py
+2
-2
No files found.
src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py
View file @
59f0ce82
...
...
@@ -47,7 +47,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
batch_size
:
int
=
1
,
num_inference_steps
:
int
=
100
,
generator
:
Optional
[
torch
.
Generator
]
=
None
,
sample
_length_in_s
:
Optional
[
float
]
=
None
,
audio
_length_in_s
:
Optional
[
float
]
=
None
,
return_dict
:
bool
=
True
,
)
->
Union
[
AudioPipelineOutput
,
Tuple
]:
r
"""
...
...
@@ -60,6 +60,9 @@ class DanceDiffusionPipeline(DiffusionPipeline):
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
audio_length_in_s (`float`, *optional*, defaults to `self.unet.config.sample_size/self.unet.config.sample_rate`):
The length of the generated audio sample in seconds. Note that the length of the pipeline output in samples,
*i.e.* `sample_size`, will be `audio_length_in_s` * `self.unet.sample_rate`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipeline_utils.AudioPipelineOutput`] instead of a plain tuple.
...
...
@@ -69,23 +72,23 @@ class DanceDiffusionPipeline(DiffusionPipeline):
generated audio.
"""
if
sample
_length_in_s
is
None
:
sample
_length_in_s
=
self
.
unet
.
sample_size
/
self
.
unet
.
sample_rate
if
audio
_length_in_s
is
None
:
audio
_length_in_s
=
self
.
unet
.
config
.
sample_size
/
self
.
unet
.
config
.
sample_rate
sample_size
=
sample
_length_in_s
*
self
.
unet
.
sample_rate
sample_size
=
audio
_length_in_s
*
self
.
unet
.
sample_rate
down_scale_factor
=
2
**
len
(
self
.
unet
.
up_blocks
)
if
sample_size
<
3
*
down_scale_factor
:
raise
ValueError
(
f
"
{
sample
_length_in_s
}
is too small. Make sure it's bigger or equal to"
f
"
{
audio
_length_in_s
}
is too small. Make sure it's bigger or equal to"
f
"
{
3
*
down_scale_factor
/
self
.
unet
.
sample_rate
}
."
)
original_sample_size
=
int
(
sample_size
)
if
sample_size
%
down_scale_factor
!=
0
:
sample_size
=
((
sample
_length_in_s
*
self
.
unet
.
sample_rate
)
//
down_scale_factor
+
1
)
*
down_scale_factor
sample_size
=
((
audio
_length_in_s
*
self
.
unet
.
sample_rate
)
//
down_scale_factor
+
1
)
*
down_scale_factor
logger
.
info
(
f
"
{
sample
_length_in_s
}
is increased to
{
sample_size
/
self
.
unet
.
sample_rate
}
so that it can be handled"
f
"
{
audio
_length_in_s
}
is increased to
{
sample_size
/
self
.
unet
.
sample_rate
}
so that it can be handled"
f
" by the model. It will be cut to
{
original_sample_size
/
self
.
unet
.
sample_rate
}
after the denoising"
" process."
)
...
...
tests/pipelines/dance_diffusion/test_dance_diffusion.py
View file @
59f0ce82
...
...
@@ -91,7 +91,7 @@ class PipelineIntegrationTests(unittest.TestCase):
pipe
.
set_progress_bar_config
(
disable
=
None
)
generator
=
torch
.
Generator
(
device
=
device
).
manual_seed
(
0
)
output
=
pipe
(
generator
=
generator
,
num_inference_steps
=
100
,
sample
_length_in_s
=
4.096
)
output
=
pipe
(
generator
=
generator
,
num_inference_steps
=
100
,
audio
_length_in_s
=
4.096
)
audio
=
output
.
audios
audio_slice
=
audio
[
0
,
-
3
:,
-
3
:]
...
...
@@ -108,7 +108,7 @@ class PipelineIntegrationTests(unittest.TestCase):
pipe
.
set_progress_bar_config
(
disable
=
None
)
generator
=
torch
.
Generator
(
device
=
device
).
manual_seed
(
0
)
output
=
pipe
(
generator
=
generator
,
num_inference_steps
=
100
,
sample
_length_in_s
=
4.096
)
output
=
pipe
(
generator
=
generator
,
num_inference_steps
=
100
,
audio
_length_in_s
=
4.096
)
audio
=
output
.
audios
audio_slice
=
audio
[
0
,
-
3
:,
-
3
:]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment