Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
160474ac
Unverified
Commit
160474ac
authored
Aug 01, 2023
by
Will Berman
Committed by
GitHub
Aug 01, 2023
Browse files
train dreambooth fix pre encode class prompt (#4395)
parent
c10861ee
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
18 deletions
+18
-18
examples/dreambooth/train_dreambooth.py
examples/dreambooth/train_dreambooth.py
+9
-9
examples/dreambooth/train_dreambooth_lora.py
examples/dreambooth/train_dreambooth_lora.py
+9
-9
No files found.
examples/dreambooth/train_dreambooth.py
View file @
160474ac
...
@@ -592,14 +592,14 @@ class DreamBoothDataset(Dataset):
...
@@ -592,14 +592,14 @@ class DreamBoothDataset(Dataset):
size
=
512
,
size
=
512
,
center_crop
=
False
,
center_crop
=
False
,
encoder_hidden_states
=
None
,
encoder_hidden_states
=
None
,
instance
_prompt_encoder_hidden_states
=
None
,
class
_prompt_encoder_hidden_states
=
None
,
tokenizer_max_length
=
None
,
tokenizer_max_length
=
None
,
):
):
self
.
size
=
size
self
.
size
=
size
self
.
center_crop
=
center_crop
self
.
center_crop
=
center_crop
self
.
tokenizer
=
tokenizer
self
.
tokenizer
=
tokenizer
self
.
encoder_hidden_states
=
encoder_hidden_states
self
.
encoder_hidden_states
=
encoder_hidden_states
self
.
instance
_prompt_encoder_hidden_states
=
instance
_prompt_encoder_hidden_states
self
.
class
_prompt_encoder_hidden_states
=
class
_prompt_encoder_hidden_states
self
.
tokenizer_max_length
=
tokenizer_max_length
self
.
tokenizer_max_length
=
tokenizer_max_length
self
.
instance_data_root
=
Path
(
instance_data_root
)
self
.
instance_data_root
=
Path
(
instance_data_root
)
...
@@ -662,8 +662,8 @@ class DreamBoothDataset(Dataset):
...
@@ -662,8 +662,8 @@ class DreamBoothDataset(Dataset):
class_image
=
class_image
.
convert
(
"RGB"
)
class_image
=
class_image
.
convert
(
"RGB"
)
example
[
"class_images"
]
=
self
.
image_transforms
(
class_image
)
example
[
"class_images"
]
=
self
.
image_transforms
(
class_image
)
if
self
.
instance
_prompt_encoder_hidden_states
is
not
None
:
if
self
.
class
_prompt_encoder_hidden_states
is
not
None
:
example
[
"class_prompt_ids"
]
=
self
.
instance
_prompt_encoder_hidden_states
example
[
"class_prompt_ids"
]
=
self
.
class
_prompt_encoder_hidden_states
else
:
else
:
class_text_inputs
=
tokenize_prompt
(
class_text_inputs
=
tokenize_prompt
(
self
.
tokenizer
,
self
.
class_prompt
,
tokenizer_max_length
=
self
.
tokenizer_max_length
self
.
tokenizer
,
self
.
class_prompt
,
tokenizer_max_length
=
self
.
tokenizer_max_length
...
@@ -1027,10 +1027,10 @@ def main(args):
...
@@ -1027,10 +1027,10 @@ def main(args):
else
:
else
:
validation_prompt_encoder_hidden_states
=
None
validation_prompt_encoder_hidden_states
=
None
if
args
.
instance
_prompt
is
not
None
:
if
args
.
class
_prompt
is
not
None
:
pre_computed_
instance
_prompt_encoder_hidden_states
=
compute_text_embeddings
(
args
.
instance
_prompt
)
pre_computed_
class
_prompt_encoder_hidden_states
=
compute_text_embeddings
(
args
.
class
_prompt
)
else
:
else
:
pre_computed_
instance
_prompt_encoder_hidden_states
=
None
pre_computed_
class
_prompt_encoder_hidden_states
=
None
text_encoder
=
None
text_encoder
=
None
tokenizer
=
None
tokenizer
=
None
...
@@ -1041,7 +1041,7 @@ def main(args):
...
@@ -1041,7 +1041,7 @@ def main(args):
pre_computed_encoder_hidden_states
=
None
pre_computed_encoder_hidden_states
=
None
validation_prompt_encoder_hidden_states
=
None
validation_prompt_encoder_hidden_states
=
None
validation_prompt_negative_prompt_embeds
=
None
validation_prompt_negative_prompt_embeds
=
None
pre_computed_
instance
_prompt_encoder_hidden_states
=
None
pre_computed_
class
_prompt_encoder_hidden_states
=
None
# Dataset and DataLoaders creation:
# Dataset and DataLoaders creation:
train_dataset
=
DreamBoothDataset
(
train_dataset
=
DreamBoothDataset
(
...
@@ -1054,7 +1054,7 @@ def main(args):
...
@@ -1054,7 +1054,7 @@ def main(args):
size
=
args
.
resolution
,
size
=
args
.
resolution
,
center_crop
=
args
.
center_crop
,
center_crop
=
args
.
center_crop
,
encoder_hidden_states
=
pre_computed_encoder_hidden_states
,
encoder_hidden_states
=
pre_computed_encoder_hidden_states
,
instance
_prompt_encoder_hidden_states
=
pre_computed_
instance
_prompt_encoder_hidden_states
,
class
_prompt_encoder_hidden_states
=
pre_computed_
class
_prompt_encoder_hidden_states
,
tokenizer_max_length
=
args
.
tokenizer_max_length
,
tokenizer_max_length
=
args
.
tokenizer_max_length
,
)
)
...
...
examples/dreambooth/train_dreambooth_lora.py
View file @
160474ac
...
@@ -492,14 +492,14 @@ class DreamBoothDataset(Dataset):
...
@@ -492,14 +492,14 @@ class DreamBoothDataset(Dataset):
size
=
512
,
size
=
512
,
center_crop
=
False
,
center_crop
=
False
,
encoder_hidden_states
=
None
,
encoder_hidden_states
=
None
,
instance
_prompt_encoder_hidden_states
=
None
,
class
_prompt_encoder_hidden_states
=
None
,
tokenizer_max_length
=
None
,
tokenizer_max_length
=
None
,
):
):
self
.
size
=
size
self
.
size
=
size
self
.
center_crop
=
center_crop
self
.
center_crop
=
center_crop
self
.
tokenizer
=
tokenizer
self
.
tokenizer
=
tokenizer
self
.
encoder_hidden_states
=
encoder_hidden_states
self
.
encoder_hidden_states
=
encoder_hidden_states
self
.
instance
_prompt_encoder_hidden_states
=
instance
_prompt_encoder_hidden_states
self
.
class
_prompt_encoder_hidden_states
=
class
_prompt_encoder_hidden_states
self
.
tokenizer_max_length
=
tokenizer_max_length
self
.
tokenizer_max_length
=
tokenizer_max_length
self
.
instance_data_root
=
Path
(
instance_data_root
)
self
.
instance_data_root
=
Path
(
instance_data_root
)
...
@@ -562,8 +562,8 @@ class DreamBoothDataset(Dataset):
...
@@ -562,8 +562,8 @@ class DreamBoothDataset(Dataset):
class_image
=
class_image
.
convert
(
"RGB"
)
class_image
=
class_image
.
convert
(
"RGB"
)
example
[
"class_images"
]
=
self
.
image_transforms
(
class_image
)
example
[
"class_images"
]
=
self
.
image_transforms
(
class_image
)
if
self
.
instance
_prompt_encoder_hidden_states
is
not
None
:
if
self
.
class
_prompt_encoder_hidden_states
is
not
None
:
example
[
"class_prompt_ids"
]
=
self
.
instance
_prompt_encoder_hidden_states
example
[
"class_prompt_ids"
]
=
self
.
class
_prompt_encoder_hidden_states
else
:
else
:
class_text_inputs
=
tokenize_prompt
(
class_text_inputs
=
tokenize_prompt
(
self
.
tokenizer
,
self
.
class_prompt
,
tokenizer_max_length
=
self
.
tokenizer_max_length
self
.
tokenizer
,
self
.
class_prompt
,
tokenizer_max_length
=
self
.
tokenizer_max_length
...
@@ -993,10 +993,10 @@ def main(args):
...
@@ -993,10 +993,10 @@ def main(args):
else
:
else
:
validation_prompt_encoder_hidden_states
=
None
validation_prompt_encoder_hidden_states
=
None
if
args
.
instance
_prompt
is
not
None
:
if
args
.
class
_prompt
is
not
None
:
pre_computed_
instance
_prompt_encoder_hidden_states
=
compute_text_embeddings
(
args
.
instance_prompt
)
pre_computed_
class
_prompt_encoder_hidden_states
=
compute_text_embeddings
(
args
.
instance_prompt
)
else
:
else
:
pre_computed_
instance
_prompt_encoder_hidden_states
=
None
pre_computed_
class
_prompt_encoder_hidden_states
=
None
text_encoder
=
None
text_encoder
=
None
tokenizer
=
None
tokenizer
=
None
...
@@ -1007,7 +1007,7 @@ def main(args):
...
@@ -1007,7 +1007,7 @@ def main(args):
pre_computed_encoder_hidden_states
=
None
pre_computed_encoder_hidden_states
=
None
validation_prompt_encoder_hidden_states
=
None
validation_prompt_encoder_hidden_states
=
None
validation_prompt_negative_prompt_embeds
=
None
validation_prompt_negative_prompt_embeds
=
None
pre_computed_
instance
_prompt_encoder_hidden_states
=
None
pre_computed_
class
_prompt_encoder_hidden_states
=
None
# Dataset and DataLoaders creation:
# Dataset and DataLoaders creation:
train_dataset
=
DreamBoothDataset
(
train_dataset
=
DreamBoothDataset
(
...
@@ -1020,7 +1020,7 @@ def main(args):
...
@@ -1020,7 +1020,7 @@ def main(args):
size
=
args
.
resolution
,
size
=
args
.
resolution
,
center_crop
=
args
.
center_crop
,
center_crop
=
args
.
center_crop
,
encoder_hidden_states
=
pre_computed_encoder_hidden_states
,
encoder_hidden_states
=
pre_computed_encoder_hidden_states
,
instance
_prompt_encoder_hidden_states
=
pre_computed_
instance
_prompt_encoder_hidden_states
,
class
_prompt_encoder_hidden_states
=
pre_computed_
class
_prompt_encoder_hidden_states
,
tokenizer_max_length
=
args
.
tokenizer_max_length
,
tokenizer_max_length
=
args
.
tokenizer_max_length
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment