Unverified commit 7e6886f5, authored by Will Berman and committed by GitHub
Browse files

controlnet training resize inputs to multiple of 8 (#3135)

controlnet training center crop input images to multiple of 8

The pipeline code resizes inputs to multiples of 8.
Because the training script does not do this resizing, the
encoded image can end up with different height/width dimensions
than the encoded conditioning image (which uses a separate
encoder that's part of the controlnet model).

We resize and center crop the inputs so that each image and its
conditioning image have the same size (and match all other images in
the batch). We also check that `--resolution` is a multiple of 8.
parent a4c91be7
...@@ -525,6 +525,11 @@ def parse_args(input_args=None): ...@@ -525,6 +525,11 @@ def parse_args(input_args=None):
" or the same number of `--validation_prompt`s and `--validation_image`s" " or the same number of `--validation_prompt`s and `--validation_image`s"
) )
if args.resolution % 8 != 0:
raise ValueError(
"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
)
return args return args
...@@ -607,6 +612,7 @@ def make_train_dataset(args, tokenizer, accelerator): ...@@ -607,6 +612,7 @@ def make_train_dataset(args, tokenizer, accelerator):
image_transforms = transforms.Compose( image_transforms = transforms.Compose(
[ [
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]), transforms.Normalize([0.5], [0.5]),
] ]
...@@ -615,6 +621,7 @@ def make_train_dataset(args, tokenizer, accelerator): ...@@ -615,6 +621,7 @@ def make_train_dataset(args, tokenizer, accelerator):
conditioning_image_transforms = transforms.Compose( conditioning_image_transforms = transforms.Compose(
[ [
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution),
transforms.ToTensor(), transforms.ToTensor(),
] ]
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment