You need to sign in or sign up before continuing.
Unverified Commit 160474ac authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

train dreambooth fix pre encode class prompt (#4395)

parent c10861ee
......@@ -592,14 +592,14 @@ class DreamBoothDataset(Dataset):
size=512,
center_crop=False,
encoder_hidden_states=None,
instance_prompt_encoder_hidden_states=None,
class_prompt_encoder_hidden_states=None,
tokenizer_max_length=None,
):
self.size = size
self.center_crop = center_crop
self.tokenizer = tokenizer
self.encoder_hidden_states = encoder_hidden_states
self.instance_prompt_encoder_hidden_states = instance_prompt_encoder_hidden_states
self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
self.tokenizer_max_length = tokenizer_max_length
self.instance_data_root = Path(instance_data_root)
......@@ -662,8 +662,8 @@ class DreamBoothDataset(Dataset):
class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image)
if self.instance_prompt_encoder_hidden_states is not None:
example["class_prompt_ids"] = self.instance_prompt_encoder_hidden_states
if self.class_prompt_encoder_hidden_states is not None:
example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
else:
class_text_inputs = tokenize_prompt(
self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
......@@ -1027,10 +1027,10 @@ def main(args):
else:
validation_prompt_encoder_hidden_states = None
if args.instance_prompt is not None:
pre_computed_instance_prompt_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
if args.class_prompt is not None:
pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)
else:
pre_computed_instance_prompt_encoder_hidden_states = None
pre_computed_class_prompt_encoder_hidden_states = None
text_encoder = None
tokenizer = None
......@@ -1041,7 +1041,7 @@ def main(args):
pre_computed_encoder_hidden_states = None
validation_prompt_encoder_hidden_states = None
validation_prompt_negative_prompt_embeds = None
pre_computed_instance_prompt_encoder_hidden_states = None
pre_computed_class_prompt_encoder_hidden_states = None
# Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset(
......@@ -1054,7 +1054,7 @@ def main(args):
size=args.resolution,
center_crop=args.center_crop,
encoder_hidden_states=pre_computed_encoder_hidden_states,
instance_prompt_encoder_hidden_states=pre_computed_instance_prompt_encoder_hidden_states,
class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
tokenizer_max_length=args.tokenizer_max_length,
)
......
......@@ -492,14 +492,14 @@ class DreamBoothDataset(Dataset):
size=512,
center_crop=False,
encoder_hidden_states=None,
instance_prompt_encoder_hidden_states=None,
class_prompt_encoder_hidden_states=None,
tokenizer_max_length=None,
):
self.size = size
self.center_crop = center_crop
self.tokenizer = tokenizer
self.encoder_hidden_states = encoder_hidden_states
self.instance_prompt_encoder_hidden_states = instance_prompt_encoder_hidden_states
self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
self.tokenizer_max_length = tokenizer_max_length
self.instance_data_root = Path(instance_data_root)
......@@ -562,8 +562,8 @@ class DreamBoothDataset(Dataset):
class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image)
if self.instance_prompt_encoder_hidden_states is not None:
example["class_prompt_ids"] = self.instance_prompt_encoder_hidden_states
if self.class_prompt_encoder_hidden_states is not None:
example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
else:
class_text_inputs = tokenize_prompt(
self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
......@@ -993,10 +993,10 @@ def main(args):
else:
validation_prompt_encoder_hidden_states = None
if args.instance_prompt is not None:
pre_computed_instance_prompt_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
if args.class_prompt is not None:
pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
else:
pre_computed_instance_prompt_encoder_hidden_states = None
pre_computed_class_prompt_encoder_hidden_states = None
text_encoder = None
tokenizer = None
......@@ -1007,7 +1007,7 @@ def main(args):
pre_computed_encoder_hidden_states = None
validation_prompt_encoder_hidden_states = None
validation_prompt_negative_prompt_embeds = None
pre_computed_instance_prompt_encoder_hidden_states = None
pre_computed_class_prompt_encoder_hidden_states = None
# Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset(
......@@ -1020,7 +1020,7 @@ def main(args):
size=args.resolution,
center_crop=args.center_crop,
encoder_hidden_states=pre_computed_encoder_hidden_states,
instance_prompt_encoder_hidden_states=pre_computed_instance_prompt_encoder_hidden_states,
class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
tokenizer_max_length=args.tokenizer_max_length,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment