# Taken from: https://github.com/dbolya/tomesd

import torch
from typing import Tuple, Callable
import math

def do_nothing(x: torch.Tensor, mode:str=None):
    """Identity merge/unmerge: hand ``x`` back unchanged.

    ``mode`` is accepted and ignored so this function can be returned in place
    of a ``merge(x, mode=...)`` callable when no tokens are merged.
    """
    result = x
    return result


def mps_gather_workaround(input, dim, index):
    """Drop-in replacement for ``torch.gather``, selected when running on MPS.

    If the last dimension of ``input`` has size 1, the gather runs on a copy
    with an extra trailing axis (``index`` is unsqueezed to match), and that
    axis is squeezed away afterwards. A negative ``dim`` is shifted down by one
    so it still names the same axis after the unsqueeze. Any other input goes
    straight to ``torch.gather``.
    """
    if input.shape[-1] != 1:
        return torch.gather(input, dim, index)

    # NOTE(review): presumably works around an MPS backend problem with gathers
    # on tensors whose last dimension is 1 -- confirm against upstream tomesd.
    shifted_dim = dim if dim >= 0 else dim - 1
    gathered = torch.gather(input.unsqueeze(-1), shifted_dim, index.unsqueeze(-1))
    return gathered.squeeze(-1)


def bipartite_soft_matching_random2d(metric: torch.Tensor,
                                     w: int, h: int, sx: int, sy: int, r: int,
                                     no_rand: bool = False) -> Tuple[Callable, Callable]:
    """
    Partitions the tokens into src and dst and merges r tokens from src to dst.
    Dst tokens are partitioned by choosing one randomly in each (sx, sy) region.
    Args:
     - metric [B, N, C]: metric to use for similarity; token i is treated as the
       grid cell (i // w, i % w), so N must equal w * h
     - w: image width in tokens
     - h: image height in tokens
     - sx: stride in the x dimension for dst; if it does not divide w, the
       leftover right-hand columns are always src
     - sy: stride in the y dimension for dst; if it does not divide h, the
       leftover bottom rows are always src
     - r: number of tokens to remove (by merging); clamped to the number of src tokens
     - no_rand: if true, disable randomness (use top left corner only)
    Returns:
     - (merge, unmerge): merge(x, mode="mean") maps [B, N, C] to [B, N - r, C] by
       scatter-reducing the r most similar src tokens into their best dst token;
       unmerge maps that result back to [B, N, C]. Both are do_nothing when there
       is nothing to merge.
    """
    B, N, _ = metric.shape

    if r <= 0 or w == 1 or h == 1:
        return do_nothing, do_nothing

    hsy, wsx = h // sy, w // sx

    # A grid smaller than one (sy, sx) window yields no dst tokens at all; the
    # similarity max below would then reduce over an empty dimension and raise.
    if hsy == 0 or wsx == 0:
        return do_nothing, do_nothing

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():
        # For each sy by sx kernel, randomly assign one token to be dst and the rest src
        if no_rand:
            rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
        else:
            rand_idx = torch.randint(sy * sx, size=(hsy, wsx, 1), device=metric.device)

        # The image might not divide sx and sy, so we need to work on a view of the top left of the idx buffer instead
        idx_buffer_view = torch.zeros(hsy, wsx, sy * sx, device=metric.device, dtype=torch.int64)
        idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
        idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)

        # Image is not divisible by sx or sy so we need to move it into a new buffer;
        # the leftover rows/columns stay 0, i.e. they are always src tokens
        if (hsy * sy) < h or (wsx * sx) < w:
            idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
            idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
        else:
            idx_buffer = idx_buffer_view

        # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
        rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)

        # We're finished with these
        del idx_buffer, idx_buffer_view

        # rand_idx is currently dst|src, so split them
        num_dst = hsy * wsx
        a_idx = rand_idx[:, num_dst:, :]  # src
        b_idx = rand_idx[:, :num_dst, :]  # dst

        def split(x):
            # Reorder x's tokens into (src, dst) according to the partition above.
            C = x.shape[-1]
            src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
            dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
            return src, dst

        # Cosine similarity between A and B
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        scores = a @ b.transpose(-1, -2)

        # Can't reduce more than the # tokens in src
        r = min(a.shape[1], r)

        # Find the most similar greedily
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        # Keep the unmerged src tokens, then reduce the merged src tokens into
        # their chosen dst tokens with `mode` (a torch scatter_reduce reduction).
        src, dst = split(x)
        n, t1, c = src.shape

        unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        # Inverse layout of merge: merged src tokens receive a copy of the dst
        # token they were merged into.
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        _, _, c = unm.shape

        src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))

        # Combine back to the original shape
        out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)

        return out

    return merge, unmerge


def get_functions(x, ratio, original_shape, *, stride_x=2, stride_y=2, max_downsample=1, no_rand=False):
    """
    Builds the ToMe (merge, unmerge) pair for one attention call.
    Args:
     - x [B, N, C]: tokens used as the similarity metric (and later merged)
     - ratio: fraction of the N tokens to remove by merging
     - original_shape: 4-tuple of the original input; only the last two entries
       (h, w) are read
     - stride_x, stride_y: dst sampling strides forwarded to
       bipartite_soft_matching_random2d
     - max_downsample: merge only when the downsampling factor inferred from
       N versus original_shape is at most this (default 1: full resolution only)
     - no_rand: if true, pick dst tokens deterministically (top left of each window)
    Returns:
     - (merge, unmerge); both are do_nothing when merging is skipped
    """
    _, _, original_h, original_w = original_shape
    num_tokens = x.shape[1]
    original_tokens = original_h * original_w
    downsample = int(math.ceil(math.sqrt(original_tokens // num_tokens)))

    # downsample == 0 means x has more tokens than the original grid; no grid can
    # be inferred, and dividing by it below would raise ZeroDivisionError.
    if downsample == 0 or downsample > max_downsample:
        return do_nothing, do_nothing

    w = int(math.ceil(original_w / downsample))
    h = int(math.ceil(original_h / downsample))

    # The matcher indexes an h x w grid over the N tokens; a mismatched grid would
    # fail inside its tensor expand, so skip merging instead.
    if w * h != num_tokens:
        return do_nothing, do_nothing

    r = int(num_tokens * ratio)
    return bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)



class TomePatchModel:
    """Node that returns a copy of a model patched with Token Merging (ToMe) on attn1.

    An attn1 patch merges the query tokens (keys and values pass through
    unchanged). A matching attn1 output patch applies the unmerge function
    produced by the most recent merge.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": { "model": ("MODEL",),
                              "ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
                              }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "_for_testing"

    def patch(self, model, ratio):
        """Clone ``model`` and register the ToMe attn1 patches on the clone.

        ``ratio`` is the fraction of query tokens removed by merging.
        """
        # The unmerge function from the latest merge. It lives in this closure,
        # not on the node instance, so every model produced by this node keeps
        # its own state instead of sharing (and overwriting) one attribute.
        unmerge = None

        def tomesd_m(q, k, v, extra_options):
            nonlocal unmerge
            #NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q
            #however from my basic testing it seems that using q instead gives better results
            m, unmerge = get_functions(q, ratio, extra_options["original_shape"])
            return m(q), k, v

        def tomesd_u(n, extra_options):
            # Nothing was merged yet, so there is nothing to undo.
            if unmerge is None:
                return n
            return unmerge(n)

        m = model.clone()
        m.set_model_attn1_patch(tomesd_m)
        m.set_model_attn1_output_patch(tomesd_u)
        return (m, )


# Maps each node type name exported by this file to its class
# (presumably read by the host application's node loader).
NODE_CLASS_MAPPINGS = dict(TomePatchModel=TomePatchModel)