chenpangpang/ComfyUI, commit b8636a44
Authored May 20, 2023 by comfyanonymous
Parent: 797c4e8d

Make scaled_dot_product switch to sliced attention on OOM.

Showing 1 changed file with 43 additions and 36 deletions:
comfy/ldm/modules/diffusionmodules/model.py (+43, -36)
@@ -146,6 +146,41 @@ class ResnetBlock(nn.Module):
         return x+h
 
+def slice_attention(q, k, v):
+    r1 = torch.zeros_like(k, device=q.device)
+    scale = (int(q.shape[-1])**(-0.5))
+
+    mem_free_total = model_management.get_free_memory(q.device)
+
+    gb = 1024 ** 3
+    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
+    modifier = 3 if q.element_size() == 2 else 2.5
+    mem_required = tensor_size * modifier
+    steps = 1
+
+    if mem_required > mem_free_total:
+        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
+
+    while True:
+        try:
+            slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+            for i in range(0, q.shape[1], slice_size):
+                end = i + slice_size
+                s1 = torch.bmm(q[:, i:end], k) * scale
+
+                s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
+                del s1
+
+                r1[:, :, i:end] = torch.bmm(v, s2)
+                del s2
+            break
+        except model_management.OOM_EXCEPTION as e:
+            steps *= 2
+            if steps > 128:
+                raise e
+            print("out of memory error, increasing steps and trying again", steps)
+
+    return r1
+
 class AttnBlock(nn.Module):
     def __init__(self, in_channels):
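Note (not part of the commit): a quick worked example of the step estimate in slice_attention above, using made-up sizes. tensor_size is the byte count of the full (b, hw, hw) attention matrix, and steps is rounded up to the next power of two so that each slice of the query dimension should fit in the currently free memory:

import math

# Hypothetical sizes, chosen only to illustrate the estimate; q is assumed
# to be (b, hw, c) in fp16, so q.element_size() == 2.
b, hw = 1, 64 * 64
element_size = 2
tensor_size = b * hw * hw * element_size   # bytes of the (b, hw, hw) attention matrix
modifier = 3 if element_size == 2 else 2.5
mem_required = tensor_size * modifier      # about 96 MiB
mem_free_total = 32 * 1024 ** 2            # pretend only 32 MiB are free

steps = 1
if mem_required > mem_free_total:
    steps = 2 ** math.ceil(math.log(mem_required / mem_free_total, 2))

print(steps)  # 4: the query (hw) dimension is processed in 4 slices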
@@ -183,48 +218,15 @@ class AttnBlock(nn.Module):
         # compute attention
         b,c,h,w = q.shape
-        scale = (int(c)**(-0.5))
         q = q.reshape(b,c,h*w)
         q = q.permute(0,2,1)   # b,hw,c
         k = k.reshape(b,c,h*w) # b,c,hw
         v = v.reshape(b,c,h*w)
-        r1 = torch.zeros_like(k, device=q.device)
+        r1 = slice_attention(q, k, v)
-        mem_free_total = model_management.get_free_memory(q.device)
-
-        gb = 1024 ** 3
-        tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
-        modifier = 3 if q.element_size() == 2 else 2.5
-        mem_required = tensor_size * modifier
-        steps = 1
-
-        if mem_required > mem_free_total:
-            steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
-
-        while True:
-            try:
-                slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
-                for i in range(0, q.shape[1], slice_size):
-                    end = i + slice_size
-                    s1 = torch.bmm(q[:, i:end], k) * scale
-
-                    s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
-                    del s1
-
-                    r1[:, :, i:end] = torch.bmm(v, s2)
-                    del s2
-                break
-            except model_management.OOM_EXCEPTION as e:
-                steps *= 2
-                if steps > 128:
-                    raise e
-                print("out of memory error, increasing steps and trying again", steps)
-
         h_ = r1.reshape(b,c,h,w)
         del r1
 
         h_ = self.proj_out(h_)
 
         return x+h_
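For reference, the shape convention the new r1 = slice_attention(q, k, v) call relies on can be checked with a small standalone snippet (illustrative tensors, not from the commit): q is (b, hw, c), k and v are (b, c, hw), and the result comes back as (b, c, hw), which is what the following reshape into (b, c, h, w) expects. A single slice covering the whole query range reproduces plain softmax attention:

import torch

b, hw, c = 2, 16, 8
q = torch.randn(b, hw, c)   # (b, hw, c)
k = torch.randn(b, c, hw)   # (b, c, hw)
v = torch.randn(b, c, hw)   # (b, c, hw)
scale = int(c) ** (-0.5)

# Reference result: ordinary softmax attention done in one pass.
attn = torch.nn.functional.softmax(torch.bmm(q, k) * scale, dim=2)  # (b, hw, hw)
r_ref = torch.bmm(v, attn.permute(0, 2, 1))                         # (b, c, hw)

# The same computation as one slice_attention slice spanning all of hw.
r1 = torch.zeros_like(k)
s1 = torch.bmm(q[:, 0:hw], k) * scale
s2 = torch.nn.functional.softmax(s1, dim=2).permute(0, 2, 1)
r1[:, :, 0:hw] = torch.bmm(v, s2)

print(torch.allclose(r1, r_ref))  # True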
@@ -335,9 +337,14 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):
             lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
             (q, k, v),
         )
 
-        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
-        out = out.transpose(2, 3).reshape(B, C, H, W)
+        try:
+            out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
+            out = out.transpose(2, 3).reshape(B, C, H, W)
+        except model_management.OOM_EXCEPTION as e:
+            print("scaled_dot_product_attention OOMed: switched to slice attention")
+            out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
 
         out = self.proj_out(out)
         return x+out
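The change to MemoryEfficientAttnBlockPytorch above is the commit's core idea: try the fused kernel first and only fall back to the manually sliced path when it runs out of memory. A minimal, generic sketch of that pattern follows (not the commit's code; torch.cuda.OutOfMemoryError is used here as a stand-in, since what model_management.OOM_EXCEPTION aliases is not shown in this diff):

import torch

def attention_with_fallback(q, k, v, fallback):
    # Try PyTorch's fused scaled-dot-product attention first.
    try:
        return torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
    except torch.cuda.OutOfMemoryError:
        # On OOM, recompute with a slower, slice-based implementation that
        # bounds peak memory (e.g. slice_attention above, after reshaping
        # q, k, v to the layouts it expects).
        print("scaled_dot_product_attention OOMed, using fallback")
        return fallback(q, k, v)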