Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
0f55c17e
Commit
0f55c17e
authored
Dec 01, 2023
by
Patrick von Platen
Browse files
fix style
parent
5058d27f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
182 additions
and
104 deletions
+182
-104
examples/community/regional_prompting_stable_diffusion.py
examples/community/regional_prompting_stable_diffusion.py
+182
-104
No files found.
examples/community/regional_prompting_stable_diffusion.py
View file @
0f55c17e
import
torchvision.transforms.functional
as
FF
import
math
import
torch
import
torchvision
from
typing
import
Dict
,
Optional
from
typing
import
Dict
,
Optional
import
torch
import
torchvision.transforms.functional
as
FF
from
transformers
import
CLIPFeatureExtractor
,
CLIPTextModel
,
CLIPTokenizer
from
diffusers
import
StableDiffusionPipeline
from
diffusers
import
StableDiffusionPipeline
from
diffusers.models
import
AutoencoderKL
,
UNet2DConditionModel
from
diffusers.models
import
AutoencoderKL
,
UNet2DConditionModel
from
diffusers.utils
import
USE_PEFT_BACKEND
from
diffusers.pipelines.stable_diffusion.safety_checker
import
StableDiffusionSafetyChecker
from
diffusers.pipelines.stable_diffusion.safety_checker
import
StableDiffusionSafetyChecker
from
diffusers.schedulers
import
KarrasDiffusionSchedulers
from
diffusers.schedulers
import
KarrasDiffusionSchedulers
from
transformers
import
CLIPFeatureExtractor
,
CLIPTextModel
,
CLIPTokenizer
from
diffusers.utils
import
USE_PEFT_BACKEND
try
:
try
:
from
compel
import
Compel
from
compel
import
Compel
except
:
except
ImportError
:
Compel
=
None
Compel
=
None
KCOMM
=
"ADDCOMM"
KCOMM
=
"ADDCOMM"
KBRK
=
"BREAK"
KBRK
=
"BREAK"
class
RegionalPromptingStableDiffusionPipeline
(
StableDiffusionPipeline
):
class
RegionalPromptingStableDiffusionPipeline
(
StableDiffusionPipeline
):
r
"""
r
"""
Args for Regional Prompting Pipeline:
Args for Regional Prompting Pipeline:
...
@@ -56,6 +60,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
...
@@ -56,6 +60,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
feature_extractor ([`CLIPImageProcessor`]):
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
vae
:
AutoencoderKL
,
vae
:
AutoencoderKL
,
...
@@ -67,7 +72,9 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
...
@@ -67,7 +72,9 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
feature_extractor
:
CLIPFeatureExtractor
,
feature_extractor
:
CLIPFeatureExtractor
,
requires_safety_checker
:
bool
=
True
,
requires_safety_checker
:
bool
=
True
,
):
):
super
().
__init__
(
vae
,
text_encoder
,
tokenizer
,
unet
,
scheduler
,
safety_checker
,
feature_extractor
,
requires_safety_checker
)
super
().
__init__
(
vae
,
text_encoder
,
tokenizer
,
unet
,
scheduler
,
safety_checker
,
feature_extractor
,
requires_safety_checker
)
self
.
register_modules
(
self
.
register_modules
(
vae
=
vae
,
vae
=
vae
,
text_encoder
=
text_encoder
,
text_encoder
=
text_encoder
,
...
@@ -93,32 +100,34 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
...
@@ -93,32 +100,34 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
latents
:
Optional
[
torch
.
FloatTensor
]
=
None
,
latents
:
Optional
[
torch
.
FloatTensor
]
=
None
,
output_type
:
Optional
[
str
]
=
"pil"
,
output_type
:
Optional
[
str
]
=
"pil"
,
return_dict
:
bool
=
True
,
return_dict
:
bool
=
True
,
rp_args
:
Dict
[
str
,
str
]
=
None
,
rp_args
:
Dict
[
str
,
str
]
=
None
,
):
):
active
=
KBRK
in
prompt
[
0
]
if
type
(
prompt
)
==
list
else
KBRK
in
prompt
# noqa: E721
active
=
KBRK
in
prompt
[
0
]
if
type
(
prompt
)
==
list
else
KBRK
in
prompt
if
negative_prompt
is
None
:
if
negative_prompt
is
None
:
negative_prompt
=
""
if
type
(
prompt
)
==
str
else
[
""
]
*
len
(
prompt
)
negative_prompt
=
""
if
type
(
prompt
)
==
str
else
[
""
]
*
len
(
prompt
)
# noqa: E721
device
=
self
.
_execution_device
device
=
self
.
_execution_device
regions
=
0
regions
=
0
self
.
power
=
int
(
rp_args
[
"power"
])
if
"power"
in
rp_args
else
1
self
.
power
=
int
(
rp_args
[
"power"
])
if
"power"
in
rp_args
else
1
prompts
=
prompt
if
type
(
prompt
)
==
list
else
[
prompt
]
prompts
=
prompt
if
type
(
prompt
)
==
list
else
[
prompt
]
# noqa: E721
n_prompts
=
negative_prompt
if
type
(
negative_prompt
)
==
list
else
[
negative_prompt
]
n_prompts
=
negative_prompt
if
type
(
negative_prompt
)
==
list
else
[
negative_prompt
]
# noqa: E721
self
.
batch
=
batch
=
num_images_per_prompt
*
len
(
prompts
)
self
.
batch
=
batch
=
num_images_per_prompt
*
len
(
prompts
)
all_prompts_cn
,
all_prompts_p
=
promptsmaker
(
prompts
,
num_images_per_prompt
)
all_prompts_cn
,
all_prompts_p
=
promptsmaker
(
prompts
,
num_images_per_prompt
)
all_n_prompts_cn
,
_
=
promptsmaker
(
n_prompts
,
num_images_per_prompt
)
all_n_prompts_cn
,
_
=
promptsmaker
(
n_prompts
,
num_images_per_prompt
)
cn
=
len
(
all_prompts_cn
)
==
len
(
all_n_prompts_cn
)
cn
=
len
(
all_prompts_cn
)
==
len
(
all_n_prompts_cn
)
if
Compel
:
if
Compel
:
compel
=
Compel
(
tokenizer
=
self
.
tokenizer
,
text_encoder
=
self
.
text_encoder
)
compel
=
Compel
(
tokenizer
=
self
.
tokenizer
,
text_encoder
=
self
.
text_encoder
)
def
getcompelembs
(
prps
):
def
getcompelembs
(
prps
):
embl
=
[]
embl
=
[]
for
prp
in
prps
:
for
prp
in
prps
:
embl
.
append
(
compel
.
build_conditioning_tensor
(
prp
))
embl
.
append
(
compel
.
build_conditioning_tensor
(
prp
))
return
torch
.
cat
(
embl
)
return
torch
.
cat
(
embl
)
conds
=
getcompelembs
(
all_prompts_cn
)
conds
=
getcompelembs
(
all_prompts_cn
)
unconds
=
getcompelembs
(
all_n_prompts_cn
)
if
cn
else
getcompelembs
(
n_prompts
)
unconds
=
getcompelembs
(
all_n_prompts_cn
)
if
cn
else
getcompelembs
(
n_prompts
)
embs
=
getcompelembs
(
prompts
)
embs
=
getcompelembs
(
prompts
)
...
@@ -126,16 +135,20 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
...
@@ -126,16 +135,20 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
prompt
=
negative_prompt
=
None
prompt
=
negative_prompt
=
None
else
:
else
:
conds
=
self
.
encode_prompt
(
prompts
,
device
,
1
,
True
)[
0
]
conds
=
self
.
encode_prompt
(
prompts
,
device
,
1
,
True
)[
0
]
unconds
=
self
.
encode_prompt
(
n_prompts
,
device
,
1
,
True
)[
0
]
if
cn
else
self
.
encode_prompt
(
all_n_prompts_cn
,
device
,
1
,
True
)[
0
]
unconds
=
(
self
.
encode_prompt
(
n_prompts
,
device
,
1
,
True
)[
0
]
if
cn
else
self
.
encode_prompt
(
all_n_prompts_cn
,
device
,
1
,
True
)[
0
]
)
embs
=
n_embs
=
None
embs
=
n_embs
=
None
if
not
active
:
if
not
active
:
pcallback
=
None
pcallback
=
None
mode
=
None
mode
=
None
else
:
else
:
if
any
(
x
in
rp_args
[
"mode"
].
upper
()
for
x
in
[
"COL"
,
"ROW"
]):
if
any
(
x
in
rp_args
[
"mode"
].
upper
()
for
x
in
[
"COL"
,
"ROW"
]):
mode
=
"COL"
if
"COL"
in
rp_args
[
"mode"
].
upper
()
else
"ROW"
mode
=
"COL"
if
"COL"
in
rp_args
[
"mode"
].
upper
()
else
"ROW"
ocells
,
icells
,
regions
=
make_cells
(
rp_args
[
"div"
])
ocells
,
icells
,
regions
=
make_cells
(
rp_args
[
"div"
])
elif
"PRO"
in
rp_args
[
"mode"
].
upper
():
elif
"PRO"
in
rp_args
[
"mode"
].
upper
():
regions
=
len
(
all_prompts_p
[
0
])
regions
=
len
(
all_prompts_p
[
0
])
...
@@ -145,10 +158,10 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
...
@@ -145,10 +158,10 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
self
.
target_tokens
=
target_tokens
=
tokendealer
(
self
,
all_prompts_p
)
self
.
target_tokens
=
target_tokens
=
tokendealer
(
self
,
all_prompts_p
)
thresholds
=
[
float
(
x
)
for
x
in
rp_args
[
"th"
].
split
(
","
)]
thresholds
=
[
float
(
x
)
for
x
in
rp_args
[
"th"
].
split
(
","
)]
orig_hw
=
(
height
,
width
)
orig_hw
=
(
height
,
width
)
revers
=
True
revers
=
True
def
pcallback
(
s_self
,
step
:
int
,
timestep
:
int
,
latents
:
torch
.
FloatTensor
,
selfs
=
None
):
def
pcallback
(
s_self
,
step
:
int
,
timestep
:
int
,
latents
:
torch
.
FloatTensor
,
selfs
=
None
):
if
"PRO"
in
mode
:
# in Prompt mode, make masks from sum of attension maps
if
"PRO"
in
mode
:
# in Prompt mode, make masks from sum of attension maps
self
.
step
=
step
self
.
step
=
step
...
@@ -167,7 +180,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
...
@@ -167,7 +180,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
allmasks
[
b
::
batch
]
=
[
torch
.
where
(
x
>
0
,
1
,
0
)
for
x
in
allmasks
[
b
::
batch
]]
allmasks
[
b
::
batch
]
=
[
torch
.
where
(
x
>
0
,
1
,
0
)
for
x
in
allmasks
[
b
::
batch
]]
allmasks
.
append
(
mask
)
allmasks
.
append
(
mask
)
basemasks
[
b
]
=
mask
if
basemasks
[
b
]
is
None
else
basemasks
[
b
]
+
mask
basemasks
[
b
]
=
mask
if
basemasks
[
b
]
is
None
else
basemasks
[
b
]
+
mask
basemasks
=
[
1
-
mask
for
mask
in
basemasks
]
basemasks
=
[
1
-
mask
for
mask
in
basemasks
]
basemasks
=
[
torch
.
where
(
x
>
0
,
1
,
0
)
for
x
in
basemasks
]
basemasks
=
[
torch
.
where
(
x
>
0
,
1
,
0
)
for
x
in
basemasks
]
allmasks
=
basemasks
+
allmasks
allmasks
=
basemasks
+
allmasks
...
@@ -176,7 +189,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
...
@@ -176,7 +189,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
return
latents
return
latents
def
hook_forward
(
module
):
def
hook_forward
(
module
):
#diffusers==0.23.2
#
diffusers==0.23.2
def
forward
(
def
forward
(
hidden_states
:
torch
.
FloatTensor
,
hidden_states
:
torch
.
FloatTensor
,
encoder_hidden_states
:
Optional
[
torch
.
FloatTensor
]
=
None
,
encoder_hidden_states
:
Optional
[
torch
.
FloatTensor
]
=
None
,
...
@@ -184,22 +197,21 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
...
@@ -184,22 +197,21 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
temb
:
Optional
[
torch
.
FloatTensor
]
=
None
,
temb
:
Optional
[
torch
.
FloatTensor
]
=
None
,
scale
:
float
=
1.0
,
scale
:
float
=
1.0
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
attn
=
module
attn
=
module
xshape
=
hidden_states
.
shape
xshape
=
hidden_states
.
shape
self
.
hw
=
(
h
,
w
)
=
split_dims
(
xshape
[
1
],
*
orig_hw
)
self
.
hw
=
(
h
,
w
)
=
split_dims
(
xshape
[
1
],
*
orig_hw
)
if
revers
:
if
revers
:
nx
,
px
=
hidden_states
.
chunk
(
2
)
nx
,
px
=
hidden_states
.
chunk
(
2
)
else
:
else
:
px
,
nx
=
hidden_states
.
chunk
(
2
)
px
,
nx
=
hidden_states
.
chunk
(
2
)
if
cn
:
if
cn
:
hidden_states
=
torch
.
cat
([
px
for
i
in
range
(
regions
)]
+
[
nx
for
i
in
range
(
regions
)],
0
)
hidden_states
=
torch
.
cat
([
px
for
i
in
range
(
regions
)]
+
[
nx
for
i
in
range
(
regions
)],
0
)
encoder_hidden_states
=
torch
.
cat
([
conds
]
+
[
unconds
])
encoder_hidden_states
=
torch
.
cat
([
conds
]
+
[
unconds
])
else
:
else
:
hidden_states
=
torch
.
cat
([
px
for
i
in
range
(
regions
)]
+
[
nx
],
0
)
hidden_states
=
torch
.
cat
([
px
for
i
in
range
(
regions
)]
+
[
nx
],
0
)
encoder_hidden_states
=
torch
.
cat
([
conds
]
+
[
unconds
])
encoder_hidden_states
=
torch
.
cat
([
conds
]
+
[
unconds
])
residual
=
hidden_states
residual
=
hidden_states
...
@@ -247,7 +259,14 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
...
@@ -247,7 +259,14 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states
=
scaled_dot_product_attention
(
hidden_states
=
scaled_dot_product_attention
(
self
,
query
,
key
,
value
,
attn_mask
=
attention_mask
,
dropout_p
=
0.0
,
is_causal
=
False
,
getattn
=
"PRO"
in
mode
self
,
query
,
key
,
value
,
attn_mask
=
attention_mask
,
dropout_p
=
0.0
,
is_causal
=
False
,
getattn
=
"PRO"
in
mode
,
)
)
hidden_states
=
hidden_states
.
transpose
(
1
,
2
).
reshape
(
batch_size
,
-
1
,
attn
.
heads
*
head_dim
)
hidden_states
=
hidden_states
.
transpose
(
1
,
2
).
reshape
(
batch_size
,
-
1
,
attn
.
heads
*
head_dim
)
...
@@ -272,18 +291,38 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
...
@@ -272,18 +291,38 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
center
=
reshaped
.
shape
[
0
]
//
2
center
=
reshaped
.
shape
[
0
]
//
2
px
=
reshaped
[
0
:
center
]
if
cn
else
reshaped
[
0
:
-
batch
]
px
=
reshaped
[
0
:
center
]
if
cn
else
reshaped
[
0
:
-
batch
]
nx
=
reshaped
[
center
:]
if
cn
else
reshaped
[
-
batch
:]
nx
=
reshaped
[
center
:]
if
cn
else
reshaped
[
-
batch
:]
outs
=
[
px
,
nx
]
if
cn
else
[
px
]
outs
=
[
px
,
nx
]
if
cn
else
[
px
]
for
out
in
outs
:
for
out
in
outs
:
c
=
0
c
=
0
for
i
,
ocell
in
enumerate
(
ocells
):
for
i
,
ocell
in
enumerate
(
ocells
):
for
icell
in
icells
[
i
]:
for
icell
in
icells
[
i
]:
if
"ROW"
in
mode
:
if
"ROW"
in
mode
:
out
[
0
:
batch
,
int
(
h
*
ocell
[
0
]):
int
(
h
*
ocell
[
1
]),
int
(
w
*
icell
[
0
]):
int
(
w
*
icell
[
1
]),:]
=
out
[
c
*
batch
:(
c
+
1
)
*
batch
,
int
(
h
*
ocell
[
0
]):
int
(
h
*
ocell
[
1
]),
int
(
w
*
icell
[
0
]):
int
(
w
*
icell
[
1
]),:]
out
[
0
:
batch
,
int
(
h
*
ocell
[
0
])
:
int
(
h
*
ocell
[
1
]),
int
(
w
*
icell
[
0
])
:
int
(
w
*
icell
[
1
]),
:,
]
=
out
[
c
*
batch
:
(
c
+
1
)
*
batch
,
int
(
h
*
ocell
[
0
])
:
int
(
h
*
ocell
[
1
]),
int
(
w
*
icell
[
0
])
:
int
(
w
*
icell
[
1
]),
:,
]
else
:
else
:
out
[
0
:
batch
,
int
(
h
*
icell
[
0
]):
int
(
h
*
icell
[
1
]),
int
(
w
*
ocell
[
0
]):
int
(
w
*
ocell
[
1
]),:]
=
out
[
c
*
batch
:(
c
+
1
)
*
batch
,
int
(
h
*
icell
[
0
]):
int
(
h
*
icell
[
1
]),
int
(
w
*
ocell
[
0
]):
int
(
w
*
ocell
[
1
]),:]
out
[
0
:
batch
,
int
(
h
*
icell
[
0
])
:
int
(
h
*
icell
[
1
]),
int
(
w
*
ocell
[
0
])
:
int
(
w
*
ocell
[
1
]),
:,
]
=
out
[
c
*
batch
:
(
c
+
1
)
*
batch
,
int
(
h
*
icell
[
0
])
:
int
(
h
*
icell
[
1
]),
int
(
w
*
ocell
[
0
])
:
int
(
w
*
ocell
[
1
]),
:,
]
c
+=
1
c
+=
1
px
,
nx
=
(
px
[
0
:
batch
],
nx
[
0
:
batch
])
if
cn
else
(
px
[
0
:
batch
],
nx
)
px
,
nx
=
(
px
[
0
:
batch
],
nx
[
0
:
batch
])
if
cn
else
(
px
[
0
:
batch
],
nx
)
hidden_states
=
torch
.
cat
([
nx
,
px
],
0
)
if
revers
else
torch
.
cat
([
px
,
nx
],
0
)
hidden_states
=
torch
.
cat
([
nx
,
px
],
0
)
if
revers
else
torch
.
cat
([
px
,
nx
],
0
)
hidden_states
=
hidden_states
.
reshape
(
xshape
)
hidden_states
=
hidden_states
.
reshape
(
xshape
)
#### Regional Prompting Prompt mode
#### Regional Prompting Prompt mode
...
@@ -292,16 +331,18 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
...
@@ -292,16 +331,18 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
px
=
reshaped
[
0
:
center
]
if
cn
else
reshaped
[
0
:
-
batch
]
px
=
reshaped
[
0
:
center
]
if
cn
else
reshaped
[
0
:
-
batch
]
nx
=
reshaped
[
center
:]
if
cn
else
reshaped
[
-
batch
:]
nx
=
reshaped
[
center
:]
if
cn
else
reshaped
[
-
batch
:]
if
(
h
,
w
)
in
self
.
attnmasks
and
self
.
maskready
:
if
(
h
,
w
)
in
self
.
attnmasks
and
self
.
maskready
:
def
mask
(
input
):
def
mask
(
input
):
out
=
torch
.
multiply
(
input
,
self
.
attnmasks
[(
h
,
w
)])
out
=
torch
.
multiply
(
input
,
self
.
attnmasks
[(
h
,
w
)])
for
b
in
range
(
batch
):
for
b
in
range
(
batch
):
for
r
in
range
(
1
,
regions
):
for
r
in
range
(
1
,
regions
):
out
[
b
]
=
out
[
b
]
+
out
[
r
*
batch
+
b
]
out
[
b
]
=
out
[
b
]
+
out
[
r
*
batch
+
b
]
return
out
return
out
px
,
nx
=
(
mask
(
px
),
mask
(
nx
))
if
cn
else
(
mask
(
px
),
nx
)
px
,
nx
=
(
mask
(
px
),
mask
(
nx
))
if
cn
else
(
mask
(
px
),
nx
)
px
,
nx
=
(
px
[
0
:
batch
],
nx
[
0
:
batch
])
if
cn
else
(
px
[
0
:
batch
],
nx
)
px
,
nx
=
(
px
[
0
:
batch
],
nx
[
0
:
batch
])
if
cn
else
(
px
[
0
:
batch
],
nx
)
hidden_states
=
torch
.
cat
([
nx
,
px
],
0
)
if
revers
else
torch
.
cat
([
px
,
nx
],
0
)
hidden_states
=
torch
.
cat
([
nx
,
px
],
0
)
if
revers
else
torch
.
cat
([
px
,
nx
],
0
)
return
hidden_states
return
hidden_states
return
forward
return
forward
...
@@ -328,7 +369,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
...
@@ -328,7 +369,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
latents
=
latents
,
latents
=
latents
,
output_type
=
output_type
,
output_type
=
output_type
,
return_dict
=
return_dict
,
return_dict
=
return_dict
,
callback_on_step_end
=
pcallback
callback_on_step_end
=
pcallback
,
)
)
if
"save_mask"
in
rp_args
:
if
"save_mask"
in
rp_args
:
...
@@ -336,13 +377,14 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
...
@@ -336,13 +377,14 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
else
:
else
:
save_mask
=
False
save_mask
=
False
if
mode
==
"PROMPT"
and
save_mask
:
saveattnmaps
(
self
,
output
,
height
,
width
,
thresholds
,
num_inference_steps
//
2
,
regions
)
if
mode
==
"PROMPT"
and
save_mask
:
saveattnmaps
(
self
,
output
,
height
,
width
,
thresholds
,
num_inference_steps
//
2
,
regions
)
return
output
return
output
### Make prompt list for each regions
### Make prompt list for each regions
def
promptsmaker
(
prompts
,
batch
):
def
promptsmaker
(
prompts
,
batch
):
out_p
=
[]
out_p
=
[]
plen
=
len
(
prompts
)
plen
=
len
(
prompts
)
for
prompt
in
prompts
:
for
prompt
in
prompts
:
...
@@ -352,24 +394,26 @@ def promptsmaker(prompts,batch):
...
@@ -352,24 +394,26 @@ def promptsmaker(prompts,batch):
add
=
add
+
" "
add
=
add
+
" "
prompts
=
prompt
.
split
(
KBRK
)
prompts
=
prompt
.
split
(
KBRK
)
out_p
.
append
([
add
+
p
for
p
in
prompts
])
out_p
.
append
([
add
+
p
for
p
in
prompts
])
out
=
[
None
]
*
batch
*
len
(
out_p
[
0
])
*
len
(
out_p
)
out
=
[
None
]
*
batch
*
len
(
out_p
[
0
])
*
len
(
out_p
)
for
p
,
prs
in
enumerate
(
out_p
):
# inputs prompts
for
p
,
prs
in
enumerate
(
out_p
):
# inputs prompts
for
r
,
pr
in
enumerate
(
prs
):
# prompts for regions
for
r
,
pr
in
enumerate
(
prs
):
# prompts for regions
start
=
(
p
+
r
*
plen
)
*
batch
start
=
(
p
+
r
*
plen
)
*
batch
out
[
start
:
start
+
batch
]
=
[
pr
]
*
batch
#P1R1B1,P1R1B2...,P1R2B1,P1R2B2...,P2R1B1...
out
[
start
:
start
+
batch
]
=
[
pr
]
*
batch
#
P1R1B1,P1R1B2...,P1R2B1,P1R2B2...,P2R1B1...
return
out
,
out_p
return
out
,
out_p
### make regions from ratios
### make regions from ratios
### ";" makes outercells, "," makes inner cells
### ";" makes outercells, "," makes inner cells
def
make_cells
(
ratios
):
def
make_cells
(
ratios
):
if
";"
not
in
ratios
and
","
in
ratios
:
ratios
=
ratios
.
replace
(
","
,
";"
)
if
";"
not
in
ratios
and
","
in
ratios
:
ratios
=
ratios
.
replace
(
","
,
";"
)
ratios
=
ratios
.
split
(
";"
)
ratios
=
ratios
.
split
(
";"
)
ratios
=
[
inratios
.
split
(
","
)
for
inratios
in
ratios
]
ratios
=
[
inratios
.
split
(
","
)
for
inratios
in
ratios
]
icells
=
[]
icells
=
[]
ocells
=
[]
ocells
=
[]
def
startend
(
cells
,
array
):
def
startend
(
cells
,
array
):
current_start
=
0
current_start
=
0
array
=
[
float
(
x
)
for
x
in
array
]
array
=
[
float
(
x
)
for
x
in
array
]
for
value
in
array
:
for
value
in
array
:
...
@@ -377,72 +421,80 @@ def make_cells(ratios):
...
@@ -377,72 +421,80 @@ def make_cells(ratios):
cells
.
append
([
current_start
,
end
])
cells
.
append
([
current_start
,
end
])
current_start
=
end
current_start
=
end
startend
(
ocells
,[
r
[
0
]
for
r
in
ratios
])
startend
(
ocells
,
[
r
[
0
]
for
r
in
ratios
])
for
inratios
in
ratios
:
for
inratios
in
ratios
:
if
2
>
len
(
inratios
):
if
2
>
len
(
inratios
):
icells
.
append
([[
0
,
1
]])
icells
.
append
([[
0
,
1
]])
else
:
else
:
add
=
[]
add
=
[]
startend
(
add
,
inratios
[
1
:])
startend
(
add
,
inratios
[
1
:])
icells
.
append
(
add
)
icells
.
append
(
add
)
return
ocells
,
icells
,
sum
(
len
(
cell
)
for
cell
in
icells
)
return
ocells
,
icells
,
sum
(
len
(
cell
)
for
cell
in
icells
)
def
make_emblist
(
self
,
prompts
):
def
make_emblist
(
self
,
prompts
):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
tokens
=
self
.
tokenizer
(
prompts
,
max_length
=
self
.
tokenizer
.
model_max_length
,
padding
=
True
,
truncation
=
True
,
return_tensors
=
'pt'
).
input_ids
.
to
(
self
.
device
)
tokens
=
self
.
tokenizer
(
embs
=
self
.
text_encoder
(
tokens
,
output_hidden_states
=
True
).
last_hidden_state
.
to
(
self
.
device
,
dtype
=
self
.
dtype
)
prompts
,
max_length
=
self
.
tokenizer
.
model_max_length
,
padding
=
True
,
truncation
=
True
,
return_tensors
=
"pt"
).
input_ids
.
to
(
self
.
device
)
embs
=
self
.
text_encoder
(
tokens
,
output_hidden_states
=
True
).
last_hidden_state
.
to
(
self
.
device
,
dtype
=
self
.
dtype
)
return
embs
return
embs
import
math
def
split_dims
(
xs
,
height
,
width
):
def
split_dims
(
xs
,
height
,
width
):
xs
=
xs
xs
=
xs
def
repeat_div
(
x
,
y
):
def
repeat_div
(
x
,
y
):
while
y
>
0
:
while
y
>
0
:
x
=
math
.
ceil
(
x
/
2
)
x
=
math
.
ceil
(
x
/
2
)
y
=
y
-
1
y
=
y
-
1
return
x
return
x
scale
=
math
.
ceil
(
math
.
log2
(
math
.
sqrt
(
height
*
width
/
xs
)))
scale
=
math
.
ceil
(
math
.
log2
(
math
.
sqrt
(
height
*
width
/
xs
)))
dsh
=
repeat_div
(
height
,
scale
)
dsh
=
repeat_div
(
height
,
scale
)
dsw
=
repeat_div
(
width
,
scale
)
dsw
=
repeat_div
(
width
,
scale
)
return
dsh
,
dsw
return
dsh
,
dsw
##### for prompt mode
##### for prompt mode
def
get_attn_maps
(
self
,
attn
):
def
get_attn_maps
(
self
,
attn
):
height
,
width
=
self
.
hw
height
,
width
=
self
.
hw
target_tokens
=
self
.
target_tokens
target_tokens
=
self
.
target_tokens
if
(
height
,
width
)
not
in
self
.
attnmaps_sizes
:
if
(
height
,
width
)
not
in
self
.
attnmaps_sizes
:
self
.
attnmaps_sizes
.
append
((
height
,
width
))
self
.
attnmaps_sizes
.
append
((
height
,
width
))
for
b
in
range
(
self
.
batch
):
for
b
in
range
(
self
.
batch
):
for
t
in
target_tokens
:
for
t
in
target_tokens
:
power
=
self
.
power
power
=
self
.
power
add
=
attn
[
b
,:,:,
t
[
0
]
:
t
[
0
]
+
len
(
t
)]
**
(
power
)
*
(
self
.
attnmaps_sizes
.
index
((
height
,
width
))
+
1
)
add
=
attn
[
b
,
:,
:,
t
[
0
]
:
t
[
0
]
+
len
(
t
)]
**
(
power
)
*
(
self
.
attnmaps_sizes
.
index
((
height
,
width
))
+
1
)
add
=
torch
.
sum
(
add
,
dim
=
2
)
add
=
torch
.
sum
(
add
,
dim
=
2
)
key
=
f
"
{
t
}
-
{
b
}
"
key
=
f
"
{
t
}
-
{
b
}
"
if
key
not
in
self
.
attnmaps
:
if
key
not
in
self
.
attnmaps
:
self
.
attnmaps
[
key
]
=
add
self
.
attnmaps
[
key
]
=
add
else
:
else
:
if
self
.
attnmaps
[
key
].
shape
[
1
]
!=
add
.
shape
[
1
]:
if
self
.
attnmaps
[
key
].
shape
[
1
]
!=
add
.
shape
[
1
]:
add
=
add
.
view
(
8
,
height
,
width
)
add
=
add
.
view
(
8
,
height
,
width
)
add
=
FF
.
resize
(
add
,
self
.
attnmaps_sizes
[
0
],
antialias
=
None
)
add
=
FF
.
resize
(
add
,
self
.
attnmaps_sizes
[
0
],
antialias
=
None
)
add
=
add
.
reshape_as
(
self
.
attnmaps
[
key
])
add
=
add
.
reshape_as
(
self
.
attnmaps
[
key
])
self
.
attnmaps
[
key
]
=
self
.
attnmaps
[
key
]
+
add
self
.
attnmaps
[
key
]
=
self
.
attnmaps
[
key
]
+
add
def
reset_attnmaps
(
self
):
# init parameters in every batch
def
reset_attnmaps
(
self
):
# init parameters in every batch
self
.
step
=
0
self
.
step
=
0
self
.
attnmaps
=
{}
#maked from attention maps
self
.
attnmaps
=
{}
#
maked from attention maps
self
.
attnmaps_sizes
=
[]
#height,width set of u-net blocks
self
.
attnmaps_sizes
=
[]
#
height,width set of u-net blocks
self
.
attnmasks
=
{}
#maked from attnmaps for regions
self
.
attnmasks
=
{}
#
maked from attnmaps for regions
self
.
maskready
=
False
self
.
maskready
=
False
self
.
history
=
{}
self
.
history
=
{}
def
saveattnmaps
(
self
,
output
,
h
,
w
,
th
,
step
,
regions
):
def
saveattnmaps
(
self
,
output
,
h
,
w
,
th
,
step
,
regions
):
masks
=
[]
masks
=
[]
for
i
,
mask
in
enumerate
(
self
.
history
[
step
].
values
()):
for
i
,
mask
in
enumerate
(
self
.
history
[
step
].
values
()):
img
,
_
,
mask
=
makepmask
(
self
,
mask
,
h
,
w
,
th
[
i
%
len
(
th
)],
step
)
img
,
_
,
mask
=
makepmask
(
self
,
mask
,
h
,
w
,
th
[
i
%
len
(
th
)],
step
)
if
self
.
ex
:
if
self
.
ex
:
masks
=
[
x
-
mask
for
x
in
masks
]
masks
=
[
x
-
mask
for
x
in
masks
]
masks
.
append
(
mask
)
masks
.
append
(
mask
)
...
@@ -452,46 +504,71 @@ def saveattnmaps(self,output,h,w,th,step,regions):
...
@@ -452,46 +504,71 @@ def saveattnmaps(self,output,h,w,th,step,regions):
else
:
else
:
output
.
images
.
append
(
img
)
output
.
images
.
append
(
img
)
def
makepmask
(
self
,
mask
,
h
,
w
,
th
,
step
):
# make masks from attention cache return [for preview, for attention, for Latent]
def
makepmask
(
self
,
mask
,
h
,
w
,
th
,
step
):
# make masks from attention cache return [for preview, for attention, for Latent]
th
=
th
-
step
*
0.005
th
=
th
-
step
*
0.005
if
0.05
>=
th
:
th
=
0.05
if
0.05
>=
th
:
mask
=
torch
.
mean
(
mask
,
dim
=
0
)
th
=
0.05
mask
=
torch
.
mean
(
mask
,
dim
=
0
)
mask
=
mask
/
mask
.
max
().
item
()
mask
=
mask
/
mask
.
max
().
item
()
mask
=
torch
.
where
(
mask
>
th
,
1
,
0
)
mask
=
torch
.
where
(
mask
>
th
,
1
,
0
)
mask
=
mask
.
float
()
mask
=
mask
.
float
()
mask
=
mask
.
view
(
1
,
*
self
.
attnmaps_sizes
[
0
])
mask
=
mask
.
view
(
1
,
*
self
.
attnmaps_sizes
[
0
])
img
=
FF
.
to_pil_image
(
mask
)
img
=
FF
.
to_pil_image
(
mask
)
img
=
img
.
resize
((
w
,
h
))
img
=
img
.
resize
((
w
,
h
))
mask
=
FF
.
resize
(
mask
,(
h
,
w
),
interpolation
=
FF
.
InterpolationMode
.
NEAREST
,
antialias
=
None
)
mask
=
FF
.
resize
(
mask
,
(
h
,
w
),
interpolation
=
FF
.
InterpolationMode
.
NEAREST
,
antialias
=
None
)
lmask
=
mask
lmask
=
mask
mask
=
mask
.
reshape
(
h
*
w
)
mask
=
mask
.
reshape
(
h
*
w
)
mask
=
torch
.
where
(
mask
>
0.1
,
1
,
0
)
mask
=
torch
.
where
(
mask
>
0.1
,
1
,
0
)
return
img
,
mask
,
lmask
return
img
,
mask
,
lmask
def
tokendealer
(
self
,
all_prompts
):
def
tokendealer
(
self
,
all_prompts
):
for
prompts
in
all_prompts
:
for
prompts
in
all_prompts
:
targets
=
[
p
.
split
(
","
)[
-
1
]
for
p
in
prompts
[
1
:]]
targets
=
[
p
.
split
(
","
)[
-
1
]
for
p
in
prompts
[
1
:]]
tt
=
[]
tt
=
[]
for
target
in
targets
:
for
target
in
targets
:
ptokens
=
(
self
.
tokenizer
(
prompts
,
max_length
=
self
.
tokenizer
.
model_max_length
,
padding
=
True
,
truncation
=
True
,
return_tensors
=
'pt'
).
input_ids
)[
0
]
ptokens
=
(
ttokens
=
(
self
.
tokenizer
(
target
,
max_length
=
self
.
tokenizer
.
model_max_length
,
padding
=
True
,
truncation
=
True
,
return_tensors
=
'pt'
).
input_ids
)[
0
]
self
.
tokenizer
(
prompts
,
max_length
=
self
.
tokenizer
.
model_max_length
,
padding
=
True
,
truncation
=
True
,
return_tensors
=
"pt"
,
).
input_ids
)[
0
]
ttokens
=
(
self
.
tokenizer
(
target
,
max_length
=
self
.
tokenizer
.
model_max_length
,
padding
=
True
,
truncation
=
True
,
return_tensors
=
"pt"
,
).
input_ids
)[
0
]
tlist
=
[]
tlist
=
[]
for
t
in
range
(
ttokens
.
shape
[
0
]
-
2
):
for
t
in
range
(
ttokens
.
shape
[
0
]
-
2
):
for
p
in
range
(
ptokens
.
shape
[
0
]):
for
p
in
range
(
ptokens
.
shape
[
0
]):
if
ttokens
[
t
+
1
]
==
ptokens
[
p
]:
if
ttokens
[
t
+
1
]
==
ptokens
[
p
]:
tlist
.
append
(
p
)
tlist
.
append
(
p
)
if
tlist
!=
[]
:
tt
.
append
(
tlist
)
if
tlist
!=
[]:
tt
.
append
(
tlist
)
return
tt
return
tt
def
scaled_dot_product_attention
(
self
,
query
,
key
,
value
,
attn_mask
=
None
,
dropout_p
=
0.0
,
is_causal
=
False
,
scale
=
None
,
getattn
=
False
)
->
torch
.
Tensor
:
def
scaled_dot_product_attention
(
self
,
query
,
key
,
value
,
attn_mask
=
None
,
dropout_p
=
0.0
,
is_causal
=
False
,
scale
=
None
,
getattn
=
False
)
->
torch
.
Tensor
:
# Efficient implementation equivalent to the following:
# Efficient implementation equivalent to the following:
L
,
S
=
query
.
size
(
-
2
),
key
.
size
(
-
2
)
L
,
S
=
query
.
size
(
-
2
),
key
.
size
(
-
2
)
scale_factor
=
1
/
math
.
sqrt
(
query
.
size
(
-
1
))
if
scale
is
None
else
scale
scale_factor
=
1
/
math
.
sqrt
(
query
.
size
(
-
1
))
if
scale
is
None
else
scale
attn_bias
=
torch
.
zeros
(
L
,
S
,
dtype
=
query
.
dtype
,
device
=
self
.
device
)
attn_bias
=
torch
.
zeros
(
L
,
S
,
dtype
=
query
.
dtype
,
device
=
self
.
device
)
if
is_causal
:
if
is_causal
:
assert
attn_mask
is
None
assert
attn_mask
is
None
temp_mask
=
torch
.
ones
(
L
,
S
,
dtype
=
torch
.
bool
).
tril
(
diagonal
=
0
)
temp_mask
=
torch
.
ones
(
L
,
S
,
dtype
=
torch
.
bool
).
tril
(
diagonal
=
0
)
...
@@ -506,6 +583,7 @@ def scaled_dot_product_attention(self, query, key, value, attn_mask=None, dropou
...
@@ -506,6 +583,7 @@ def scaled_dot_product_attention(self, query, key, value, attn_mask=None, dropou
attn_weight
=
query
@
key
.
transpose
(
-
2
,
-
1
)
*
scale_factor
attn_weight
=
query
@
key
.
transpose
(
-
2
,
-
1
)
*
scale_factor
attn_weight
+=
attn_bias
attn_weight
+=
attn_bias
attn_weight
=
torch
.
softmax
(
attn_weight
,
dim
=-
1
)
attn_weight
=
torch
.
softmax
(
attn_weight
,
dim
=-
1
)
if
getattn
:
get_attn_maps
(
self
,
attn_weight
)
if
getattn
:
get_attn_maps
(
self
,
attn_weight
)
attn_weight
=
torch
.
dropout
(
attn_weight
,
dropout_p
,
train
=
True
)
attn_weight
=
torch
.
dropout
(
attn_weight
,
dropout_p
,
train
=
True
)
return
attn_weight
@
value
return
attn_weight
@
value
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment