Commit d9075be4 (Unverified)
Authored Nov 27, 2023 by YiYi Xu, committed by GitHub on Nov 27, 2023

[load_textual_inversion]: allow multiple tokens (#5837)

Co-authored-by: yiyixuxu <yixu310@gmail,com>

Parent: b135b6e9
Showing 2 changed files with 62 additions and 2 deletions (+62 -2):

- src/diffusers/loaders/textual_inversion.py (+14 -2)
- tests/pipelines/test_pipelines.py (+48 -0)
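The change lets `load_textual_inversion` accept several embeddings and tokens in a single call. A minimal usage sketch based on the new test below (the tiny test checkpoint and the token names are illustrative, not requirements of the API):

```python
import torch
from diffusers import StableDiffusionPipeline

# Tiny test checkpoint used by the new test; its text encoder has 32-dim token embeddings.
pipe = StableDiffusionPipeline.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
)

emb_a = torch.ones((32,))       # one embedding vector per new token
emb_b = torch.ones((32,)) * 2

# New in this commit: a list of embeddings paired with a list of tokens in one call ...
pipe.load_textual_inversion([emb_a, emb_b], token=["<token-a>", "<token-b>"])

# ... or, equivalently, one (n, dim) tensor carrying n embeddings for n tokens:
# pipe.load_textual_inversion(torch.stack([emb_a, emb_b], dim=0), token=["<token-a>", "<token-b>"])
```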
src/diffusers/loaders/textual_inversion.py (view file @ d9075be4)
...

```diff
@@ -189,7 +189,7 @@ class TextualInversionLoaderMixin:
                 f" `{self.load_textual_inversion.__name__}`"
             )
 
-        if len(pretrained_model_name_or_paths) != len(tokens):
+        if len(pretrained_model_name_or_paths) > 1 and len(pretrained_model_name_or_paths) != len(tokens):
             raise ValueError(
                 f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)} "
                 f"Make sure both lists have the same length."
```

...
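In effect, the length check now only fires when more than one model source is passed with a mismatched token list; a single source is allowed to carry several tokens. A hedged, standalone restatement of the new rule (file names are illustrative, not from the commit):

```python
def check_lengths(paths, tokens):
    # Mirrors the updated condition: only a list of *several* sources must match the token list 1:1.
    if len(paths) > 1 and len(paths) != len(tokens):
        raise ValueError("number of models and number of tokens must match")

check_lengths(["stacked_embeddings.safetensors"], ["<a>", "<b>"])   # OK after this commit: one source, two tokens
check_lengths(["a.safetensors", "b.safetensors"], ["<a>", "<b>"])   # OK: lengths match
# check_lengths(["a.safetensors", "b.safetensors"], ["<a>"])        # still raises ValueError
```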
```diff
@@ -382,7 +382,9 @@ class TextualInversionLoaderMixin:
             if not isinstance(pretrained_model_name_or_path, list)
             else pretrained_model_name_or_path
         )
-        tokens = len(pretrained_model_name_or_paths) * [token] if (isinstance(token, str) or token is None) else token
+        tokens = [token] if not isinstance(token, list) else token
+        if tokens[0] is None:
+            tokens = tokens * len(pretrained_model_name_or_paths)
 
         # 3. Check inputs
         self._check_text_inv_inputs(tokenizer, text_encoder, pretrained_model_name_or_paths, tokens)
```

...
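The replaced one-liner broadcast any scalar token to the number of sources; the new normalization wraps a non-list token and broadcasts only the `None` default, so an explicit token list is kept as-is even when it is longer than the source list (the stacked-tensor case). A minimal sketch of that normalization under the same assumptions (the helper name is hypothetical):

```python
def normalize_tokens(token, paths):
    # Wrap a scalar token into a list; broadcast only the None default across all sources.
    tokens = [token] if not isinstance(token, list) else token
    if tokens[0] is None:
        tokens = tokens * len(paths)
    return tokens

assert normalize_tokens(None, ["a", "b"]) == [None, None]               # default: one None per source
assert normalize_tokens("<x>", ["a"]) == ["<x>"]                        # scalar token wrapped, not broadcast
assert normalize_tokens(["<a>", "<b>"], ["stacked"]) == ["<a>", "<b>"]  # explicit list kept as-is
```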
```diff
@@ -390,6 +392,16 @@ class TextualInversionLoaderMixin:
 
         # 4. Load state dicts of textual embeddings
         state_dicts = load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
 
+        # 4.1 Handle the special case when state_dict is a tensor that contains n embeddings for n tokens
+        if len(tokens) > 1 and len(state_dicts) == 1:
+            if isinstance(state_dicts[0], torch.Tensor):
+                state_dicts = list(state_dicts[0])
+                if len(tokens) != len(state_dicts):
+                    raise ValueError(
+                        f"You have passed a state_dict contains {len(state_dicts)} embeddings, and list of tokens of length {len(tokens)} "
+                        f"Make sure both have the same length."
+                    )
+
         # 4. Retrieve tokens and embeddings
         tokens, embeddings = self._retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer)
```

...
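The new 4.1 block is what lets one bare tensor stand in for several embeddings: iterating a tensor walks its first dimension, so `list(state_dicts[0])` splits an `(n, dim)` tensor into `n` per-token vectors, which are then matched against the token list. A small sketch of that splitting (variable names are hypothetical):

```python
import torch

stacked = torch.stack([torch.ones(32), torch.ones(32) * 2], dim=0)  # shape (2, 32): two embeddings
per_token = list(stacked)                                           # -> two tensors of shape (32,)
assert len(per_token) == 2
assert per_token[1].sum().item() == 64.0
```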
tests/pipelines/test_pipelines.py (view file @ d9075be4)

...
```diff
@@ -792,6 +792,54 @@ class DownloadTests(unittest.TestCase):
         out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
 
         assert out.shape == (1, 128, 128, 3)
 
+    def test_text_inversion_multi_tokens(self):
+        pipe1 = StableDiffusionPipeline.from_pretrained(
+            "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
+        )
+        pipe1 = pipe1.to(torch_device)
+
+        token1, token2 = "<*>", "<**>"
+        ten1 = torch.ones((32,))
+        ten2 = torch.ones((32,)) * 2
+
+        num_tokens = len(pipe1.tokenizer)
+
+        pipe1.load_textual_inversion(ten1, token=token1)
+        pipe1.load_textual_inversion(ten2, token=token2)
+        emb1 = pipe1.text_encoder.get_input_embeddings().weight
+
+        pipe2 = StableDiffusionPipeline.from_pretrained(
+            "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
+        )
+        pipe2 = pipe2.to(torch_device)
+        pipe2.load_textual_inversion([ten1, ten2], token=[token1, token2])
+        emb2 = pipe2.text_encoder.get_input_embeddings().weight
+
+        pipe3 = StableDiffusionPipeline.from_pretrained(
+            "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
+        )
+        pipe3 = pipe3.to(torch_device)
+        pipe3.load_textual_inversion(torch.stack([ten1, ten2], dim=0), token=[token1, token2])
+        emb3 = pipe3.text_encoder.get_input_embeddings().weight
+
+        assert len(pipe1.tokenizer) == len(pipe2.tokenizer) == len(pipe3.tokenizer) == num_tokens + 2
+        assert (
+            pipe1.tokenizer.convert_tokens_to_ids(token1)
+            == pipe2.tokenizer.convert_tokens_to_ids(token1)
+            == pipe3.tokenizer.convert_tokens_to_ids(token1)
+            == num_tokens
+        )
+        assert (
+            pipe1.tokenizer.convert_tokens_to_ids(token2)
+            == pipe2.tokenizer.convert_tokens_to_ids(token2)
+            == pipe3.tokenizer.convert_tokens_to_ids(token2)
+            == num_tokens + 1
+        )
+        assert emb1[num_tokens].sum().item() == emb2[num_tokens].sum().item() == emb3[num_tokens].sum().item()
+        assert (
+            emb1[num_tokens + 1].sum().item()
+            == emb2[num_tokens + 1].sum().item()
+            == emb3[num_tokens + 1].sum().item()
+        )
+
     def test_download_ignore_files(self):
         # Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4
         with tempfile.TemporaryDirectory() as tmpdirname:
```

...