Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
2858d7e1
Unverified
Commit
2858d7e1
authored
May 17, 2023
by
Patrick von Platen
Committed by
GitHub
May 17, 2023
Browse files
[From ckpt] Fix from_ckpt (#3466)
* Correct from_ckpt * make style
parent
88295f92
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
10 deletions
+14
-10
src/diffusers/loaders.py
src/diffusers/loaders.py
+1
-1
src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
...diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+13
-9
No files found.
src/diffusers/loaders.py
View file @
2858d7e1
...
...
@@ -1326,7 +1326,7 @@ class FromCkptMixin:
file_extension
=
pretrained_model_link_or_path
.
rsplit
(
"."
,
1
)[
-
1
]
from_safetensors
=
file_extension
==
"safetensors"
if
from_safetensors
and
use_safetensors
is
Tru
e
:
if
from_safetensors
and
use_safetensors
is
Fals
e
:
raise
ValueError
(
"Make sure to install `safetensors` with `pip install safetensors`."
)
# TODO: For now we only support stable diffusion
...
...
src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
View file @
2858d7e1
...
...
@@ -140,17 +140,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
new_item
=
new_item
.
replace
(
"norm.weight"
,
"group_norm.weight"
)
new_item
=
new_item
.
replace
(
"norm.bias"
,
"group_norm.bias"
)
new_item
=
new_item
.
replace
(
"q.weight"
,
"
query
.weight"
)
new_item
=
new_item
.
replace
(
"q.bias"
,
"
query
.bias"
)
new_item
=
new_item
.
replace
(
"q.weight"
,
"
to_q
.weight"
)
new_item
=
new_item
.
replace
(
"q.bias"
,
"
to_q
.bias"
)
new_item
=
new_item
.
replace
(
"k.weight"
,
"
key
.weight"
)
new_item
=
new_item
.
replace
(
"k.bias"
,
"
key
.bias"
)
new_item
=
new_item
.
replace
(
"k.weight"
,
"
to_k
.weight"
)
new_item
=
new_item
.
replace
(
"k.bias"
,
"
to_k
.bias"
)
new_item
=
new_item
.
replace
(
"v.weight"
,
"
value
.weight"
)
new_item
=
new_item
.
replace
(
"v.bias"
,
"
value
.bias"
)
new_item
=
new_item
.
replace
(
"v.weight"
,
"
to_v
.weight"
)
new_item
=
new_item
.
replace
(
"v.bias"
,
"
to_v
.bias"
)
new_item
=
new_item
.
replace
(
"proj_out.weight"
,
"
proj_attn
.weight"
)
new_item
=
new_item
.
replace
(
"proj_out.bias"
,
"
proj_attn
.bias"
)
new_item
=
new_item
.
replace
(
"proj_out.weight"
,
"
to_out.0
.weight"
)
new_item
=
new_item
.
replace
(
"proj_out.bias"
,
"
to_out.0
.bias"
)
new_item
=
shave_segments
(
new_item
,
n_shave_prefix_segments
=
n_shave_prefix_segments
)
...
...
@@ -204,8 +204,12 @@ def assign_to_checkpoint(
new_path
=
new_path
.
replace
(
replacement
[
"old"
],
replacement
[
"new"
])
# proj_attn.weight has to be converted from conv 1D to linear
if
"proj_attn.weight"
in
new_path
:
is_attn_weight
=
"proj_attn.weight"
in
new_path
or
(
"attentions"
in
new_path
and
"to_"
in
new_path
)
shape
=
old_checkpoint
[
path
[
"old"
]].
shape
if
is_attn_weight
and
len
(
shape
)
==
3
:
checkpoint
[
new_path
]
=
old_checkpoint
[
path
[
"old"
]][:,
:,
0
]
elif
is_attn_weight
and
len
(
shape
)
==
4
:
checkpoint
[
new_path
]
=
old_checkpoint
[
path
[
"old"
]][:,
:,
0
,
0
]
else
:
checkpoint
[
new_path
]
=
old_checkpoint
[
path
[
"old"
]]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment