renzhc / diffusers_dcu · Commit 4c660d16 (unverified)

[Stable Diffusion] Fix padding / truncation (#1226)

* [Stable Diffusion] Fix padding / truncation
* finish

Authored Nov 13, 2022 by Patrick von Platen; committed via GitHub on Nov 13, 2022.
Parent: 81715661
Showing 9 changed files with 88 additions and 25 deletions (+88, -25).
Files changed:

- src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py (+4, -3)
- src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py (+5, -3)
- src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py (+5, -3)
- src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py (+5, -3)
- src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py (+4, -3)
- src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py (+4, -3)
- src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py (+4, -3)
- src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py (+4, -3)
- tests/pipelines/stable_diffusion/test_stable_diffusion.py (+53, -1)
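Every pipeline in this commit applies the same pattern: pad the prompt to `model_max_length`, explicitly enable truncation, and detect truncation by comparing against a second, untruncated tokenization. The previous code truncated by slicing the ids manually after the fact, which leaves the sequence without a properly placed EOS token; the new code lets the tokenizer truncate and only re-tokenizes without truncation to report what was cut off. A minimal, self-contained sketch of that pattern outside the pipelines; the `openai/clip-vit-large-patch14` checkpoint and the prompt text are illustrative stand-ins, not taken from the commit:

```python
import torch
from transformers import CLIPTokenizer

# Illustrative stand-in for the pipelines' tokenizer (any CLIP tokenizer behaves the same way).
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
prompt = "a photo of " + "a very detailed castle " * 30  # comfortably longer than 77 tokens

# New behaviour: pad to model_max_length and let the tokenizer do the truncation.
text_inputs = tokenizer(
    prompt,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
text_input_ids = text_inputs.input_ids  # shape (1, 77), ends with the EOS token

# Re-tokenize without truncation only to find out whether (and what) was cut off.
untruncated_ids = tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids

if not torch.equal(text_input_ids, untruncated_ids):
    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
    print(f"Truncated because CLIP only handles {tokenizer.model_max_length} tokens: {removed_text}")
```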
src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py

```diff
@@ -248,17 +248,18 @@ class CycleDiffusionPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

         text_embeddings = self.text_encoder(text_input_ids.to(device))[0]

         # duplicate text embeddings for each generation per prompt, using mps friendly method
```
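A note on the new decode slice: with `truncation=True` the last kept position is taken by the EOS token, so the token that originally sat at index `model_max_length - 1` is also lost; the trailing `-1` simply drops the untruncated sequence's own EOS. A small worked example with made-up numbers (the token layout is an assumption for illustration, not taken from the diff):

```python
# Pretend model_max_length is 8 and the untruncated ids are:
#   [BOS, t1, t2, t3, t4, t5, t6, t7, t8, t9, EOS]   -> 11 positions
# After truncation to 8 the kept ids are:
#   [BOS, t1, t2, t3, t4, t5, t6, EOS]               -> t7, t8, t9 were dropped
model_max_length = 8
untruncated = ["BOS", "t1", "t2", "t3", "t4", "t5", "t6", "t7", "t8", "t9", "EOS"]

removed = untruncated[model_max_length - 1 : -1]
assert removed == ["t7", "t8", "t9"]  # exactly the dropped tokens, EOS excluded
```

The new test below relies on the same accounting: of 100 "@" tokens only 75 fit next to BOS and EOS, so the truncation warning contains 25 of them.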
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py

```diff
@@ -114,17 +114,19 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="np",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not np.array_equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

         text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
         text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
```
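The ONNX pipelines receive the same change with NumPy in place of PyTorch. Like `torch.equal`, `np.array_equal` simply returns `False` when the two arrays have different shapes, so the comparison is safe even though the untruncated ids can be longer than 77 positions. A quick check with illustrative values only:

```python
import numpy as np

truncated = np.array([[49406, 320, 49407]])         # example ids, truncated shape (1, 3)
untruncated = np.array([[49406, 320, 525, 49407]])  # longer sequence, shape (1, 4)

# Mismatched shapes are not an error; the arrays are simply reported as unequal.
assert not np.array_equal(truncated, untruncated)
```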
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py

```diff
@@ -161,17 +161,19 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="np",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not np.array_equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

         text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
         text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
```
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py

```diff
@@ -175,17 +175,19 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="np",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not np.array_equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

         text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
         text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
```
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

```diff
@@ -236,17 +236,18 @@ class StableDiffusionPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

         text_embeddings = self.text_encoder(text_input_ids.to(device))[0]

         # duplicate text embeddings for each generation per prompt, using mps friendly method
```
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

```diff
@@ -244,17 +244,18 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

         text_embeddings = self.text_encoder(text_input_ids.to(device))[0]

         # duplicate text embeddings for each generation per prompt, using mps friendly method
```
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

```diff
@@ -244,17 +244,18 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

         text_embeddings = self.text_encoder(text_input_ids.to(device))[0]

         # duplicate text embeddings for each generation per prompt, using mps friendly method
```
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py

```diff
@@ -213,17 +213,18 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

         text_embeddings = self.text_encoder(text_input_ids.to(device))[0]

         # duplicate text embeddings for each generation per prompt, using mps friendly method
```
tests/pipelines/stable_diffusion/test_stable_diffusion.py

```diff
@@ -33,9 +33,10 @@ from diffusers import (
     UNet2DConditionModel,
     UNet2DModel,
     VQModel,
+    logging,
 )
 from diffusers.utils import floats_tensor, load_numpy, slow, torch_device
-from diffusers.utils.testing_utils import require_torch_gpu
+from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

 from ...test_pipelines_common import PipelineTesterMixin
@@ -619,6 +620,57 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         assert image.shape == (1, 128, 128, 3)

+    def test_stable_diffusion_long_prompt(self):
+        unet = self.dummy_cond_unet
+        scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        do_classifier_free_guidance = True
+        negative_prompt = None
+        num_images_per_prompt = 1
+        logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion")
+
+        prompt = 25 * "@"
+        with CaptureLogger(logger) as cap_logger_3:
+            text_embeddings_3 = sd_pipe._encode_prompt(
+                prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+            )
+
+        prompt = 100 * "@"
+        with CaptureLogger(logger) as cap_logger:
+            text_embeddings = sd_pipe._encode_prompt(
+                prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+            )
+
+        negative_prompt = "Hello"
+        with CaptureLogger(logger) as cap_logger_2:
+            text_embeddings_2 = sd_pipe._encode_prompt(
+                prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+            )
+
+        assert text_embeddings_3.shape == text_embeddings_2.shape == text_embeddings.shape
+        assert text_embeddings.shape[1] == 77
+
+        assert cap_logger.out == cap_logger_2.out
+        # 100 - 77 + 1 (BOS token) + 1 (EOS token) = 25
+        assert cap_logger.out.count("@") == 25
+        assert cap_logger_3.out == ""
+
     @slow
     @require_torch_gpu
```
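The new test asserts on the truncation warning through `CaptureLogger` from `diffusers.utils.testing_utils`. A minimal sketch of that pattern outside the test suite, based on how it is used above (the warning text here is just an example, not the pipeline's message):

```python
from diffusers import logging
from diffusers.utils.testing_utils import CaptureLogger

# Same logger name the test captures.
logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion")

with CaptureLogger(logger) as cap:
    logger.warning("prompt was truncated")

# Everything logged on this logger inside the context manager is collected on `.out`.
assert "prompt was truncated" in cap.out
```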