Unverified commit f7ebe569 authored by Yuta Hayashibe, committed by GitHub

Warning for too long prompts in DiffusionPipelines (Resolve #447) (#472)

* Return encoded texts by DiffusionPipelines

* Updated README to show how to use encoded_text_input

* Reverted examples in README.md

* Reverted all

* Warning for long prompts

* Fix bugs

* Formatted
parent 57b70c59
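
All four Stable Diffusion pipelines get the same treatment in the hunks below: `truncation=True` is removed from the tokenizer call so that overflowing tokens remain visible, the overflow is decoded and logged as a warning, and the ids are then truncated by hand before encoding. A minimal standalone sketch of that pattern, assuming the `openai/clip-vit-large-patch14` tokenizer and the standard-library logger in place of the `diffusers` logging utility:

import logging

from transformers import CLIPTokenizer

logger = logging.getLogger(__name__)

# Illustrative checkpoint; any CLIP tokenizer with model_max_length == 77 behaves the same.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

prompt = "a photo of an astronaut riding a horse on mars, " * 16  # deliberately too long

# Tokenize WITHOUT truncation=True, so the overflow is still present in the
# output and can be decoded for the warning message.
text_inputs = tokenizer(
    prompt,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    return_tensors="pt",
)
text_input_ids = text_inputs.input_ids

if text_input_ids.shape[-1] > tokenizer.model_max_length:
    removed_text = tokenizer.batch_decode(text_input_ids[:, tokenizer.model_max_length :])
    logger.warning(
        "The following part of your input was truncated because CLIP can only handle sequences up to"
        f" {tokenizer.model_max_length} tokens: {removed_text}"
    )
    # truncate by hand, now that the user has been told what was dropped
    text_input_ids = text_input_ids[:, : tokenizer.model_max_length]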
@@ -10,10 +10,14 @@ from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import logging
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
 
 
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
 class StableDiffusionPipeline(DiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.
@@ -188,14 +192,22 @@ class StableDiffusionPipeline(DiffusionPipeline):
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
         # get prompt text embeddings
-        text_input = self.tokenizer(
+        text_inputs = self.tokenizer(
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
-            truncation=True,
             return_tensors="pt",
         )
-        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+        text_input_ids = text_inputs.input_ids
+
+        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+            logger.warning(
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+            )
+            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -203,7 +215,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
-            max_length = text_input.input_ids.shape[-1]
+            max_length = text_input_ids.shape[-1]
             uncond_input = self.tokenizer(
                 [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
             )
...
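With the hunk above applied, an over-long prompt no longer truncates silently; the dropped words are spelled out in a warning and generation proceeds on the first `model_max_length` (77 for CLIP) tokens. A hypothetical invocation; the checkpoint id, auth token, and device are illustrative, not part of the commit:

from diffusers import StableDiffusionPipeline

# Illustrative checkpoint and device; any Stable Diffusion weights behave the same.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", use_auth_token=True
).to("cuda")

long_prompt = "a castle on a hill" + ", intricate detail" * 40  # far past 77 CLIP tokens
image = pipe(long_prompt).images[0]
# Expected warning (wording taken from the diff above):
#   The following part of your input was truncated because CLIP can only handle
#   sequences up to 77 tokens: [...]
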
@@ -12,10 +12,14 @@ from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import logging
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
 
 
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
 def preprocess(image):
     w, h = image.size
     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
@@ -216,14 +220,22 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
 
         # get prompt text embeddings
-        text_input = self.tokenizer(
+        text_inputs = self.tokenizer(
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
-            truncation=True,
             return_tensors="pt",
         )
-        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+        text_input_ids = text_inputs.input_ids
+
+        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+            logger.warning(
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+            )
+            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -231,7 +243,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
-            max_length = text_input.input_ids.shape[-1]
+            max_length = text_input_ids.shape[-1]
             uncond_input = self.tokenizer(
                 [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
             )
...
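The `max_length = text_input_ids.shape[-1]` change in the last hunk of each file follows from the manual truncation: the unconditional ("") input for classifier-free guidance must be padded to the same length as the truncated prompt ids, which is now read off `text_input_ids` rather than the raw `text_input`. Continuing the standalone sketch from the top of the page:

# The unconditional embedding for classifier-free guidance is padded to the
# same sequence length as the truncated prompt ids.
batch_size = 1  # illustrative; matches the single prompt string above
max_length = text_input_ids.shape[-1]  # == tokenizer.model_max_length after truncation
uncond_input = tokenizer(
    [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
)
assert uncond_input.input_ids.shape == text_input_ids.shape
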
@@ -254,14 +254,22 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
 
         # get prompt text embeddings
-        text_input = self.tokenizer(
+        text_inputs = self.tokenizer(
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
-            truncation=True,
             return_tensors="pt",
         )
-        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+        text_input_ids = text_inputs.input_ids
+
+        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+            logger.warning(
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+            )
+            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -269,7 +277,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
-            max_length = text_input.input_ids.shape[-1]
+            max_length = text_input_ids.shape[-1]
             uncond_input = self.tokenizer(
                 [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
             )
...
@@ -8,9 +8,13 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer
 from ...onnx_utils import OnnxRuntimeModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import logging
 from . import StableDiffusionPipelineOutput
 
 
+logger = logging.get_logger(__name__)
+
+
 class StableDiffusionOnnxPipeline(DiffusionPipeline):
     vae_decoder: OnnxRuntimeModel
     text_encoder: OnnxRuntimeModel
@@ -66,14 +70,22 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
         # get prompt text embeddings
-        text_input = self.tokenizer(
+        text_inputs = self.tokenizer(
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
-            truncation=True,
-            return_tensors="np",
+            return_tensors="pt",
         )
-        text_embeddings = self.text_encoder(input_ids=text_input.input_ids.astype(np.int32))[0]
+        text_input_ids = text_inputs.input_ids
+
+        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+            logger.warning(
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+            )
+            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+        text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -81,7 +93,7 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
-            max_length = text_input.input_ids.shape[-1]
+            max_length = text_input_ids.shape[-1]
             uncond_input = self.tokenizer(
                 [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
             )
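
One wrinkle in the ONNX hunk: the tokenizer call now uses `return_tensors="pt"` (PyTorch tensors), yet the ids are still cast with `.astype(np.int32)`, a NumPy method that `torch.Tensor` does not provide. A sketch of the same warn-and-truncate logic kept in NumPy end to end; the function name and signature are hypothetical, not from the commit:

import logging

import numpy as np

logger = logging.getLogger(__name__)


def encode_prompt_np(tokenizer, text_encoder, prompt):
    # Same warn-and-truncate logic as the hunk above, but staying in NumPy:
    # return_tensors="np" yields np.ndarray ids, so .astype(np.int32) is valid.
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="np",
    )
    text_input_ids = text_inputs.input_ids

    if text_input_ids.shape[-1] > tokenizer.model_max_length:
        removed_text = tokenizer.batch_decode(text_input_ids[:, tokenizer.model_max_length :])
        logger.warning(
            "The following part of your input was truncated because CLIP can only handle sequences up to"
            f" {tokenizer.model_max_length} tokens: {removed_text}"
        )
        text_input_ids = text_input_ids[:, : tokenizer.model_max_length]

    # ONNX Runtime expects int32 input ids.
    return text_encoder(input_ids=text_input_ids.astype(np.int32))[0]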