Unverified Commit 173e1b14 authored by Sam Gao's avatar Sam Gao Committed by GitHub
Browse files

[Examples] Uniform notations in train_flux_lora (#10011)



[Examples] uniform naming notations

since the in parameter `size` represents `args.resolution`, I thus replace the `args.resolution` inside DreamBoothData with `size`. And revise some notations such as `center_crop`.
Co-authored-by: default avatarLinoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
parent e46e139f
...@@ -971,6 +971,7 @@ class DreamBoothDataset(Dataset): ...@@ -971,6 +971,7 @@ class DreamBoothDataset(Dataset):
def __init__( def __init__(
self, self,
args,
instance_data_root, instance_data_root,
instance_prompt, instance_prompt,
class_prompt, class_prompt,
...@@ -980,10 +981,8 @@ class DreamBoothDataset(Dataset): ...@@ -980,10 +981,8 @@ class DreamBoothDataset(Dataset):
class_num=None, class_num=None,
size=1024, size=1024,
repeats=1, repeats=1,
center_crop=False,
): ):
self.size = size self.size = size
self.center_crop = center_crop
self.instance_prompt = instance_prompt self.instance_prompt = instance_prompt
self.custom_instance_prompts = None self.custom_instance_prompts = None
...@@ -1075,11 +1074,11 @@ class DreamBoothDataset(Dataset): ...@@ -1075,11 +1074,11 @@ class DreamBoothDataset(Dataset):
# flip # flip
image = train_flip(image) image = train_flip(image)
if args.center_crop: if args.center_crop:
y1 = max(0, int(round((image.height - args.resolution) / 2.0))) y1 = max(0, int(round((image.height - self.size) / 2.0)))
x1 = max(0, int(round((image.width - args.resolution) / 2.0))) x1 = max(0, int(round((image.width - self.size) / 2.0)))
image = train_crop(image) image = train_crop(image)
else: else:
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) y1, x1, h, w = train_crop.get_params(image, (self.size, self.size))
image = crop(image, y1, x1, h, w) image = crop(image, y1, x1, h, w)
image = train_transforms(image) image = train_transforms(image)
self.pixel_values.append(image) self.pixel_values.append(image)
...@@ -1827,6 +1826,7 @@ def main(args): ...@@ -1827,6 +1826,7 @@ def main(args):
# Dataset and DataLoaders creation: # Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset( train_dataset = DreamBoothDataset(
args=args,
instance_data_root=args.instance_data_dir, instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt, instance_prompt=args.instance_prompt,
train_text_encoder_ti=args.train_text_encoder_ti, train_text_encoder_ti=args.train_text_encoder_ti,
...@@ -1836,7 +1836,6 @@ def main(args): ...@@ -1836,7 +1836,6 @@ def main(args):
class_num=args.num_class_images, class_num=args.num_class_images,
size=args.resolution, size=args.resolution,
repeats=args.repeats, repeats=args.repeats,
center_crop=args.center_crop,
) )
train_dataloader = torch.utils.data.DataLoader( train_dataloader = torch.utils.data.DataLoader(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment