"docs/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "304efbb6509ff931b050fcb2f2895a8d98b1d220"
Unverified Commit ce08cb72 authored by Ruizhe Wang's avatar Ruizhe Wang Committed by GitHub
Browse files

[Dreambooth] Editable number of class images (#2251)



* [Dreambooth] Editable number of class images

* 'class_num=None' bug fix

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 4aa68291
...@@ -454,6 +454,7 @@ class DreamBoothDataset(Dataset): ...@@ -454,6 +454,7 @@ class DreamBoothDataset(Dataset):
tokenizer, tokenizer,
class_data_root=None, class_data_root=None,
class_prompt=None, class_prompt=None,
class_num=None,
size=512, size=512,
center_crop=False, center_crop=False,
): ):
...@@ -474,7 +475,10 @@ class DreamBoothDataset(Dataset): ...@@ -474,7 +475,10 @@ class DreamBoothDataset(Dataset):
self.class_data_root = Path(class_data_root) self.class_data_root = Path(class_data_root)
self.class_data_root.mkdir(parents=True, exist_ok=True) self.class_data_root.mkdir(parents=True, exist_ok=True)
self.class_images_path = list(self.class_data_root.iterdir()) self.class_images_path = list(self.class_data_root.iterdir())
self.num_class_images = len(self.class_images_path) if class_num is not None:
self.num_class_images = min(len(self.class_images_path), class_num)
else:
self.num_class_images = len(self.class_images_path)
self._length = max(self.num_class_images, self.num_instance_images) self._length = max(self.num_class_images, self.num_instance_images)
self.class_prompt = class_prompt self.class_prompt = class_prompt
else: else:
...@@ -814,6 +818,7 @@ def main(args): ...@@ -814,6 +818,7 @@ def main(args):
instance_prompt=args.instance_prompt, instance_prompt=args.instance_prompt,
class_data_root=args.class_data_dir if args.with_prior_preservation else None, class_data_root=args.class_data_dir if args.with_prior_preservation else None,
class_prompt=args.class_prompt, class_prompt=args.class_prompt,
class_num=args.num_class_images,
tokenizer=tokenizer, tokenizer=tokenizer,
size=args.resolution, size=args.resolution,
center_crop=args.center_crop, center_crop=args.center_crop,
......
...@@ -231,6 +231,7 @@ class DreamBoothDataset(Dataset): ...@@ -231,6 +231,7 @@ class DreamBoothDataset(Dataset):
tokenizer, tokenizer,
class_data_root=None, class_data_root=None,
class_prompt=None, class_prompt=None,
class_num=None,
size=512, size=512,
center_crop=False, center_crop=False,
): ):
...@@ -251,7 +252,10 @@ class DreamBoothDataset(Dataset): ...@@ -251,7 +252,10 @@ class DreamBoothDataset(Dataset):
self.class_data_root = Path(class_data_root) self.class_data_root = Path(class_data_root)
self.class_data_root.mkdir(parents=True, exist_ok=True) self.class_data_root.mkdir(parents=True, exist_ok=True)
self.class_images_path = list(self.class_data_root.iterdir()) self.class_images_path = list(self.class_data_root.iterdir())
self.num_class_images = len(self.class_images_path) if class_num is not None:
self.num_class_images = min(len(self.class_images_path), class_num)
else:
self.num_class_images = len(self.class_images_path)
self._length = max(self.num_class_images, self.num_instance_images) self._length = max(self.num_class_images, self.num_instance_images)
self.class_prompt = class_prompt self.class_prompt = class_prompt
else: else:
...@@ -419,6 +423,7 @@ def main(): ...@@ -419,6 +423,7 @@ def main():
instance_prompt=args.instance_prompt, instance_prompt=args.instance_prompt,
class_data_root=args.class_data_dir if args.with_prior_preservation else None, class_data_root=args.class_data_dir if args.with_prior_preservation else None,
class_prompt=args.class_prompt, class_prompt=args.class_prompt,
class_num=args.num_class_images,
tokenizer=tokenizer, tokenizer=tokenizer,
size=args.resolution, size=args.resolution,
center_crop=args.center_crop, center_crop=args.center_crop,
......
...@@ -417,6 +417,7 @@ class DreamBoothDataset(Dataset): ...@@ -417,6 +417,7 @@ class DreamBoothDataset(Dataset):
tokenizer, tokenizer,
class_data_root=None, class_data_root=None,
class_prompt=None, class_prompt=None,
class_num=None,
size=512, size=512,
center_crop=False, center_crop=False,
): ):
...@@ -437,7 +438,10 @@ class DreamBoothDataset(Dataset): ...@@ -437,7 +438,10 @@ class DreamBoothDataset(Dataset):
self.class_data_root = Path(class_data_root) self.class_data_root = Path(class_data_root)
self.class_data_root.mkdir(parents=True, exist_ok=True) self.class_data_root.mkdir(parents=True, exist_ok=True)
self.class_images_path = list(self.class_data_root.iterdir()) self.class_images_path = list(self.class_data_root.iterdir())
self.num_class_images = len(self.class_images_path) if class_num is not None:
self.num_class_images = min(len(self.class_images_path), class_num)
else:
self.num_class_images = len(self.class_images_path)
self._length = max(self.num_class_images, self.num_instance_images) self._length = max(self.num_class_images, self.num_instance_images)
self.class_prompt = class_prompt self.class_prompt = class_prompt
else: else:
...@@ -771,6 +775,7 @@ def main(args): ...@@ -771,6 +775,7 @@ def main(args):
instance_prompt=args.instance_prompt, instance_prompt=args.instance_prompt,
class_data_root=args.class_data_dir if args.with_prior_preservation else None, class_data_root=args.class_data_dir if args.with_prior_preservation else None,
class_prompt=args.class_prompt, class_prompt=args.class_prompt,
class_num=args.num_class_images,
tokenizer=tokenizer, tokenizer=tokenizer,
size=args.resolution, size=args.resolution,
center_crop=args.center_crop, center_crop=args.center_crop,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment