renzhc / diffusers_dcu / Commits / 20ce68f9

Unverified commit 20ce68f9, authored Nov 30, 2022 by Patrick von Platen, committed by GitHub on Nov 30, 2022. Parent: 110ffe25.

Fix dtype model loading (#1449)

* Add test
* up
* no bfloat16 for mps
* fix
* rename test
Showing 7 changed files with 47 additions and 14 deletions.
src/diffusers/modeling_utils.py                                        +15  -0
tests/models/test_models_unet_1d.py                                     +4  -4
tests/pipelines/versatile_diffusion/test_versatile_diffusion_mega.py    +1  -1
tests/test_modeling_common.py                                          +19  -1
tests/test_pipelines.py                                                 +1  -1
tests/test_scheduler.py                                                 +4  -4
tests/test_scheduler_flax.py                                            +3  -3
src/diffusers/modeling_utils.py

```diff
@@ -472,6 +472,21 @@ class ModelMixin(torch.nn.Module):
             model = cls.from_config(config, **unused_kwargs)
 
             state_dict = load_state_dict(model_file)
+            dtype = set(v.dtype for v in state_dict.values())
+
+            if len(dtype) > 1 and torch.float32 not in dtype:
+                raise ValueError(
+                    f"The weights of the model file {model_file} have a mixture of incompatible dtypes {dtype}. Please"
+                    f" make sure that {model_file} weights have only one dtype."
+                )
+            elif len(dtype) > 1 and torch.float32 in dtype:
+                dtype = torch.float32
+            else:
+                dtype = dtype.pop()
+
+            # move model to correct dtype
+            model = model.to(dtype)
+
             model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
                 model,
                 state_dict,
```
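In effect, `from_pretrained` now infers the checkpoint's dtype from the loaded state dict and moves the model to it before copying weights: a single dtype is used as-is, a mixture that includes float32 is upcast to float32, and a mixture of incompatible low-precision dtypes raises an error. Below is a minimal, self-contained sketch of that rule run on toy state dicts; the helper name `infer_checkpoint_dtype` is hypothetical and not part of diffusers.

```python
import torch


def infer_checkpoint_dtype(state_dict):
    """Return the single dtype a checkpoint should be loaded in (hypothetical helper)."""
    dtypes = set(v.dtype for v in state_dict.values())
    if len(dtypes) > 1 and torch.float32 not in dtypes:
        # e.g. a checkpoint mixing float16 and bfloat16 cannot be reconciled
        raise ValueError(f"Weights have a mixture of incompatible dtypes: {dtypes}")
    elif len(dtypes) > 1 and torch.float32 in dtypes:
        # a mixture that includes float32 is upcast to float32
        return torch.float32
    # exactly one dtype: use it as-is
    return dtypes.pop()


print(infer_checkpoint_dtype({"w": torch.zeros(2, dtype=torch.float16)}))  # torch.float16
print(infer_checkpoint_dtype({"w": torch.zeros(2), "b": torch.zeros(2, dtype=torch.float16)}))  # torch.float32
```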
tests/models/test_models_unet_1d.py

```diff
@@ -63,8 +63,8 @@ class UNet1DModelTests(ModelTesterMixin, unittest.TestCase):
         super().test_outputs_equivalence()
 
     @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
-    def test_from_pretrained_save_pretrained(self):
-        super().test_from_pretrained_save_pretrained()
+    def test_from_save_pretrained(self):
+        super().test_from_save_pretrained()
 
     @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
     def test_model_from_pretrained(self):
@@ -183,8 +183,8 @@ class UNetRLModelTests(ModelTesterMixin, unittest.TestCase):
         super().test_outputs_equivalence()
 
     @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
-    def test_from_pretrained_save_pretrained(self):
-        super().test_from_pretrained_save_pretrained()
+    def test_from_save_pretrained(self):
+        super().test_from_save_pretrained()
 
     @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
     def test_model_from_pretrained(self):
```
tests/pipelines/versatile_diffusion/test_versatile_diffusion_mega.py

```diff
@@ -42,7 +42,7 @@ class VersatileDiffusionMegaPipelineIntegrationTests(unittest.TestCase):
         gc.collect()
         torch.cuda.empty_cache()
 
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         pipe = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", torch_dtype=torch.float16)
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
```
tests/test_modeling_common.py

```diff
@@ -27,7 +27,7 @@ from diffusers.utils import torch_device
 class ModelTesterMixin:
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
         model = self.model_class(**init_dict)
@@ -57,6 +57,24 @@ class ModelTesterMixin:
         max_diff = (image - new_image).abs().sum().item()
         self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
 
+    def test_from_save_pretrained_dtype(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        for dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            if torch_device == "mps" and dtype == torch.bfloat16:
+                continue
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.to(dtype)
+                model.save_pretrained(tmpdirname)
+                new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True)
+                assert new_model.dtype == dtype
+                new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False)
+                assert new_model.dtype == dtype
+
     def test_determinism(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
         model = self.model_class(**init_dict)
```
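The new test exercises the fix end to end for every dtype on every model class. A minimal usage sketch of the same round trip is below, assuming a small UNet2DModel config like the one used in tests/test_pipelines.py; the exact config values are illustrative.

```python
import tempfile

import torch
from diffusers import UNet2DModel

# A tiny illustrative config; any valid model config works the same way.
model = UNet2DModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=3,
    out_channels=3,
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
).to(torch.float16)

with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname)
    reloaded = UNet2DModel.from_pretrained(tmpdirname)
    # With the modeling_utils.py change above, the reloaded model keeps the
    # dtype the weights were saved in instead of defaulting to float32.
    assert reloaded.dtype == torch.float16
```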
tests/test_pipelines.py

```diff
@@ -659,7 +659,7 @@ class PipelineSlowTests(unittest.TestCase):
             == "Keyword arguments {'not_used': True} are not expected by DDPMPipeline and will be ignored.\n"
         )
 
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         # 1. Load models
         model = UNet2DModel(
             block_out_channels=(32, 64),
```
tests/test_scheduler.py

```diff
@@ -334,7 +334,7 @@ class SchedulerCommonTest(unittest.TestCase):
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         kwargs = dict(self.forward_default_kwargs)
         num_inference_steps = kwargs.pop("num_inference_steps", None)
@@ -875,7 +875,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         pass
 
     def check_over_forward(self, time_step=0, **forward_kwargs):
@@ -1068,7 +1068,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         pass
 
     def check_over_forward(self, time_step=0, **forward_kwargs):
@@ -1745,7 +1745,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         pass
 
     def check_over_forward(self, time_step=0, **forward_kwargs):
```
tests/test_scheduler_flax.py

```diff
@@ -126,7 +126,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
         assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         kwargs = dict(self.forward_default_kwargs)
         num_inference_steps = kwargs.pop("num_inference_steps", None)
@@ -408,7 +408,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
         assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         kwargs = dict(self.forward_default_kwargs)
         num_inference_steps = kwargs.pop("num_inference_steps", None)
@@ -690,7 +690,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
         assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         pass
 
     def test_scheduler_outputs_equivalence(self):
```