Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
0244e2af
Commit
0244e2af
authored
Jun 22, 2022
by
Patrick von Platen
Browse files
correct diffusion test
parent
3a177754
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
3 deletions
+3
-3
src/diffusers/pipelines/pipeline_grad_tts.py
src/diffusers/pipelines/pipeline_grad_tts.py
+1
-1
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+2
-2
No files found.
src/diffusers/pipelines/pipeline_grad_tts.py
View file @
0244e2af
...
@@ -471,7 +471,7 @@ class GradTTSPipeline(DiffusionPipeline):
...
@@ -471,7 +471,7 @@ class GradTTSPipeline(DiffusionPipeline):
mu_y
=
mu_y
.
transpose
(
1
,
2
)
mu_y
=
mu_y
.
transpose
(
1
,
2
)
# Sample latent representation from terminal distribution N(mu_y, I)
# Sample latent representation from terminal distribution N(mu_y, I)
z
=
mu_y
+
torch
.
randn
(
mu_y
.
shape
,
device
=
mu_y
.
device
,
generator
=
generator
)
/
temperature
z
=
mu_y
+
torch
.
randn
(
mu_y
.
shape
,
generator
=
generator
)
.
to
(
mu_y
.
device
)
xt
=
z
*
y_mask
xt
=
z
*
y_mask
h
=
1.0
/
num_inference_steps
h
=
1.0
/
num_inference_steps
...
...
tests/test_modeling_utils.py
View file @
0244e2af
...
@@ -714,9 +714,9 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -714,9 +714,9 @@ class PipelineTesterMixin(unittest.TestCase):
assert
mel_spec
.
shape
==
(
1
,
80
,
143
)
assert
mel_spec
.
shape
==
(
1
,
80
,
143
)
expected_slice
=
torch
.
tensor
(
expected_slice
=
torch
.
tensor
(
[
-
6.
6119
,
-
6.
5963
,
-
6.
2776
,
-
6.7496
,
-
6.
7096
,
-
6.
5131
,
-
6.464
3
,
-
6.
481
7
,
-
6.
7185
]
[
-
6.
7584
,
-
6.
8347
,
-
6.
3293
,
-
6.
6437
,
-
6.
7233
,
-
6.46
8
4
,
-
6.
118
7
,
-
6.
3172
,
-
6.6890
]
)
)
assert
(
mel_spec
[
0
,
:
3
,
:
3
].
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
assert
(
mel_spec
[
0
,
:
3
,
:
3
].
cpu
().
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
def
test_module_from_pipeline
(
self
):
def
test_module_from_pipeline
(
self
):
model
=
DiffWave
(
num_res_layers
=
4
)
model
=
DiffWave
(
num_res_layers
=
4
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment