Commit 4c660d16 (unverified)
Authored Nov 13, 2022 by Patrick von Platen; committed by GitHub, Nov 13, 2022
Parent commit: 81715661

[Stable Diffusion] Fix padding / truncation (#1226)

* [Stable Diffusion] Fix padding / truncation
* finish
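What the fix changes, in short: these pipelines used to tokenize with padding="max_length" but no truncation, then slice off anything past model_max_length by hand; now they pass truncation=True and detect truncation by re-tokenizing the prompt without a length cap and comparing the two id tensors. A minimal standalone sketch of the new behavior (an illustration, not part of this commit; it assumes the stock openai/clip-vit-base-patch32 tokenizer, whose model_max_length is 77):

    import numpy as np
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    prompt = " ".join(100 * ["painting"])  # well past the 77-token CLIP limit

    # New behavior: truncation=True clamps the ids to model_max_length (77),
    # so the text encoder never sees an over-long sequence.
    text_input_ids = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="np",
    ).input_ids

    # Re-tokenize without a length cap; for a short prompt both calls agree,
    # so a mismatch is exactly the "something was truncated" signal.
    untruncated_ids = tokenizer(prompt, padding="max_length", return_tensors="np").input_ids

    if not np.array_equal(text_input_ids, untruncated_ids):
        # Slice [model_max_length - 1 : -1]: position 76 of the untruncated ids
        # holds the first token that no longer fits (it was replaced by EOS in
        # the truncated ids), and -1 drops the trailing EOS of the long sequence.
        removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
        print(f"truncated away: {removed_text}")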
Changes: 9 changed files with 88 additions and 25 deletions (+88 -25)
src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py                  +4 -3
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py            +5 -3
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py    +5 -3
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py    +5 -3
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py                 +4 -3
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py         +4 -3
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py         +4 -3
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py  +4 -3
tests/pipelines/stable_diffusion/test_stable_diffusion.py                             +53 -1
src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py

@@ -248,17 +248,18 @@ class CycleDiffusionPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(device))[0]

         # duplicate text embeddings for each generation per prompt, using mps friendly method
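Two details in the hunk above are worth noting, since the same pattern repeats in every pipeline below. First, the warning now decodes from untruncated_ids rather than text_input_ids: once truncation=True is set, text_input_ids never contains anything past model_max_length, so the old slice would always decode to nothing. Second, the window [model_max_length - 1 : -1] covers everything from the first token that was cut through the last token before the final EOS. The PyTorch pipelines compare the tensors with torch.equal; the ONNX pipelines use np.array_equal.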
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py

@@ -114,17 +114,19 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="np",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not np.array_equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+
         text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
         text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py

@@ -161,17 +161,19 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="np",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not np.array_equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+
         text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
         text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py

@@ -175,17 +175,19 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="np",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not np.array_equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+
         text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
         text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

@@ -236,17 +236,18 @@ class StableDiffusionPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(device))[0]

         # duplicate text embeddings for each generation per prompt, using mps friendly method
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

@@ -244,17 +244,18 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(device))[0]

         # duplicate text embeddings for each generation per prompt, using mps friendly method
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

@@ -244,17 +244,18 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(device))[0]

         # duplicate text embeddings for each generation per prompt, using mps friendly method
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py

@@ -213,17 +213,18 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(device))[0]

         # duplicate text embeddings for each generation per prompt, using mps friendly method
tests/pipelines/stable_diffusion/test_stable_diffusion.py

@@ -33,9 +33,10 @@ from diffusers import (
     UNet2DConditionModel,
     UNet2DModel,
     VQModel,
+    logging,
 )
 from diffusers.utils import floats_tensor, load_numpy, slow, torch_device
-from diffusers.utils.testing_utils import require_torch_gpu
+from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

 from ...test_pipelines_common import PipelineTesterMixin

@@ -619,6 +620,57 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         assert image.shape == (1, 128, 128, 3)

+    def test_stable_diffusion_long_prompt(self):
+        unet = self.dummy_cond_unet
+        scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        do_classifier_free_guidance = True
+        negative_prompt = None
+        num_images_per_prompt = 1
+        logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion")
+
+        prompt = 25 * "@"
+        with CaptureLogger(logger) as cap_logger_3:
+            text_embeddings_3 = sd_pipe._encode_prompt(
+                prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+            )
+
+        prompt = 100 * "@"
+        with CaptureLogger(logger) as cap_logger:
+            text_embeddings = sd_pipe._encode_prompt(
+                prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+            )
+
+        negative_prompt = "Hello"
+        with CaptureLogger(logger) as cap_logger_2:
+            text_embeddings_2 = sd_pipe._encode_prompt(
+                prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+            )
+
+        assert text_embeddings_3.shape == text_embeddings_2.shape == text_embeddings.shape
+        assert text_embeddings.shape[1] == 77
+
+        assert cap_logger.out == cap_logger_2.out
+        # 100 - 77 + 1 (BOS token) + 1 (EOS token) = 25
+        assert cap_logger.out.count("@") == 25
+        assert cap_logger_3.out == ""
+
     @slow
     @require_torch_gpu
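The counting in the test's final assertions can be checked by hand; a small worked version of the comment's arithmetic (an illustration, not part of the commit; it assumes each "@" maps to a single token, as it does for the tiny test tokenizer):

    # CLIP reserves one position for BOS and one for EOS, so of the 77 slots
    # only 75 carry prompt tokens.
    model_max_length = 77
    kept = model_max_length - 2   # 75 "@" tokens survive truncation
    removed = 100 - kept          # 25 "@" tokens land in the warning
    assert removed == 25          # matches cap_logger.out.count("@") == 25
    # A 25-token prompt fits outright, so cap_logger_3 captures no warning at all.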