Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
31adeb41
"vscode:/vscode.git/clone" did not exist on "d76af4d4c12bc0fe121b4530102f8f15964e98cc"
Unverified
Commit
31adeb41
authored
Jul 04, 2024
by
Sayak Paul
Committed by
GitHub
Jul 04, 2024
Browse files
[Tests] fix sharding tests (#8764)
fix sharding tests
parent
a7b9634e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
10 additions
and
5 deletions
+10
-5
src/diffusers/models/embeddings.py
src/diffusers/models/embeddings.py
+2
-1
tests/models/autoencoders/test_models_vae.py
tests/models/autoencoders/test_models_vae.py
+4
-3
tests/models/test_modeling_common.py
tests/models/test_modeling_common.py
+4
-1
No files found.
src/diffusers/models/embeddings.py
View file @
31adeb41
...
...
@@ -415,9 +415,10 @@ class GaussianFourierProjection(nn.Module):
if
set_W_to_weight
:
# to delete later
del
self
.
weight
self
.
W
=
nn
.
Parameter
(
torch
.
randn
(
embedding_size
)
*
scale
,
requires_grad
=
False
)
self
.
weight
=
self
.
W
del
self
.
W
def
forward
(
self
,
x
):
if
self
.
log
:
...
...
tests/models/autoencoders/test_models_vae.py
View file @
31adeb41
...
...
@@ -361,9 +361,10 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
forward_requires_fresh_args
=
True
def
inputs_dict
(
self
,
seed
=
None
):
generator
=
torch
.
Generator
(
"cpu"
)
if
seed
is
not
None
:
generator
.
manual_seed
(
0
)
if
seed
is
None
:
generator
=
torch
.
Generator
(
"cpu"
).
manual_seed
(
0
)
else
:
generator
=
torch
.
Generator
(
"cpu"
).
manual_seed
(
seed
)
image
=
randn_tensor
((
4
,
3
,
32
,
32
),
generator
=
generator
,
device
=
torch
.
device
(
torch_device
))
return
{
"sample"
:
image
,
"generator"
:
generator
}
...
...
tests/models/test_modeling_common.py
View file @
31adeb41
...
...
@@ -905,11 +905,13 @@ class ModelTesterMixin:
actual_num_shards
=
len
([
file
for
file
in
os
.
listdir
(
tmp_dir
)
if
file
.
endswith
(
".safetensors"
)])
self
.
assertTrue
(
actual_num_shards
==
expected_num_shards
)
new_model
=
self
.
model_class
.
from_pretrained
(
tmp_dir
)
new_model
=
self
.
model_class
.
from_pretrained
(
tmp_dir
)
.
eval
()
new_model
=
new_model
.
to
(
torch_device
)
torch
.
manual_seed
(
0
)
_
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
new_output
=
new_model
(
**
inputs_dict
)
self
.
assertTrue
(
torch
.
allclose
(
base_output
[
0
],
new_output
[
0
],
atol
=
1e-5
))
@
require_torch_gpu
...
...
@@ -940,6 +942,7 @@ class ModelTesterMixin:
new_model
=
new_model
.
to
(
torch_device
)
torch
.
manual_seed
(
0
)
_
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
new_output
=
new_model
(
**
inputs_dict
)
self
.
assertTrue
(
torch
.
allclose
(
base_output
[
0
],
new_output
[
0
],
atol
=
1e-5
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment