Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
539ff487
"...composable_kernel_onnxruntime.git" did not exist on "6dfb4e7851a99eab605d239873b7eca777980fa8"
Commit
539ff487
authored
Apr 03, 2023
by
comfyanonymous
Browse files
Pull latest tomesd code from upstream.
parent
f50b1fec
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
48 additions
and
21 deletions
+48
-21
comfy/ldm/modules/tomesd.py
comfy/ldm/modules/tomesd.py
+48
-21
No files found.
comfy/ldm/modules/tomesd.py
View file @
539ff487
#Taken from: https://github.com/dbolya/tomesd
import
torch
import
torch
from
typing
import
Tuple
,
Callable
from
typing
import
Tuple
,
Callable
...
@@ -8,13 +8,23 @@ def do_nothing(x: torch.Tensor, mode:str=None):
...
@@ -8,13 +8,23 @@ def do_nothing(x: torch.Tensor, mode:str=None):
return
x
return
x
def
mps_gather_workaround
(
input
,
dim
,
index
):
if
input
.
shape
[
-
1
]
==
1
:
return
torch
.
gather
(
input
.
unsqueeze
(
-
1
),
dim
-
1
if
dim
<
0
else
dim
,
index
.
unsqueeze
(
-
1
)
).
squeeze
(
-
1
)
else
:
return
torch
.
gather
(
input
,
dim
,
index
)
def
bipartite_soft_matching_random2d
(
metric
:
torch
.
Tensor
,
def
bipartite_soft_matching_random2d
(
metric
:
torch
.
Tensor
,
w
:
int
,
h
:
int
,
sx
:
int
,
sy
:
int
,
r
:
int
,
w
:
int
,
h
:
int
,
sx
:
int
,
sy
:
int
,
r
:
int
,
no_rand
:
bool
=
False
)
->
Tuple
[
Callable
,
Callable
]:
no_rand
:
bool
=
False
)
->
Tuple
[
Callable
,
Callable
]:
"""
"""
Partitions the tokens into src and dst and merges r tokens from src to dst.
Partitions the tokens into src and dst and merges r tokens from src to dst.
Dst tokens are partitioned by choosing one randomy in each (sx, sy) region.
Dst tokens are partitioned by choosing one randomy in each (sx, sy) region.
Args:
Args:
- metric [B, N, C]: metric to use for similarity
- metric [B, N, C]: metric to use for similarity
- w: image width in tokens
- w: image width in tokens
...
@@ -28,33 +38,49 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
...
@@ -28,33 +38,49 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
if
r
<=
0
:
if
r
<=
0
:
return
do_nothing
,
do_nothing
return
do_nothing
,
do_nothing
gather
=
mps_gather_workaround
if
metric
.
device
.
type
==
"mps"
else
torch
.
gather
with
torch
.
no_grad
():
with
torch
.
no_grad
():
hsy
,
wsx
=
h
//
sy
,
w
//
sx
hsy
,
wsx
=
h
//
sy
,
w
//
sx
# For each sy by sx kernel, randomly assign one token to be dst and the rest src
# For each sy by sx kernel, randomly assign one token to be dst and the rest src
idx_buffer
=
torch
.
zeros
(
1
,
hsy
,
wsx
,
sy
*
sx
,
1
,
device
=
metric
.
device
)
if
no_rand
:
if
no_rand
:
rand_idx
=
torch
.
zeros
(
1
,
hsy
,
wsx
,
1
,
1
,
device
=
metric
.
device
,
dtype
=
torch
.
int64
)
rand_idx
=
torch
.
zeros
(
hsy
,
wsx
,
1
,
device
=
metric
.
device
,
dtype
=
torch
.
int64
)
else
:
else
:
rand_idx
=
torch
.
randint
(
sy
*
sx
,
size
=
(
1
,
hsy
,
wsx
,
1
,
1
),
device
=
metric
.
device
)
rand_idx
=
torch
.
randint
(
sy
*
sx
,
size
=
(
hsy
,
wsx
,
1
),
device
=
metric
.
device
)
idx_buffer
.
scatter_
(
dim
=
3
,
index
=
rand_idx
,
src
=-
torch
.
ones_like
(
rand_idx
,
dtype
=
idx_buffer
.
dtype
))
# The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead
idx_buffer
=
idx_buffer
.
view
(
1
,
hsy
,
wsx
,
sy
,
sx
,
1
).
transpose
(
2
,
3
).
reshape
(
1
,
N
,
1
)
idx_buffer_view
=
torch
.
zeros
(
hsy
,
wsx
,
sy
*
sx
,
device
=
metric
.
device
,
dtype
=
torch
.
int64
)
rand_idx
=
idx_buffer
.
argsort
(
dim
=
1
)
idx_buffer_view
.
scatter_
(
dim
=
2
,
index
=
rand_idx
,
src
=-
torch
.
ones_like
(
rand_idx
,
dtype
=
rand_idx
.
dtype
))
idx_buffer_view
=
idx_buffer_view
.
view
(
hsy
,
wsx
,
sy
,
sx
).
transpose
(
1
,
2
).
reshape
(
hsy
*
sy
,
wsx
*
sx
)
# Image is not divisible by sx or sy so we need to move it into a new buffer
if
(
hsy
*
sy
)
<
h
or
(
wsx
*
sx
)
<
w
:
idx_buffer
=
torch
.
zeros
(
h
,
w
,
device
=
metric
.
device
,
dtype
=
torch
.
int64
)
idx_buffer
[:(
hsy
*
sy
),
:(
wsx
*
sx
)]
=
idx_buffer_view
else
:
idx_buffer
=
idx_buffer_view
# We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
rand_idx
=
idx_buffer
.
reshape
(
1
,
-
1
,
1
).
argsort
(
dim
=
1
)
# We're finished with these
del
idx_buffer
,
idx_buffer_view
num_dst
=
int
((
1
/
(
sx
*
sy
))
*
N
)
# rand_idx is currently dst|src, so split them
num_dst
=
hsy
*
wsx
a_idx
=
rand_idx
[:,
num_dst
:,
:]
# src
a_idx
=
rand_idx
[:,
num_dst
:,
:]
# src
b_idx
=
rand_idx
[:,
:
num_dst
,
:]
# dst
b_idx
=
rand_idx
[:,
:
num_dst
,
:]
# dst
def
split
(
x
):
def
split
(
x
):
C
=
x
.
shape
[
-
1
]
C
=
x
.
shape
[
-
1
]
src
=
x
.
gather
(
dim
=
1
,
index
=
a_idx
.
expand
(
B
,
N
-
num_dst
,
C
))
src
=
gather
(
x
,
dim
=
1
,
index
=
a_idx
.
expand
(
B
,
N
-
num_dst
,
C
))
dst
=
x
.
gather
(
dim
=
1
,
index
=
b_idx
.
expand
(
B
,
num_dst
,
C
))
dst
=
gather
(
x
,
dim
=
1
,
index
=
b_idx
.
expand
(
B
,
num_dst
,
C
))
return
src
,
dst
return
src
,
dst
# Cosine similarity between A and B
metric
=
metric
/
metric
.
norm
(
dim
=-
1
,
keepdim
=
True
)
metric
=
metric
/
metric
.
norm
(
dim
=-
1
,
keepdim
=
True
)
a
,
b
=
split
(
metric
)
a
,
b
=
split
(
metric
)
scores
=
a
@
b
.
transpose
(
-
1
,
-
2
)
scores
=
a
@
b
.
transpose
(
-
1
,
-
2
)
...
@@ -62,19 +88,20 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
...
@@ -62,19 +88,20 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
# Can't reduce more than the # tokens in src
# Can't reduce more than the # tokens in src
r
=
min
(
a
.
shape
[
1
],
r
)
r
=
min
(
a
.
shape
[
1
],
r
)
# Find the most similar greedily
node_max
,
node_idx
=
scores
.
max
(
dim
=-
1
)
node_max
,
node_idx
=
scores
.
max
(
dim
=-
1
)
edge_idx
=
node_max
.
argsort
(
dim
=-
1
,
descending
=
True
)[...,
None
]
edge_idx
=
node_max
.
argsort
(
dim
=-
1
,
descending
=
True
)[...,
None
]
unm_idx
=
edge_idx
[...,
r
:,
:]
# Unmerged Tokens
unm_idx
=
edge_idx
[...,
r
:,
:]
# Unmerged Tokens
src_idx
=
edge_idx
[...,
:
r
,
:]
# Merged Tokens
src_idx
=
edge_idx
[...,
:
r
,
:]
# Merged Tokens
dst_idx
=
node_idx
[...,
None
]
.
gather
(
dim
=-
2
,
index
=
src_idx
)
dst_idx
=
gather
(
node_idx
[...,
None
]
,
dim
=-
2
,
index
=
src_idx
)
def
merge
(
x
:
torch
.
Tensor
,
mode
=
"mean"
)
->
torch
.
Tensor
:
def
merge
(
x
:
torch
.
Tensor
,
mode
=
"mean"
)
->
torch
.
Tensor
:
src
,
dst
=
split
(
x
)
src
,
dst
=
split
(
x
)
n
,
t1
,
c
=
src
.
shape
n
,
t1
,
c
=
src
.
shape
unm
=
src
.
gather
(
dim
=-
2
,
index
=
unm_idx
.
expand
(
n
,
t1
-
r
,
c
))
unm
=
gather
(
src
,
dim
=-
2
,
index
=
unm_idx
.
expand
(
n
,
t1
-
r
,
c
))
src
=
src
.
gather
(
dim
=-
2
,
index
=
src_idx
.
expand
(
n
,
r
,
c
))
src
=
gather
(
src
,
dim
=-
2
,
index
=
src_idx
.
expand
(
n
,
r
,
c
))
dst
=
dst
.
scatter_reduce
(
-
2
,
dst_idx
.
expand
(
n
,
r
,
c
),
src
,
reduce
=
mode
)
dst
=
dst
.
scatter_reduce
(
-
2
,
dst_idx
.
expand
(
n
,
r
,
c
),
src
,
reduce
=
mode
)
return
torch
.
cat
([
unm
,
dst
],
dim
=
1
)
return
torch
.
cat
([
unm
,
dst
],
dim
=
1
)
...
@@ -84,13 +111,13 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
...
@@ -84,13 +111,13 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
unm
,
dst
=
x
[...,
:
unm_len
,
:],
x
[...,
unm_len
:,
:]
unm
,
dst
=
x
[...,
:
unm_len
,
:],
x
[...,
unm_len
:,
:]
_
,
_
,
c
=
unm
.
shape
_
,
_
,
c
=
unm
.
shape
src
=
dst
.
gather
(
dim
=-
2
,
index
=
dst_idx
.
expand
(
B
,
r
,
c
))
src
=
gather
(
dst
,
dim
=-
2
,
index
=
dst_idx
.
expand
(
B
,
r
,
c
))
# Combine back to the original shape
# Combine back to the original shape
out
=
torch
.
zeros
(
B
,
N
,
c
,
device
=
x
.
device
,
dtype
=
x
.
dtype
)
out
=
torch
.
zeros
(
B
,
N
,
c
,
device
=
x
.
device
,
dtype
=
x
.
dtype
)
out
.
scatter_
(
dim
=-
2
,
index
=
b_idx
.
expand
(
B
,
num_dst
,
c
),
src
=
dst
)
out
.
scatter_
(
dim
=-
2
,
index
=
b_idx
.
expand
(
B
,
num_dst
,
c
),
src
=
dst
)
out
.
scatter_
(
dim
=-
2
,
index
=
a_idx
.
expand
(
B
,
a_idx
.
shape
[
1
],
1
)
.
gather
(
dim
=
1
,
index
=
unm_idx
).
expand
(
B
,
unm_len
,
c
),
src
=
unm
)
out
.
scatter_
(
dim
=-
2
,
index
=
gather
(
a_idx
.
expand
(
B
,
a_idx
.
shape
[
1
],
1
)
,
dim
=
1
,
index
=
unm_idx
).
expand
(
B
,
unm_len
,
c
),
src
=
unm
)
out
.
scatter_
(
dim
=-
2
,
index
=
a_idx
.
expand
(
B
,
a_idx
.
shape
[
1
],
1
)
.
gather
(
dim
=
1
,
index
=
src_idx
).
expand
(
B
,
r
,
c
),
src
=
src
)
out
.
scatter_
(
dim
=-
2
,
index
=
gather
(
a_idx
.
expand
(
B
,
a_idx
.
shape
[
1
],
1
)
,
dim
=
1
,
index
=
src_idx
).
expand
(
B
,
r
,
c
),
src
=
src
)
return
out
return
out
...
@@ -100,14 +127,14 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
...
@@ -100,14 +127,14 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
def
get_functions
(
x
,
ratio
,
original_shape
):
def
get_functions
(
x
,
ratio
,
original_shape
):
b
,
c
,
original_h
,
original_w
=
original_shape
b
,
c
,
original_h
,
original_w
=
original_shape
original_tokens
=
original_h
*
original_w
original_tokens
=
original_h
*
original_w
downsample
=
int
(
math
.
sqrt
(
original_tokens
//
x
.
shape
[
1
]))
downsample
=
int
(
math
.
ceil
(
math
.
sqrt
(
original_tokens
//
x
.
shape
[
1
]))
)
stride_x
=
2
stride_x
=
2
stride_y
=
2
stride_y
=
2
max_downsample
=
1
max_downsample
=
1
if
downsample
<=
max_downsample
:
if
downsample
<=
max_downsample
:
w
=
original_w
/
/
downsample
w
=
int
(
math
.
ceil
(
original_w
/
downsample
))
h
=
original_h
/
/
downsample
h
=
int
(
math
.
ceil
(
original_h
/
downsample
))
r
=
int
(
x
.
shape
[
1
]
*
ratio
)
r
=
int
(
x
.
shape
[
1
]
*
ratio
)
no_rand
=
False
no_rand
=
False
m
,
u
=
bipartite_soft_matching_random2d
(
x
,
w
,
h
,
stride_x
,
stride_y
,
r
,
no_rand
)
m
,
u
=
bipartite_soft_matching_random2d
(
x
,
w
,
h
,
stride_x
,
stride_y
,
r
,
no_rand
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment