Unverified Commit 9c29e938 authored by Yijun Lee's avatar Yijun Lee Committed by GitHub
Browse files

Set LANCZOS as the default interpolation method for image resizing. (#11492)

* Set LANCZOS as the default interpolation method for image resizing.

* style: run make style and quality checks
parent 071807c8
...@@ -673,6 +673,15 @@ def parse_args(input_args=None): ...@@ -673,6 +673,15 @@ def parse_args(input_args=None):
default=False, default=False,
help="Cache the VAE latents", help="Cache the VAE latents",
) )
parser.add_argument(
"--image_interpolation_mode",
type=str,
default="lanczos",
choices=[
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
],
help="The image interpolation method to use for resizing images.",
)
if input_args is not None: if input_args is not None:
args = parser.parse_args(input_args) args = parser.parse_args(input_args)
...@@ -907,6 +916,10 @@ class DreamBoothDataset(Dataset): ...@@ -907,6 +916,10 @@ class DreamBoothDataset(Dataset):
self.num_instance_images = len(self.instance_images) self.num_instance_images = len(self.instance_images)
self._length = self.num_instance_images self._length = self.num_instance_images
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
if interpolation is None:
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
if class_data_root is not None: if class_data_root is not None:
self.class_data_root = Path(class_data_root) self.class_data_root = Path(class_data_root)
self.class_data_root.mkdir(parents=True, exist_ok=True) self.class_data_root.mkdir(parents=True, exist_ok=True)
...@@ -921,7 +934,7 @@ class DreamBoothDataset(Dataset): ...@@ -921,7 +934,7 @@ class DreamBoothDataset(Dataset):
self.image_transforms = transforms.Compose( self.image_transforms = transforms.Compose(
[ [
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), transforms.Resize(size, interpolation=interpolation),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]), transforms.Normalize([0.5], [0.5]),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment