Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
3ebd2d1f
Unverified
Commit
3ebd2d1f
authored
May 17, 2023
by
Patrick von Platen
Committed by
GitHub
May 17, 2023
Browse files
Make dreambooth lora more robust to orig unet (#3462)
* Make dreambooth lora more robust to orig unet * up
parent
15f1bab1
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
13 deletions
+5
-13
examples/dreambooth/train_dreambooth_lora.py
examples/dreambooth/train_dreambooth_lora.py
+5
-13
No files found.
examples/dreambooth/train_dreambooth_lora.py
View file @
3ebd2d1f
...
@@ -31,7 +31,7 @@ import transformers
...
@@ -31,7 +31,7 @@ import transformers
from
accelerate
import
Accelerator
from
accelerate
import
Accelerator
from
accelerate.logging
import
get_logger
from
accelerate.logging
import
get_logger
from
accelerate.utils
import
ProjectConfiguration
,
set_seed
from
accelerate.utils
import
ProjectConfiguration
,
set_seed
from
huggingface_hub
import
create_repo
,
model_info
,
upload_folder
from
huggingface_hub
import
create_repo
,
upload_folder
from
packaging
import
version
from
packaging
import
version
from
PIL
import
Image
from
PIL
import
Image
from
torch.utils.data
import
Dataset
from
torch.utils.data
import
Dataset
...
@@ -589,16 +589,6 @@ class PromptDataset(Dataset):
...
@@ -589,16 +589,6 @@ class PromptDataset(Dataset):
return
example
return
example
def model_has_vae(args):
    """Return True when the pretrained model ships a VAE config.

    Supports both a local model directory and a Hugging Face Hub repo id
    (queried through ``model_info``). The caller uses this to decide whether
    an ``AutoencoderKL`` can be loaded from the ``"vae"`` subfolder.
    """
    if os.path.isdir(args.pretrained_model_name_or_path):
        # Local checkout: look for vae/<config> on disk using native separators.
        local_config = os.path.join(args.pretrained_model_name_or_path, "vae", AutoencoderKL.config_name)
        return os.path.isfile(local_config)
    # Hub repo: `rfilename` entries always use forward slashes, so compare
    # against a POSIX-style path. os.path.join would yield "vae\\<config>"
    # on Windows and never match, making this branch always return False.
    config_file_name = "vae/" + AutoencoderKL.config_name
    files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings
    return any(sibling.rfilename == config_file_name for sibling in files_in_repo)
def
tokenize_prompt
(
tokenizer
,
prompt
,
tokenizer_max_length
=
None
):
def
tokenize_prompt
(
tokenizer
,
prompt
,
tokenizer_max_length
=
None
):
if
tokenizer_max_length
is
not
None
:
if
tokenizer_max_length
is
not
None
:
max_length
=
tokenizer_max_length
max_length
=
tokenizer_max_length
...
@@ -753,11 +743,13 @@ def main(args):
...
@@ -753,11 +743,13 @@ def main(args):
text_encoder
=
text_encoder_cls
.
from_pretrained
(
text_encoder
=
text_encoder_cls
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"text_encoder"
,
revision
=
args
.
revision
args
.
pretrained_model_name_or_path
,
subfolder
=
"text_encoder"
,
revision
=
args
.
revision
)
)
if
model_has_vae
(
args
)
:
try
:
vae
=
AutoencoderKL
.
from_pretrained
(
vae
=
AutoencoderKL
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"vae"
,
revision
=
args
.
revision
args
.
pretrained_model_name_or_path
,
subfolder
=
"vae"
,
revision
=
args
.
revision
)
)
else
:
except
OSError
:
# IF does not have a VAE so let's just set it to None
# We don't have to error out here
vae
=
None
vae
=
None
unet
=
UNet2DConditionModel
.
from_pretrained
(
unet
=
UNet2DConditionModel
.
from_pretrained
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment