Unverified Commit 4c660d16 authored by Patrick von Platen, committed by GitHub

[Stable Diffusion] Fix padding / truncation (#1226)

* [Stable Diffusion] Fix padding / truncation

* finish
parent 81715661
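All of the pipeline hunks below apply the same pattern: tokenize the prompt with truncation=True, tokenize it once more without truncation, and if the two results differ, warn about exactly which tokens were dropped. As a rough standalone sketch of that pattern (not part of this commit; the function name warn_if_truncated and the openai/clip-vit-base-patch32 checkpoint are illustrative assumptions only):

import torch
from transformers import CLIPTokenizer


def warn_if_truncated(tokenizer: CLIPTokenizer, prompt: str) -> torch.Tensor:
    # Tokenize with truncation so the text encoder never sees more than
    # model_max_length tokens (77 for CLIP).
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids

    # Tokenize again without truncation; a mismatch means part of the prompt was cut off.
    untruncated_ids = tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
    if not torch.equal(text_input_ids, untruncated_ids):
        # Decode only the dropped span: start at model_max_length - 1 (that slot holds
        # the EOS token in the truncated ids) and stop before the final EOS.
        removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
        print(f"Truncated after {tokenizer.model_max_length} tokens: {removed_text}")
    return text_input_ids


if __name__ == "__main__":
    tok = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    warn_if_truncated(tok, "a photo of a cat " * 40)  # long prompt triggers the warning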
@@ -248,17 +248,18 @@ class CycleDiffusionPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

         text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -114,17 +114,19 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="np",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not np.array_equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

         text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
         text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
@@ -161,17 +161,19 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="np",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not np.array_equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

         text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
         text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
@@ -175,17 +175,19 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="np",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not np.array_equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

         text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
         text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
@@ -236,17 +236,18 @@ class StableDiffusionPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

         text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -244,17 +244,18 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

         text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -244,17 +244,18 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

         text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -213,17 +213,18 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

         text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -33,9 +33,10 @@ from diffusers import (
     UNet2DConditionModel,
     UNet2DModel,
     VQModel,
+    logging,
 )
 from diffusers.utils import floats_tensor, load_numpy, slow, torch_device
-from diffusers.utils.testing_utils import require_torch_gpu
+from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

 from ...test_pipelines_common import PipelineTesterMixin
@@ -619,6 +620,57 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         assert image.shape == (1, 128, 128, 3)

+    def test_stable_diffusion_long_prompt(self):
+        unet = self.dummy_cond_unet
+        scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        do_classifier_free_guidance = True
+        negative_prompt = None
+        num_images_per_prompt = 1
+        logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion")
+
+        prompt = 25 * "@"
+        with CaptureLogger(logger) as cap_logger_3:
+            text_embeddings_3 = sd_pipe._encode_prompt(
+                prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+            )
+
+        prompt = 100 * "@"
+        with CaptureLogger(logger) as cap_logger:
+            text_embeddings = sd_pipe._encode_prompt(
+                prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+            )
+
+        negative_prompt = "Hello"
+        with CaptureLogger(logger) as cap_logger_2:
+            text_embeddings_2 = sd_pipe._encode_prompt(
+                prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+            )
+
+        assert text_embeddings_3.shape == text_embeddings_2.shape == text_embeddings.shape
+        assert text_embeddings.shape[1] == 77
+
+        assert cap_logger.out == cap_logger_2.out
+        # 100 - 77 + 1 (BOS token) + 1 (EOS token) = 25
+        assert cap_logger.out.count("@") == 25
+        assert cap_logger_3.out == ""
+
 @slow
 @require_torch_gpu