Unverified commit e2899989, authored by Adalberto, committed by GitHub

fix mask discrepancies in train_dreambooth_inpaint (#1529)

The mask and the instance image were being cropped in different ways when --center_crop was not passed, causing the model to learn to ignore the mask in some cases. This PR fixes that and generates more consistent results.
parent 634be6e5
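To make the bug concrete, here is a small standalone sketch (not part of the diff below) of the before/after behavior. It uses torchvision transforms the same way the script does; the file name and the hand-built mask are illustrative placeholders standing in for the dataset's real images and the script's random_mask helper.

```python
# Illustrative sketch only: why two independent RandomCrop calls misalign the
# image and the mask, and how cropping the PIL image once keeps them aligned.
import torch
from PIL import Image
from torchvision import transforms

size = 512
resize_and_crop = transforms.Compose(
    [
        transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.RandomCrop(size),  # the case without --center_crop
    ]
)
to_tensor = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

image = Image.open("instance.jpg").convert("RGB")  # placeholder path

# Pre-fix pattern: the image and the mask each went through their own
# RandomCrop call, so they generally came from different windows of the original.
crop_for_image = resize_and_crop(image)
crop_for_mask = resize_and_crop(image)  # a second call draws new random crop parameters

# Post-fix pattern: crop the PIL image once in the dataset, then build the mask
# at the already-cropped size and apply only tensor transforms afterwards.
instance_image = resize_and_crop(image)            # geometric transforms, applied once
pixel_values = to_tensor(instance_image)           # ToTensor + Normalize, no cropping
mask = torch.zeros(1, *instance_image.size[::-1])  # stand-in for random_mask(instance_image.size, 1, False)
mask[:, 100:200, 150:300] = 1.0                    # masked region now lines up with instance_image
```

Because ToTensor and Normalize are purely per-pixel, moving the geometric transforms into the dataset means the tensors the model sees and the PIL images used to build the mask describe exactly the same crop.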
@@ -295,10 +295,15 @@ class DreamBoothDataset(Dataset):
         else:
             self.class_data_root = None
 
-        self.image_transforms = transforms.Compose(
+        self.image_transforms_resize_and_crop = transforms.Compose(
             [
                 transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+            ]
+        )
+
+        self.image_transforms = transforms.Compose(
+            [
                 transforms.ToTensor(),
                 transforms.Normalize([0.5], [0.5]),
             ]
@@ -312,6 +317,7 @@ class DreamBoothDataset(Dataset):
         instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
         if not instance_image.mode == "RGB":
             instance_image = instance_image.convert("RGB")
+        instance_image = self.image_transforms_resize_and_crop(instance_image)
         example["PIL_images"] = instance_image
         example["instance_images"] = self.image_transforms(instance_image)
 
@@ -327,6 +333,7 @@ class DreamBoothDataset(Dataset):
             class_image = Image.open(self.class_images_path[index % self.num_class_images])
             if not class_image.mode == "RGB":
                 class_image = class_image.convert("RGB")
+            class_image = self.image_transforms_resize_and_crop(class_image)
             example["class_images"] = self.image_transforms(class_image)
             example["class_PIL_images"] = class_image
             example["class_prompt_ids"] = self.tokenizer(
@@ -513,12 +520,6 @@ def main():
     )
 
     def collate_fn(examples):
-        image_transforms = transforms.Compose(
-            [
-                transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
-                transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
-            ]
-        )
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
 
@@ -535,9 +536,6 @@ def main():
             pil_image = example["PIL_images"]
             # generate a random mask
             mask = random_mask(pil_image.size, 1, False)
-            # apply transforms
-            mask = image_transforms(mask)
-            pil_image = image_transforms(pil_image)
             # prepare mask and masked image
             mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)
 
@@ -548,9 +546,6 @@ def main():
             for pil_image in pior_pil:
                 # generate a random mask
                 mask = random_mask(pil_image.size, 1, False)
-                # apply transforms
-                mask = image_transforms(mask)
-                pil_image = image_transforms(pil_image)
                 # prepare mask and masked image
                 mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)