renzhc / diffusers_dcu

Commit 8a3f0c1f (unverified)
Authored Jan 16, 2023 by Patrick von Platen; committed by GitHub, Jan 16, 2023

[Conversion] Improve safetensors (#1989)

Parent: f6a5c359
Showing 1 changed file with 17 additions and 4 deletions.

scripts/convert_original_stable_diffusion_to_diffusers.py  (+17, -4)
...
@@ -20,6 +20,8 @@ import re
 import torch

+from safetensors import safe_open
+
 try:
     from omegaconf import OmegaConf
...
@@ -839,6 +841,11 @@ if __name__ == "__main__":
             " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
         ),
     )
+    parser.add_argument(
+        "--from_safetensors",
+        action="store_true",
+        help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.",
+    )
     parser.add_argument(
         "--upcast_attention",
         default=False,
...
@@ -855,11 +862,17 @@ if __name__ == "__main__":
     image_size = args.image_size
     prediction_type = args.prediction_type
-    if args.device is None:
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        checkpoint = torch.load(args.checkpoint_path, map_location=device)
+    if args.from_safetensors:
+        checkpoint = {}
+        with safe_open(args.checkpoint_path, framework="pt", device="cpu") as f:
+            for key in f.keys():
+                checkpoint[key] = f.get_tensor(key)
     else:
-        checkpoint = torch.load(args.checkpoint_path, map_location=args.device)
+        if args.device is None:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+            checkpoint = torch.load(args.checkpoint_path, map_location=device)
+        else:
+            checkpoint = torch.load(args.checkpoint_path, map_location=args.device)

     # Sometimes models don't have the global_step item
     if "global_step" in checkpoint:
...
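
For reference, here is a minimal, standalone sketch of the loading path this commit introduces: read a `.safetensors` checkpoint with `safe_open` into a plain state dict on CPU, otherwise fall back to the existing `torch.load` path. The checkpoint filename is illustrative only and not part of the commit.

# A minimal sketch, assuming a local checkpoint file; the name
# "v1-5-pruned.safetensors" is illustrative only.
import torch
from safetensors import safe_open

checkpoint_path = "v1-5-pruned.safetensors"
from_safetensors = checkpoint_path.endswith(".safetensors")

if from_safetensors:
    # safetensors: iterate over the stored keys and materialize each
    # tensor on CPU, building an ordinary state-dict-like mapping.
    checkpoint = {}
    with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            checkpoint[key] = f.get_tensor(key)
else:
    # PyTorch pickle checkpoint: keep the original torch.load behavior.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(checkpoint_path, map_location=device)

print(f"loaded {len(checkpoint)} tensors")

A typical invocation of the updated script would then pass the new flag along with the checkpoint path; the `--dump_path` argument below is assumed from the script's other options and may differ:

python scripts/convert_original_stable_diffusion_to_diffusers.py \
    --checkpoint_path ./v1-5-pruned.safetensors \
    --from_safetensors \
    --dump_path ./converted-model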