merge.py 1.13 KB
Newer Older
Melos's avatar
Melos committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36


import math
from typing import Callable, Tuple

import torch


def self_soft_matching(
    metric: torch.Tensor,
    r: int,
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Build a function that keeps the ``r`` least redundant tokens.

    Each token is scored by its highest cosine similarity to any token that
    comes *after* it in the sequence. The ``r`` tokens with the lowest score
    are selected; all other tokens are dropped by the returned function.
    The last token has no successors, so its score is at most about -1 and
    it ranks last (ties aside), i.e. it is kept whenever ``r >= 1``.

    Args:
        metric: 3-D tensor ``(batch, tokens, channels)`` used to measure
            similarity between tokens. Rows are L2-normalised without an
            epsilon, so an all-zero row produces NaN similarities.
        r: Number of tokens to keep. Values outside ``[0, tokens]`` are
            clamped to that range.

    Returns:
        A function ``merge(src)`` that takes a 3-D tensor
        ``(n, tokens, c)`` and returns ``(n, r, c)`` containing the selected
        tokens in their original sequence order. ``n`` must match the batch
        size of ``metric`` (or ``metric`` must have batch size 1).
    """
    t = metric.shape[1]
    # A request for more tokens than exist would otherwise produce a
    # mis-sized index (negative slice start) and an opaque expand() error
    # inside merge(); keep every token instead.
    r = max(0, min(r, t))

    with torch.no_grad():
        metric = metric / metric.norm(dim=-1, keepdim=True)
        # Pairwise cosine similarity between all tokens: (batch, t, t).
        scores = metric @ metric.transpose(-1, -2)

        # Subtract 2 on and below the diagonal so those entries fall
        # beneath the valid cosine range; each token is then effectively
        # compared only with later tokens. The (t, t) mask is created on the
        # target device and broadcasts over the batch dimension.
        past_penalty = torch.tril(torch.ones(t, t, device=metric.device)) * 2
        scores = scores - past_penalty

        # Best similarity of each token to any later token.
        node_max = scores.max(dim=-1).values
        # Token indices ordered from most to least similar.
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        # The trailing r entries are the least similar tokens: keep them.
        keep_idx = edge_idx[..., t - r:, :]
        # Sort once here so merge() returns tokens in sequence order with a
        # single gather instead of re-sorting on every call.
        keep_idx = keep_idx.sort(dim=1).values

    def merge(src: torch.Tensor) -> torch.Tensor:
        """Select the kept tokens from ``src`` in original order."""
        n, _, c = src.shape
        return src.gather(dim=-2, index=keep_idx.expand(n, r, c))

    return merge