Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
b1b99b59
Commit
b1b99b59
authored
Jul 21, 2022
by
Patrick von Platen
Browse files
some more cleaning
parent
606ac57e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
116 additions
and
0 deletions
+116
-0
scripts/change_naming_configs_and_checkpoints.py
scripts/change_naming_configs_and_checkpoints.py
+112
-0
src/diffusers/configuration_utils.py
src/diffusers/configuration_utils.py
+4
-0
No files found.
scripts/change_naming_configs_and_checkpoints.py
0 → 100644
View file @
b1b99b59
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conversion script for the LDM checkpoints. """
import
argparse
import
os
import
json
import
torch
from
diffusers
import
UNet2DModel
,
UNet2DConditionModel
from
transformers.file_utils
import
has_file
do_only_config
=
False
do_only_weights
=
True
do_only_renaming
=
False
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--repo_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"The config json file corresponding to the architecture."
,
)
parser
.
add_argument
(
"--dump_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the output model."
)
args
=
parser
.
parse_args
()
config_parameters_to_change
=
{
"image_size"
:
"sample_size"
,
"num_res_blocks"
:
"layers_per_block"
,
"block_channels"
:
"block_out_channels"
,
"down_blocks"
:
"down_block_types"
,
"up_blocks"
:
"up_block_types"
,
"downscale_freq_shift"
:
"freq_shift"
,
"resnet_num_groups"
:
"norm_num_groups"
,
"resnet_act_fn"
:
"act_fn"
,
"resnet_eps"
:
"norm_eps"
,
"num_head_channels"
:
"attention_head_dim"
,
}
key_parameters_to_change
=
{
"time_steps"
:
"time_proj"
,
"mid"
:
"mid_block"
,
"downsample_blocks"
:
"down_blocks"
,
"upsample_blocks"
:
"up_blocks"
,
}
subfolder
=
""
if
has_file
(
args
.
repo_path
,
"config.json"
)
else
"unet"
with
open
(
os
.
path
.
join
(
args
.
repo_path
,
subfolder
,
"config.json"
),
"r"
,
encoding
=
"utf-8"
)
as
reader
:
text
=
reader
.
read
()
config
=
json
.
loads
(
text
)
if
do_only_config
:
for
key
in
config_parameters_to_change
.
keys
():
config
.
pop
(
key
,
None
)
if
has_file
(
args
.
repo_path
,
"config.json"
):
model
=
UNet2DModel
(
**
config
)
else
:
class_name
=
UNet2DConditionModel
if
"ldm-text2im-large-256"
in
args
.
repo_path
else
UNet2DModel
model
=
class_name
(
**
config
)
if
do_only_config
:
model
.
save_config
(
os
.
path
.
join
(
args
.
repo_path
,
subfolder
))
config
=
dict
(
model
.
config
)
if
do_only_renaming
:
for
key
,
value
in
config_parameters_to_change
.
items
():
if
key
in
config
:
config
[
value
]
=
config
[
key
]
del
config
[
key
]
config
[
"down_block_types"
]
=
[
k
.
replace
(
"UNetRes"
,
""
)
for
k
in
config
[
"down_block_types"
]]
config
[
"up_block_types"
]
=
[
k
.
replace
(
"UNetRes"
,
""
)
for
k
in
config
[
"up_block_types"
]]
if
do_only_weights
:
state_dict
=
torch
.
load
(
os
.
path
.
join
(
args
.
repo_path
,
subfolder
,
"diffusion_pytorch_model.bin"
))
new_state_dict
=
{}
for
param_key
,
param_value
in
state_dict
.
items
():
if
param_key
.
endswith
(
".op.bias"
)
or
param_key
.
endswith
(
".op.weight"
):
continue
has_changed
=
False
for
key
,
new_key
in
key_parameters_to_change
.
items
():
if
not
has_changed
and
param_key
.
split
(
"."
)[
0
]
==
key
:
new_state_dict
[
"."
.
join
([
new_key
]
+
param_key
.
split
(
"."
)[
1
:])]
=
param_value
has_changed
=
True
if
not
has_changed
:
new_state_dict
[
param_key
]
=
param_value
model
.
load_state_dict
(
new_state_dict
)
model
.
save_pretrained
(
os
.
path
.
join
(
args
.
repo_path
,
subfolder
))
src/diffusers/configuration_utils.py
View file @
b1b99b59
...
@@ -48,6 +48,7 @@ class ConfigMixin:
...
@@ -48,6 +48,7 @@ class ConfigMixin:
"""
"""
config_name
=
None
config_name
=
None
ignore_for_config
=
[]
def
register_to_config
(
self
,
**
kwargs
):
def
register_to_config
(
self
,
**
kwargs
):
if
self
.
config_name
is
None
:
if
self
.
config_name
is
None
:
...
@@ -212,6 +213,9 @@ class ConfigMixin:
...
@@ -212,6 +213,9 @@ class ConfigMixin:
# remove general kwargs if present in dict
# remove general kwargs if present in dict
if
"kwargs"
in
expected_keys
:
if
"kwargs"
in
expected_keys
:
expected_keys
.
remove
(
"kwargs"
)
expected_keys
.
remove
(
"kwargs"
)
# remove keys to be ignored
if
len
(
cls
.
ignore_for_config
)
>
0
:
expected_keys
=
expected_keys
-
set
(
cls
.
ignore_for_config
)
init_dict
=
{}
init_dict
=
{}
for
key
in
expected_keys
:
for
key
in
expected_keys
:
if
key
in
kwargs
:
if
key
in
kwargs
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment