renzhc / diffusers_dcu · Commits · cdf2ae8a

Unverified commit cdf2ae8a, authored Jun 29, 2023 by takuoko, committed via GitHub on Jun 29, 2023.
[Enhance] Add LoRA rank args in train_text_to_image_lora (#3866)
* add rank args in lora finetune
* del network_alpha
Parent: 49949f32
Changes: 1 changed file with 11 additions and 1 deletion.
examples/text_to_image/train_text_to_image_lora.py (+11 -1)
@@ -343,6 +343,12 @@ def parse_args():
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
     parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+    parser.add_argument(
+        "--rank",
+        type=int,
+        default=4,
+        help=("The dimension of the LoRA update matrices."),
+    )
 
     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
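For reference, a minimal sketch (not part of the commit) of how the newly added flag behaves once parsed: --rank defaults to 4, so training commands that do not pass the flag keep their previous behaviour, and any positive integer can be supplied to change the LoRA dimension.

import argparse

# Reduced parser containing only the newly added flag (illustrative only).
parser = argparse.ArgumentParser()
parser.add_argument("--rank", type=int, default=4, help="The dimension of the LoRA update matrices.")

print(parser.parse_args([]).rank)               # -> 4, the default used when --rank is omitted
print(parser.parse_args(["--rank", "8"]).rank)  # -> 8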
@@ -464,7 +470,11 @@ def main():
             block_id = int(name[len("down_blocks.")])
             hidden_size = unet.config.block_out_channels[block_id]
 
-        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        lora_attn_procs[name] = LoRAAttnProcessor(
+            hidden_size=hidden_size,
+            cross_attention_dim=cross_attention_dim,
+            rank=args.rank,
+        )
 
     unet.set_attn_processor(lora_attn_procs)
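To make the help text concrete: in LoRA, each adapted projection is augmented with a low-rank update B @ A, and --rank sets the inner dimension of those update matrices, so the number of trainable adapter parameters grows roughly linearly with the rank. Below is a minimal, self-contained sketch in plain PyTorch (an illustration of the idea, not the diffusers LoRAAttnProcessor implementation) showing that relationship.

import torch.nn as nn

class LoRAUpdate(nn.Module):
    # Illustrative low-rank update for a projection; `rank` plays the same
    # role as the script's new --rank argument.
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)  # maps in_features -> rank
        self.up = nn.Linear(rank, out_features, bias=False)   # maps rank -> out_features
        nn.init.zeros_(self.up.weight)  # start as a no-op update, as LoRA typically does

    def forward(self, hidden_states):
        return self.up(self.down(hidden_states))

for rank in (4, 8, 16):
    update = LoRAUpdate(320, 320, rank=rank)
    n_params = sum(p.numel() for p in update.parameters())
    print(rank, n_params)  # parameter count scales as (in_features + out_features) * rank

Because the new flag defaults to 4, which matches the default rank used by LoRAAttnProcessor, existing runs should be unaffected unless --rank is passed explicitly; larger ranks trade additional trainable parameters and adapter size for a more expressive update.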