# roi_pool.py
import torch.nn as nn
from torch.nn.modules.utils import _pair

from ..functions.roi_pool import roi_pool


class RoIPool(nn.Module):
    """RoI pooling layer.

    Thin ``nn.Module`` wrapper that dispatches either to this package's
    ``roi_pool`` function or, when ``use_torchvision`` is set, to
    ``torchvision.ops.roi_pool``.

    Args:
        out_size (int | tuple[int, int]): Output spatial size of each pooled
            RoI. Stored unchanged; it is expanded to an (h, w) pair with
            ``_pair`` only when calling the torchvision op.
        spatial_scale (float | int): Scale factor passed to the pooling op
            (presumably mapping RoI coordinates onto ``features`` -- TODO
            confirm for the custom op). Stored as ``float``.
        use_torchvision (bool): If True, use ``torchvision.ops.roi_pool``.
            torchvision is imported lazily in ``forward``, so it is only
            required when this flag is enabled.
    """

    def __init__(self, out_size, spatial_scale, use_torchvision=False):
        super(RoIPool, self).__init__()

        self.out_size = out_size
        self.spatial_scale = float(spatial_scale)
        self.use_torchvision = use_torchvision

    def forward(self, features, rois):
        """Pool a fixed-size feature map for every RoI.

        Args:
            features (Tensor): Input feature map(s); presumably (N, C, H, W)
                -- TODO confirm.
            rois (Tensor): RoIs passed straight through to the selected op.
                Assumed to be (K, 5) rows of (batch_idx, x1, y1, x2, y2),
                a format torchvision accepts -- TODO confirm for the custom op.

        Returns:
            Tensor: Output of the selected pooling op.
        """
        if self.use_torchvision:
            # Imported lazily so torchvision is only needed when requested.
            from torchvision.ops import roi_pool as tv_roi_pool
            # torchvision takes an explicit (h, w) output size; _pair also
            # accepts an out_size that is already a pair.
            return tv_roi_pool(features, rois, _pair(self.out_size),
                               self.spatial_scale)
        return roi_pool(features, rois, self.out_size, self.spatial_scale)

    def extra_repr(self):
        # nn.Module.__repr__ would otherwise print a bare "RoIPool()", because
        # this module has no parameters or submodules; surface its config.
        return 'out_size={}, spatial_scale={}, use_torchvision={}'.format(
            self.out_size, self.spatial_scale, self.use_torchvision)