Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
d9075be4
Unverified
Commit
d9075be4
authored
Nov 27, 2023
by
YiYi Xu
Committed by
GitHub
Nov 27, 2023
Browse files
[load_textual_inversion]: allow multiple tokens (#5837)
Co-authored-by:
yiyixuxu
<
yixu310@gmail.com
>
parent
b135b6e9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
62 additions
and
2 deletions
+62
-2
src/diffusers/loaders/textual_inversion.py
src/diffusers/loaders/textual_inversion.py
+14
-2
tests/pipelines/test_pipelines.py
tests/pipelines/test_pipelines.py
+48
-0
No files found.
src/diffusers/loaders/textual_inversion.py
View file @
d9075be4
...
...
@@ -189,7 +189,7 @@ class TextualInversionLoaderMixin:
f
" `
{
self
.
load_textual_inversion
.
__name__
}
`"
)
if
len
(
pretrained_model_name_or_paths
)
!=
len
(
tokens
):
if
len
(
pretrained_model_name_or_paths
)
>
1
and
len
(
pretrained_model_name_or_paths
)
!=
len
(
tokens
):
raise
ValueError
(
f
"You have passed a list of models of length
{
len
(
pretrained_model_name_or_paths
)
}
, and list of tokens of length
{
len
(
tokens
)
}
"
f
"Make sure both lists have the same length."
...
...
@@ -382,7 +382,9 @@ class TextualInversionLoaderMixin:
if
not
isinstance
(
pretrained_model_name_or_path
,
list
)
else
pretrained_model_name_or_path
)
tokens
=
len
(
pretrained_model_name_or_paths
)
*
[
token
]
if
(
isinstance
(
token
,
str
)
or
token
is
None
)
else
token
tokens
=
[
token
]
if
not
isinstance
(
token
,
list
)
else
token
if
tokens
[
0
]
is
None
:
tokens
=
tokens
*
len
(
pretrained_model_name_or_paths
)
# 3. Check inputs
self
.
_check_text_inv_inputs
(
tokenizer
,
text_encoder
,
pretrained_model_name_or_paths
,
tokens
)
...
...
@@ -390,6 +392,16 @@ class TextualInversionLoaderMixin:
# 4. Load state dicts of textual embeddings
state_dicts
=
load_textual_inversion_state_dicts
(
pretrained_model_name_or_paths
,
**
kwargs
)
# 4.1 Handle the special case when state_dict is a tensor that contains n embeddings for n tokens
if
len
(
tokens
)
>
1
and
len
(
state_dicts
)
==
1
:
if
isinstance
(
state_dicts
[
0
],
torch
.
Tensor
):
state_dicts
=
list
(
state_dicts
[
0
])
if
len
(
tokens
)
!=
len
(
state_dicts
):
raise
ValueError
(
f
"You have passed a state_dict contains
{
len
(
state_dicts
)
}
embeddings, and list of tokens of length
{
len
(
tokens
)
}
"
f
"Make sure both have the same length."
)
# 4. Retrieve tokens and embeddings
tokens
,
embeddings
=
self
.
_retrieve_tokens_and_embeddings
(
tokens
,
state_dicts
,
tokenizer
)
...
...
tests/pipelines/test_pipelines.py
View file @
d9075be4
...
...
@@ -792,6 +792,54 @@ class DownloadTests(unittest.TestCase):
out
=
pipe
(
prompt
,
num_inference_steps
=
1
,
output_type
=
"numpy"
).
images
assert
out
.
shape
==
(
1
,
128
,
128
,
3
)
def test_text_inversion_multi_tokens(self):
    """Check that loading several textual-inversion tokens at once matches loading them one by one.

    Three pipelines are loaded from the same checkpoint and receive the same two
    embeddings in three different ways: two separate calls, a single call with a
    list of tensors, and a single call with one stacked tensor. Afterwards the
    tokenizer lengths, the ids assigned to the new tokens, and the sums of the
    two newly added embedding rows must agree across all three pipelines.
    """
    repo_id = "hf-internal-testing/tiny-stable-diffusion-torch"

    def fresh_pipeline():
        # Every loading variant gets its own pipeline, loaded anew and moved to the test device.
        return StableDiffusionPipeline.from_pretrained(repo_id, safety_checker=None).to(torch_device)

    first_token, second_token = "<*>", "<**>"
    first_vector = torch.ones((32,))
    second_vector = torch.ones((32,)) * 2

    # Variant 1: one call per token.
    pipe1 = fresh_pipeline()
    num_tokens = len(pipe1.tokenizer)  # tokenizer length before any embedding is loaded
    pipe1.load_textual_inversion(first_vector, token=first_token)
    pipe1.load_textual_inversion(second_vector, token=second_token)
    emb1 = pipe1.text_encoder.get_input_embeddings().weight

    # Variant 2: a list of tensors paired with a list of tokens.
    pipe2 = fresh_pipeline()
    pipe2.load_textual_inversion([first_vector, second_vector], token=[first_token, second_token])
    emb2 = pipe2.text_encoder.get_input_embeddings().weight

    # Variant 3: a single stacked tensor paired with a list of tokens.
    pipe3 = fresh_pipeline()
    stacked_vectors = torch.stack([first_vector, second_vector], dim=0)
    pipe3.load_textual_inversion(stacked_vectors, token=[first_token, second_token])
    emb3 = pipe3.text_encoder.get_input_embeddings().weight

    pipelines = (pipe1, pipe2, pipe3)

    # Each tokenizer must have grown by exactly the two new tokens.
    for pipe in pipelines:
        assert len(pipe.tokenizer) == num_tokens + 2

    # The new tokens must map to the same ids in every pipeline.
    for token, expected_id in ((first_token, num_tokens), (second_token, num_tokens + 1)):
        for pipe in pipelines:
            assert pipe.tokenizer.convert_tokens_to_ids(token) == expected_id

    # The embedding rows stored for the new tokens must agree between the variants.
    for row in (num_tokens, num_tokens + 1):
        sum1 = emb1[row].sum().item()
        sum2 = emb2[row].sum().item()
        sum3 = emb3[row].sum().item()
        assert sum1 == sum2 == sum3
def
test_download_ignore_files
(
self
):
# Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment