Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
c11de135
Unverified
Commit
c11de135
authored
Jan 16, 2024
by
Aryan V S
Committed by
GitHub
Jan 16, 2024
Browse files
[training] fix training resuming problem for fp16 (SD LoRA DreamBooth) (#6554)
* fix training resume * update * update
parent
357855f8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
44 additions
and
6 deletions
+44
-6
examples/dreambooth/train_dreambooth_lora.py
examples/dreambooth/train_dreambooth_lora.py
+44
-6
No files found.
examples/dreambooth/train_dreambooth_lora.py
View file @
c11de135
...
...
@@ -35,7 +35,7 @@ from huggingface_hub import create_repo, upload_folder
from
huggingface_hub.utils
import
insecure_hashlib
from
packaging
import
version
from
peft
import
LoraConfig
from
peft.utils
import
get_peft_model_state_dict
from
peft.utils
import
get_peft_model_state_dict
,
set_peft_model_state_dict
from
PIL
import
Image
from
PIL.ImageOps
import
exif_transpose
from
torch.utils.data
import
Dataset
...
...
@@ -54,7 +54,13 @@ from diffusers import (
)
from
diffusers.loaders
import
LoraLoaderMixin
from
diffusers.optimization
import
get_scheduler
from
diffusers.utils
import
check_min_version
,
convert_state_dict_to_diffusers
,
is_wandb_available
from
diffusers.training_utils
import
_set_state_dict_into_text_encoder
,
cast_training_params
from
diffusers.utils
import
(
check_min_version
,
convert_state_dict_to_diffusers
,
convert_unet_state_dict_to_peft
,
is_wandb_available
,
)
from
diffusers.utils.import_utils
import
is_xformers_available
from
diffusers.utils.torch_utils
import
is_compiled_module
...
...
@@ -892,10 +898,33 @@ def main(args):
raise
ValueError
(
f
"unexpected save model:
{
model
.
__class__
}
"
)
lora_state_dict
,
network_alphas
=
LoraLoaderMixin
.
lora_state_dict
(
input_dir
)
LoraLoaderMixin
.
load_lora_into_unet
(
lora_state_dict
,
network_alphas
=
network_alphas
,
unet
=
unet_
)
LoraLoaderMixin
.
load_lora_into_text_encoder
(
lora_state_dict
,
network_alphas
=
network_alphas
,
text_encoder
=
text_encoder_
)
unet_state_dict
=
{
f
'
{
k
.
replace
(
"unet."
,
""
)
}
'
:
v
for
k
,
v
in
lora_state_dict
.
items
()
if
k
.
startswith
(
"unet."
)}
unet_state_dict
=
convert_unet_state_dict_to_peft
(
unet_state_dict
)
incompatible_keys
=
set_peft_model_state_dict
(
unet_
,
unet_state_dict
,
adapter_name
=
"default"
)
if
incompatible_keys
is
not
None
:
# check only for unexpected keys
unexpected_keys
=
getattr
(
incompatible_keys
,
"unexpected_keys"
,
None
)
if
unexpected_keys
:
logger
.
warning
(
f
"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f
"
{
unexpected_keys
}
. "
)
if
args
.
train_text_encoder
:
_set_state_dict_into_text_encoder
(
lora_state_dict
,
prefix
=
"text_encoder."
,
text_encoder
=
text_encoder_
)
# Make sure the trainable params are in float32. This is again needed since the base models
# are in `weight_dtype`. More details:
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
if
args
.
mixed_precision
==
"fp16"
:
models
=
[
unet_
]
if
args
.
train_text_encoder
:
models
.
append
(
text_encoder_
)
# only upcast trainable parameters (LoRA) into fp32
cast_training_params
(
models
,
dtype
=
torch
.
float32
)
accelerator
.
register_save_state_pre_hook
(
save_model_hook
)
accelerator
.
register_load_state_pre_hook
(
load_model_hook
)
...
...
@@ -910,6 +939,15 @@ def main(args):
args
.
learning_rate
*
args
.
gradient_accumulation_steps
*
args
.
train_batch_size
*
accelerator
.
num_processes
)
# Make sure the trainable params are in float32.
if
args
.
mixed_precision
==
"fp16"
:
models
=
[
unet
]
if
args
.
train_text_encoder
:
models
.
append
(
text_encoder
)
# only upcast trainable parameters (LoRA) into fp32
cast_training_params
(
models
,
dtype
=
torch
.
float32
)
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
if
args
.
use_8bit_adam
:
try
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment