"git@developer.sourcefind.cn:change/sglang.git" did not exist on "9a00e6f453e764c0b286e2a62f652a1202c0bf9c"
Unverified Commit ed4efbd6 authored by RogerSinghChugh's avatar RogerSinghChugh Committed by GitHub
Browse files

Update training script for txt to img sdxl with lora supp with new interpolation. (#11496)

* Update training script for txt to img sdxl with lora supp with new interpolation.

* ran make style and make quality.
parent 9c29e938
...@@ -480,6 +480,15 @@ def parse_args(input_args=None): ...@@ -480,6 +480,15 @@ def parse_args(input_args=None):
action="store_true", action="store_true",
help="debug loss for each image, if filenames are available in the dataset", help="debug loss for each image, if filenames are available in the dataset",
) )
parser.add_argument(
"--image_interpolation_mode",
type=str,
default="lanczos",
choices=[
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
],
help="The image interpolation method to use for resizing images.",
)
if input_args is not None: if input_args is not None:
args = parser.parse_args(input_args) args = parser.parse_args(input_args)
...@@ -913,8 +922,14 @@ def main(args): ...@@ -913,8 +922,14 @@ def main(args):
tokens_two = tokenize_prompt(tokenizer_two, captions) tokens_two = tokenize_prompt(tokenizer_two, captions)
return tokens_one, tokens_two return tokens_one, tokens_two
# Get the specified interpolation method from the args
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
# Raise an error if the interpolation method is invalid
if interpolation is None:
raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.")
# Preprocessing the datasets. # Preprocessing the datasets.
train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR) train_resize = transforms.Resize(args.resolution, interpolation=interpolation) # Use dynamic interpolation method
train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution) train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
train_flip = transforms.RandomHorizontalFlip(p=1.0) train_flip = transforms.RandomHorizontalFlip(p=1.0)
train_transforms = transforms.Compose( train_transforms = transforms.Compose(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment