Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
160474ac
You need to sign in or sign up before continuing.
Unverified
Commit
160474ac
authored
Aug 01, 2023
by
Will Berman
Committed by
GitHub
Aug 01, 2023
Browse files
train dreambooth fix pre encode class prompt (#4395)
parent
c10861ee
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
18 deletions
+18
-18
examples/dreambooth/train_dreambooth.py
examples/dreambooth/train_dreambooth.py
+9
-9
examples/dreambooth/train_dreambooth_lora.py
examples/dreambooth/train_dreambooth_lora.py
+9
-9
No files found.
examples/dreambooth/train_dreambooth.py
View file @
160474ac
...
...
@@ -592,14 +592,14 @@ class DreamBoothDataset(Dataset):
size
=
512
,
center_crop
=
False
,
encoder_hidden_states
=
None
,
instance
_prompt_encoder_hidden_states
=
None
,
class
_prompt_encoder_hidden_states
=
None
,
tokenizer_max_length
=
None
,
):
self
.
size
=
size
self
.
center_crop
=
center_crop
self
.
tokenizer
=
tokenizer
self
.
encoder_hidden_states
=
encoder_hidden_states
self
.
instance
_prompt_encoder_hidden_states
=
instance
_prompt_encoder_hidden_states
self
.
class
_prompt_encoder_hidden_states
=
class
_prompt_encoder_hidden_states
self
.
tokenizer_max_length
=
tokenizer_max_length
self
.
instance_data_root
=
Path
(
instance_data_root
)
...
...
@@ -662,8 +662,8 @@ class DreamBoothDataset(Dataset):
class_image
=
class_image
.
convert
(
"RGB"
)
example
[
"class_images"
]
=
self
.
image_transforms
(
class_image
)
if
self
.
instance
_prompt_encoder_hidden_states
is
not
None
:
example
[
"class_prompt_ids"
]
=
self
.
instance
_prompt_encoder_hidden_states
if
self
.
class
_prompt_encoder_hidden_states
is
not
None
:
example
[
"class_prompt_ids"
]
=
self
.
class
_prompt_encoder_hidden_states
else
:
class_text_inputs
=
tokenize_prompt
(
self
.
tokenizer
,
self
.
class_prompt
,
tokenizer_max_length
=
self
.
tokenizer_max_length
...
...
@@ -1027,10 +1027,10 @@ def main(args):
else
:
validation_prompt_encoder_hidden_states
=
None
if
args
.
instance
_prompt
is
not
None
:
pre_computed_
instance
_prompt_encoder_hidden_states
=
compute_text_embeddings
(
args
.
instance
_prompt
)
if
args
.
class
_prompt
is
not
None
:
pre_computed_
class
_prompt_encoder_hidden_states
=
compute_text_embeddings
(
args
.
class
_prompt
)
else
:
pre_computed_
instance
_prompt_encoder_hidden_states
=
None
pre_computed_
class
_prompt_encoder_hidden_states
=
None
text_encoder
=
None
tokenizer
=
None
...
...
@@ -1041,7 +1041,7 @@ def main(args):
pre_computed_encoder_hidden_states
=
None
validation_prompt_encoder_hidden_states
=
None
validation_prompt_negative_prompt_embeds
=
None
pre_computed_
instance
_prompt_encoder_hidden_states
=
None
pre_computed_
class
_prompt_encoder_hidden_states
=
None
# Dataset and DataLoaders creation:
train_dataset
=
DreamBoothDataset
(
...
...
@@ -1054,7 +1054,7 @@ def main(args):
size
=
args
.
resolution
,
center_crop
=
args
.
center_crop
,
encoder_hidden_states
=
pre_computed_encoder_hidden_states
,
instance
_prompt_encoder_hidden_states
=
pre_computed_
instance
_prompt_encoder_hidden_states
,
class
_prompt_encoder_hidden_states
=
pre_computed_
class
_prompt_encoder_hidden_states
,
tokenizer_max_length
=
args
.
tokenizer_max_length
,
)
...
...
examples/dreambooth/train_dreambooth_lora.py
View file @
160474ac
...
...
@@ -492,14 +492,14 @@ class DreamBoothDataset(Dataset):
size
=
512
,
center_crop
=
False
,
encoder_hidden_states
=
None
,
instance
_prompt_encoder_hidden_states
=
None
,
class
_prompt_encoder_hidden_states
=
None
,
tokenizer_max_length
=
None
,
):
self
.
size
=
size
self
.
center_crop
=
center_crop
self
.
tokenizer
=
tokenizer
self
.
encoder_hidden_states
=
encoder_hidden_states
self
.
instance
_prompt_encoder_hidden_states
=
instance
_prompt_encoder_hidden_states
self
.
class
_prompt_encoder_hidden_states
=
class
_prompt_encoder_hidden_states
self
.
tokenizer_max_length
=
tokenizer_max_length
self
.
instance_data_root
=
Path
(
instance_data_root
)
...
...
@@ -562,8 +562,8 @@ class DreamBoothDataset(Dataset):
class_image
=
class_image
.
convert
(
"RGB"
)
example
[
"class_images"
]
=
self
.
image_transforms
(
class_image
)
if
self
.
instance
_prompt_encoder_hidden_states
is
not
None
:
example
[
"class_prompt_ids"
]
=
self
.
instance
_prompt_encoder_hidden_states
if
self
.
class
_prompt_encoder_hidden_states
is
not
None
:
example
[
"class_prompt_ids"
]
=
self
.
class
_prompt_encoder_hidden_states
else
:
class_text_inputs
=
tokenize_prompt
(
self
.
tokenizer
,
self
.
class_prompt
,
tokenizer_max_length
=
self
.
tokenizer_max_length
...
...
@@ -993,10 +993,10 @@ def main(args):
else
:
validation_prompt_encoder_hidden_states
=
None
if
args
.
instance
_prompt
is
not
None
:
pre_computed_
instance
_prompt_encoder_hidden_states
=
compute_text_embeddings
(
args
.
instance_prompt
)
if
args
.
class
_prompt
is
not
None
:
pre_computed_
class
_prompt_encoder_hidden_states
=
compute_text_embeddings
(
args
.
instance_prompt
)
else
:
pre_computed_
instance
_prompt_encoder_hidden_states
=
None
pre_computed_
class
_prompt_encoder_hidden_states
=
None
text_encoder
=
None
tokenizer
=
None
...
...
@@ -1007,7 +1007,7 @@ def main(args):
pre_computed_encoder_hidden_states
=
None
validation_prompt_encoder_hidden_states
=
None
validation_prompt_negative_prompt_embeds
=
None
pre_computed_
instance
_prompt_encoder_hidden_states
=
None
pre_computed_
class
_prompt_encoder_hidden_states
=
None
# Dataset and DataLoaders creation:
train_dataset
=
DreamBoothDataset
(
...
...
@@ -1020,7 +1020,7 @@ def main(args):
size
=
args
.
resolution
,
center_crop
=
args
.
center_crop
,
encoder_hidden_states
=
pre_computed_encoder_hidden_states
,
instance
_prompt_encoder_hidden_states
=
pre_computed_
instance
_prompt_encoder_hidden_states
,
class
_prompt_encoder_hidden_states
=
pre_computed_
class
_prompt_encoder_hidden_states
,
tokenizer_max_length
=
args
.
tokenizer_max_length
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment