Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
0f55c17e
Commit
0f55c17e
authored
Dec 01, 2023
by
Patrick von Platen
Browse files
fix style
parent
5058d27f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
182 additions
and
104 deletions
+182
-104
examples/community/regional_prompting_stable_diffusion.py
examples/community/regional_prompting_stable_diffusion.py
+182
-104
No files found.
examples/community/regional_prompting_stable_diffusion.py
View file @
0f55c17e
import
torchvision.transforms.functional
as
FF
import
math
import
torch
import
torchvision
from
typing
import
Dict
,
Optional
from
typing
import
Dict
,
Optional
import
torch
import
torchvision.transforms.functional
as
FF
from
transformers
import
CLIPFeatureExtractor
,
CLIPTextModel
,
CLIPTokenizer
from
diffusers
import
StableDiffusionPipeline
from
diffusers
import
StableDiffusionPipeline
from
diffusers.models
import
AutoencoderKL
,
UNet2DConditionModel
from
diffusers.models
import
AutoencoderKL
,
UNet2DConditionModel
from
diffusers.utils
import
USE_PEFT_BACKEND
from
diffusers.pipelines.stable_diffusion.safety_checker
import
StableDiffusionSafetyChecker
from
diffusers.pipelines.stable_diffusion.safety_checker
import
StableDiffusionSafetyChecker
from
diffusers.schedulers
import
KarrasDiffusionSchedulers
from
diffusers.schedulers
import
KarrasDiffusionSchedulers
from
transformers
import
CLIPFeatureExtractor
,
CLIPTextModel
,
CLIPTokenizer
from
diffusers.utils
import
USE_PEFT_BACKEND
try
:
try
:
from
compel
import
Compel
from
compel
import
Compel
except
:
except
ImportError
:
Compel
=
None
Compel
=
None
KCOMM
=
"ADDCOMM"
KCOMM
=
"ADDCOMM"
KBRK
=
"BREAK"
KBRK
=
"BREAK"
class
RegionalPromptingStableDiffusionPipeline
(
StableDiffusionPipeline
):
class
RegionalPromptingStableDiffusionPipeline
(
StableDiffusionPipeline
):
r
"""
r
"""
Args for Regional Prompting Pipeline:
Args for Regional Prompting Pipeline:
rp_args:dict
rp_args:dict
Required
Required
rp_args["mode"]: cols, rows, prompt, prompt-ex
rp_args["mode"]: cols, rows, prompt, prompt-ex
for cols, rows mode
for cols, rows mode
rp_args["div"]: ex) 1;1;1(Divide into 3 regions)
rp_args["div"]: ex) 1;1;1(Divide into 3 regions)
for prompt, prompt-ex mode
for prompt, prompt-ex mode
rp_args["th"]: ex) 0.5,0.5,0.6 (threshold for prompt mode)
rp_args["th"]: ex) 0.5,0.5,0.6 (threshold for prompt mode)
Optional
Optional
rp_args["save_mask"]: True/False (save masks in prompt mode)
rp_args["save_mask"]: True/False (save masks in prompt mode)
...
@@ -56,6 +60,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
...
@@ -56,6 +60,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
feature_extractor ([`CLIPImageProcessor`]):
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
vae
:
AutoencoderKL
,
vae
:
AutoencoderKL
,
...
@@ -67,7 +72,9 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
...
@@ -67,7 +72,9 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
feature_extractor
:
CLIPFeatureExtractor
,
feature_extractor
:
CLIPFeatureExtractor
,
requires_safety_checker
:
bool
=
True
,
requires_safety_checker
:
bool
=
True
,
):
):
super
().
__init__
(
vae
,
text_encoder
,
tokenizer
,
unet
,
scheduler
,
safety_checker
,
feature_extractor
,
requires_safety_checker
)
super
().
__init__
(
vae
,
text_encoder
,
tokenizer
,
unet
,
scheduler
,
safety_checker
,
feature_extractor
,
requires_safety_checker
)
self
.
register_modules
(
self
.
register_modules
(
vae
=
vae
,
vae
=
vae
,
text_encoder
=
text_encoder
,
text_encoder
=
text_encoder
,
...
@@ -93,50 +100,56 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
...
@@ -93,50 +100,56 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
latents
:
Optional
[
torch
.
FloatTensor
]
=
None
,
latents
:
Optional
[
torch
.
FloatTensor
]
=
None
,
output_type
:
Optional
[
str
]
=
"pil"
,
output_type
:
Optional
[
str
]
=
"pil"
,
return_dict
:
bool
=
True
,
return_dict
:
bool
=
True
,
rp_args
:
Dict
[
str
,
str
]
=
None
,
rp_args
:
Dict
[
str
,
str
]
=
None
,
):
):
active
=
KBRK
in
prompt
[
0
]
if
type
(
prompt
)
==
list
else
KBRK
in
prompt
# noqa: E721
active
=
KBRK
in
prompt
[
0
]
if
type
(
prompt
)
==
list
else
KBRK
in
prompt
if
negative_prompt
is
None
:
if
negative_prompt
is
None
:
negative_prompt
=
""
if
type
(
prompt
)
==
str
else
[
""
]
*
len
(
prompt
)
negative_prompt
=
""
if
type
(
prompt
)
==
str
else
[
""
]
*
len
(
prompt
)
# noqa: E721
device
=
self
.
_execution_device
device
=
self
.
_execution_device
regions
=
0
regions
=
0
self
.
power
=
int
(
rp_args
[
"power"
])
if
"power"
in
rp_args
else
1
self
.
power
=
int
(
rp_args
[
"power"
])
if
"power"
in
rp_args
else
1
prompts
=
prompt
if
type
(
prompt
)
==
list
else
[
prompt
]
prompts
=
prompt
if
type
(
prompt
)
==
list
else
[
prompt
]
# noqa: E721
n_prompts
=
negative_prompt
if
type
(
negative_prompt
)
==
list
else
[
negative_prompt
]
n_prompts
=
negative_prompt
if
type
(
negative_prompt
)
==
list
else
[
negative_prompt
]
# noqa: E721
self
.
batch
=
batch
=
num_images_per_prompt
*
len
(
prompts
)
self
.
batch
=
batch
=
num_images_per_prompt
*
len
(
prompts
)
all_prompts_cn
,
all_prompts_p
=
promptsmaker
(
prompts
,
num_images_per_prompt
)
all_prompts_cn
,
all_prompts_p
=
promptsmaker
(
prompts
,
num_images_per_prompt
)
all_n_prompts_cn
,
_
=
promptsmaker
(
n_prompts
,
num_images_per_prompt
)
all_n_prompts_cn
,
_
=
promptsmaker
(
n_prompts
,
num_images_per_prompt
)
cn
=
len
(
all_prompts_cn
)
==
len
(
all_n_prompts_cn
)
cn
=
len
(
all_prompts_cn
)
==
len
(
all_n_prompts_cn
)
if
Compel
:
if
Compel
:
compel
=
Compel
(
tokenizer
=
self
.
tokenizer
,
text_encoder
=
self
.
text_encoder
)
compel
=
Compel
(
tokenizer
=
self
.
tokenizer
,
text_encoder
=
self
.
text_encoder
)
def
getcompelembs
(
prps
):
def
getcompelembs
(
prps
):
embl
=
[]
embl
=
[]
for
prp
in
prps
:
for
prp
in
prps
:
embl
.
append
(
compel
.
build_conditioning_tensor
(
prp
))
embl
.
append
(
compel
.
build_conditioning_tensor
(
prp
))
return
torch
.
cat
(
embl
)
return
torch
.
cat
(
embl
)
conds
=
getcompelembs
(
all_prompts_cn
)
conds
=
getcompelembs
(
all_prompts_cn
)
unconds
=
getcompelembs
(
all_n_prompts_cn
)
if
cn
else
getcompelembs
(
n_prompts
)
unconds
=
getcompelembs
(
all_n_prompts_cn
)
if
cn
else
getcompelembs
(
n_prompts
)
embs
=
getcompelembs
(
prompts
)
embs
=
getcompelembs
(
prompts
)
n_embs
=
getcompelembs
(
n_prompts
)
n_embs
=
getcompelembs
(
n_prompts
)
prompt
=
negative_prompt
=
None
prompt
=
negative_prompt
=
None
else
:
else
:
conds
=
self
.
encode_prompt
(
prompts
,
device
,
1
,
True
)[
0
]
conds
=
self
.
encode_prompt
(
prompts
,
device
,
1
,
True
)[
0
]
unconds
=
self
.
encode_prompt
(
n_prompts
,
device
,
1
,
True
)[
0
]
if
cn
else
self
.
encode_prompt
(
all_n_prompts_cn
,
device
,
1
,
True
)[
0
]
unconds
=
(
self
.
encode_prompt
(
n_prompts
,
device
,
1
,
True
)[
0
]
if
cn
else
self
.
encode_prompt
(
all_n_prompts_cn
,
device
,
1
,
True
)[
0
]
)
embs
=
n_embs
=
None
embs
=
n_embs
=
None
if
not
active
:
if
not
active
:
pcallback
=
None
pcallback
=
None
mode
=
None
mode
=
None
else
:
else
:
if
any
(
x
in
rp_args
[
"mode"
].
upper
()
for
x
in
[
"COL"
,
"ROW"
]):
if
any
(
x
in
rp_args
[
"mode"
].
upper
()
for
x
in
[
"COL"
,
"ROW"
]):
mode
=
"COL"
if
"COL"
in
rp_args
[
"mode"
].
upper
()
else
"ROW"
mode
=
"COL"
if
"COL"
in
rp_args
[
"mode"
].
upper
()
else
"ROW"
ocells
,
icells
,
regions
=
make_cells
(
rp_args
[
"div"
])
ocells
,
icells
,
regions
=
make_cells
(
rp_args
[
"div"
])
elif
"PRO"
in
rp_args
[
"mode"
].
upper
():
elif
"PRO"
in
rp_args
[
"mode"
].
upper
():
regions
=
len
(
all_prompts_p
[
0
])
regions
=
len
(
all_prompts_p
[
0
])
mode
=
"PROMPT"
mode
=
"PROMPT"
...
@@ -144,14 +157,14 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
...
@@ -144,14 +157,14 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
self
.
ex
=
"EX"
in
rp_args
[
"mode"
].
upper
()
self
.
ex
=
"EX"
in
rp_args
[
"mode"
].
upper
()
self
.
target_tokens
=
target_tokens
=
tokendealer
(
self
,
all_prompts_p
)
self
.
target_tokens
=
target_tokens
=
tokendealer
(
self
,
all_prompts_p
)
thresholds
=
[
float
(
x
)
for
x
in
rp_args
[
"th"
].
split
(
","
)]
thresholds
=
[
float
(
x
)
for
x
in
rp_args
[
"th"
].
split
(
","
)]
orig_hw
=
(
height
,
width
)
orig_hw
=
(
height
,
width
)
revers
=
True
revers
=
True
def
pcallback
(
s_self
,
step
:
int
,
timestep
:
int
,
latents
:
torch
.
FloatTensor
,
selfs
=
None
):
def
pcallback
(
s_self
,
step
:
int
,
timestep
:
int
,
latents
:
torch
.
FloatTensor
,
selfs
=
None
):
if
"PRO"
in
mode
:
# in Prompt mode, make masks from sum of attension maps
if
"PRO"
in
mode
:
# in Prompt mode, make masks from sum of attension maps
self
.
step
=
step
self
.
step
=
step
if
len
(
self
.
attnmaps_sizes
)
>
3
:
if
len
(
self
.
attnmaps_sizes
)
>
3
:
self
.
history
[
step
]
=
self
.
attnmaps
.
copy
()
self
.
history
[
step
]
=
self
.
attnmaps
.
copy
()
for
hw
in
self
.
attnmaps_sizes
:
for
hw
in
self
.
attnmaps_sizes
:
...
@@ -167,7 +180,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
...
@@ -167,7 +180,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
allmasks
[
b
::
batch
]
=
[
torch
.
where
(
x
>
0
,
1
,
0
)
for
x
in
allmasks
[
b
::
batch
]]
allmasks
[
b
::
batch
]
=
[
torch
.
where
(
x
>
0
,
1
,
0
)
for
x
in
allmasks
[
b
::
batch
]]
allmasks
.
append
(
mask
)
allmasks
.
append
(
mask
)
basemasks
[
b
]
=
mask
if
basemasks
[
b
]
is
None
else
basemasks
[
b
]
+
mask
basemasks
[
b
]
=
mask
if
basemasks
[
b
]
is
None
else
basemasks
[
b
]
+
mask
basemasks
=
[
1
-
mask
for
mask
in
basemasks
]
basemasks
=
[
1
-
mask
for
mask
in
basemasks
]
basemasks
=
[
torch
.
where
(
x
>
0
,
1
,
0
)
for
x
in
basemasks
]
basemasks
=
[
torch
.
where
(
x
>
0
,
1
,
0
)
for
x
in
basemasks
]
allmasks
=
basemasks
+
allmasks
allmasks
=
basemasks
+
allmasks
...
@@ -176,7 +189,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
...
@@ -176,7 +189,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
return
latents
return
latents
def
hook_forward
(
module
):
def
hook_forward
(
module
):
#diffusers==0.23.2
#
diffusers==0.23.2
def
forward
(
def
forward
(
hidden_states
:
torch
.
FloatTensor
,
hidden_states
:
torch
.
FloatTensor
,
encoder_hidden_states
:
Optional
[
torch
.
FloatTensor
]
=
None
,
encoder_hidden_states
:
Optional
[
torch
.
FloatTensor
]
=
None
,
...
@@ -184,22 +197,21 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
...
@@ -184,22 +197,21 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
temb
:
Optional
[
torch
.
FloatTensor
]
=
None
,
temb
:
Optional
[
torch
.
FloatTensor
]
=
None
,
scale
:
float
=
1.0
,
scale
:
float
=
1.0
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
attn
=
module
attn
=
module
xshape
=
hidden_states
.
shape
xshape
=
hidden_states
.
shape
self
.
hw
=
(
h
,
w
)
=
split_dims
(
xshape
[
1
],
*
orig_hw
)
self
.
hw
=
(
h
,
w
)
=
split_dims
(
xshape
[
1
],
*
orig_hw
)
if
revers
:
if
revers
:
nx
,
px
=
hidden_states
.
chunk
(
2
)
nx
,
px
=
hidden_states
.
chunk
(
2
)
else
:
else
:
px
,
nx
=
hidden_states
.
chunk
(
2
)
px
,
nx
=
hidden_states
.
chunk
(
2
)
if
cn
:
if
cn
:
hidden_states
=
torch
.
cat
([
px
for
i
in
range
(
regions
)]
+
[
nx
for
i
in
range
(
regions
)],
0
)
hidden_states
=
torch
.
cat
([
px
for
i
in
range
(
regions
)]
+
[
nx
for
i
in
range
(
regions
)],
0
)
encoder_hidden_states
=
torch
.
cat
([
conds
]
+
[
unconds
])
encoder_hidden_states
=
torch
.
cat
([
conds
]
+
[
unconds
])
else
:
else
:
hidden_states
=
torch
.
cat
([
px
for
i
in
range
(
regions
)]
+
[
nx
],
0
)
hidden_states
=
torch
.
cat
([
px
for
i
in
range
(
regions
)]
+
[
nx
],
0
)
encoder_hidden_states
=
torch
.
cat
([
conds
]
+
[
unconds
])
encoder_hidden_states
=
torch
.
cat
([
conds
]
+
[
unconds
])
residual
=
hidden_states
residual
=
hidden_states
...
@@ -247,12 +259,19 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
...
@@ -247,12 +259,19 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states
=
scaled_dot_product_attention
(
hidden_states
=
scaled_dot_product_attention
(
self
,
query
,
key
,
value
,
attn_mask
=
attention_mask
,
dropout_p
=
0.0
,
is_causal
=
False
,
getattn
=
"PRO"
in
mode
self
,
query
,
key
,
value
,
attn_mask
=
attention_mask
,
dropout_p
=
0.0
,
is_causal
=
False
,
getattn
=
"PRO"
in
mode
,
)
)
hidden_states
=
hidden_states
.
transpose
(
1
,
2
).
reshape
(
batch_size
,
-
1
,
attn
.
heads
*
head_dim
)
hidden_states
=
hidden_states
.
transpose
(
1
,
2
).
reshape
(
batch_size
,
-
1
,
attn
.
heads
*
head_dim
)
hidden_states
=
hidden_states
.
to
(
query
.
dtype
)
hidden_states
=
hidden_states
.
to
(
query
.
dtype
)
# linear proj
# linear proj
hidden_states
=
attn
.
to_out
[
0
](
hidden_states
,
*
args
)
hidden_states
=
attn
.
to_out
[
0
](
hidden_states
,
*
args
)
# dropout
# dropout
...
@@ -272,18 +291,38 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
...
@@ -272,18 +291,38 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
center
=
reshaped
.
shape
[
0
]
//
2
center
=
reshaped
.
shape
[
0
]
//
2
px
=
reshaped
[
0
:
center
]
if
cn
else
reshaped
[
0
:
-
batch
]
px
=
reshaped
[
0
:
center
]
if
cn
else
reshaped
[
0
:
-
batch
]
nx
=
reshaped
[
center
:]
if
cn
else
reshaped
[
-
batch
:]
nx
=
reshaped
[
center
:]
if
cn
else
reshaped
[
-
batch
:]
outs
=
[
px
,
nx
]
if
cn
else
[
px
]
outs
=
[
px
,
nx
]
if
cn
else
[
px
]
for
out
in
outs
:
for
out
in
outs
:
c
=
0
c
=
0
for
i
,
ocell
in
enumerate
(
ocells
):
for
i
,
ocell
in
enumerate
(
ocells
):
for
icell
in
icells
[
i
]:
for
icell
in
icells
[
i
]:
if
"ROW"
in
mode
:
if
"ROW"
in
mode
:
out
[
0
:
batch
,
int
(
h
*
ocell
[
0
]):
int
(
h
*
ocell
[
1
]),
int
(
w
*
icell
[
0
]):
int
(
w
*
icell
[
1
]),:]
=
out
[
c
*
batch
:(
c
+
1
)
*
batch
,
int
(
h
*
ocell
[
0
]):
int
(
h
*
ocell
[
1
]),
int
(
w
*
icell
[
0
]):
int
(
w
*
icell
[
1
]),:]
out
[
0
:
batch
,
int
(
h
*
ocell
[
0
])
:
int
(
h
*
ocell
[
1
]),
int
(
w
*
icell
[
0
])
:
int
(
w
*
icell
[
1
]),
:,
]
=
out
[
c
*
batch
:
(
c
+
1
)
*
batch
,
int
(
h
*
ocell
[
0
])
:
int
(
h
*
ocell
[
1
]),
int
(
w
*
icell
[
0
])
:
int
(
w
*
icell
[
1
]),
:,
]
else
:
else
:
out
[
0
:
batch
,
int
(
h
*
icell
[
0
]):
int
(
h
*
icell
[
1
]),
int
(
w
*
ocell
[
0
]):
int
(
w
*
ocell
[
1
]),:]
=
out
[
c
*
batch
:(
c
+
1
)
*
batch
,
int
(
h
*
icell
[
0
]):
int
(
h
*
icell
[
1
]),
int
(
w
*
ocell
[
0
]):
int
(
w
*
ocell
[
1
]),:]
out
[
0
:
batch
,
int
(
h
*
icell
[
0
])
:
int
(
h
*
icell
[
1
]),
int
(
w
*
ocell
[
0
])
:
int
(
w
*
ocell
[
1
]),
:,
]
=
out
[
c
*
batch
:
(
c
+
1
)
*
batch
,
int
(
h
*
icell
[
0
])
:
int
(
h
*
icell
[
1
]),
int
(
w
*
ocell
[
0
])
:
int
(
w
*
ocell
[
1
]),
:,
]
c
+=
1
c
+=
1
px
,
nx
=
(
px
[
0
:
batch
],
nx
[
0
:
batch
])
if
cn
else
(
px
[
0
:
batch
],
nx
)
px
,
nx
=
(
px
[
0
:
batch
],
nx
[
0
:
batch
])
if
cn
else
(
px
[
0
:
batch
],
nx
)
hidden_states
=
torch
.
cat
([
nx
,
px
],
0
)
if
revers
else
torch
.
cat
([
px
,
nx
],
0
)
hidden_states
=
torch
.
cat
([
nx
,
px
],
0
)
if
revers
else
torch
.
cat
([
px
,
nx
],
0
)
hidden_states
=
hidden_states
.
reshape
(
xshape
)
hidden_states
=
hidden_states
.
reshape
(
xshape
)
#### Regional Prompting Prompt mode
#### Regional Prompting Prompt mode
...
@@ -291,17 +330,19 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
...
@@ -291,17 +330,19 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
center
=
reshaped
.
shape
[
0
]
//
2
center
=
reshaped
.
shape
[
0
]
//
2
px
=
reshaped
[
0
:
center
]
if
cn
else
reshaped
[
0
:
-
batch
]
px
=
reshaped
[
0
:
center
]
if
cn
else
reshaped
[
0
:
-
batch
]
nx
=
reshaped
[
center
:]
if
cn
else
reshaped
[
-
batch
:]
nx
=
reshaped
[
center
:]
if
cn
else
reshaped
[
-
batch
:]
if
(
h
,
w
)
in
self
.
attnmasks
and
self
.
maskready
:
if
(
h
,
w
)
in
self
.
attnmasks
and
self
.
maskready
:
def
mask
(
input
):
def
mask
(
input
):
out
=
torch
.
multiply
(
input
,
self
.
attnmasks
[(
h
,
w
)])
out
=
torch
.
multiply
(
input
,
self
.
attnmasks
[(
h
,
w
)])
for
b
in
range
(
batch
):
for
b
in
range
(
batch
):
for
r
in
range
(
1
,
regions
):
for
r
in
range
(
1
,
regions
):
out
[
b
]
=
out
[
b
]
+
out
[
r
*
batch
+
b
]
out
[
b
]
=
out
[
b
]
+
out
[
r
*
batch
+
b
]
return
out
return
out
px
,
nx
=
(
mask
(
px
),
mask
(
nx
))
if
cn
else
(
mask
(
px
),
nx
)
px
,
nx
=
(
mask
(
px
),
mask
(
nx
))
if
cn
else
(
mask
(
px
),
nx
)
px
,
nx
=
(
px
[
0
:
batch
],
nx
[
0
:
batch
])
if
cn
else
(
px
[
0
:
batch
],
nx
)
px
,
nx
=
(
px
[
0
:
batch
],
nx
[
0
:
batch
])
if
cn
else
(
px
[
0
:
batch
],
nx
)
hidden_states
=
torch
.
cat
([
nx
,
px
],
0
)
if
revers
else
torch
.
cat
([
px
,
nx
],
0
)
hidden_states
=
torch
.
cat
([
nx
,
px
],
0
)
if
revers
else
torch
.
cat
([
px
,
nx
],
0
)
return
hidden_states
return
hidden_states
return
forward
return
forward
...
@@ -328,7 +369,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
...
@@ -328,7 +369,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
latents
=
latents
,
latents
=
latents
,
output_type
=
output_type
,
output_type
=
output_type
,
return_dict
=
return_dict
,
return_dict
=
return_dict
,
callback_on_step_end
=
pcallback
callback_on_step_end
=
pcallback
,
)
)
if
"save_mask"
in
rp_args
:
if
"save_mask"
in
rp_args
:
...
@@ -336,13 +377,14 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
...
@@ -336,13 +377,14 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
else
:
else
:
save_mask
=
False
save_mask
=
False
if
mode
==
"PROMPT"
and
save_mask
:
saveattnmaps
(
self
,
output
,
height
,
width
,
thresholds
,
num_inference_steps
//
2
,
regions
)
if
mode
==
"PROMPT"
and
save_mask
:
saveattnmaps
(
self
,
output
,
height
,
width
,
thresholds
,
num_inference_steps
//
2
,
regions
)
return
output
return
output
### Make prompt list for each regions
### Make prompt list for each regions
def
promptsmaker
(
prompts
,
batch
):
def
promptsmaker
(
prompts
,
batch
):
out_p
=
[]
out_p
=
[]
plen
=
len
(
prompts
)
plen
=
len
(
prompts
)
for
prompt
in
prompts
:
for
prompt
in
prompts
:
...
@@ -352,24 +394,26 @@ def promptsmaker(prompts,batch):
...
@@ -352,24 +394,26 @@ def promptsmaker(prompts,batch):
add
=
add
+
" "
add
=
add
+
" "
prompts
=
prompt
.
split
(
KBRK
)
prompts
=
prompt
.
split
(
KBRK
)
out_p
.
append
([
add
+
p
for
p
in
prompts
])
out_p
.
append
([
add
+
p
for
p
in
prompts
])
out
=
[
None
]
*
batch
*
len
(
out_p
[
0
])
*
len
(
out_p
)
out
=
[
None
]
*
batch
*
len
(
out_p
[
0
])
*
len
(
out_p
)
for
p
,
prs
in
enumerate
(
out_p
):
# inputs prompts
for
p
,
prs
in
enumerate
(
out_p
):
# inputs prompts
for
r
,
pr
in
enumerate
(
prs
):
# prompts for regions
for
r
,
pr
in
enumerate
(
prs
):
# prompts for regions
start
=
(
p
+
r
*
plen
)
*
batch
start
=
(
p
+
r
*
plen
)
*
batch
out
[
start
:
start
+
batch
]
=
[
pr
]
*
batch
#P1R1B1,P1R1B2...,P1R2B1,P1R2B2...,P2R1B1...
out
[
start
:
start
+
batch
]
=
[
pr
]
*
batch
#
P1R1B1,P1R1B2...,P1R2B1,P1R2B2...,P2R1B1...
return
out
,
out_p
return
out
,
out_p
### make regions from ratios
### make regions from ratios
### ";" makes outercells, "," makes inner cells
### ";" makes outercells, "," makes inner cells
def
make_cells
(
ratios
):
def
make_cells
(
ratios
):
if
";"
not
in
ratios
and
","
in
ratios
:
ratios
=
ratios
.
replace
(
","
,
";"
)
if
";"
not
in
ratios
and
","
in
ratios
:
ratios
=
ratios
.
replace
(
","
,
";"
)
ratios
=
ratios
.
split
(
";"
)
ratios
=
ratios
.
split
(
";"
)
ratios
=
[
inratios
.
split
(
","
)
for
inratios
in
ratios
]
ratios
=
[
inratios
.
split
(
","
)
for
inratios
in
ratios
]
icells
=
[]
icells
=
[]
ocells
=
[]
ocells
=
[]
def
startend
(
cells
,
array
):
def
startend
(
cells
,
array
):
current_start
=
0
current_start
=
0
array
=
[
float
(
x
)
for
x
in
array
]
array
=
[
float
(
x
)
for
x
in
array
]
for
value
in
array
:
for
value
in
array
:
...
@@ -377,72 +421,80 @@ def make_cells(ratios):
...
@@ -377,72 +421,80 @@ def make_cells(ratios):
cells
.
append
([
current_start
,
end
])
cells
.
append
([
current_start
,
end
])
current_start
=
end
current_start
=
end
startend
(
ocells
,[
r
[
0
]
for
r
in
ratios
])
startend
(
ocells
,
[
r
[
0
]
for
r
in
ratios
])
for
inratios
in
ratios
:
for
inratios
in
ratios
:
if
2
>
len
(
inratios
):
if
2
>
len
(
inratios
):
icells
.
append
([[
0
,
1
]])
icells
.
append
([[
0
,
1
]])
else
:
else
:
add
=
[]
add
=
[]
startend
(
add
,
inratios
[
1
:])
startend
(
add
,
inratios
[
1
:])
icells
.
append
(
add
)
icells
.
append
(
add
)
return
ocells
,
icells
,
sum
(
len
(
cell
)
for
cell
in
icells
)
return
ocells
,
icells
,
sum
(
len
(
cell
)
for
cell
in
icells
)
def
make_emblist
(
self
,
prompts
):
def
make_emblist
(
self
,
prompts
):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
tokens
=
self
.
tokenizer
(
prompts
,
max_length
=
self
.
tokenizer
.
model_max_length
,
padding
=
True
,
truncation
=
True
,
return_tensors
=
'pt'
).
input_ids
.
to
(
self
.
device
)
tokens
=
self
.
tokenizer
(
embs
=
self
.
text_encoder
(
tokens
,
output_hidden_states
=
True
).
last_hidden_state
.
to
(
self
.
device
,
dtype
=
self
.
dtype
)
prompts
,
max_length
=
self
.
tokenizer
.
model_max_length
,
padding
=
True
,
truncation
=
True
,
return_tensors
=
"pt"
).
input_ids
.
to
(
self
.
device
)
embs
=
self
.
text_encoder
(
tokens
,
output_hidden_states
=
True
).
last_hidden_state
.
to
(
self
.
device
,
dtype
=
self
.
dtype
)
return
embs
return
embs
import
math
def
split_dims
(
xs
,
height
,
width
):
def
split_dims
(
xs
,
height
,
width
):
xs
=
xs
xs
=
xs
def
repeat_div
(
x
,
y
):
def
repeat_div
(
x
,
y
):
while
y
>
0
:
while
y
>
0
:
x
=
math
.
ceil
(
x
/
2
)
x
=
math
.
ceil
(
x
/
2
)
y
=
y
-
1
y
=
y
-
1
return
x
return
x
scale
=
math
.
ceil
(
math
.
log2
(
math
.
sqrt
(
height
*
width
/
xs
)))
scale
=
math
.
ceil
(
math
.
log2
(
math
.
sqrt
(
height
*
width
/
xs
)))
dsh
=
repeat_div
(
height
,
scale
)
dsh
=
repeat_div
(
height
,
scale
)
dsw
=
repeat_div
(
width
,
scale
)
dsw
=
repeat_div
(
width
,
scale
)
return
dsh
,
dsw
return
dsh
,
dsw
##### for prompt mode
##### for prompt mode
def
get_attn_maps
(
self
,
attn
):
def
get_attn_maps
(
self
,
attn
):
height
,
width
=
self
.
hw
height
,
width
=
self
.
hw
target_tokens
=
self
.
target_tokens
target_tokens
=
self
.
target_tokens
if
(
height
,
width
)
not
in
self
.
attnmaps_sizes
:
if
(
height
,
width
)
not
in
self
.
attnmaps_sizes
:
self
.
attnmaps_sizes
.
append
((
height
,
width
))
self
.
attnmaps_sizes
.
append
((
height
,
width
))
for
b
in
range
(
self
.
batch
):
for
b
in
range
(
self
.
batch
):
for
t
in
target_tokens
:
for
t
in
target_tokens
:
power
=
self
.
power
power
=
self
.
power
add
=
attn
[
b
,:,:,
t
[
0
]
:
t
[
0
]
+
len
(
t
)]
**
(
power
)
*
(
self
.
attnmaps_sizes
.
index
((
height
,
width
))
+
1
)
add
=
attn
[
b
,
:,
:,
t
[
0
]
:
t
[
0
]
+
len
(
t
)]
**
(
power
)
*
(
self
.
attnmaps_sizes
.
index
((
height
,
width
))
+
1
)
add
=
torch
.
sum
(
add
,
dim
=
2
)
add
=
torch
.
sum
(
add
,
dim
=
2
)
key
=
f
"
{
t
}
-
{
b
}
"
key
=
f
"
{
t
}
-
{
b
}
"
if
key
not
in
self
.
attnmaps
:
if
key
not
in
self
.
attnmaps
:
self
.
attnmaps
[
key
]
=
add
self
.
attnmaps
[
key
]
=
add
else
:
else
:
if
self
.
attnmaps
[
key
].
shape
[
1
]
!=
add
.
shape
[
1
]:
if
self
.
attnmaps
[
key
].
shape
[
1
]
!=
add
.
shape
[
1
]:
add
=
add
.
view
(
8
,
height
,
width
)
add
=
add
.
view
(
8
,
height
,
width
)
add
=
FF
.
resize
(
add
,
self
.
attnmaps_sizes
[
0
],
antialias
=
None
)
add
=
FF
.
resize
(
add
,
self
.
attnmaps_sizes
[
0
],
antialias
=
None
)
add
=
add
.
reshape_as
(
self
.
attnmaps
[
key
])
add
=
add
.
reshape_as
(
self
.
attnmaps
[
key
])
self
.
attnmaps
[
key
]
=
self
.
attnmaps
[
key
]
+
add
self
.
attnmaps
[
key
]
=
self
.
attnmaps
[
key
]
+
add
def
reset_attnmaps
(
self
):
# init parameters in every batch
def
reset_attnmaps
(
self
):
# init parameters in every batch
self
.
step
=
0
self
.
step
=
0
self
.
attnmaps
=
{}
#maked from attention maps
self
.
attnmaps
=
{}
#
maked from attention maps
self
.
attnmaps_sizes
=
[]
#height,width set of u-net blocks
self
.
attnmaps_sizes
=
[]
#
height,width set of u-net blocks
self
.
attnmasks
=
{}
#maked from attnmaps for regions
self
.
attnmasks
=
{}
#
maked from attnmaps for regions
self
.
maskready
=
False
self
.
maskready
=
False
self
.
history
=
{}
self
.
history
=
{}
def
saveattnmaps
(
self
,
output
,
h
,
w
,
th
,
step
,
regions
):
def
saveattnmaps
(
self
,
output
,
h
,
w
,
th
,
step
,
regions
):
masks
=
[]
masks
=
[]
for
i
,
mask
in
enumerate
(
self
.
history
[
step
].
values
()):
for
i
,
mask
in
enumerate
(
self
.
history
[
step
].
values
()):
img
,
_
,
mask
=
makepmask
(
self
,
mask
,
h
,
w
,
th
[
i
%
len
(
th
)],
step
)
img
,
_
,
mask
=
makepmask
(
self
,
mask
,
h
,
w
,
th
[
i
%
len
(
th
)],
step
)
if
self
.
ex
:
if
self
.
ex
:
masks
=
[
x
-
mask
for
x
in
masks
]
masks
=
[
x
-
mask
for
x
in
masks
]
masks
.
append
(
mask
)
masks
.
append
(
mask
)
...
@@ -452,46 +504,71 @@ def saveattnmaps(self,output,h,w,th,step,regions):
...
@@ -452,46 +504,71 @@ def saveattnmaps(self,output,h,w,th,step,regions):
else
:
else
:
output
.
images
.
append
(
img
)
output
.
images
.
append
(
img
)
def
makepmask
(
self
,
mask
,
h
,
w
,
th
,
step
):
# make masks from attention cache return [for preview, for attention, for Latent]
def
makepmask
(
self
,
mask
,
h
,
w
,
th
,
step
):
# make masks from attention cache return [for preview, for attention, for Latent]
th
=
th
-
step
*
0.005
th
=
th
-
step
*
0.005
if
0.05
>=
th
:
th
=
0.05
if
0.05
>=
th
:
mask
=
torch
.
mean
(
mask
,
dim
=
0
)
th
=
0.05
mask
=
torch
.
mean
(
mask
,
dim
=
0
)
mask
=
mask
/
mask
.
max
().
item
()
mask
=
mask
/
mask
.
max
().
item
()
mask
=
torch
.
where
(
mask
>
th
,
1
,
0
)
mask
=
torch
.
where
(
mask
>
th
,
1
,
0
)
mask
=
mask
.
float
()
mask
=
mask
.
float
()
mask
=
mask
.
view
(
1
,
*
self
.
attnmaps_sizes
[
0
])
mask
=
mask
.
view
(
1
,
*
self
.
attnmaps_sizes
[
0
])
img
=
FF
.
to_pil_image
(
mask
)
img
=
FF
.
to_pil_image
(
mask
)
img
=
img
.
resize
((
w
,
h
))
img
=
img
.
resize
((
w
,
h
))
mask
=
FF
.
resize
(
mask
,(
h
,
w
),
interpolation
=
FF
.
InterpolationMode
.
NEAREST
,
antialias
=
None
)
mask
=
FF
.
resize
(
mask
,
(
h
,
w
),
interpolation
=
FF
.
InterpolationMode
.
NEAREST
,
antialias
=
None
)
lmask
=
mask
lmask
=
mask
mask
=
mask
.
reshape
(
h
*
w
)
mask
=
mask
.
reshape
(
h
*
w
)
mask
=
torch
.
where
(
mask
>
0.1
,
1
,
0
)
mask
=
torch
.
where
(
mask
>
0.1
,
1
,
0
)
return
img
,
mask
,
lmask
return
img
,
mask
,
lmask
def
tokendealer
(
self
,
all_prompts
):
def
tokendealer
(
self
,
all_prompts
):
for
prompts
in
all_prompts
:
for
prompts
in
all_prompts
:
targets
=
[
p
.
split
(
","
)[
-
1
]
for
p
in
prompts
[
1
:]]
targets
=
[
p
.
split
(
","
)[
-
1
]
for
p
in
prompts
[
1
:]]
tt
=
[]
tt
=
[]
for
target
in
targets
:
for
target
in
targets
:
ptokens
=
(
self
.
tokenizer
(
prompts
,
max_length
=
self
.
tokenizer
.
model_max_length
,
padding
=
True
,
truncation
=
True
,
return_tensors
=
'pt'
).
input_ids
)[
0
]
ptokens
=
(
ttokens
=
(
self
.
tokenizer
(
target
,
max_length
=
self
.
tokenizer
.
model_max_length
,
padding
=
True
,
truncation
=
True
,
return_tensors
=
'pt'
).
input_ids
)[
0
]
self
.
tokenizer
(
prompts
,
max_length
=
self
.
tokenizer
.
model_max_length
,
padding
=
True
,
truncation
=
True
,
return_tensors
=
"pt"
,
).
input_ids
)[
0
]
ttokens
=
(
self
.
tokenizer
(
target
,
max_length
=
self
.
tokenizer
.
model_max_length
,
padding
=
True
,
truncation
=
True
,
return_tensors
=
"pt"
,
).
input_ids
)[
0
]
tlist
=
[]
tlist
=
[]
for
t
in
range
(
ttokens
.
shape
[
0
]
-
2
):
for
t
in
range
(
ttokens
.
shape
[
0
]
-
2
):
for
p
in
range
(
ptokens
.
shape
[
0
]):
for
p
in
range
(
ptokens
.
shape
[
0
]):
if
ttokens
[
t
+
1
]
==
ptokens
[
p
]:
if
ttokens
[
t
+
1
]
==
ptokens
[
p
]:
tlist
.
append
(
p
)
tlist
.
append
(
p
)
if
tlist
!=
[]
:
tt
.
append
(
tlist
)
if
tlist
!=
[]:
tt
.
append
(
tlist
)
return
tt
return
tt
def
scaled_dot_product_attention
(
self
,
query
,
key
,
value
,
attn_mask
=
None
,
dropout_p
=
0.0
,
is_causal
=
False
,
scale
=
None
,
getattn
=
False
)
->
torch
.
Tensor
:
def
scaled_dot_product_attention
(
self
,
query
,
key
,
value
,
attn_mask
=
None
,
dropout_p
=
0.0
,
is_causal
=
False
,
scale
=
None
,
getattn
=
False
)
->
torch
.
Tensor
:
# Efficient implementation equivalent to the following:
# Efficient implementation equivalent to the following:
L
,
S
=
query
.
size
(
-
2
),
key
.
size
(
-
2
)
L
,
S
=
query
.
size
(
-
2
),
key
.
size
(
-
2
)
scale_factor
=
1
/
math
.
sqrt
(
query
.
size
(
-
1
))
if
scale
is
None
else
scale
scale_factor
=
1
/
math
.
sqrt
(
query
.
size
(
-
1
))
if
scale
is
None
else
scale
attn_bias
=
torch
.
zeros
(
L
,
S
,
dtype
=
query
.
dtype
,
device
=
self
.
device
)
attn_bias
=
torch
.
zeros
(
L
,
S
,
dtype
=
query
.
dtype
,
device
=
self
.
device
)
if
is_causal
:
if
is_causal
:
assert
attn_mask
is
None
assert
attn_mask
is
None
temp_mask
=
torch
.
ones
(
L
,
S
,
dtype
=
torch
.
bool
).
tril
(
diagonal
=
0
)
temp_mask
=
torch
.
ones
(
L
,
S
,
dtype
=
torch
.
bool
).
tril
(
diagonal
=
0
)
...
@@ -506,6 +583,7 @@ def scaled_dot_product_attention(self, query, key, value, attn_mask=None, dropou
...
@@ -506,6 +583,7 @@ def scaled_dot_product_attention(self, query, key, value, attn_mask=None, dropou
attn_weight
=
query
@
key
.
transpose
(
-
2
,
-
1
)
*
scale_factor
attn_weight
=
query
@
key
.
transpose
(
-
2
,
-
1
)
*
scale_factor
attn_weight
+=
attn_bias
attn_weight
+=
attn_bias
attn_weight
=
torch
.
softmax
(
attn_weight
,
dim
=-
1
)
attn_weight
=
torch
.
softmax
(
attn_weight
,
dim
=-
1
)
if
getattn
:
get_attn_maps
(
self
,
attn_weight
)
if
getattn
:
get_attn_maps
(
self
,
attn_weight
)
attn_weight
=
torch
.
dropout
(
attn_weight
,
dropout_p
,
train
=
True
)
attn_weight
=
torch
.
dropout
(
attn_weight
,
dropout_p
,
train
=
True
)
return
attn_weight
@
value
return
attn_weight
@
value
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment