Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
01782c22
Unverified
Commit
01782c22
authored
Nov 29, 2023
by
Kashif Rasul
Committed by
GitHub
Nov 29, 2023
Browse files
[Wuerstchen] Adapt lora training example scripts to use PEFT (#5959)
* Adapt lora example scripts to use PEFT * add to_out.0
parent
d63a498c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
67 additions
and
24 deletions
+67
-24
examples/wuerstchen/text_to_image/requirements.txt
examples/wuerstchen/text_to_image/requirements.txt
+1
-0
examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
...uerstchen/text_to_image/train_text_to_image_lora_prior.py
+66
-24
No files found.
examples/wuerstchen/text_to_image/requirements.txt
View file @
01782c22
...
...
@@ -5,3 +5,4 @@ wandb
huggingface-cli
bitsandbytes
deepspeed
peft>=0.6.0
examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
View file @
01782c22
...
...
@@ -31,14 +31,14 @@ from accelerate.utils import ProjectConfiguration, set_seed
from
datasets
import
load_dataset
from
huggingface_hub
import
create_repo
,
hf_hub_download
,
upload_folder
from
modeling_efficient_net_encoder
import
EfficientNetEncoder
from
peft
import
LoraConfig
from
peft.utils
import
get_peft_model_state_dict
from
torchvision
import
transforms
from
tqdm
import
tqdm
from
transformers
import
CLIPTextModel
,
PreTrainedTokenizerFast
from
transformers.utils
import
ContextManagers
from
diffusers
import
AutoPipelineForText2Image
,
DDPMWuerstchenScheduler
,
WuerstchenPriorPipeline
from
diffusers.loaders
import
AttnProcsLayers
from
diffusers.models.attention_processor
import
LoRAAttnProcessor
from
diffusers.optimization
import
get_scheduler
from
diffusers.pipelines.wuerstchen
import
DEFAULT_STAGE_C_TIMESTEPS
,
WuerstchenPrior
from
diffusers.utils
import
check_min_version
,
is_wandb_available
,
make_image_grid
...
...
@@ -139,17 +139,17 @@ More information on all the CLI arguments and the environment are available on y
f
.
write
(
yaml
+
model_card
)
def
log_validation
(
text_encoder
,
tokenizer
,
attn_process
or
s
,
args
,
accelerator
,
weight_dtype
,
epoch
):
def
log_validation
(
text_encoder
,
tokenizer
,
pri
or
,
args
,
accelerator
,
weight_dtype
,
epoch
):
logger
.
info
(
"Running validation... "
)
pipeline
=
AutoPipelineForText2Image
.
from_pretrained
(
args
.
pretrained_decoder_model_name_or_path
,
prior
=
accelerator
.
unwrap_model
(
prior
),
prior_text_encoder
=
accelerator
.
unwrap_model
(
text_encoder
),
prior_tokenizer
=
tokenizer
,
torch_dtype
=
weight_dtype
,
)
pipeline
=
pipeline
.
to
(
accelerator
.
device
)
pipeline
.
prior_prior
.
set_attn_processor
(
attn_processors
)
pipeline
.
set_progress_bar_config
(
disable
=
True
)
if
args
.
seed
is
None
:
...
...
@@ -159,7 +159,7 @@ def log_validation(text_encoder, tokenizer, attn_processors, args, accelerator,
images
=
[]
for
i
in
range
(
len
(
args
.
validation_prompts
)):
with
torch
.
autocast
(
"cuda"
):
with
torch
.
cuda
.
amp
.
autocast
():
image
=
pipeline
(
args
.
validation_prompts
[
i
],
prior_timesteps
=
DEFAULT_STAGE_C_TIMESTEPS
,
...
...
@@ -167,7 +167,6 @@ def log_validation(text_encoder, tokenizer, attn_processors, args, accelerator,
height
=
args
.
resolution
,
width
=
args
.
resolution
,
).
images
[
0
]
images
.
append
(
image
)
for
tracker
in
accelerator
.
trackers
:
...
...
@@ -527,11 +526,50 @@ def main():
prior
.
to
(
accelerator
.
device
,
dtype
=
weight_dtype
)
# lora attn processor
lora_attn_procs
=
{}
for
name
in
prior
.
attn_processors
.
keys
():
lora_attn_procs
[
name
]
=
LoRAAttnProcessor
(
hidden_size
=
prior
.
config
[
"c"
],
rank
=
args
.
rank
)
prior
.
set_attn_processor
(
lora_attn_procs
)
lora_layers
=
AttnProcsLayers
(
prior
.
attn_processors
)
prior_lora_config
=
LoraConfig
(
r
=
args
.
rank
,
target_modules
=
[
"to_k"
,
"to_q"
,
"to_v"
,
"to_out.0"
,
"add_k_proj"
,
"add_v_proj"
]
)
prior
.
add_adapter
(
prior_lora_config
)
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
    """Write the prior's LoRA weights in place of the default model dump.

    Registered as an ``accelerator`` save-state pre-hook, so it runs whenever
    ``accelerator.save_state(...)`` is called.

    Args:
        models: Models handed to the hook; each must be an instance of the
            (unwrapped) prior's class.
        weights: List of weights matching ``models``. One entry is popped per
            handled model so that the model is not saved a second time.
        output_dir: Directory the LoRA weights are written to.

    Raises:
        ValueError: If a model of an unexpected class is encountered.
    """
    # Only the main process serializes the checkpoint.
    if not accelerator.is_main_process:
        return

    prior_lora_layers_to_save = None

    for model in models:
        if isinstance(model, type(accelerator.unwrap_model(prior))):
            prior_lora_layers_to_save = get_peft_model_state_dict(model)
        else:
            raise ValueError(f"unexpected save model: {model.__class__}")

        # make sure to pop weight so that corresponding model is not saved again;
        # the list may already be empty, so guard against popping from it.
        if weights:
            weights.pop()

    WuerstchenPriorPipeline.save_lora_weights(
        output_dir,
        unet_lora_layers=prior_lora_layers_to_save,
    )
def load_model_hook(models, input_dir):
    """Restore the prior's LoRA weights from a saved training state.

    Registered as an ``accelerator`` load-state pre-hook. Every model is
    popped from ``models`` (emptying the list) and the LoRA state dict found
    in ``input_dir`` is loaded into the prior.

    Args:
        models: Models handed to the hook; each must be an instance of the
            (unwrapped) prior's class. The list is drained in place.
        input_dir: Directory containing the saved LoRA weights.

    Raises:
        ValueError: If a model of an unexpected class is encountered.
    """
    prior_ = None

    while models:
        model = models.pop()

        if isinstance(model, type(accelerator.unwrap_model(prior))):
            prior_ = model
        else:
            raise ValueError(f"unexpected save model: {model.__class__}")

    lora_state_dict, network_alphas = WuerstchenPriorPipeline.lora_state_dict(input_dir)
    WuerstchenPriorPipeline.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=prior_)
    # Only the prior carries a LoRA adapter in this script, so there is nothing
    # to load into the text encoder. The previous call to
    # `load_lora_into_text_encoder` omitted its required `text_encoder`
    # argument and failed with a TypeError on every resume.
accelerator
.
register_save_state_pre_hook
(
save_model_hook
)
accelerator
.
register_load_state_pre_hook
(
load_model_hook
)
if
args
.
allow_tf32
:
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
True
...
...
@@ -547,8 +585,9 @@ def main():
optimizer_cls
=
bnb
.
optim
.
AdamW8bit
else
:
optimizer_cls
=
torch
.
optim
.
AdamW
params_to_optimize
=
list
(
filter
(
lambda
p
:
p
.
requires_grad
,
prior
.
parameters
()))
optimizer
=
optimizer_cls
(
lora_layers
.
parameters
()
,
params_to_optimize
,
lr
=
args
.
learning_rate
,
betas
=
(
args
.
adam_beta1
,
args
.
adam_beta2
),
weight_decay
=
args
.
adam_weight_decay
,
...
...
@@ -674,8 +713,8 @@ def main():
num_training_steps
=
args
.
max_train_steps
*
args
.
gradient_accumulation_steps
,
)
lora_layers
,
optimizer
,
train_dataloader
,
lr_scheduler
=
accelerator
.
prepare
(
lora_layers
,
optimizer
,
train_dataloader
,
lr_scheduler
prior
,
optimizer
,
train_dataloader
,
lr_scheduler
=
accelerator
.
prepare
(
prior
,
optimizer
,
train_dataloader
,
lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
...
...
@@ -782,7 +821,7 @@ def main():
# Backpropagate
accelerator
.
backward
(
loss
)
if
accelerator
.
sync_gradients
:
accelerator
.
clip_grad_norm_
(
lora_layers
.
parameters
()
,
args
.
max_grad_norm
)
accelerator
.
clip_grad_norm_
(
params_to_optimize
,
args
.
max_grad_norm
)
optimizer
.
step
()
lr_scheduler
.
step
()
optimizer
.
zero_grad
()
...
...
@@ -828,17 +867,19 @@ def main():
if
accelerator
.
is_main_process
:
if
args
.
validation_prompts
is
not
None
and
epoch
%
args
.
validation_epochs
==
0
:
log_validation
(
text_encoder
,
tokenizer
,
prior
.
attn_processors
,
args
,
accelerator
,
weight_dtype
,
global_step
)
log_validation
(
text_encoder
,
tokenizer
,
prior
,
args
,
accelerator
,
weight_dtype
,
global_step
)
# Create the pipeline using the trained modules and save it.
accelerator
.
wait_for_everyone
()
if
accelerator
.
is_main_process
:
prior
=
accelerator
.
unwrap_model
(
prior
)
prior
=
prior
.
to
(
torch
.
float32
)
prior_lora_state_dict
=
get_peft_model_state_dict
(
prior
)
WuerstchenPriorPipeline
.
save_lora_weights
(
os
.
path
.
join
(
args
.
output_dir
,
"prior_lora"
),
unet_lora_layers
=
lora_layers
,
save_directory
=
args
.
output_dir
,
unet_lora_layers
=
prior_lora_state_dict
,
)
# Run a final round of inference.
...
...
@@ -849,11 +890,12 @@ def main():
args
.
pretrained_decoder_model_name_or_path
,
prior_text_encoder
=
accelerator
.
unwrap_model
(
text_encoder
),
prior_tokenizer
=
tokenizer
,
torch_dtype
=
weight_dtype
,
)
pipeline
=
pipeline
.
to
(
accelerator
.
device
,
torch_dtype
=
weight_dtype
)
# load lora weights
pipeline
.
prior_pipe
.
load_lora_weights
(
os
.
path
.
join
(
args
.
output_dir
,
"prior_lora"
))
pipeline
=
pipeline
.
to
(
accelerator
.
device
)
# load lora weights
pipeline
.
prior_pipe
.
load_lora_weights
(
args
.
output_dir
,
weight_name
=
"pytorch_lora_weights.safetensors"
)
pipeline
.
set_progress_bar_config
(
disable
=
True
)
if
args
.
seed
is
None
:
...
...
@@ -862,7 +904,7 @@ def main():
generator
=
torch
.
Generator
(
device
=
accelerator
.
device
).
manual_seed
(
args
.
seed
)
for
i
in
range
(
len
(
args
.
validation_prompts
)):
with
torch
.
autocast
(
"cuda"
):
with
torch
.
cuda
.
amp
.
autocast
():
image
=
pipeline
(
args
.
validation_prompts
[
i
],
prior_timesteps
=
DEFAULT_STAGE_C_TIMESTEPS
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment