Unverified Commit 3d574b3b authored by Haofan Wang's avatar Haofan Wang Committed by GitHub
Browse files

Fix a bug of flip in SDXL training script (#6547)

* Update train_text_to_image_sdxl.py

* Update train_text_to_image_lora_sdxl.py
parent 09903774
......@@ -836,6 +836,9 @@ def main(args):
for image in images:
original_sizes.append((image.height, image.width))
image = train_resize(image)
if args.random_flip and random.random() < 0.5:
# flip
image = train_flip(image)
if args.center_crop:
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
......@@ -843,10 +846,6 @@ def main(args):
else:
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
image = crop(image, y1, x1, h, w)
if args.random_flip and random.random() < 0.5:
# flip
x1 = image.width - x1
image = train_flip(image)
crop_top_left = (y1, x1)
crop_top_lefts.append(crop_top_left)
image = train_transforms(image)
......
......@@ -839,6 +839,9 @@ def main(args):
for image in images:
original_sizes.append((image.height, image.width))
image = train_resize(image)
if args.random_flip and random.random() < 0.5:
# flip
image = train_flip(image)
if args.center_crop:
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
......@@ -846,10 +849,6 @@ def main(args):
else:
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
image = crop(image, y1, x1, h, w)
if args.random_flip and random.random() < 0.5:
# flip
x1 = image.width - x1
image = train_flip(image)
crop_top_left = (y1, x1)
crop_top_lefts.append(crop_top_left)
image = train_transforms(image)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment