Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
a785992c
Unverified
Commit
a785992c
authored
Jul 09, 2024
by
Sayak Paul
Committed by
GitHub
Jul 09, 2024
Browse files
[Tests] fix more sharding tests (#8797)
* fix * fix * ugly * okay * fix more * fix oops
parent
35cc66dc
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
3 deletions
+5
-3
tests/models/test_modeling_common.py
tests/models/test_modeling_common.py
+5
-3
No files found.
tests/models/test_modeling_common.py
View file @
a785992c
...
@@ -885,11 +885,11 @@ class ModelTesterMixin:
...
@@ -885,11 +885,11 @@ class ModelTesterMixin:
@
require_torch_gpu
@
require_torch_gpu
def
test_sharded_checkpoints
(
self
):
def
test_sharded_checkpoints
(
self
):
torch
.
manual_seed
(
0
)
config
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
config
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
config
).
eval
()
model
=
self
.
model_class
(
**
config
).
eval
()
model
=
model
.
to
(
torch_device
)
model
=
model
.
to
(
torch_device
)
torch
.
manual_seed
(
0
)
base_output
=
model
(
**
inputs_dict
)
base_output
=
model
(
**
inputs_dict
)
model_size
=
compute_module_sizes
(
model
)[
""
]
model_size
=
compute_module_sizes
(
model
)[
""
]
...
@@ -909,6 +909,7 @@ class ModelTesterMixin:
...
@@ -909,6 +909,7 @@ class ModelTesterMixin:
new_model
=
new_model
.
to
(
torch_device
)
new_model
=
new_model
.
to
(
torch_device
)
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
if
"generator"
in
inputs_dict
:
_
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
_
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
new_output
=
new_model
(
**
inputs_dict
)
new_output
=
new_model
(
**
inputs_dict
)
...
@@ -942,6 +943,7 @@ class ModelTesterMixin:
...
@@ -942,6 +943,7 @@ class ModelTesterMixin:
new_model
=
new_model
.
to
(
torch_device
)
new_model
=
new_model
.
to
(
torch_device
)
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
if
"generator"
in
inputs_dict
:
_
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
_
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
new_output
=
new_model
(
**
inputs_dict
)
new_output
=
new_model
(
**
inputs_dict
)
self
.
assertTrue
(
torch
.
allclose
(
base_output
[
0
],
new_output
[
0
],
atol
=
1e-5
))
self
.
assertTrue
(
torch
.
allclose
(
base_output
[
0
],
new_output
[
0
],
atol
=
1e-5
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment