Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
6921393a
"vscode:/vscode.git/clone" did not exist on "b2b531e0412f1207a8d17ea24cfbece77490053e"
Commit
6921393a
authored
Jun 27, 2022
by
patil-suraj
Browse files
add fast test for ldm
parent
17bf65e1
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
16 additions
and
1 deletion
+16
-1
src/diffusers/models/embeddings.py
src/diffusers/models/embeddings.py
+1
-0
src/diffusers/models/unet_grad_tts.py
src/diffusers/models/unet_grad_tts.py
+0
-1
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+15
-0
No files found.
src/diffusers/models/embeddings.py
View file @
6921393a
...
@@ -34,6 +34,7 @@ def get_timestep_embedding(timesteps, embedding_dim):
...
@@ -34,6 +34,7 @@ def get_timestep_embedding(timesteps, embedding_dim):
emb
=
torch
.
nn
.
functional
.
pad
(
emb
,
(
0
,
1
,
0
,
0
))
emb
=
torch
.
nn
.
functional
.
pad
(
emb
,
(
0
,
1
,
0
,
0
))
return
emb
return
emb
# unet_glide.py
# unet_glide.py
def
timestep_embedding
(
timesteps
,
dim
,
max_period
=
10000
):
def
timestep_embedding
(
timesteps
,
dim
,
max_period
=
10000
):
"""
"""
...
...
src/diffusers/models/unet_grad_tts.py
View file @
6921393a
...
@@ -198,7 +198,6 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
...
@@ -198,7 +198,6 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
if
not
isinstance
(
spk
,
type
(
None
)):
if
not
isinstance
(
spk
,
type
(
None
)):
s
=
self
.
spk_mlp
(
spk
)
s
=
self
.
spk_mlp
(
spk
)
t
=
self
.
time_pos_emb
(
timesteps
,
scale
=
self
.
pe_scale
)
t
=
self
.
time_pos_emb
(
timesteps
,
scale
=
self
.
pe_scale
)
t
=
self
.
mlp
(
t
)
t
=
self
.
mlp
(
t
)
...
...
tests/test_modeling_utils.py
View file @
6921393a
...
@@ -694,6 +694,21 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -694,6 +694,21 @@ class PipelineTesterMixin(unittest.TestCase):
expected_slice
=
torch
.
tensor
([
0.7295
,
0.7358
,
0.7256
,
0.7435
,
0.7095
,
0.6884
,
0.7325
,
0.6921
,
0.6458
])
expected_slice
=
torch
.
tensor
([
0.7295
,
0.7358
,
0.7256
,
0.7435
,
0.7095
,
0.6884
,
0.7325
,
0.6921
,
0.6458
])
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
@
slow
def
test_ldm_text2img_fast
(
self
):
model_id
=
"fusing/latent-diffusion-text2im-large"
ldm
=
LatentDiffusionPipeline
.
from_pretrained
(
model_id
)
prompt
=
"A painting of a squirrel eating a burger"
generator
=
torch
.
manual_seed
(
0
)
image
=
ldm
([
prompt
],
generator
=
generator
,
num_inference_steps
=
20
)
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
assert
image
.
shape
==
(
1
,
3
,
256
,
256
)
expected_slice
=
torch
.
rensor
([
0.3163
,
0.8670
,
0.6465
,
0.1865
,
0.6291
,
0.5139
,
0.2824
,
0.3723
,
0.4344
])
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
@
slow
@
slow
def
test_glide_text2img
(
self
):
def
test_glide_text2img
(
self
):
model_id
=
"fusing/glide-base"
model_id
=
"fusing/glide-base"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment