Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
0f55c17e
Commit
0f55c17e
authored
Dec 01, 2023
by
Patrick von Platen
Browse files
fix style
parent
5058d27f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
182 additions
and
104 deletions
+182
-104
examples/community/regional_prompting_stable_diffusion.py
examples/community/regional_prompting_stable_diffusion.py
+182
-104
No files found.
examples/community/regional_prompting_stable_diffusion.py
View file @
0f55c17e
import
torchvision.transforms.functional
as
FF
import
torch
import
torchvision
import
math
from
typing
import
Dict
,
Optional
import
torch
import
torchvision.transforms.functional
as
FF
from
transformers
import
CLIPFeatureExtractor
,
CLIPTextModel
,
CLIPTokenizer
from
diffusers
import
StableDiffusionPipeline
from
diffusers.models
import
AutoencoderKL
,
UNet2DConditionModel
from
diffusers.utils
import
USE_PEFT_BACKEND
from
diffusers.pipelines.stable_diffusion.safety_checker
import
StableDiffusionSafetyChecker
from
diffusers.schedulers
import
KarrasDiffusionSchedulers
from
transformers
import
CLIPFeatureExtractor
,
CLIPTextModel
,
CLIPTokenizer
from
diffusers.utils
import
USE_PEFT_BACKEND
# `compel` is an optional dependency. When it is not installed, `Compel` is
# bound to None so that code can test `if Compel:` before using it. Only
# ImportError is caught; any other failure while importing is left to surface.
try:
    from compel import Compel
except ImportError:
    Compel = None

# Keywords recognised inside prompts. KBRK is the separator on which a prompt
# is split into per-region sub-prompts.
# NOTE(review): KCOMM's use is not visible in this part of the file; it is
# presumably the "common prompt" marker — confirm.
KCOMM = "ADDCOMM"
KBRK = "BREAK"
class
RegionalPromptingStableDiffusionPipeline
(
StableDiffusionPipeline
):
r
"""
Args for Regional Prompting Pipeline:
...
...
@@ -56,6 +60,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
def
__init__
(
self
,
vae
:
AutoencoderKL
,
...
...
@@ -67,7 +72,9 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
feature_extractor
:
CLIPFeatureExtractor
,
requires_safety_checker
:
bool
=
True
,
):
super
().
__init__
(
vae
,
text_encoder
,
tokenizer
,
unet
,
scheduler
,
safety_checker
,
feature_extractor
,
requires_safety_checker
)
super
().
__init__
(
vae
,
text_encoder
,
tokenizer
,
unet
,
scheduler
,
safety_checker
,
feature_extractor
,
requires_safety_checker
)
self
.
register_modules
(
vae
=
vae
,
text_encoder
=
text_encoder
,
...
...
@@ -93,32 +100,34 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
latents
:
Optional
[
torch
.
FloatTensor
]
=
None
,
output_type
:
Optional
[
str
]
=
"pil"
,
return_dict
:
bool
=
True
,
rp_args
:
Dict
[
str
,
str
]
=
None
,
rp_args
:
Dict
[
str
,
str
]
=
None
,
):
active
=
KBRK
in
prompt
[
0
]
if
type
(
prompt
)
==
list
else
KBRK
in
prompt
if
negative_prompt
is
None
:
negative_prompt
=
""
if
type
(
prompt
)
==
str
else
[
""
]
*
len
(
prompt
)
active
=
KBRK
in
prompt
[
0
]
if
type
(
prompt
)
==
list
else
KBRK
in
prompt
# noqa: E721
if
negative_prompt
is
None
:
negative_prompt
=
""
if
type
(
prompt
)
==
str
else
[
""
]
*
len
(
prompt
)
# noqa: E721
device
=
self
.
_execution_device
regions
=
0
self
.
power
=
int
(
rp_args
[
"power"
])
if
"power"
in
rp_args
else
1
prompts
=
prompt
if
type
(
prompt
)
==
list
else
[
prompt
]
n_prompts
=
negative_prompt
if
type
(
negative_prompt
)
==
list
else
[
negative_prompt
]
prompts
=
prompt
if
type
(
prompt
)
==
list
else
[
prompt
]
# noqa: E721
n_prompts
=
negative_prompt
if
type
(
negative_prompt
)
==
list
else
[
negative_prompt
]
# noqa: E721
self
.
batch
=
batch
=
num_images_per_prompt
*
len
(
prompts
)
all_prompts_cn
,
all_prompts_p
=
promptsmaker
(
prompts
,
num_images_per_prompt
)
all_n_prompts_cn
,
_
=
promptsmaker
(
n_prompts
,
num_images_per_prompt
)
all_prompts_cn
,
all_prompts_p
=
promptsmaker
(
prompts
,
num_images_per_prompt
)
all_n_prompts_cn
,
_
=
promptsmaker
(
n_prompts
,
num_images_per_prompt
)
cn
=
len
(
all_prompts_cn
)
==
len
(
all_n_prompts_cn
)
if
Compel
:
compel
=
Compel
(
tokenizer
=
self
.
tokenizer
,
text_encoder
=
self
.
text_encoder
)
def
getcompelembs
(
prps
):
embl
=
[]
for
prp
in
prps
:
embl
.
append
(
compel
.
build_conditioning_tensor
(
prp
))
return
torch
.
cat
(
embl
)
conds
=
getcompelembs
(
all_prompts_cn
)
unconds
=
getcompelembs
(
all_n_prompts_cn
)
if
cn
else
getcompelembs
(
n_prompts
)
embs
=
getcompelembs
(
prompts
)
...
...
@@ -126,16 +135,20 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
prompt
=
negative_prompt
=
None
else
:
conds
=
self
.
encode_prompt
(
prompts
,
device
,
1
,
True
)[
0
]
unconds
=
self
.
encode_prompt
(
n_prompts
,
device
,
1
,
True
)[
0
]
if
cn
else
self
.
encode_prompt
(
all_n_prompts_cn
,
device
,
1
,
True
)[
0
]
unconds
=
(
self
.
encode_prompt
(
n_prompts
,
device
,
1
,
True
)[
0
]
if
cn
else
self
.
encode_prompt
(
all_n_prompts_cn
,
device
,
1
,
True
)[
0
]
)
embs
=
n_embs
=
None
if
not
active
:
pcallback
=
None
mode
=
None
else
:
if
any
(
x
in
rp_args
[
"mode"
].
upper
()
for
x
in
[
"COL"
,
"ROW"
]):
if
any
(
x
in
rp_args
[
"mode"
].
upper
()
for
x
in
[
"COL"
,
"ROW"
]):
mode
=
"COL"
if
"COL"
in
rp_args
[
"mode"
].
upper
()
else
"ROW"
ocells
,
icells
,
regions
=
make_cells
(
rp_args
[
"div"
])
ocells
,
icells
,
regions
=
make_cells
(
rp_args
[
"div"
])
elif
"PRO"
in
rp_args
[
"mode"
].
upper
():
regions
=
len
(
all_prompts_p
[
0
])
...
...
@@ -145,10 +158,10 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
self
.
target_tokens
=
target_tokens
=
tokendealer
(
self
,
all_prompts_p
)
thresholds
=
[
float
(
x
)
for
x
in
rp_args
[
"th"
].
split
(
","
)]
orig_hw
=
(
height
,
width
)
orig_hw
=
(
height
,
width
)
revers
=
True
def
pcallback
(
s_self
,
step
:
int
,
timestep
:
int
,
latents
:
torch
.
FloatTensor
,
selfs
=
None
):
def
pcallback
(
s_self
,
step
:
int
,
timestep
:
int
,
latents
:
torch
.
FloatTensor
,
selfs
=
None
):
if
"PRO"
in
mode
:
# in Prompt mode, make masks from the sum of attention maps
self
.
step
=
step
...
...
@@ -167,7 +180,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
allmasks
[
b
::
batch
]
=
[
torch
.
where
(
x
>
0
,
1
,
0
)
for
x
in
allmasks
[
b
::
batch
]]
allmasks
.
append
(
mask
)
basemasks
[
b
]
=
mask
if
basemasks
[
b
]
is
None
else
basemasks
[
b
]
+
mask
basemasks
=
[
1
-
mask
for
mask
in
basemasks
]
basemasks
=
[
1
-
mask
for
mask
in
basemasks
]
basemasks
=
[
torch
.
where
(
x
>
0
,
1
,
0
)
for
x
in
basemasks
]
allmasks
=
basemasks
+
allmasks
...
...
@@ -176,7 +189,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
return
latents
def
hook_forward
(
module
):
#diffusers==0.23.2
#
diffusers==0.23.2
def
forward
(
hidden_states
:
torch
.
FloatTensor
,
encoder_hidden_states
:
Optional
[
torch
.
FloatTensor
]
=
None
,
...
...
@@ -184,22 +197,21 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
temb
:
Optional
[
torch
.
FloatTensor
]
=
None
,
scale
:
float
=
1.0
,
)
->
torch
.
Tensor
:
attn
=
module
xshape
=
hidden_states
.
shape
self
.
hw
=
(
h
,
w
)
=
split_dims
(
xshape
[
1
],
*
orig_hw
)
self
.
hw
=
(
h
,
w
)
=
split_dims
(
xshape
[
1
],
*
orig_hw
)
if
revers
:
nx
,
px
=
hidden_states
.
chunk
(
2
)
nx
,
px
=
hidden_states
.
chunk
(
2
)
else
:
px
,
nx
=
hidden_states
.
chunk
(
2
)
px
,
nx
=
hidden_states
.
chunk
(
2
)
if
cn
:
hidden_states
=
torch
.
cat
([
px
for
i
in
range
(
regions
)]
+
[
nx
for
i
in
range
(
regions
)],
0
)
encoder_hidden_states
=
torch
.
cat
([
conds
]
+
[
unconds
])
hidden_states
=
torch
.
cat
([
px
for
i
in
range
(
regions
)]
+
[
nx
for
i
in
range
(
regions
)],
0
)
encoder_hidden_states
=
torch
.
cat
([
conds
]
+
[
unconds
])
else
:
hidden_states
=
torch
.
cat
([
px
for
i
in
range
(
regions
)]
+
[
nx
],
0
)
encoder_hidden_states
=
torch
.
cat
([
conds
]
+
[
unconds
])
hidden_states
=
torch
.
cat
([
px
for
i
in
range
(
regions
)]
+
[
nx
],
0
)
encoder_hidden_states
=
torch
.
cat
([
conds
]
+
[
unconds
])
residual
=
hidden_states
...
...
@@ -247,7 +259,14 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states
=
scaled_dot_product_attention
(
self
,
query
,
key
,
value
,
attn_mask
=
attention_mask
,
dropout_p
=
0.0
,
is_causal
=
False
,
getattn
=
"PRO"
in
mode
self
,
query
,
key
,
value
,
attn_mask
=
attention_mask
,
dropout_p
=
0.0
,
is_causal
=
False
,
getattn
=
"PRO"
in
mode
,
)
hidden_states
=
hidden_states
.
transpose
(
1
,
2
).
reshape
(
batch_size
,
-
1
,
attn
.
heads
*
head_dim
)
...
...
@@ -272,18 +291,38 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
center
=
reshaped
.
shape
[
0
]
//
2
px
=
reshaped
[
0
:
center
]
if
cn
else
reshaped
[
0
:
-
batch
]
nx
=
reshaped
[
center
:]
if
cn
else
reshaped
[
-
batch
:]
outs
=
[
px
,
nx
]
if
cn
else
[
px
]
outs
=
[
px
,
nx
]
if
cn
else
[
px
]
for
out
in
outs
:
c
=
0
for
i
,
ocell
in
enumerate
(
ocells
):
for
i
,
ocell
in
enumerate
(
ocells
):
for
icell
in
icells
[
i
]:
if
"ROW"
in
mode
:
out
[
0
:
batch
,
int
(
h
*
ocell
[
0
]):
int
(
h
*
ocell
[
1
]),
int
(
w
*
icell
[
0
]):
int
(
w
*
icell
[
1
]),:]
=
out
[
c
*
batch
:(
c
+
1
)
*
batch
,
int
(
h
*
ocell
[
0
]):
int
(
h
*
ocell
[
1
]),
int
(
w
*
icell
[
0
]):
int
(
w
*
icell
[
1
]),:]
out
[
0
:
batch
,
int
(
h
*
ocell
[
0
])
:
int
(
h
*
ocell
[
1
]),
int
(
w
*
icell
[
0
])
:
int
(
w
*
icell
[
1
]),
:,
]
=
out
[
c
*
batch
:
(
c
+
1
)
*
batch
,
int
(
h
*
ocell
[
0
])
:
int
(
h
*
ocell
[
1
]),
int
(
w
*
icell
[
0
])
:
int
(
w
*
icell
[
1
]),
:,
]
else
:
out
[
0
:
batch
,
int
(
h
*
icell
[
0
]):
int
(
h
*
icell
[
1
]),
int
(
w
*
ocell
[
0
]):
int
(
w
*
ocell
[
1
]),:]
=
out
[
c
*
batch
:(
c
+
1
)
*
batch
,
int
(
h
*
icell
[
0
]):
int
(
h
*
icell
[
1
]),
int
(
w
*
ocell
[
0
]):
int
(
w
*
ocell
[
1
]),:]
out
[
0
:
batch
,
int
(
h
*
icell
[
0
])
:
int
(
h
*
icell
[
1
]),
int
(
w
*
ocell
[
0
])
:
int
(
w
*
ocell
[
1
]),
:,
]
=
out
[
c
*
batch
:
(
c
+
1
)
*
batch
,
int
(
h
*
icell
[
0
])
:
int
(
h
*
icell
[
1
]),
int
(
w
*
ocell
[
0
])
:
int
(
w
*
ocell
[
1
]),
:,
]
c
+=
1
px
,
nx
=
(
px
[
0
:
batch
],
nx
[
0
:
batch
])
if
cn
else
(
px
[
0
:
batch
],
nx
)
hidden_states
=
torch
.
cat
([
nx
,
px
],
0
)
if
revers
else
torch
.
cat
([
px
,
nx
],
0
)
hidden_states
=
torch
.
cat
([
nx
,
px
],
0
)
if
revers
else
torch
.
cat
([
px
,
nx
],
0
)
hidden_states
=
hidden_states
.
reshape
(
xshape
)
#### Regional Prompting Prompt mode
...
...
@@ -292,16 +331,18 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
px
=
reshaped
[
0
:
center
]
if
cn
else
reshaped
[
0
:
-
batch
]
nx
=
reshaped
[
center
:]
if
cn
else
reshaped
[
-
batch
:]
if
(
h
,
w
)
in
self
.
attnmasks
and
self
.
maskready
:
if
(
h
,
w
)
in
self
.
attnmasks
and
self
.
maskready
:
def
mask
(
input
):
out
=
torch
.
multiply
(
input
,
self
.
attnmasks
[(
h
,
w
)])
out
=
torch
.
multiply
(
input
,
self
.
attnmasks
[(
h
,
w
)])
for
b
in
range
(
batch
):
for
r
in
range
(
1
,
regions
):
out
[
b
]
=
out
[
b
]
+
out
[
r
*
batch
+
b
]
return
out
px
,
nx
=
(
mask
(
px
),
mask
(
nx
))
if
cn
else
(
mask
(
px
),
nx
)
px
,
nx
=
(
px
[
0
:
batch
],
nx
[
0
:
batch
])
if
cn
else
(
px
[
0
:
batch
],
nx
)
hidden_states
=
torch
.
cat
([
nx
,
px
],
0
)
if
revers
else
torch
.
cat
([
px
,
nx
],
0
)
hidden_states
=
torch
.
cat
([
nx
,
px
],
0
)
if
revers
else
torch
.
cat
([
px
,
nx
],
0
)
return
hidden_states
return
forward
...
...
@@ -328,7 +369,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
latents
=
latents
,
output_type
=
output_type
,
return_dict
=
return_dict
,
callback_on_step_end
=
pcallback
callback_on_step_end
=
pcallback
,
)
if
"save_mask"
in
rp_args
:
...
...
@@ -336,13 +377,14 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
else
:
save_mask
=
False
if
mode
==
"PROMPT"
and
save_mask
:
saveattnmaps
(
self
,
output
,
height
,
width
,
thresholds
,
num_inference_steps
//
2
,
regions
)
if
mode
==
"PROMPT"
and
save_mask
:
saveattnmaps
(
self
,
output
,
height
,
width
,
thresholds
,
num_inference_steps
//
2
,
regions
)
return
output
### Make the prompt list for each region
def
promptsmaker
(
prompts
,
batch
):
def
promptsmaker
(
prompts
,
batch
):
out_p
=
[]
plen
=
len
(
prompts
)
for
prompt
in
prompts
:
...
...
@@ -352,24 +394,26 @@ def promptsmaker(prompts,batch):
add
=
add
+
" "
prompts
=
prompt
.
split
(
KBRK
)
out_p
.
append
([
add
+
p
for
p
in
prompts
])
out
=
[
None
]
*
batch
*
len
(
out_p
[
0
])
*
len
(
out_p
)
out
=
[
None
]
*
batch
*
len
(
out_p
[
0
])
*
len
(
out_p
)
for
p
,
prs
in
enumerate
(
out_p
):
# inputs prompts
for
r
,
pr
in
enumerate
(
prs
):
# prompts for regions
start
=
(
p
+
r
*
plen
)
*
batch
out
[
start
:
start
+
batch
]
=
[
pr
]
*
batch
#P1R1B1,P1R1B2...,P1R2B1,P1R2B2...,P2R1B1...
out
[
start
:
start
+
batch
]
=
[
pr
]
*
batch
#
P1R1B1,P1R1B2...,P1R2B1,P1R2B2...,P2R1B1...
return
out
,
out_p
### make regions from ratios
### ";" makes outercells, "," makes inner cells
def
make_cells
(
ratios
):
if
";"
not
in
ratios
and
","
in
ratios
:
ratios
=
ratios
.
replace
(
","
,
";"
)
if
";"
not
in
ratios
and
","
in
ratios
:
ratios
=
ratios
.
replace
(
","
,
";"
)
ratios
=
ratios
.
split
(
";"
)
ratios
=
[
inratios
.
split
(
","
)
for
inratios
in
ratios
]
icells
=
[]
ocells
=
[]
def
startend
(
cells
,
array
):
def
startend
(
cells
,
array
):
current_start
=
0
array
=
[
float
(
x
)
for
x
in
array
]
for
value
in
array
:
...
...
@@ -377,72 +421,80 @@ def make_cells(ratios):
cells
.
append
([
current_start
,
end
])
current_start
=
end
startend
(
ocells
,[
r
[
0
]
for
r
in
ratios
])
startend
(
ocells
,
[
r
[
0
]
for
r
in
ratios
])
for
inratios
in
ratios
:
if
2
>
len
(
inratios
):
icells
.
append
([[
0
,
1
]])
icells
.
append
([[
0
,
1
]])
else
:
add
=
[]
startend
(
add
,
inratios
[
1
:])
startend
(
add
,
inratios
[
1
:])
icells
.
append
(
add
)
return
ocells
,
icells
,
sum
(
len
(
cell
)
for
cell
in
icells
)
def make_emblist(self, prompts):
    """Encode ``prompts`` with the pipeline's tokenizer and text encoder.

    Args:
        self: Object exposing ``tokenizer``, ``text_encoder``, ``device`` and
            ``dtype`` (the pipeline instance).
        prompts: Text passed straight to ``self.tokenizer``.

    Returns:
        The encoder's ``last_hidden_state`` moved to ``self.device`` and cast
        to ``self.dtype``.
    """
    # Encoding is inference-only, so no autograd graph is recorded.
    with torch.no_grad():
        tokens = self.tokenizer(
            prompts,
            max_length=self.tokenizer.model_max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).input_ids.to(self.device)
        embs = self.text_encoder(tokens, output_hidden_states=True).last_hidden_state.to(
            self.device, dtype=self.dtype
        )
    return embs
import
math
def split_dims(xs, height, width):
    """Return the (height, width) obtained by repeatedly halving the inputs.

    The number of halvings is derived from how much smaller ``xs`` is than
    ``height * width``; each halving rounds up.

    Args:
        xs: Number of spatial positions at the target resolution
            (callers in this file pass ``hidden_states.shape[1]``).
        height: Height to downscale.
        width: Width to downscale.

    Returns:
        Tuple ``(dsh, dsw)`` of the downscaled height and width.
    """

    def repeat_div(x, y):
        # Halve ``x`` (rounding up) ``y`` times; a non-positive ``y`` is a no-op.
        while y > 0:
            x = math.ceil(x / 2)
            y = y - 1
        return x

    # sqrt(area ratio) is the per-axis shrink factor; log2 turns it into a
    # count of halvings, rounded up so odd sizes are handled.
    scale = math.ceil(math.log2(math.sqrt(height * width / xs)))
    dsh = repeat_div(height, scale)
    dsw = repeat_div(width, scale)
    return dsh, dsw
##### for prompt mode
def get_attn_maps(self, attn):
    """Accumulate attention weights of the target tokens into ``self.attnmaps``.

    Reads ``self.hw``, ``self.target_tokens``, ``self.batch`` and
    ``self.power``; updates ``self.attnmaps_sizes`` and ``self.attnmaps``.

    Args:
        self: State holder (the pipeline instance).
        attn: 4-D attention weights indexed as ``attn[b, :, :, positions]``.
            The last axis is indexed by token position.
            NOTE(review): axes 1 and 2 are presumably (heads, query positions)
            — confirm against the attention caller.

    Returns:
        None. Maps are stored under the key ``f"{t}-{b}"`` for every target
        token group ``t`` and batch index ``b``.
    """
    height, width = self.hw
    target_tokens = self.target_tokens
    # Remember each resolution the first time it is seen; its position in the
    # list is used as a weight below.
    if (height, width) not in self.attnmaps_sizes:
        self.attnmaps_sizes.append((height, width))

    # Loop-invariant values, hoisted out of the nested loops.
    power = self.power
    level_weight = self.attnmaps_sizes.index((height, width)) + 1

    for b in range(self.batch):
        for t in target_tokens:
            # Take len(t) consecutive token columns starting at t[0], raise to
            # `power`, then scale by the resolution weight.
            add = attn[b, :, :, t[0] : t[0] + len(t)] ** (power) * level_weight
            add = torch.sum(add, dim=2)

            key = f"{t}-{b}"
            if key not in self.attnmaps:
                self.attnmaps[key] = add
            else:
                if self.attnmaps[key].shape[1] != add.shape[1]:
                    # Different resolution from the stored map: resize to the
                    # first recorded size before summing. The leading
                    # dimension is taken from the tensor itself instead of
                    # being hard-coded to 8.
                    add = add.view(add.shape[0], height, width)
                    add = FF.resize(add, self.attnmaps_sizes[0], antialias=None)
                    add = add.reshape_as(self.attnmaps[key])

                self.attnmaps[key] = self.attnmaps[key] + add
def reset_attnmaps(self):
    """Reset the attention-map state on ``self``; run once for every batch.

    Args:
        self: State holder (the pipeline instance). Its attributes are
            overwritten in place.

    Returns:
        None.
    """
    self.step = 0
    self.attnmaps = {}  # built from attention maps
    self.attnmaps_sizes = []  # (height, width) of the U-Net blocks seen so far
    self.attnmasks = {}  # built from attnmaps, one set per region
    self.maskready = False
    self.history = {}
def
saveattnmaps
(
self
,
output
,
h
,
w
,
th
,
step
,
regions
):
def
saveattnmaps
(
self
,
output
,
h
,
w
,
th
,
step
,
regions
):
masks
=
[]
for
i
,
mask
in
enumerate
(
self
.
history
[
step
].
values
()):
img
,
_
,
mask
=
makepmask
(
self
,
mask
,
h
,
w
,
th
[
i
%
len
(
th
)],
step
)
img
,
_
,
mask
=
makepmask
(
self
,
mask
,
h
,
w
,
th
[
i
%
len
(
th
)],
step
)
if
self
.
ex
:
masks
=
[
x
-
mask
for
x
in
masks
]
masks
.
append
(
mask
)
...
...
@@ -452,46 +504,71 @@ def saveattnmaps(self,output,h,w,th,step,regions):
else
:
output
.
images
.
append
(
img
)
def makepmask(self, mask, h, w, th, step):
    """Make masks from the attention cache.

    Args:
        self: State holder; ``self.attnmaps_sizes[0]`` gives the 2-D size the
            averaged map is viewed as.
        mask: Accumulated attention map; it is averaged over dimension 0.
        h: Output height.
        w: Output width.
        th: Base threshold applied to the normalised map.
        step: Step index; the threshold is lowered by ``0.005`` per step.

    Returns:
        Tuple ``(img, mask, lmask)``: a PIL preview image resized to
        ``(w, h)``, a flat 0/1 mask of length ``h * w`` (for attention), and
        the resized ``(1, h, w)`` float mask (for latents).
    """
    # The threshold decays with the step count but is floored at 0.05.
    th = th - step * 0.005
    if 0.05 >= th:
        th = 0.05

    # Average over dim 0, normalise by the peak, then binarise.
    mask = torch.mean(mask, dim=0)
    mask = mask / mask.max().item()
    mask = torch.where(mask > th, 1, 0)
    mask = mask.float()
    mask = mask.view(1, *self.attnmaps_sizes[0])

    # Preview image at the requested output size (PIL takes (width, height)).
    img = FF.to_pil_image(mask)
    img = img.resize((w, h))

    # Nearest-neighbour resize keeps the mask binary.
    mask = FF.resize(mask, (h, w), interpolation=FF.InterpolationMode.NEAREST, antialias=None)
    lmask = mask
    mask = mask.reshape(h * w)
    mask = torch.where(mask > 0.1, 1, 0)
    return img, mask, lmask
def tokendealer(self, all_prompts):
    """Locate each region's target word inside the tokenized base prompt.

    For every prompt group, the target of each entry after the first is the
    text following its last comma. The target is tokenized, its first and last
    tokens are dropped, and every position in the tokenized first prompt that
    matches one of the remaining tokens is collected.

    Args:
        self: Object exposing ``tokenizer`` (the pipeline instance).
        all_prompts: Iterable of prompt groups; each group is a list of
            strings whose entries after the first carry the targets.

    Returns:
        A list with one list of token positions per target that was found.
        Only the result for the LAST group in ``all_prompts`` is returned
        (the accumulator is reset per group). An empty ``all_prompts`` yields
        ``[]``.
    """

    def _encode(text):
        # Token ids of the first sequence in the tokenizer output.
        return self.tokenizer(
            text,
            max_length=self.tokenizer.model_max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).input_ids[0]

    # Bound up front so an empty `all_prompts` returns [] instead of raising.
    tt = []
    for prompts in all_prompts:
        targets = [p.split(",")[-1] for p in prompts[1:]]
        tt = []
        if not targets:
            continue

        # The prompt tokenization does not depend on the target; do it once.
        ptokens = _encode(prompts).tolist()

        for target in targets:
            ttokens = _encode(target).tolist()
            # Skip the first and last target tokens, then record every prompt
            # position equal to one of the remaining tokens.
            tlist = [p for tok in ttokens[1:-1] for p, ptok in enumerate(ptokens) if tok == ptok]
            if tlist:
                tt.append(tlist)

    return tt
def
scaled_dot_product_attention
(
self
,
query
,
key
,
value
,
attn_mask
=
None
,
dropout_p
=
0.0
,
is_causal
=
False
,
scale
=
None
,
getattn
=
False
)
->
torch
.
Tensor
:
def
scaled_dot_product_attention
(
self
,
query
,
key
,
value
,
attn_mask
=
None
,
dropout_p
=
0.0
,
is_causal
=
False
,
scale
=
None
,
getattn
=
False
)
->
torch
.
Tensor
:
# Efficient implementation equivalent to the following:
L
,
S
=
query
.
size
(
-
2
),
key
.
size
(
-
2
)
scale_factor
=
1
/
math
.
sqrt
(
query
.
size
(
-
1
))
if
scale
is
None
else
scale
attn_bias
=
torch
.
zeros
(
L
,
S
,
dtype
=
query
.
dtype
,
device
=
self
.
device
)
attn_bias
=
torch
.
zeros
(
L
,
S
,
dtype
=
query
.
dtype
,
device
=
self
.
device
)
if
is_causal
:
assert
attn_mask
is
None
temp_mask
=
torch
.
ones
(
L
,
S
,
dtype
=
torch
.
bool
).
tril
(
diagonal
=
0
)
...
...
@@ -506,6 +583,7 @@ def scaled_dot_product_attention(self, query, key, value, attn_mask=None, dropou
attn_weight
=
query
@
key
.
transpose
(
-
2
,
-
1
)
*
scale_factor
attn_weight
+=
attn_bias
attn_weight
=
torch
.
softmax
(
attn_weight
,
dim
=-
1
)
if
getattn
:
get_attn_maps
(
self
,
attn_weight
)
if
getattn
:
get_attn_maps
(
self
,
attn_weight
)
attn_weight
=
torch
.
dropout
(
attn_weight
,
dropout_p
,
train
=
True
)
return
attn_weight
@
value
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment