Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
2858d7e1
Unverified
Commit
2858d7e1
authored
May 17, 2023
by
Patrick von Platen
Committed by
GitHub
May 17, 2023
Browse files
[From ckpt] Fix from_ckpt (#3466)
* Correct from_ckpt * make style
parent
88295f92
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
10 deletions
+14
-10
src/diffusers/loaders.py
src/diffusers/loaders.py
+1
-1
src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
...diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+13
-9
No files found.
src/diffusers/loaders.py
View file @
2858d7e1
...
...
@@ -1326,7 +1326,7 @@ class FromCkptMixin:
file_extension
=
pretrained_model_link_or_path
.
rsplit
(
"."
,
1
)[
-
1
]
from_safetensors
=
file_extension
==
"safetensors"
if
from_safetensors
and
use_safetensors
is
Tru
e
:
if
from_safetensors
and
use_safetensors
is
Fals
e
:
raise
ValueError
(
"Make sure to install `safetensors` with `pip install safetensors`."
)
# TODO: For now we only support stable diffusion
...
...
src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
View file @
2858d7e1
...
...
@@ -140,17 +140,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
new_item
=
new_item
.
replace
(
"norm.weight"
,
"group_norm.weight"
)
new_item
=
new_item
.
replace
(
"norm.bias"
,
"group_norm.bias"
)
new_item
=
new_item
.
replace
(
"q.weight"
,
"
query
.weight"
)
new_item
=
new_item
.
replace
(
"q.bias"
,
"
query
.bias"
)
new_item
=
new_item
.
replace
(
"q.weight"
,
"
to_q
.weight"
)
new_item
=
new_item
.
replace
(
"q.bias"
,
"
to_q
.bias"
)
new_item
=
new_item
.
replace
(
"k.weight"
,
"
key
.weight"
)
new_item
=
new_item
.
replace
(
"k.bias"
,
"
key
.bias"
)
new_item
=
new_item
.
replace
(
"k.weight"
,
"
to_k
.weight"
)
new_item
=
new_item
.
replace
(
"k.bias"
,
"
to_k
.bias"
)
new_item
=
new_item
.
replace
(
"v.weight"
,
"
value
.weight"
)
new_item
=
new_item
.
replace
(
"v.bias"
,
"
value
.bias"
)
new_item
=
new_item
.
replace
(
"v.weight"
,
"
to_v
.weight"
)
new_item
=
new_item
.
replace
(
"v.bias"
,
"
to_v
.bias"
)
new_item
=
new_item
.
replace
(
"proj_out.weight"
,
"
proj_attn
.weight"
)
new_item
=
new_item
.
replace
(
"proj_out.bias"
,
"
proj_attn
.bias"
)
new_item
=
new_item
.
replace
(
"proj_out.weight"
,
"
to_out.0
.weight"
)
new_item
=
new_item
.
replace
(
"proj_out.bias"
,
"
to_out.0
.bias"
)
new_item
=
shave_segments
(
new_item
,
n_shave_prefix_segments
=
n_shave_prefix_segments
)
...
...
@@ -204,8 +204,12 @@ def assign_to_checkpoint(
new_path
=
new_path
.
replace
(
replacement
[
"old"
],
replacement
[
"new"
])
# proj_attn.weight has to be converted from conv 1D to linear
if
"proj_attn.weight"
in
new_path
:
is_attn_weight
=
"proj_attn.weight"
in
new_path
or
(
"attentions"
in
new_path
and
"to_"
in
new_path
)
shape
=
old_checkpoint
[
path
[
"old"
]].
shape
if
is_attn_weight
and
len
(
shape
)
==
3
:
checkpoint
[
new_path
]
=
old_checkpoint
[
path
[
"old"
]][:,
:,
0
]
elif
is_attn_weight
and
len
(
shape
)
==
4
:
checkpoint
[
new_path
]
=
old_checkpoint
[
path
[
"old"
]][:,
:,
0
,
0
]
else
:
checkpoint
[
new_path
]
=
old_checkpoint
[
path
[
"old"
]]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment