Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
20ce68f9
Unverified
Commit
20ce68f9
authored
Nov 30, 2022
by
Patrick von Platen
Committed by
GitHub
Nov 30, 2022
Browse files
Fix dtype model loading (#1449)
* Add test * up * no bfloat16 for mps * fix * rename test
parent
110ffe25
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
47 additions
and
14 deletions
+47
-14
src/diffusers/modeling_utils.py
src/diffusers/modeling_utils.py
+15
-0
tests/models/test_models_unet_1d.py
tests/models/test_models_unet_1d.py
+4
-4
tests/pipelines/versatile_diffusion/test_versatile_diffusion_mega.py
...ines/versatile_diffusion/test_versatile_diffusion_mega.py
+1
-1
tests/test_modeling_common.py
tests/test_modeling_common.py
+19
-1
tests/test_pipelines.py
tests/test_pipelines.py
+1
-1
tests/test_scheduler.py
tests/test_scheduler.py
+4
-4
tests/test_scheduler_flax.py
tests/test_scheduler_flax.py
+3
-3
No files found.
src/diffusers/modeling_utils.py
View file @
20ce68f9
...
@@ -472,6 +472,21 @@ class ModelMixin(torch.nn.Module):
...
@@ -472,6 +472,21 @@ class ModelMixin(torch.nn.Module):
model
=
cls
.
from_config
(
config
,
**
unused_kwargs
)
model
=
cls
.
from_config
(
config
,
**
unused_kwargs
)
state_dict
=
load_state_dict
(
model_file
)
state_dict
=
load_state_dict
(
model_file
)
dtype
=
set
(
v
.
dtype
for
v
in
state_dict
.
values
())
if
len
(
dtype
)
>
1
and
torch
.
float32
not
in
dtype
:
raise
ValueError
(
f
"The weights of the model file
{
model_file
}
have a mixture of incompatible dtypes
{
dtype
}
. Please"
f
" make sure that
{
model_file
}
weights have only one dtype."
)
elif
len
(
dtype
)
>
1
and
torch
.
float32
in
dtype
:
dtype
=
torch
.
float32
else
:
dtype
=
dtype
.
pop
()
# move model to correct dtype
model
=
model
.
to
(
dtype
)
model
,
missing_keys
,
unexpected_keys
,
mismatched_keys
,
error_msgs
=
cls
.
_load_pretrained_model
(
model
,
missing_keys
,
unexpected_keys
,
mismatched_keys
,
error_msgs
=
cls
.
_load_pretrained_model
(
model
,
model
,
state_dict
,
state_dict
,
...
...
tests/models/test_models_unet_1d.py
View file @
20ce68f9
...
@@ -63,8 +63,8 @@ class UNet1DModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -63,8 +63,8 @@ class UNet1DModelTests(ModelTesterMixin, unittest.TestCase):
super
().
test_outputs_equivalence
()
super
().
test_outputs_equivalence
()
@
unittest
.
skipIf
(
torch_device
==
"mps"
,
"mish op not supported in MPS"
)
@
unittest
.
skipIf
(
torch_device
==
"mps"
,
"mish op not supported in MPS"
)
def
test_from_
pretrained_
save_pretrained
(
self
):
def
test_from_save_pretrained
(
self
):
super
().
test_from_
pretrained_
save_pretrained
()
super
().
test_from_save_pretrained
()
@
unittest
.
skipIf
(
torch_device
==
"mps"
,
"mish op not supported in MPS"
)
@
unittest
.
skipIf
(
torch_device
==
"mps"
,
"mish op not supported in MPS"
)
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
...
@@ -183,8 +183,8 @@ class UNetRLModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -183,8 +183,8 @@ class UNetRLModelTests(ModelTesterMixin, unittest.TestCase):
super
().
test_outputs_equivalence
()
super
().
test_outputs_equivalence
()
@
unittest
.
skipIf
(
torch_device
==
"mps"
,
"mish op not supported in MPS"
)
@
unittest
.
skipIf
(
torch_device
==
"mps"
,
"mish op not supported in MPS"
)
def
test_from_
pretrained_
save_pretrained
(
self
):
def
test_from_save_pretrained
(
self
):
super
().
test_from_
pretrained_
save_pretrained
()
super
().
test_from_save_pretrained
()
@
unittest
.
skipIf
(
torch_device
==
"mps"
,
"mish op not supported in MPS"
)
@
unittest
.
skipIf
(
torch_device
==
"mps"
,
"mish op not supported in MPS"
)
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
...
...
tests/pipelines/versatile_diffusion/test_versatile_diffusion_mega.py
View file @
20ce68f9
...
@@ -42,7 +42,7 @@ class VersatileDiffusionMegaPipelineIntegrationTests(unittest.TestCase):
...
@@ -42,7 +42,7 @@ class VersatileDiffusionMegaPipelineIntegrationTests(unittest.TestCase):
gc
.
collect
()
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
def
test_from_
pretrained_
save_pretrained
(
self
):
def
test_from_save_pretrained
(
self
):
pipe
=
VersatileDiffusionPipeline
.
from_pretrained
(
"shi-labs/versatile-diffusion"
,
torch_dtype
=
torch
.
float16
)
pipe
=
VersatileDiffusionPipeline
.
from_pretrained
(
"shi-labs/versatile-diffusion"
,
torch_dtype
=
torch
.
float16
)
pipe
.
to
(
torch_device
)
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
tests/test_modeling_common.py
View file @
20ce68f9
...
@@ -27,7 +27,7 @@ from diffusers.utils import torch_device
...
@@ -27,7 +27,7 @@ from diffusers.utils import torch_device
class
ModelTesterMixin
:
class
ModelTesterMixin
:
def
test_from_
pretrained_
save_pretrained
(
self
):
def
test_from_save_pretrained
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
=
self
.
model_class
(
**
init_dict
)
...
@@ -57,6 +57,24 @@ class ModelTesterMixin:
...
@@ -57,6 +57,24 @@ class ModelTesterMixin:
max_diff
=
(
image
-
new_image
).
abs
().
sum
().
item
()
max_diff
=
(
image
-
new_image
).
abs
().
sum
().
item
()
self
.
assertLessEqual
(
max_diff
,
5e-5
,
"Models give different forward passes"
)
self
.
assertLessEqual
(
max_diff
,
5e-5
,
"Models give different forward passes"
)
def
test_from_save_pretrained_dtype
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
.
to
(
torch_device
)
model
.
eval
()
for
dtype
in
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]:
if
torch_device
==
"mps"
and
dtype
==
torch
.
bfloat16
:
continue
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
model
.
to
(
dtype
)
model
.
save_pretrained
(
tmpdirname
)
new_model
=
self
.
model_class
.
from_pretrained
(
tmpdirname
,
low_cpu_mem_usage
=
True
)
assert
new_model
.
dtype
==
dtype
new_model
=
self
.
model_class
.
from_pretrained
(
tmpdirname
,
low_cpu_mem_usage
=
False
)
assert
new_model
.
dtype
==
dtype
def
test_determinism
(
self
):
def
test_determinism
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
=
self
.
model_class
(
**
init_dict
)
...
...
tests/test_pipelines.py
View file @
20ce68f9
...
@@ -659,7 +659,7 @@ class PipelineSlowTests(unittest.TestCase):
...
@@ -659,7 +659,7 @@ class PipelineSlowTests(unittest.TestCase):
==
"Keyword arguments {'not_used': True} are not expected by DDPMPipeline and will be ignored.
\n
"
==
"Keyword arguments {'not_used': True} are not expected by DDPMPipeline and will be ignored.
\n
"
)
)
def
test_from_
pretrained_
save_pretrained
(
self
):
def
test_from_save_pretrained
(
self
):
# 1. Load models
# 1. Load models
model
=
UNet2DModel
(
model
=
UNet2DModel
(
block_out_channels
=
(
32
,
64
),
block_out_channels
=
(
32
,
64
),
...
...
tests/test_scheduler.py
View file @
20ce68f9
...
@@ -334,7 +334,7 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -334,7 +334,7 @@ class SchedulerCommonTest(unittest.TestCase):
assert
torch
.
sum
(
torch
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
assert
torch
.
sum
(
torch
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
def
test_from_
pretrained_
save_pretrained
(
self
):
def
test_from_save_pretrained
(
self
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
=
dict
(
self
.
forward_default_kwargs
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
...
@@ -875,7 +875,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
...
@@ -875,7 +875,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
assert
torch
.
sum
(
torch
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
assert
torch
.
sum
(
torch
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
def
test_from_
pretrained_
save_pretrained
(
self
):
def
test_from_save_pretrained
(
self
):
pass
pass
def
check_over_forward
(
self
,
time_step
=
0
,
**
forward_kwargs
):
def
check_over_forward
(
self
,
time_step
=
0
,
**
forward_kwargs
):
...
@@ -1068,7 +1068,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -1068,7 +1068,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
assert
torch
.
sum
(
torch
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
assert
torch
.
sum
(
torch
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
def
test_from_
pretrained_
save_pretrained
(
self
):
def
test_from_save_pretrained
(
self
):
pass
pass
def
check_over_forward
(
self
,
time_step
=
0
,
**
forward_kwargs
):
def
check_over_forward
(
self
,
time_step
=
0
,
**
forward_kwargs
):
...
@@ -1745,7 +1745,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
...
@@ -1745,7 +1745,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
assert
torch
.
sum
(
torch
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
assert
torch
.
sum
(
torch
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
def
test_from_
pretrained_
save_pretrained
(
self
):
def
test_from_save_pretrained
(
self
):
pass
pass
def
check_over_forward
(
self
,
time_step
=
0
,
**
forward_kwargs
):
def
check_over_forward
(
self
,
time_step
=
0
,
**
forward_kwargs
):
...
...
tests/test_scheduler_flax.py
View file @
20ce68f9
...
@@ -126,7 +126,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
...
@@ -126,7 +126,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
assert
jnp
.
sum
(
jnp
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
assert
jnp
.
sum
(
jnp
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
def
test_from_
pretrained_
save_pretrained
(
self
):
def
test_from_save_pretrained
(
self
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
=
dict
(
self
.
forward_default_kwargs
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
...
@@ -408,7 +408,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
...
@@ -408,7 +408,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
assert
jnp
.
sum
(
jnp
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
assert
jnp
.
sum
(
jnp
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
def
test_from_
pretrained_
save_pretrained
(
self
):
def
test_from_save_pretrained
(
self
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
=
dict
(
self
.
forward_default_kwargs
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
...
@@ -690,7 +690,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
...
@@ -690,7 +690,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
assert
jnp
.
sum
(
jnp
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
assert
jnp
.
sum
(
jnp
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
def
test_from_
pretrained_
save_pretrained
(
self
):
def
test_from_save_pretrained
(
self
):
pass
pass
def
test_scheduler_outputs_equivalence
(
self
):
def
test_scheduler_outputs_equivalence
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment