renzhc / diffusers_dcu · Commits

Commit 01782c22 (unverified)
Authored Nov 29, 2023 by Kashif Rasul; committed via GitHub, Nov 29, 2023
[Wuerstchen] Adapt lora training example scripts to use PEFT (#5959)
* Adapt lora example scripts to use PEFT
* add to_out.0
Parent: d63a498c

Changes: 2 changed files with 67 additions and 24 deletions (+67 −24)

- examples/wuerstchen/text_to_image/requirements.txt (+1 −0)
- examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py (+66 −24)
examples/wuerstchen/text_to_image/requirements.txt

```diff
@@ -5,3 +5,4 @@ wandb
 huggingface-cli
 bitsandbytes
 deepspeed
+peft>=0.6.0
```
examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
```diff
@@ -31,14 +31,14 @@ from accelerate.utils import ProjectConfiguration, set_seed
 from datasets import load_dataset
 from huggingface_hub import create_repo, hf_hub_download, upload_folder
 from modeling_efficient_net_encoder import EfficientNetEncoder
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
 from torchvision import transforms
 from tqdm import tqdm
 from transformers import CLIPTextModel, PreTrainedTokenizerFast
 from transformers.utils import ContextManagers
 
 from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler, WuerstchenPriorPipeline
-from diffusers.loaders import AttnProcsLayers
-from diffusers.models.attention_processor import LoRAAttnProcessor
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPrior
 from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
```
```diff
@@ -139,17 +139,17 @@ More information on all the CLI arguments and the environment are available on y
         f.write(yaml + model_card)
 
 
-def log_validation(text_encoder, tokenizer, attn_processors, args, accelerator, weight_dtype, epoch):
+def log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, epoch):
     logger.info("Running validation... ")
 
     pipeline = AutoPipelineForText2Image.from_pretrained(
         args.pretrained_decoder_model_name_or_path,
+        prior=accelerator.unwrap_model(prior),
         prior_text_encoder=accelerator.unwrap_model(text_encoder),
         prior_tokenizer=tokenizer,
         torch_dtype=weight_dtype,
     )
     pipeline = pipeline.to(accelerator.device)
-    pipeline.prior_prior.set_attn_processor(attn_processors)
     pipeline.set_progress_bar_config(disable=True)
 
     if args.seed is None:
```
```diff
@@ -159,7 +159,7 @@ def log_validation(text_encoder, tokenizer, attn_processors, args, accelerator,
     images = []
     for i in range(len(args.validation_prompts)):
-        with torch.autocast("cuda"):
+        with torch.cuda.amp.autocast():
             image = pipeline(
                 args.validation_prompts[i],
                 prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
```
```diff
@@ -167,7 +167,6 @@ def log_validation(text_encoder, tokenizer, attn_processors, args, accelerator,
                 height=args.resolution,
                 width=args.resolution,
             ).images[0]
-
         images.append(image)
 
     for tracker in accelerator.trackers:
```
```diff
@@ -527,11 +526,50 @@ def main():
     prior.to(accelerator.device, dtype=weight_dtype)
 
     # lora attn processor
-    lora_attn_procs = {}
-    for name in prior.attn_processors.keys():
-        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=prior.config["c"], rank=args.rank)
-    prior.set_attn_processor(lora_attn_procs)
-    lora_layers = AttnProcsLayers(prior.attn_processors)
+    prior_lora_config = LoraConfig(
+        r=args.rank, target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"]
+    )
+    prior.add_adapter(prior_lora_config)
+
+    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+    def save_model_hook(models, weights, output_dir):
+        if accelerator.is_main_process:
+            prior_lora_layers_to_save = None
+
+            for model in models:
+                if isinstance(model, type(accelerator.unwrap_model(prior))):
+                    prior_lora_layers_to_save = get_peft_model_state_dict(model)
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
+
+                # make sure to pop weight so that corresponding model is not saved again
+                weights.pop()
+
+            WuerstchenPriorPipeline.save_lora_weights(
+                output_dir,
+                unet_lora_layers=prior_lora_layers_to_save,
+            )
+
+    def load_model_hook(models, input_dir):
+        prior_ = None
+
+        while len(models) > 0:
+            model = models.pop()
+
+            if isinstance(model, type(accelerator.unwrap_model(prior))):
+                prior_ = model
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")
+
+        lora_state_dict, network_alphas = WuerstchenPriorPipeline.lora_state_dict(input_dir)
+        WuerstchenPriorPipeline.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=prior_)
+        WuerstchenPriorPipeline.load_lora_into_text_encoder(
+            lora_state_dict,
+            network_alphas=network_alphas,
+        )
+
+    accelerator.register_save_state_pre_hook(save_model_hook)
+    accelerator.register_load_state_pre_hook(load_model_hook)
+
     if args.allow_tf32:
         torch.backends.cuda.matmul.allow_tf32 = True
```
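This hunk is the heart of the migration: the hand-rolled `LoRAAttnProcessor` map is replaced by a declarative `peft.LoraConfig` that `add_adapter` injects into every matching projection (`to_k`, `to_q`, `to_v`, `to_out.0`, `add_k_proj`, `add_v_proj`). A minimal sketch of the same mechanism, using a toy attention module as a stand-in for the Wuerstchen prior (the `TinyAttention` class and its dimensions are illustrative, not from the script):

```python
import torch.nn as nn
from peft import LoraConfig, get_peft_model


class TinyAttention(nn.Module):
    """Hypothetical stand-in for one attention block of the prior."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        # A ModuleList gives the submodule the name "to_out.0", matching the diff.
        self.to_out = nn.ModuleList([nn.Linear(dim, dim)])


config = LoraConfig(r=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"])
model = get_peft_model(TinyAttention(), config)
model.print_trainable_parameters()  # only the LoRA A/B matrices are trainable
```

In the script itself, `prior.add_adapter(prior_lora_config)` does this wrapping in place, so no separate `lora_layers` object has to be tracked.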
```diff
@@ -547,8 +585,9 @@ def main():
         optimizer_cls = bnb.optim.AdamW8bit
     else:
         optimizer_cls = torch.optim.AdamW
 
+    params_to_optimize = list(filter(lambda p: p.requires_grad, prior.parameters()))
     optimizer = optimizer_cls(
-        lora_layers.parameters(),
+        params_to_optimize,
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
```
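Because `add_adapter` freezes the base weights, a plain `requires_grad` filter is enough to hand the optimizer exactly the adapter parameters. Continuing the toy sketch above (the hyperparameter values are illustrative):

```python
import torch

# `model` is the PEFT-wrapped module from the previous sketch; only its
# injected LoRA matrices still have requires_grad=True.
params_to_optimize = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(params_to_optimize, lr=1e-4, weight_decay=1e-2)
print(f"optimizing {sum(p.numel() for p in params_to_optimize)} parameters")
```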
```diff
@@ -674,8 +713,8 @@ def main():
         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
     )
 
-    lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        lora_layers, optimizer, train_dataloader, lr_scheduler
+    prior, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        prior, optimizer, train_dataloader, lr_scheduler
     )
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
```
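With PEFT the whole adapted `prior` now goes through `accelerator.prepare`, not a detached `AttnProcsLayers` wrapper; the frozen base weights ride along, while only the adapter weights accumulate gradients. A self-contained sketch of the pattern with stand-in objects:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()
net = torch.nn.Linear(4, 4)  # stand-in for the adapted prior
optimizer = torch.optim.AdamW(net.parameters(), lr=1e-4)
dataloader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=2)
# prepare() wraps the full module, mirroring the diff's switch from
# preparing lora_layers to preparing the adapted prior itself.
net, optimizer, dataloader = accelerator.prepare(net, optimizer, dataloader)
```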
```diff
@@ -782,7 +821,7 @@ def main():
                 # Backpropagate
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    accelerator.clip_grad_norm_(lora_layers.parameters(), args.max_grad_norm)
+                    accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
```
```diff
@@ -828,17 +867,19 @@ def main():
         if accelerator.is_main_process:
             if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
-                log_validation(
-                    text_encoder, tokenizer, prior.attn_processors, args, accelerator, weight_dtype, global_step
-                )
+                log_validation(
+                    text_encoder, tokenizer, prior, args, accelerator, weight_dtype, global_step
+                )
 
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        prior = accelerator.unwrap_model(prior)
         prior = prior.to(torch.float32)
+        prior_lora_state_dict = get_peft_model_state_dict(prior)
         WuerstchenPriorPipeline.save_lora_weights(
-            os.path.join(args.output_dir, "prior_lora"),
-            unet_lora_layers=lora_layers,
+            save_directory=args.output_dir,
+            unet_lora_layers=prior_lora_state_dict,
         )
 
     # Run a final round of inference.
```
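Saving now routes through `get_peft_model_state_dict`, which reduces the checkpoint to just the adapter tensors before `WuerstchenPriorPipeline.save_lora_weights` writes them out. Continuing the toy sketch, this is roughly what the extracted dict looks like:

```python
from peft.utils import get_peft_model_state_dict

# Only the adapter tensors survive the extraction; keys name the injected
# LoRA matrices (lora_A / lora_B), not the frozen base weights.
lora_state_dict = get_peft_model_state_dict(model)
for key, tensor in lora_state_dict.items():
    print(key, tuple(tensor.shape))
```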
```diff
@@ -849,11 +890,12 @@ def main():
             args.pretrained_decoder_model_name_or_path,
             prior_text_encoder=accelerator.unwrap_model(text_encoder),
             prior_tokenizer=tokenizer,
+            torch_dtype=weight_dtype,
         )
-        pipeline = pipeline.to(accelerator.device, torch_dtype=weight_dtype)
+        pipeline = pipeline.to(accelerator.device)
 
         # load lora weights
-        pipeline.prior_pipe.load_lora_weights(os.path.join(args.output_dir, "prior_lora"))
+        pipeline.prior_pipe.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
 
         pipeline.set_progress_bar_config(disable=True)
 
         if args.seed is None:
```
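After training, the adapter lands in `args.output_dir` as `pytorch_lora_weights.safetensors`, and the final-inference block loads it straight from there. A standalone inference sketch under the same convention (the `warp-ai/wuerstchen` checkpoint id, the `./wuerstchen-prior-lora` directory, and the prompt are assumptions, not from the diff):

```python
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "warp-ai/wuerstchen", torch_dtype=torch.float16
).to("cuda")
# Load the adapter produced by the training script into the prior sub-pipeline.
pipeline.prior_pipe.load_lora_weights(
    "./wuerstchen-prior-lora", weight_name="pytorch_lora_weights.safetensors"
)
image = pipeline(prompt="a pokemon with blue eyes", height=512, width=512).images[0]
image.save("lora_sample.png")
```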
```diff
@@ -862,7 +904,7 @@ def main():
         generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
 
         for i in range(len(args.validation_prompts)):
-            with torch.autocast("cuda"):
+            with torch.cuda.amp.autocast():
                 image = pipeline(
                     args.validation_prompts[i],
                     prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
```