Unverified Commit 160474ac authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

train dreambooth fix pre encode class prompt (#4395)

parent c10861ee
...@@ -592,14 +592,14 @@ class DreamBoothDataset(Dataset): ...@@ -592,14 +592,14 @@ class DreamBoothDataset(Dataset):
size=512, size=512,
center_crop=False, center_crop=False,
encoder_hidden_states=None, encoder_hidden_states=None,
instance_prompt_encoder_hidden_states=None, class_prompt_encoder_hidden_states=None,
tokenizer_max_length=None, tokenizer_max_length=None,
): ):
self.size = size self.size = size
self.center_crop = center_crop self.center_crop = center_crop
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.encoder_hidden_states = encoder_hidden_states self.encoder_hidden_states = encoder_hidden_states
self.instance_prompt_encoder_hidden_states = instance_prompt_encoder_hidden_states self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
self.tokenizer_max_length = tokenizer_max_length self.tokenizer_max_length = tokenizer_max_length
self.instance_data_root = Path(instance_data_root) self.instance_data_root = Path(instance_data_root)
...@@ -662,8 +662,8 @@ class DreamBoothDataset(Dataset): ...@@ -662,8 +662,8 @@ class DreamBoothDataset(Dataset):
class_image = class_image.convert("RGB") class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image) example["class_images"] = self.image_transforms(class_image)
if self.instance_prompt_encoder_hidden_states is not None: if self.class_prompt_encoder_hidden_states is not None:
example["class_prompt_ids"] = self.instance_prompt_encoder_hidden_states example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
else: else:
class_text_inputs = tokenize_prompt( class_text_inputs = tokenize_prompt(
self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
...@@ -1027,10 +1027,10 @@ def main(args): ...@@ -1027,10 +1027,10 @@ def main(args):
else: else:
validation_prompt_encoder_hidden_states = None validation_prompt_encoder_hidden_states = None
if args.instance_prompt is not None: if args.class_prompt is not None:
pre_computed_instance_prompt_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)
else: else:
pre_computed_instance_prompt_encoder_hidden_states = None pre_computed_class_prompt_encoder_hidden_states = None
text_encoder = None text_encoder = None
tokenizer = None tokenizer = None
...@@ -1041,7 +1041,7 @@ def main(args): ...@@ -1041,7 +1041,7 @@ def main(args):
pre_computed_encoder_hidden_states = None pre_computed_encoder_hidden_states = None
validation_prompt_encoder_hidden_states = None validation_prompt_encoder_hidden_states = None
validation_prompt_negative_prompt_embeds = None validation_prompt_negative_prompt_embeds = None
pre_computed_instance_prompt_encoder_hidden_states = None pre_computed_class_prompt_encoder_hidden_states = None
# Dataset and DataLoaders creation: # Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset( train_dataset = DreamBoothDataset(
...@@ -1054,7 +1054,7 @@ def main(args): ...@@ -1054,7 +1054,7 @@ def main(args):
size=args.resolution, size=args.resolution,
center_crop=args.center_crop, center_crop=args.center_crop,
encoder_hidden_states=pre_computed_encoder_hidden_states, encoder_hidden_states=pre_computed_encoder_hidden_states,
instance_prompt_encoder_hidden_states=pre_computed_instance_prompt_encoder_hidden_states, class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
tokenizer_max_length=args.tokenizer_max_length, tokenizer_max_length=args.tokenizer_max_length,
) )
......
...@@ -492,14 +492,14 @@ class DreamBoothDataset(Dataset): ...@@ -492,14 +492,14 @@ class DreamBoothDataset(Dataset):
size=512, size=512,
center_crop=False, center_crop=False,
encoder_hidden_states=None, encoder_hidden_states=None,
instance_prompt_encoder_hidden_states=None, class_prompt_encoder_hidden_states=None,
tokenizer_max_length=None, tokenizer_max_length=None,
): ):
self.size = size self.size = size
self.center_crop = center_crop self.center_crop = center_crop
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.encoder_hidden_states = encoder_hidden_states self.encoder_hidden_states = encoder_hidden_states
self.instance_prompt_encoder_hidden_states = instance_prompt_encoder_hidden_states self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
self.tokenizer_max_length = tokenizer_max_length self.tokenizer_max_length = tokenizer_max_length
self.instance_data_root = Path(instance_data_root) self.instance_data_root = Path(instance_data_root)
...@@ -562,8 +562,8 @@ class DreamBoothDataset(Dataset): ...@@ -562,8 +562,8 @@ class DreamBoothDataset(Dataset):
class_image = class_image.convert("RGB") class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image) example["class_images"] = self.image_transforms(class_image)
if self.instance_prompt_encoder_hidden_states is not None: if self.class_prompt_encoder_hidden_states is not None:
example["class_prompt_ids"] = self.instance_prompt_encoder_hidden_states example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
else: else:
class_text_inputs = tokenize_prompt( class_text_inputs = tokenize_prompt(
self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
...@@ -993,10 +993,10 @@ def main(args): ...@@ -993,10 +993,10 @@ def main(args):
else: else:
validation_prompt_encoder_hidden_states = None validation_prompt_encoder_hidden_states = None
if args.instance_prompt is not None: if args.class_prompt is not None:
pre_computed_instance_prompt_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
else: else:
pre_computed_instance_prompt_encoder_hidden_states = None pre_computed_class_prompt_encoder_hidden_states = None
text_encoder = None text_encoder = None
tokenizer = None tokenizer = None
...@@ -1007,7 +1007,7 @@ def main(args): ...@@ -1007,7 +1007,7 @@ def main(args):
pre_computed_encoder_hidden_states = None pre_computed_encoder_hidden_states = None
validation_prompt_encoder_hidden_states = None validation_prompt_encoder_hidden_states = None
validation_prompt_negative_prompt_embeds = None validation_prompt_negative_prompt_embeds = None
pre_computed_instance_prompt_encoder_hidden_states = None pre_computed_class_prompt_encoder_hidden_states = None
# Dataset and DataLoaders creation: # Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset( train_dataset = DreamBoothDataset(
...@@ -1020,7 +1020,7 @@ def main(args): ...@@ -1020,7 +1020,7 @@ def main(args):
size=args.resolution, size=args.resolution,
center_crop=args.center_crop, center_crop=args.center_crop,
encoder_hidden_states=pre_computed_encoder_hidden_states, encoder_hidden_states=pre_computed_encoder_hidden_states,
instance_prompt_encoder_hidden_states=pre_computed_instance_prompt_encoder_hidden_states, class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
tokenizer_max_length=args.tokenizer_max_length, tokenizer_max_length=args.tokenizer_max_length,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment