test_gap.py 742 Bytes
Newer Older
dengjb's avatar
dengjb committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch

from mmdet.models import GlobalAveragePooling


class TestGlobalAveragePooling(TestCase):
    """Output-shape checks for the ``GlobalAveragePooling`` neck."""

    def test_forward(self):
        """Check the flattened output shape for several pooling configs.

        A single random (32, 128, 14, 14) tensor is fed through necks built
        with different ``kernel_size`` / ``stride`` settings. Each case pairs
        the constructor kwargs with the expected (batch, features) shape.
        """
        batch, channels = 32, 128
        feats = torch.rand(batch, channels, 14, 14)

        cases = [
            # Default construction (AdaptiveAvgPool2d): one value per channel.
            ({}, (batch, channels)),
            # kernel_size only: expected to leave a 2x2 map per channel.
            (dict(kernel_size=7), (batch, channels * 2 * 2)),
            # kernel_size together with stride: expected 4x4 map per channel.
            (dict(kernel_size=7, stride=2), (batch, channels * 4 * 4)),
        ]
        for kwargs, expected in cases:
            pooled = GlobalAveragePooling(**kwargs)(feats)
            assert pooled.shape == expected