"vscode:/vscode.git/clone" did not exist on "f52fc3fcdf6cf22fd6a7c4db789f1a530a01db19"
Unverified Commit b3b04fef authored by tongyu's avatar tongyu Committed by GitHub
Browse files

[train_text_to_image] Better image interpolation in training scripts follow up (#11426)

* Update train_text_to_image.py

* update
parent 0e3f2713
...@@ -499,6 +499,15 @@ def parse_args(): ...@@ -499,6 +499,15 @@ def parse_args():
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
), ),
) )
parser.add_argument(
"--image_interpolation_mode",
type=str,
default="lanczos",
choices=[
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
],
help="The image interpolation method to use for resizing images.",
)
args = parser.parse_args() args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
...@@ -787,10 +796,17 @@ def main(): ...@@ -787,10 +796,17 @@ def main():
) )
return inputs.input_ids return inputs.input_ids
# Preprocessing the datasets. # Get the specified interpolation method from the args
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
# Raise an error if the interpolation method is invalid
if interpolation is None:
raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.")
# Data preprocessing transformations
train_transforms = transforms.Compose( train_transforms = transforms.Compose(
[ [
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), transforms.Resize(args.resolution, interpolation=interpolation), # Use dynamic interpolation method
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(), transforms.ToTensor(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment