Unverified Commit f3209b5b authored by Álvaro Somoza's avatar Álvaro Somoza Committed by GitHub
Browse files

[SD3 Inference] T5 Token limit (#8506)



* max_sequence_length for the T5

* updated img2img

* apply suggestions

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent 96399c3e
...@@ -205,6 +205,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle ...@@ -205,6 +205,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
self, self,
prompt: Union[str, List[str]] = None, prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1, num_images_per_prompt: int = 1,
max_sequence_length: int = 256,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
): ):
...@@ -224,7 +225,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle ...@@ -224,7 +225,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
text_inputs = self.tokenizer_3( text_inputs = self.tokenizer_3(
prompt, prompt,
padding="max_length", padding="max_length",
max_length=self.tokenizer_max_length, max_length=max_sequence_length,
truncation=True, truncation=True,
add_special_tokens=True, add_special_tokens=True,
return_tensors="pt", return_tensors="pt",
...@@ -235,8 +236,8 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle ...@@ -235,8 +236,8 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
logger.warning( logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to" "The following part of your input was truncated because `max_sequence_length` is set to "
f" {self.tokenizer_max_length} tokens: {removed_text}" f" {max_sequence_length} tokens: {removed_text}"
) )
prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0] prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
...@@ -323,6 +324,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle ...@@ -323,6 +324,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
pooled_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
clip_skip: Optional[int] = None, clip_skip: Optional[int] = None,
max_sequence_length: int = 256,
): ):
r""" r"""
...@@ -403,6 +405,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle ...@@ -403,6 +405,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
t5_prompt_embed = self._get_t5_prompt_embeds( t5_prompt_embed = self._get_t5_prompt_embeds(
prompt=prompt_3, prompt=prompt_3,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device, device=device,
) )
...@@ -456,7 +459,10 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle ...@@ -456,7 +459,10 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
t5_negative_prompt_embed = self._get_t5_prompt_embeds( t5_negative_prompt_embed = self._get_t5_prompt_embeds(
prompt=negative_prompt_3, num_images_per_prompt=num_images_per_prompt, device=device prompt=negative_prompt_3,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
) )
negative_clip_prompt_embeds = torch.nn.functional.pad( negative_clip_prompt_embeds = torch.nn.functional.pad(
...@@ -486,6 +492,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle ...@@ -486,6 +492,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
pooled_prompt_embeds=None, pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None, negative_pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None, callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
): ):
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
...@@ -557,6 +564,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle ...@@ -557,6 +564,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
) )
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
def prepare_latents( def prepare_latents(
self, self,
batch_size, batch_size,
...@@ -643,6 +653,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle ...@@ -643,6 +653,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
clip_skip: Optional[int] = None, clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"], callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 256,
): ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -726,6 +737,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle ...@@ -726,6 +737,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class. `._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
Examples: Examples:
...@@ -753,6 +765,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle ...@@ -753,6 +765,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
pooled_prompt_embeds=pooled_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
) )
self._guidance_scale = guidance_scale self._guidance_scale = guidance_scale
...@@ -790,6 +803,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle ...@@ -790,6 +803,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
device=device, device=device,
clip_skip=self.clip_skip, clip_skip=self.clip_skip,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
) )
if self.do_classifier_free_guidance: if self.do_classifier_free_guidance:
......
...@@ -220,6 +220,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline): ...@@ -220,6 +220,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
self, self,
prompt: Union[str, List[str]] = None, prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1, num_images_per_prompt: int = 1,
max_sequence_length: int = 256,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
): ):
...@@ -239,7 +240,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline): ...@@ -239,7 +240,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
text_inputs = self.tokenizer_3( text_inputs = self.tokenizer_3(
prompt, prompt,
padding="max_length", padding="max_length",
max_length=self.tokenizer_max_length, max_length=max_sequence_length,
truncation=True, truncation=True,
add_special_tokens=True, add_special_tokens=True,
return_tensors="pt", return_tensors="pt",
...@@ -250,8 +251,8 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline): ...@@ -250,8 +251,8 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
logger.warning( logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to" "The following part of your input was truncated because `max_sequence_length` is set to "
f" {self.tokenizer_max_length} tokens: {removed_text}" f" {max_sequence_length} tokens: {removed_text}"
) )
prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0] prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
...@@ -340,6 +341,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline): ...@@ -340,6 +341,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
pooled_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
clip_skip: Optional[int] = None, clip_skip: Optional[int] = None,
max_sequence_length: int = 256,
): ):
r""" r"""
...@@ -420,6 +422,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline): ...@@ -420,6 +422,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
t5_prompt_embed = self._get_t5_prompt_embeds( t5_prompt_embed = self._get_t5_prompt_embeds(
prompt=prompt_3, prompt=prompt_3,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device, device=device,
) )
...@@ -473,7 +476,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline): ...@@ -473,7 +476,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
t5_negative_prompt_embed = self._get_t5_prompt_embeds( t5_negative_prompt_embed = self._get_t5_prompt_embeds(
prompt=negative_prompt_3, num_images_per_prompt=num_images_per_prompt, device=device prompt=negative_prompt_3,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
) )
negative_clip_prompt_embeds = torch.nn.functional.pad( negative_clip_prompt_embeds = torch.nn.functional.pad(
...@@ -502,6 +508,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline): ...@@ -502,6 +508,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
pooled_prompt_embeds=None, pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None, negative_pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None, callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
): ):
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
...@@ -573,6 +580,9 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline): ...@@ -573,6 +580,9 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
) )
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
def get_timesteps(self, num_inference_steps, strength, device): def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep # get the original timestep using init_timestep
init_timestep = min(num_inference_steps * strength, num_inference_steps) init_timestep = min(num_inference_steps * strength, num_inference_steps)
...@@ -684,6 +694,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline): ...@@ -684,6 +694,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
clip_skip: Optional[int] = None, clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"], callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 256,
): ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -763,6 +774,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline): ...@@ -763,6 +774,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class. `._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
Examples: Examples:
...@@ -786,6 +798,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline): ...@@ -786,6 +798,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
pooled_prompt_embeds=pooled_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
) )
self._guidance_scale = guidance_scale self._guidance_scale = guidance_scale
...@@ -822,6 +835,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline): ...@@ -822,6 +835,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
device=device, device=device,
clip_skip=self.clip_skip, clip_skip=self.clip_skip,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
) )
if self.do_classifier_free_guidance: if self.do_classifier_free_guidance:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment