"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c697c5ab57eb9bc98a38d8aba9e2fc8fb0b06b95"
Unverified Commit fc5e9066 authored by Parag Ekbote's avatar Parag Ekbote Committed by GitHub
Browse files

[train_text_to_image_sdxl]Add LANCZOS as default interpolation mode for image resizing (#11455)

* Add LANCZOS as default interplotation mode.

* update script

* Update as per code review.

* make style.
parent 8520d496
...@@ -470,6 +470,15 @@ def parse_args(input_args=None): ...@@ -470,6 +470,15 @@ def parse_args(input_args=None):
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
) )
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
parser.add_argument(
"--image_interpolation_mode",
type=str,
default="lanczos",
choices=[
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
],
help="The image interpolation method to use for resizing images.",
)
if input_args is not None: if input_args is not None:
args = parser.parse_args(input_args) args = parser.parse_args(input_args)
...@@ -861,7 +870,10 @@ def main(args): ...@@ -861,7 +870,10 @@ def main(args):
) )
# Preprocessing the datasets. # Preprocessing the datasets.
train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR) interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
if interpolation is None:
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
train_resize = transforms.Resize(args.resolution, interpolation=interpolation)
train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution) train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
train_flip = transforms.RandomHorizontalFlip(p=1.0) train_flip = transforms.RandomHorizontalFlip(p=1.0)
train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment