"...text-generation-inference.git" did not exist on "9f5c9a5e227e91abb6153d3b8ac82448cfc38056"
Unverified Commit 723dbdd3 authored by Álvaro Somoza's avatar Álvaro Somoza Committed by GitHub
Browse files

[Training] Better image interpolation in training scripts (#11206)



* initial

* Update examples/dreambooth/train_dreambooth_lora_sdxl.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* update

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarhlky <hlky@hlky.ac>
parent fbf61f46
......@@ -669,6 +669,16 @@ def parse_args(input_args=None):
),
)
parser.add_argument(
"--image_interpolation_mode",
type=str,
default="lanczos",
choices=[
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
],
help="The image interpolation method to use for resizing images.",
)
if input_args is not None:
args = parser.parse_args(input_args)
else:
......@@ -790,7 +800,12 @@ class DreamBoothDataset(Dataset):
self.original_sizes = []
self.crop_top_lefts = []
self.pixel_values = []
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
if interpolation is None:
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
train_resize = transforms.Resize(size, interpolation=interpolation)
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
train_flip = transforms.RandomHorizontalFlip(p=1.0)
train_transforms = transforms.Compose(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment