test_position_encoding.py 831 Bytes
Newer Older
limm's avatar
limm committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmpretrain.models.utils import (ConditionalPositionEncoding,
                                     RotaryEmbeddingFast)


def test_conditional_position_encoding_module():
    CPE = ConditionalPositionEncoding(in_channels=32, embed_dims=32, stride=2)
    outs = CPE(torch.randn(1, 3136, 32), (56, 56))
    assert outs.shape == torch.Size([1, 784, 32])


def test_rotary_embedding_fast_module():
    RoPE = RotaryEmbeddingFast(embed_dims=64, patch_resolution=24)
    outs = RoPE(torch.randn(1, 2, 24 * 24, 64), (24, 24))
    assert outs.shape == torch.Size([1, 2, 24 * 24, 64])

    RoPE = RotaryEmbeddingFast(embed_dims=64, patch_resolution=(14, 20))
    outs = RoPE(torch.randn(1, 2, 14 * 20, 64), (14, 20))
    assert outs.shape == torch.Size([1, 2, 14 * 20, 64])