Unverified commit 3d574b3b, authored by Haofan Wang and committed by GitHub
Browse files

Fix the random-flip bug in the SDXL training scripts (#6547)

* Update train_text_to_image_sdxl.py

* Update train_text_to_image_lora_sdxl.py
parent 09903774
...@@ -836,6 +836,9 @@ def main(args): ...@@ -836,6 +836,9 @@ def main(args):
for image in images: for image in images:
original_sizes.append((image.height, image.width)) original_sizes.append((image.height, image.width))
image = train_resize(image) image = train_resize(image)
if args.random_flip and random.random() < 0.5:
# flip
image = train_flip(image)
if args.center_crop: if args.center_crop:
y1 = max(0, int(round((image.height - args.resolution) / 2.0))) y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
x1 = max(0, int(round((image.width - args.resolution) / 2.0))) x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
...@@ -843,10 +846,6 @@ def main(args): ...@@ -843,10 +846,6 @@ def main(args):
else: else:
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
image = crop(image, y1, x1, h, w) image = crop(image, y1, x1, h, w)
if args.random_flip and random.random() < 0.5:
# flip
x1 = image.width - x1
image = train_flip(image)
crop_top_left = (y1, x1) crop_top_left = (y1, x1)
crop_top_lefts.append(crop_top_left) crop_top_lefts.append(crop_top_left)
image = train_transforms(image) image = train_transforms(image)
......
...@@ -839,6 +839,9 @@ def main(args): ...@@ -839,6 +839,9 @@ def main(args):
for image in images: for image in images:
original_sizes.append((image.height, image.width)) original_sizes.append((image.height, image.width))
image = train_resize(image) image = train_resize(image)
if args.random_flip and random.random() < 0.5:
# flip
image = train_flip(image)
if args.center_crop: if args.center_crop:
y1 = max(0, int(round((image.height - args.resolution) / 2.0))) y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
x1 = max(0, int(round((image.width - args.resolution) / 2.0))) x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
...@@ -846,10 +849,6 @@ def main(args): ...@@ -846,10 +849,6 @@ def main(args):
else: else:
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
image = crop(image, y1, x1, h, w) image = crop(image, y1, x1, h, w)
if args.random_flip and random.random() < 0.5:
# flip
x1 = image.width - x1
image = train_flip(image)
crop_top_left = (y1, x1) crop_top_left = (y1, x1)
crop_top_lefts.append(crop_top_left) crop_top_lefts.append(crop_top_left)
image = train_transforms(image) image = train_transforms(image)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment