"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "790a909b54091e27008eceb6caf06a9cbd800476"
Unverified Commit bdbaea8f authored by Bios's avatar Bios Committed by GitHub
Browse files

update StableDiffusion3Img2ImgPipeline.add image size validation (#10166)



* update StableDiffusion3Img2ImgPipeline.add image size validation

---------
Co-authored-by: default avatarhlky <hlky@hlky.ac>
parent e8b65bff
...@@ -549,6 +549,8 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, ...@@ -549,6 +549,8 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
prompt, prompt,
prompt_2, prompt_2,
prompt_3, prompt_3,
height,
width,
strength, strength,
negative_prompt=None, negative_prompt=None,
negative_prompt_2=None, negative_prompt_2=None,
...@@ -560,6 +562,15 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, ...@@ -560,6 +562,15 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
callback_on_step_end_tensor_inputs=None, callback_on_step_end_tensor_inputs=None,
max_sequence_length=None, max_sequence_length=None,
): ):
if (
height % (self.vae_scale_factor * self.patch_size) != 0
or width % (self.vae_scale_factor * self.patch_size) != 0
):
raise ValueError(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
)
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
...@@ -730,6 +741,8 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, ...@@ -730,6 +741,8 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
prompt: Union[str, List[str]] = None, prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None, prompt_2: Optional[Union[str, List[str]]] = None,
prompt_3: Optional[Union[str, List[str]]] = None, prompt_3: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
image: PipelineImageInput = None, image: PipelineImageInput = None,
strength: float = 0.6, strength: float = 0.6,
num_inference_steps: int = 50, num_inference_steps: int = 50,
...@@ -860,11 +873,15 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, ...@@ -860,11 +873,15 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
[`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images. `tuple`. When returning a tuple, the first element is a list with the generated images.
""" """
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct # 1. Check inputs. Raise error if not correct
self.check_inputs( self.check_inputs(
prompt, prompt,
prompt_2, prompt_2,
prompt_3, prompt_3,
height,
width,
strength, strength,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2, negative_prompt_2=negative_prompt_2,
...@@ -933,7 +950,7 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, ...@@ -933,7 +950,7 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
# 3. Preprocess image # 3. Preprocess image
image = self.image_processor.preprocess(image) image = self.image_processor.preprocess(image, height=height, width=width)
# 4. Prepare timesteps # 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
......
...@@ -218,6 +218,9 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro ...@@ -218,6 +218,9 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
) )
self.tokenizer_max_length = self.tokenizer.model_max_length self.tokenizer_max_length = self.tokenizer.model_max_length
self.default_sample_size = self.transformer.config.sample_size self.default_sample_size = self.transformer.config.sample_size
self.patch_size = (
self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
)
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds( def _get_t5_prompt_embeds(
...@@ -531,6 +534,8 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro ...@@ -531,6 +534,8 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
prompt, prompt,
prompt_2, prompt_2,
prompt_3, prompt_3,
height,
width,
strength, strength,
negative_prompt=None, negative_prompt=None,
negative_prompt_2=None, negative_prompt_2=None,
...@@ -542,6 +547,15 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro ...@@ -542,6 +547,15 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
callback_on_step_end_tensor_inputs=None, callback_on_step_end_tensor_inputs=None,
max_sequence_length=None, max_sequence_length=None,
): ):
if (
height % (self.vae_scale_factor * self.patch_size) != 0
or width % (self.vae_scale_factor * self.patch_size) != 0
):
raise ValueError(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
)
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
...@@ -710,6 +724,8 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro ...@@ -710,6 +724,8 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
prompt: Union[str, List[str]] = None, prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None, prompt_2: Optional[Union[str, List[str]]] = None,
prompt_3: Optional[Union[str, List[str]]] = None, prompt_3: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
image: PipelineImageInput = None, image: PipelineImageInput = None,
strength: float = 0.6, strength: float = 0.6,
num_inference_steps: int = 50, num_inference_steps: int = 50,
...@@ -824,12 +840,16 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro ...@@ -824,12 +840,16 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
[`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images. `tuple`. When returning a tuple, the first element is a list with the generated images.
""" """
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct # 1. Check inputs. Raise error if not correct
self.check_inputs( self.check_inputs(
prompt, prompt,
prompt_2, prompt_2,
prompt_3, prompt_3,
height,
width,
strength, strength,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2, negative_prompt_2=negative_prompt_2,
...@@ -890,7 +910,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro ...@@ -890,7 +910,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
# 3. Preprocess image # 3. Preprocess image
image = self.image_processor.preprocess(image) image = self.image_processor.preprocess(image, height=height, width=width)
# 4. Prepare timesteps # 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
......
...@@ -224,6 +224,9 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro ...@@ -224,6 +224,9 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
) )
self.tokenizer_max_length = self.tokenizer.model_max_length self.tokenizer_max_length = self.tokenizer.model_max_length
self.default_sample_size = self.transformer.config.sample_size self.default_sample_size = self.transformer.config.sample_size
self.patch_size = (
self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
)
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds( def _get_t5_prompt_embeds(
...@@ -538,6 +541,8 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro ...@@ -538,6 +541,8 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
prompt, prompt,
prompt_2, prompt_2,
prompt_3, prompt_3,
height,
width,
strength, strength,
negative_prompt=None, negative_prompt=None,
negative_prompt_2=None, negative_prompt_2=None,
...@@ -549,6 +554,15 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro ...@@ -549,6 +554,15 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
callback_on_step_end_tensor_inputs=None, callback_on_step_end_tensor_inputs=None,
max_sequence_length=None, max_sequence_length=None,
): ):
if (
height % (self.vae_scale_factor * self.patch_size) != 0
or width % (self.vae_scale_factor * self.patch_size) != 0
):
raise ValueError(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
)
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
...@@ -953,6 +967,8 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro ...@@ -953,6 +967,8 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
prompt, prompt,
prompt_2, prompt_2,
prompt_3, prompt_3,
height,
width,
strength, strength,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2, negative_prompt_2=negative_prompt_2,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment