import math
import functools
from typing import Tuple, Union

import torch
from torch import Tensor, nn


class PositionEmbeddingSine(nn.Module):
    """Sinusoidal position embedding used in DETR model. See `End-to-End Object Detection 
    with Transformers <https://arxiv.org/pdf/2005.12872>`_ for more details.

    :param num_pos_feats: The feature dimension for each position along x-axis or y-axis.
        The final returned dimension for each position is 2 times of the input value, 
        defaults to 64
    :param temperature: The temperature used for scaling the position embedding, defaults to 10000
    :param normalize: Whether to normalize the position embedding, defaults to False
    :param scale: A scale factor that scales the position embedding, which is used only when
        `normalize` is True, defaults to 2*math.pi
    :param eps: A value added to the denominator for numerical stability, defaults to 1e-6
    :param offset: An offset added to embed, defaults to 0.0
    """
    def __init__(
        self,
        num_pos_feats=64,
        temperature: Union[int, Tuple[int, int]] = 10000,
        normalize=False,
        scale=2 * math.pi,
        eps=1e-6,
        offset=0.0,
    ):
        super().__init__()
        dim_t = 2 * torch.arange(num_pos_feats).div(2, rounding_mode="floor") / num_pos_feats
        if isinstance(temperature, int):
            dim_tx = dim_ty = temperature**dim_t
        else:
            assert len(temperature) == 2, "Only support two elements as (t_x, t_y) in temperature"
            dim_tx, dim_ty = [t**dim_t for t in temperature]
        self.register_buffer("dim_tx", dim_tx)
        self.register_buffer("dim_ty", dim_ty)

        self.normalize = normalize
        self.scale = scale
        self.eps = eps
        self.offset = offset

    def forward(self, mask: Tensor):
        # `mask` marks padded pixels (non-zero / True means padding); positions are the
        # cumulative counts of valid pixels along the y-axis (dim 1) and x-axis (dim 2).
        mask = mask.to(torch.int)
        not_mask = 1 - mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            y_embed = (y_embed + self.offset) / (y_embed[:, -1:, :] + self.eps) * self.scale
            x_embed = (x_embed + self.offset) / (x_embed[:, :, -1:] + self.eps) * self.scale
        else:
            # RT-DETR uses unnormalized encoding with index from 0
            y_embed = y_embed + self.offset
            x_embed = x_embed + self.offset

        pos_x = x_embed[:, :, :, None] / self.dim_tx
        pos_y = y_embed[:, :, :, None] / self.dim_ty
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
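
# Usage sketch (illustrative only, not part of the original module; the batch and
# spatial sizes below are assumptions):
#   pos_embed = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
#   padding_mask = torch.zeros(2, 32, 32, dtype=torch.bool)  # (B, H, W), True marks padding
#   pos = pos_embed(padding_mask)  # (B, 2 * num_pos_feats, H, W) == (2, 256, 32, 32)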


class PositionEmbeddingLearned(nn.Module):
    """Absolute pos embedding, learned."""
    def __init__(self, num_embeddings: int = 50, num_pos_feats: int = 256):
        super().__init__()
        self.row_embed = nn.Embedding(num_embeddings, num_pos_feats)
        self.col_embed = nn.Embedding(num_embeddings, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.row_embed.weight)
        nn.init.uniform_(self.col_embed.weight)

    def forward(self, mask: Tensor):
        h, w = mask.shape[-2:]
        i = torch.arange(w, device=mask.device)
        j = torch.arange(h, device=mask.device)
        x_emb = self.col_embed(i)
        y_emb = self.row_embed(j)
        pos = (
            torch.cat(
                [
                    x_emb.unsqueeze(0).repeat(h, 1, 1),
                    y_emb.unsqueeze(1).repeat(1, w, 1),
                ],
                dim=-1,
            ).permute(2, 0, 1).unsqueeze(0).repeat(mask.shape[0], 1, 1, 1)
        )
        return pos
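
# Usage sketch (illustrative only; the learned tables cover indices 0..num_embeddings-1,
# so the feature map's H and W must not exceed num_embeddings):
#   pos_embed = PositionEmbeddingLearned(num_embeddings=50, num_pos_feats=256)
#   padding_mask = torch.zeros(2, 32, 32, dtype=torch.bool)  # only the mask's shape and device are used
#   pos = pos_embed(padding_mask)  # (B, 2 * num_pos_feats, H, W) == (2, 512, 32, 32)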


@functools.lru_cache  # use lru_cache to avoid redundant calculation for dim_t
def get_dim_t(num_pos_feats: int, temperature: int, device: torch.device):
    """Return the per-dimension denominators temperature**(2 * floor(i / 2) / num_pos_feats)."""
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=device)
    dim_t = temperature**(2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
    return dim_t


def get_sine_pos_embed(
    pos_tensor: Tensor,
    num_pos_feats: int = 128,
    temperature: int = 10000,
    scale: float = 2 * math.pi,
    exchange_xy: bool = True,
) -> Tensor:
    """Generate sine position embedding for a position tensor

    :param pos_tensor: shape as (..., 2*n).
    :param num_pos_feats: projected shape for each float in the tensor, defaults to 128
    :param temperature: the temperature used for scaling the position embedding, defaults to 10000
    :param exchange_xy: exchange pos x and pos. For example, 
        input tensor is [x, y], the results will be [pos(y), pos(x)], defaults to True
    :return: position embedding with shape (None, n * num_pos_feats)
    """
    dim_t = get_dim_t(num_pos_feats, temperature, pos_tensor.device)

    pos_res = pos_tensor.unsqueeze(-1) * scale / dim_t
    pos_res = torch.stack((pos_res[..., 0::2].sin(), pos_res[..., 1::2].cos()), dim=-1).flatten(-2)
    if exchange_xy:
        index = torch.cat([
            torch.arange(1, -1, -1, device=pos_res.device),
            torch.arange(2, pos_res.shape[-2], device=pos_res.device),
        ])
        pos_res = torch.index_select(pos_res, -2, index)
    pos_res = pos_res.view(*pos_tensor.shape[:-1], -1)
    return pos_res
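

# Minimal smoke test for the utilities above (a sketch, not part of the original file;
# all sizes are arbitrary assumptions).
if __name__ == "__main__":
    padding_mask = torch.zeros(2, 16, 20, dtype=torch.bool)  # (B, H, W), no padded pixels

    sine = PositionEmbeddingSine(num_pos_feats=64, normalize=True)
    print(sine(padding_mask).shape)  # torch.Size([2, 128, 16, 20])

    learned = PositionEmbeddingLearned(num_embeddings=50, num_pos_feats=64)
    print(learned(padding_mask).shape)  # torch.Size([2, 128, 16, 20])

    # Encode (cx, cy, w, h) box coordinates in [0, 1]; the output dim is 4 * num_pos_feats.
    boxes = torch.rand(2, 100, 4)
    print(get_sine_pos_embed(boxes, num_pos_feats=128).shape)  # torch.Size([2, 100, 512])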