Unverified Commit 62c01d26 authored by Ernie Chu's avatar Ernie Chu Committed by GitHub
Browse files

Ensure validation image RGB not RGBA (#2945)



* ensure validation image RGB not RGBA

* ensure validation image RGB not RGBA

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent f3e72e9e
...@@ -106,7 +106,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler ...@@ -106,7 +106,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
image_logs = [] image_logs = []
for validation_prompt, validation_image in zip(validation_prompts, validation_images): for validation_prompt, validation_image in zip(validation_prompts, validation_images):
validation_image = Image.open(validation_image) validation_image = Image.open(validation_image).convert('RGB')
images = [] images = []
......
...@@ -110,7 +110,7 @@ def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_d ...@@ -110,7 +110,7 @@ def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_d
prompt_ids = pipeline.prepare_text_inputs(prompts) prompt_ids = pipeline.prepare_text_inputs(prompts)
prompt_ids = shard(prompt_ids) prompt_ids = shard(prompt_ids)
validation_image = Image.open(validation_image) validation_image = Image.open(validation_image).convert('RGB')
processed_image = pipeline.prepare_image_inputs(num_samples * [validation_image]) processed_image = pipeline.prepare_image_inputs(num_samples * [validation_image])
processed_image = shard(processed_image) processed_image = shard(processed_image)
images = pipeline( images = pipeline(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment