"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "d32e9391f99c837713197f57ad9772af0cc49308"
Unverified Commit 58431f10 authored by Youlun Peng's avatar Youlun Peng Committed by GitHub
Browse files

Set LANCZOS as the default interpolation for image resizing in ControlNet training (#11449)

Set LANCZOS as the default interpolation for image resizing
parent 4a9ab650
...@@ -639,6 +639,15 @@ def parse_args(input_args=None): ...@@ -639,6 +639,15 @@ def parse_args(input_args=None):
action="store_true", action="store_true",
help="Enable model cpu offload and save memory.", help="Enable model cpu offload and save memory.",
) )
parser.add_argument(
"--image_interpolation_mode",
type=str,
default="lanczos",
choices=[
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
],
help="The image interpolation method to use for resizing images.",
)
if input_args is not None: if input_args is not None:
args = parser.parse_args(input_args) args = parser.parse_args(input_args)
...@@ -736,9 +745,13 @@ def get_train_dataset(args, accelerator): ...@@ -736,9 +745,13 @@ def get_train_dataset(args, accelerator):
def prepare_train_dataset(dataset, accelerator): def prepare_train_dataset(dataset, accelerator):
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
if interpolation is None:
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
image_transforms = transforms.Compose( image_transforms = transforms.Compose(
[ [
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), transforms.Resize(args.resolution, interpolation=interpolation),
transforms.CenterCrop(args.resolution), transforms.CenterCrop(args.resolution),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]), transforms.Normalize([0.5], [0.5]),
...@@ -747,7 +760,7 @@ def prepare_train_dataset(dataset, accelerator): ...@@ -747,7 +760,7 @@ def prepare_train_dataset(dataset, accelerator):
conditioning_image_transforms = transforms.Compose( conditioning_image_transforms = transforms.Compose(
[ [
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), transforms.Resize(args.resolution, interpolation=interpolation),
transforms.CenterCrop(args.resolution), transforms.CenterCrop(args.resolution),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]), transforms.Normalize([0.5], [0.5]),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment