Unverified Commit ed4efbd6 authored by RogerSinghChugh's avatar RogerSinghChugh Committed by GitHub
Browse files

Update training script for txt to img sdxl with lora supp with new interpolation. (#11496)

* Update training script for txt to img sdxl with lora supp with new interpolation.

* ran make style and make quality.
parent 9c29e938
...@@ -480,6 +480,15 @@ def parse_args(input_args=None): ...@@ -480,6 +480,15 @@ def parse_args(input_args=None):
action="store_true", action="store_true",
help="debug loss for each image, if filenames are available in the dataset", help="debug loss for each image, if filenames are available in the dataset",
) )
parser.add_argument(
"--image_interpolation_mode",
type=str,
default="lanczos",
choices=[
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
],
help="The image interpolation method to use for resizing images.",
)
if input_args is not None: if input_args is not None:
args = parser.parse_args(input_args) args = parser.parse_args(input_args)
...@@ -913,8 +922,14 @@ def main(args): ...@@ -913,8 +922,14 @@ def main(args):
tokens_two = tokenize_prompt(tokenizer_two, captions) tokens_two = tokenize_prompt(tokenizer_two, captions)
return tokens_one, tokens_two return tokens_one, tokens_two
# Get the specified interpolation method from the args
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
# Raise an error if the interpolation method is invalid
if interpolation is None:
raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.")
# Preprocessing the datasets. # Preprocessing the datasets.
train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR) train_resize = transforms.Resize(args.resolution, interpolation=interpolation) # Use dynamic interpolation method
train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution) train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
train_flip = transforms.RandomHorizontalFlip(p=1.0) train_flip = transforms.RandomHorizontalFlip(p=1.0)
train_transforms = transforms.Compose( train_transforms = transforms.Compose(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment