"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "269109dbfbbdbe2800535239b881e96e1828a0ef"
Unverified Commit ec3d5828 authored by Yash's avatar Yash Committed by GitHub
Browse files

[train_dreambooth_lora_flux_advanced] Add LANCZOS as the default interpolation...

[train_dreambooth_lora_flux_advanced] Add LANCZOS as the default interpolation mode for image resizing (#11472)

* [train_controlnet_sdxl] Add LANCZOS as the default interpolation mode for image resizing

* [train_dreambooth_lora_flux_advanced] Add LANCZOS as the default interpolation mode for image resizing
parent ed6cf525
...@@ -770,6 +770,15 @@ def parse_args(input_args=None): ...@@ -770,6 +770,15 @@ def parse_args(input_args=None):
), ),
) )
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument(
"--image_interpolation_mode",
type=str,
default="lanczos",
choices=[
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
],
help="The image interpolation method to use for resizing images.",
)
if input_args is not None: if input_args is not None:
args = parser.parse_args(input_args) args = parser.parse_args(input_args)
...@@ -1034,7 +1043,10 @@ class DreamBoothDataset(Dataset): ...@@ -1034,7 +1043,10 @@ class DreamBoothDataset(Dataset):
self.instance_images.extend(itertools.repeat(img, repeats)) self.instance_images.extend(itertools.repeat(img, repeats))
self.pixel_values = [] self.pixel_values = []
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
if interpolation is None:
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
train_resize = transforms.Resize(size, interpolation=interpolation)
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
train_flip = transforms.RandomHorizontalFlip(p=1.0) train_flip = transforms.RandomHorizontalFlip(p=1.0)
train_transforms = transforms.Compose( train_transforms = transforms.Compose(
...@@ -1078,7 +1090,7 @@ class DreamBoothDataset(Dataset): ...@@ -1078,7 +1090,7 @@ class DreamBoothDataset(Dataset):
self.image_transforms = transforms.Compose( self.image_transforms = transforms.Compose(
[ [
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), transforms.Resize(size, interpolation=interpolation),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]), transforms.Normalize([0.5], [0.5]),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment