Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
651c5adf
Unverified
Commit
651c5adf
authored
Jan 16, 2023
by
蓝色的秋风
Committed by
GitHub
Jan 16, 2023
Browse files
[Conversion] Support convert diffusers to safetensors (#1996)
fix: support diffusers to safetensors
parent
cc2cc00d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
2 deletions
+11
-2
scripts/convert_diffusers_to_original_stable_diffusion.py
scripts/convert_diffusers_to_original_stable_diffusion.py
+11
-2
No files found.
scripts/convert_diffusers_to_original_stable_diffusion.py
View file @
651c5adf
...
@@ -8,6 +8,8 @@ import re
...
@@ -8,6 +8,8 @@ import re
import
torch
import
torch
from
safetensors.torch
import
save_file
# =================#
# =================#
# UNet Conversion #
# UNet Conversion #
...
@@ -266,6 +268,9 @@ if __name__ == "__main__":
...
@@ -266,6 +268,9 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--model_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the model to convert."
)
parser
.
add_argument
(
"--model_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the model to convert."
)
parser
.
add_argument
(
"--checkpoint_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the output model."
)
parser
.
add_argument
(
"--checkpoint_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the output model."
)
parser
.
add_argument
(
"--half"
,
action
=
"store_true"
,
help
=
"Save weights in half precision."
)
parser
.
add_argument
(
"--half"
,
action
=
"store_true"
,
help
=
"Save weights in half precision."
)
parser
.
add_argument
(
"--use_safetensors"
,
action
=
"store_true"
,
help
=
"Save weights use safetensors, default is ckpt."
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
@@ -306,5 +311,9 @@ if __name__ == "__main__":
...
@@ -306,5 +311,9 @@ if __name__ == "__main__":
state_dict
=
{
**
unet_state_dict
,
**
vae_state_dict
,
**
text_enc_dict
}
state_dict
=
{
**
unet_state_dict
,
**
vae_state_dict
,
**
text_enc_dict
}
if
args
.
half
:
if
args
.
half
:
state_dict
=
{
k
:
v
.
half
()
for
k
,
v
in
state_dict
.
items
()}
state_dict
=
{
k
:
v
.
half
()
for
k
,
v
in
state_dict
.
items
()}
state_dict
=
{
"state_dict"
:
state_dict
}
torch
.
save
(
state_dict
,
args
.
checkpoint_path
)
if
args
.
use_safetensors
:
save_file
(
state_dict
,
args
.
checkpoint_path
)
else
:
state_dict
=
{
"state_dict"
:
state_dict
}
torch
.
save
(
state_dict
,
args
.
checkpoint_path
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment