Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
539ff487
Commit
539ff487
authored
Apr 03, 2023
by
comfyanonymous
Browse files
Pull latest tomesd code from upstream.
parent
f50b1fec
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
48 additions
and
21 deletions
+48
-21
comfy/ldm/modules/tomesd.py
comfy/ldm/modules/tomesd.py
+48
-21
No files found.
comfy/ldm/modules/tomesd.py
View file @
539ff487
#Taken from: https://github.com/dbolya/tomesd
import
torch
from
typing
import
Tuple
,
Callable
...
...
@@ -8,13 +8,23 @@ def do_nothing(x: torch.Tensor, mode:str=None):
return
x
def
mps_gather_workaround
(
input
,
dim
,
index
):
if
input
.
shape
[
-
1
]
==
1
:
return
torch
.
gather
(
input
.
unsqueeze
(
-
1
),
dim
-
1
if
dim
<
0
else
dim
,
index
.
unsqueeze
(
-
1
)
).
squeeze
(
-
1
)
else
:
return
torch
.
gather
(
input
,
dim
,
index
)
def
bipartite_soft_matching_random2d
(
metric
:
torch
.
Tensor
,
w
:
int
,
h
:
int
,
sx
:
int
,
sy
:
int
,
r
:
int
,
no_rand
:
bool
=
False
)
->
Tuple
[
Callable
,
Callable
]:
"""
Partitions the tokens into src and dst and merges r tokens from src to dst.
Dst tokens are partitioned by choosing one randomy in each (sx, sy) region.
Args:
- metric [B, N, C]: metric to use for similarity
- w: image width in tokens
...
...
@@ -28,33 +38,49 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
if
r
<=
0
:
return
do_nothing
,
do_nothing
gather
=
mps_gather_workaround
if
metric
.
device
.
type
==
"mps"
else
torch
.
gather
with
torch
.
no_grad
():
hsy
,
wsx
=
h
//
sy
,
w
//
sx
# For each sy by sx kernel, randomly assign one token to be dst and the rest src
idx_buffer
=
torch
.
zeros
(
1
,
hsy
,
wsx
,
sy
*
sx
,
1
,
device
=
metric
.
device
)
if
no_rand
:
rand_idx
=
torch
.
zeros
(
1
,
hsy
,
wsx
,
1
,
1
,
device
=
metric
.
device
,
dtype
=
torch
.
int64
)
rand_idx
=
torch
.
zeros
(
hsy
,
wsx
,
1
,
device
=
metric
.
device
,
dtype
=
torch
.
int64
)
else
:
rand_idx
=
torch
.
randint
(
sy
*
sx
,
size
=
(
1
,
hsy
,
wsx
,
1
,
1
),
device
=
metric
.
device
)
rand_idx
=
torch
.
randint
(
sy
*
sx
,
size
=
(
hsy
,
wsx
,
1
),
device
=
metric
.
device
)
idx_buffer
.
scatter_
(
dim
=
3
,
index
=
rand_idx
,
src
=-
torch
.
ones_like
(
rand_idx
,
dtype
=
idx_buffer
.
dtype
))
idx_buffer
=
idx_buffer
.
view
(
1
,
hsy
,
wsx
,
sy
,
sx
,
1
).
transpose
(
2
,
3
).
reshape
(
1
,
N
,
1
)
rand_idx
=
idx_buffer
.
argsort
(
dim
=
1
)
# The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead
idx_buffer_view
=
torch
.
zeros
(
hsy
,
wsx
,
sy
*
sx
,
device
=
metric
.
device
,
dtype
=
torch
.
int64
)
idx_buffer_view
.
scatter_
(
dim
=
2
,
index
=
rand_idx
,
src
=-
torch
.
ones_like
(
rand_idx
,
dtype
=
rand_idx
.
dtype
))
idx_buffer_view
=
idx_buffer_view
.
view
(
hsy
,
wsx
,
sy
,
sx
).
transpose
(
1
,
2
).
reshape
(
hsy
*
sy
,
wsx
*
sx
)
# Image is not divisible by sx or sy so we need to move it into a new buffer
if
(
hsy
*
sy
)
<
h
or
(
wsx
*
sx
)
<
w
:
idx_buffer
=
torch
.
zeros
(
h
,
w
,
device
=
metric
.
device
,
dtype
=
torch
.
int64
)
idx_buffer
[:(
hsy
*
sy
),
:(
wsx
*
sx
)]
=
idx_buffer_view
else
:
idx_buffer
=
idx_buffer_view
# We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
rand_idx
=
idx_buffer
.
reshape
(
1
,
-
1
,
1
).
argsort
(
dim
=
1
)
# We're finished with these
del
idx_buffer
,
idx_buffer_view
num_dst
=
int
((
1
/
(
sx
*
sy
))
*
N
)
# rand_idx is currently dst|src, so split them
num_dst
=
hsy
*
wsx
a_idx
=
rand_idx
[:,
num_dst
:,
:]
# src
b_idx
=
rand_idx
[:,
:
num_dst
,
:]
# dst
def
split
(
x
):
C
=
x
.
shape
[
-
1
]
src
=
x
.
gather
(
dim
=
1
,
index
=
a_idx
.
expand
(
B
,
N
-
num_dst
,
C
))
dst
=
x
.
gather
(
dim
=
1
,
index
=
b_idx
.
expand
(
B
,
num_dst
,
C
))
src
=
gather
(
x
,
dim
=
1
,
index
=
a_idx
.
expand
(
B
,
N
-
num_dst
,
C
))
dst
=
gather
(
x
,
dim
=
1
,
index
=
b_idx
.
expand
(
B
,
num_dst
,
C
))
return
src
,
dst
# Cosine similarity between A and B
metric
=
metric
/
metric
.
norm
(
dim
=-
1
,
keepdim
=
True
)
a
,
b
=
split
(
metric
)
scores
=
a
@
b
.
transpose
(
-
1
,
-
2
)
...
...
@@ -62,19 +88,20 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
# Can't reduce more than the # tokens in src
r
=
min
(
a
.
shape
[
1
],
r
)
# Find the most similar greedily
node_max
,
node_idx
=
scores
.
max
(
dim
=-
1
)
edge_idx
=
node_max
.
argsort
(
dim
=-
1
,
descending
=
True
)[...,
None
]
unm_idx
=
edge_idx
[...,
r
:,
:]
# Unmerged Tokens
src_idx
=
edge_idx
[...,
:
r
,
:]
# Merged Tokens
dst_idx
=
node_idx
[...,
None
]
.
gather
(
dim
=-
2
,
index
=
src_idx
)
dst_idx
=
gather
(
node_idx
[...,
None
]
,
dim
=-
2
,
index
=
src_idx
)
def
merge
(
x
:
torch
.
Tensor
,
mode
=
"mean"
)
->
torch
.
Tensor
:
src
,
dst
=
split
(
x
)
n
,
t1
,
c
=
src
.
shape
unm
=
src
.
gather
(
dim
=-
2
,
index
=
unm_idx
.
expand
(
n
,
t1
-
r
,
c
))
src
=
src
.
gather
(
dim
=-
2
,
index
=
src_idx
.
expand
(
n
,
r
,
c
))
unm
=
gather
(
src
,
dim
=-
2
,
index
=
unm_idx
.
expand
(
n
,
t1
-
r
,
c
))
src
=
gather
(
src
,
dim
=-
2
,
index
=
src_idx
.
expand
(
n
,
r
,
c
))
dst
=
dst
.
scatter_reduce
(
-
2
,
dst_idx
.
expand
(
n
,
r
,
c
),
src
,
reduce
=
mode
)
return
torch
.
cat
([
unm
,
dst
],
dim
=
1
)
...
...
@@ -84,13 +111,13 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
unm
,
dst
=
x
[...,
:
unm_len
,
:],
x
[...,
unm_len
:,
:]
_
,
_
,
c
=
unm
.
shape
src
=
dst
.
gather
(
dim
=-
2
,
index
=
dst_idx
.
expand
(
B
,
r
,
c
))
src
=
gather
(
dst
,
dim
=-
2
,
index
=
dst_idx
.
expand
(
B
,
r
,
c
))
# Combine back to the original shape
out
=
torch
.
zeros
(
B
,
N
,
c
,
device
=
x
.
device
,
dtype
=
x
.
dtype
)
out
.
scatter_
(
dim
=-
2
,
index
=
b_idx
.
expand
(
B
,
num_dst
,
c
),
src
=
dst
)
out
.
scatter_
(
dim
=-
2
,
index
=
a_idx
.
expand
(
B
,
a_idx
.
shape
[
1
],
1
)
.
gather
(
dim
=
1
,
index
=
unm_idx
).
expand
(
B
,
unm_len
,
c
),
src
=
unm
)
out
.
scatter_
(
dim
=-
2
,
index
=
a_idx
.
expand
(
B
,
a_idx
.
shape
[
1
],
1
)
.
gather
(
dim
=
1
,
index
=
src_idx
).
expand
(
B
,
r
,
c
),
src
=
src
)
out
.
scatter_
(
dim
=-
2
,
index
=
gather
(
a_idx
.
expand
(
B
,
a_idx
.
shape
[
1
],
1
)
,
dim
=
1
,
index
=
unm_idx
).
expand
(
B
,
unm_len
,
c
),
src
=
unm
)
out
.
scatter_
(
dim
=-
2
,
index
=
gather
(
a_idx
.
expand
(
B
,
a_idx
.
shape
[
1
],
1
)
,
dim
=
1
,
index
=
src_idx
).
expand
(
B
,
r
,
c
),
src
=
src
)
return
out
...
...
@@ -100,14 +127,14 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
def
get_functions
(
x
,
ratio
,
original_shape
):
b
,
c
,
original_h
,
original_w
=
original_shape
original_tokens
=
original_h
*
original_w
downsample
=
int
(
math
.
sqrt
(
original_tokens
//
x
.
shape
[
1
]))
downsample
=
int
(
math
.
ceil
(
math
.
sqrt
(
original_tokens
//
x
.
shape
[
1
]))
)
stride_x
=
2
stride_y
=
2
max_downsample
=
1
if
downsample
<=
max_downsample
:
w
=
original_w
/
/
downsample
h
=
original_h
/
/
downsample
w
=
int
(
math
.
ceil
(
original_w
/
downsample
))
h
=
int
(
math
.
ceil
(
original_h
/
downsample
))
r
=
int
(
x
.
shape
[
1
]
*
ratio
)
no_rand
=
False
m
,
u
=
bipartite_soft_matching_random2d
(
x
,
w
,
h
,
stride_x
,
stride_y
,
r
,
no_rand
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment