Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
6427aa99
Unverified
Commit
6427aa99
authored
Jul 18, 2023
by
takuoko
Committed by
GitHub
Jul 18, 2023
Browse files
[Enhance] Add rank in dreambooth (#4112)
add rank in dreambooth
parent
8b18cd8e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
3 deletions
+5
-3
examples/dreambooth/train_dreambooth_lora.py
examples/dreambooth/train_dreambooth_lora.py
+5
-3
No files found.
examples/dreambooth/train_dreambooth_lora.py
View file @
6427aa99
...
@@ -872,7 +872,9 @@ def main(args):
...
@@ -872,7 +872,9 @@ def main(args):
LoRAAttnProcessor2_0
if
hasattr
(
F
,
"scaled_dot_product_attention"
)
else
LoRAAttnProcessor
LoRAAttnProcessor2_0
if
hasattr
(
F
,
"scaled_dot_product_attention"
)
else
LoRAAttnProcessor
)
)
module
=
lora_attn_processor_class
(
hidden_size
=
hidden_size
,
cross_attention_dim
=
cross_attention_dim
)
module
=
lora_attn_processor_class
(
hidden_size
=
hidden_size
,
cross_attention_dim
=
cross_attention_dim
,
rank
=
args
.
rank
)
unet_lora_attn_procs
[
name
]
=
module
unet_lora_attn_procs
[
name
]
=
module
unet_lora_parameters
.
extend
(
module
.
parameters
())
unet_lora_parameters
.
extend
(
module
.
parameters
())
...
@@ -882,7 +884,7 @@ def main(args):
...
@@ -882,7 +884,7 @@ def main(args):
# So, instead, we monkey-patch the forward calls of its attention-blocks.
# So, instead, we monkey-patch the forward calls of its attention-blocks.
if
args
.
train_text_encoder
:
if
args
.
train_text_encoder
:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
text_lora_parameters
=
LoraLoaderMixin
.
_modify_text_encoder
(
text_encoder
,
dtype
=
torch
.
float32
)
text_lora_parameters
=
LoraLoaderMixin
.
_modify_text_encoder
(
text_encoder
,
dtype
=
torch
.
float32
,
rank
=
args
.
rank
)
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def
save_model_hook
(
models
,
weights
,
output_dir
):
def
save_model_hook
(
models
,
weights
,
output_dir
):
...
@@ -1364,7 +1366,7 @@ def main(args):
...
@@ -1364,7 +1366,7 @@ def main(args):
pipeline
=
pipeline
.
to
(
accelerator
.
device
)
pipeline
=
pipeline
.
to
(
accelerator
.
device
)
# load attention processors
# load attention processors
pipeline
.
load_lora_weights
(
args
.
output_dir
)
pipeline
.
load_lora_weights
(
args
.
output_dir
,
weight_name
=
"pytorch_lora_weights.bin"
)
# run inference
# run inference
images
=
[]
images
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment