Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
04ddad48
Unverified
Commit
04ddad48
authored
Jul 07, 2023
by
Batuhan Taskaya
Committed by
GitHub
Jul 07, 2023
Browse files
Add 'rank' parameter to Dreambooth LoRA training script (#3945)
parent
03d829d5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
2 deletions
+12
-2
examples/dreambooth/train_dreambooth_lora.py
examples/dreambooth/train_dreambooth_lora.py
+12
-2
No files found.
examples/dreambooth/train_dreambooth_lora.py
View file @
04ddad48
...
@@ -436,6 +436,12 @@ def parse_args(input_args=None):
...
@@ -436,6 +436,12 @@ def parse_args(input_args=None):
default
=
None
,
default
=
None
,
help
=
"The optional `class_label` conditioning to pass to the unet, available values are `timesteps`."
,
help
=
"The optional `class_label` conditioning to pass to the unet, available values are `timesteps`."
,
)
)
parser
.
add_argument
(
"--rank"
,
type
=
int
,
default
=
4
,
help
=
(
"The dimension of the LoRA update matrices."
),
)
if
input_args
is
not
None
:
if
input_args
is
not
None
:
args
=
parser
.
parse_args
(
input_args
)
args
=
parser
.
parse_args
(
input_args
)
...
@@ -845,7 +851,9 @@ def main(args):
...
@@ -845,7 +851,9 @@ def main(args):
LoRAAttnProcessor2_0
if
hasattr
(
F
,
"scaled_dot_product_attention"
)
else
LoRAAttnProcessor
LoRAAttnProcessor2_0
if
hasattr
(
F
,
"scaled_dot_product_attention"
)
else
LoRAAttnProcessor
)
)
unet_lora_attn_procs
[
name
]
=
lora_attn_processor_class
(
unet_lora_attn_procs
[
name
]
=
lora_attn_processor_class
(
hidden_size
=
hidden_size
,
cross_attention_dim
=
cross_attention_dim
hidden_size
=
hidden_size
,
cross_attention_dim
=
cross_attention_dim
,
rank
=
args
.
rank
,
)
)
unet
.
set_attn_processor
(
unet_lora_attn_procs
)
unet
.
set_attn_processor
(
unet_lora_attn_procs
)
...
@@ -860,7 +868,9 @@ def main(args):
...
@@ -860,7 +868,9 @@ def main(args):
for
name
,
module
in
text_encoder
.
named_modules
():
for
name
,
module
in
text_encoder
.
named_modules
():
if
name
.
endswith
(
TEXT_ENCODER_ATTN_MODULE
):
if
name
.
endswith
(
TEXT_ENCODER_ATTN_MODULE
):
text_lora_attn_procs
[
name
]
=
LoRAAttnProcessor
(
text_lora_attn_procs
[
name
]
=
LoRAAttnProcessor
(
hidden_size
=
module
.
out_proj
.
out_features
,
cross_attention_dim
=
None
hidden_size
=
module
.
out_proj
.
out_features
,
cross_attention_dim
=
None
,
rank
=
args
.
rank
,
)
)
text_encoder_lora_layers
=
AttnProcsLayers
(
text_lora_attn_procs
)
text_encoder_lora_layers
=
AttnProcsLayers
(
text_lora_attn_procs
)
temp_pipeline
=
DiffusionPipeline
.
from_pretrained
(
temp_pipeline
=
DiffusionPipeline
.
from_pretrained
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment