Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
85c4a326
Unverified
Commit
85c4a326
authored
Jul 05, 2024
by
Dhruv Nair
Committed by
GitHub
Jul 05, 2024
Browse files
Fix saving text encoder weights and kohya weights in advanced dreambooth lora script (#8766)
* update * update * update
parent
0bab9d6b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
2 deletions
+3
-2
examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
...diffusion_training/train_dreambooth_lora_sd15_advanced.py
+2
-1
examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
...diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+1
-1
No files found.
examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
View file @
85c4a326
...
@@ -1290,6 +1290,7 @@ def main(args):
...
@@ -1290,6 +1290,7 @@ def main(args):
text_encoder_one_lora_layers_to_save
=
convert_state_dict_to_diffusers
(
text_encoder_one_lora_layers_to_save
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
model
)
get_peft_model_state_dict
(
model
)
)
)
else
:
raise
ValueError
(
f
"unexpected save model:
{
model
.
__class__
}
"
)
raise
ValueError
(
f
"unexpected save model:
{
model
.
__class__
}
"
)
# make sure to pop weight so that corresponding model is not saved again
# make sure to pop weight so that corresponding model is not saved again
...
@@ -1981,7 +1982,7 @@ def main(args):
...
@@ -1981,7 +1982,7 @@ def main(args):
lora_state_dict
=
load_file
(
f
"
{
args
.
output_dir
}
/pytorch_lora_weights.safetensors"
)
lora_state_dict
=
load_file
(
f
"
{
args
.
output_dir
}
/pytorch_lora_weights.safetensors"
)
peft_state_dict
=
convert_all_state_dict_to_peft
(
lora_state_dict
)
peft_state_dict
=
convert_all_state_dict_to_peft
(
lora_state_dict
)
kohya_state_dict
=
convert_state_dict_to_kohya
(
peft_state_dict
)
kohya_state_dict
=
convert_state_dict_to_kohya
(
peft_state_dict
)
save_file
(
kohya_state_dict
,
f
"
{
args
.
output_dir
}
/
{
args
.
output_dir
}
.safetensors"
)
save_file
(
kohya_state_dict
,
f
"
{
args
.
output_dir
}
/
{
Path
(
args
.
output_dir
).
name
}
.safetensors"
)
save_model_card
(
save_model_card
(
model_id
if
not
args
.
push_to_hub
else
repo_id
,
model_id
if
not
args
.
push_to_hub
else
repo_id
,
...
...
examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
View file @
85c4a326
...
@@ -2425,7 +2425,7 @@ def main(args):
...
@@ -2425,7 +2425,7 @@ def main(args):
lora_state_dict
=
load_file
(
f
"
{
args
.
output_dir
}
/pytorch_lora_weights.safetensors"
)
lora_state_dict
=
load_file
(
f
"
{
args
.
output_dir
}
/pytorch_lora_weights.safetensors"
)
peft_state_dict
=
convert_all_state_dict_to_peft
(
lora_state_dict
)
peft_state_dict
=
convert_all_state_dict_to_peft
(
lora_state_dict
)
kohya_state_dict
=
convert_state_dict_to_kohya
(
peft_state_dict
)
kohya_state_dict
=
convert_state_dict_to_kohya
(
peft_state_dict
)
save_file
(
kohya_state_dict
,
f
"
{
args
.
output_dir
}
/
{
args
.
output_dir
}
.safetensors"
)
save_file
(
kohya_state_dict
,
f
"
{
args
.
output_dir
}
/
{
Path
(
args
.
output_dir
).
name
}
.safetensors"
)
save_model_card
(
save_model_card
(
model_id
if
not
args
.
push_to_hub
else
repo_id
,
model_id
if
not
args
.
push_to_hub
else
repo_id
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment