Unverified Commit 160474ac authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

train dreambooth fix pre encode class prompt (#4395)

parent c10861ee
......@@ -592,14 +592,14 @@ class DreamBoothDataset(Dataset):
size=512,
center_crop=False,
encoder_hidden_states=None,
instance_prompt_encoder_hidden_states=None,
class_prompt_encoder_hidden_states=None,
tokenizer_max_length=None,
):
self.size = size
self.center_crop = center_crop
self.tokenizer = tokenizer
self.encoder_hidden_states = encoder_hidden_states
self.instance_prompt_encoder_hidden_states = instance_prompt_encoder_hidden_states
self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
self.tokenizer_max_length = tokenizer_max_length
self.instance_data_root = Path(instance_data_root)
......@@ -662,8 +662,8 @@ class DreamBoothDataset(Dataset):
class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image)
if self.instance_prompt_encoder_hidden_states is not None:
example["class_prompt_ids"] = self.instance_prompt_encoder_hidden_states
if self.class_prompt_encoder_hidden_states is not None:
example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
else:
class_text_inputs = tokenize_prompt(
self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
......@@ -1027,10 +1027,10 @@ def main(args):
else:
validation_prompt_encoder_hidden_states = None
if args.instance_prompt is not None:
pre_computed_instance_prompt_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
if args.class_prompt is not None:
pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)
else:
pre_computed_instance_prompt_encoder_hidden_states = None
pre_computed_class_prompt_encoder_hidden_states = None
text_encoder = None
tokenizer = None
......@@ -1041,7 +1041,7 @@ def main(args):
pre_computed_encoder_hidden_states = None
validation_prompt_encoder_hidden_states = None
validation_prompt_negative_prompt_embeds = None
pre_computed_instance_prompt_encoder_hidden_states = None
pre_computed_class_prompt_encoder_hidden_states = None
# Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset(
......@@ -1054,7 +1054,7 @@ def main(args):
size=args.resolution,
center_crop=args.center_crop,
encoder_hidden_states=pre_computed_encoder_hidden_states,
instance_prompt_encoder_hidden_states=pre_computed_instance_prompt_encoder_hidden_states,
class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
tokenizer_max_length=args.tokenizer_max_length,
)
......
......@@ -492,14 +492,14 @@ class DreamBoothDataset(Dataset):
size=512,
center_crop=False,
encoder_hidden_states=None,
instance_prompt_encoder_hidden_states=None,
class_prompt_encoder_hidden_states=None,
tokenizer_max_length=None,
):
self.size = size
self.center_crop = center_crop
self.tokenizer = tokenizer
self.encoder_hidden_states = encoder_hidden_states
self.instance_prompt_encoder_hidden_states = instance_prompt_encoder_hidden_states
self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
self.tokenizer_max_length = tokenizer_max_length
self.instance_data_root = Path(instance_data_root)
......@@ -562,8 +562,8 @@ class DreamBoothDataset(Dataset):
class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image)
if self.instance_prompt_encoder_hidden_states is not None:
example["class_prompt_ids"] = self.instance_prompt_encoder_hidden_states
if self.class_prompt_encoder_hidden_states is not None:
example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
else:
class_text_inputs = tokenize_prompt(
self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
......@@ -993,10 +993,10 @@ def main(args):
else:
validation_prompt_encoder_hidden_states = None
if args.instance_prompt is not None:
pre_computed_instance_prompt_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
if args.class_prompt is not None:
pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
else:
pre_computed_instance_prompt_encoder_hidden_states = None
pre_computed_class_prompt_encoder_hidden_states = None
text_encoder = None
tokenizer = None
......@@ -1007,7 +1007,7 @@ def main(args):
pre_computed_encoder_hidden_states = None
validation_prompt_encoder_hidden_states = None
validation_prompt_negative_prompt_embeds = None
pre_computed_instance_prompt_encoder_hidden_states = None
pre_computed_class_prompt_encoder_hidden_states = None
# Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset(
......@@ -1020,7 +1020,7 @@ def main(args):
size=args.resolution,
center_crop=args.center_crop,
encoder_hidden_states=pre_computed_encoder_hidden_states,
instance_prompt_encoder_hidden_states=pre_computed_instance_prompt_encoder_hidden_states,
class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
tokenizer_max_length=args.tokenizer_max_length,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment