import torch.nn as nn
from torch.nn.modules.utils import _pair
from ..functions.roi_align import roi_align


class RoIAlign(nn.Module):
    """Region-of-interest (RoI) Align layer with a selectable backend.

    Dispatches either to this package's ``roi_align`` function or, when
    ``use_torchvision`` is set, to ``torchvision.ops.roi_align``.

    Args:
        out_size (int | tuple[int, int]): Output spatial size of each pooled
            RoI. On the torchvision path it is passed through ``_pair``, so a
            single int ``s`` becomes ``(s, s)``. It is handed unchanged to the
            package's own ``roi_align``.
        spatial_scale (float): Scale factor applied to RoI coordinates
            (presumably feature-map size / input-image size, e.g. 1/16;
            TODO confirm against callers). Stored as ``float``.
        sample_num (int): Number of sampling points per output bin, forwarded
            as-is to the selected backend. Stored as ``int``. Default: 0.
            torchvision documents values <= 0 as "adaptive". The package's own
            op is assumed to treat 0 the same way (TODO confirm).
        use_torchvision (bool): If True, use ``torchvision.ops.roi_align``
            instead of the package's own implementation. Default: False.
    """

    def __init__(self,
                 out_size,
                 spatial_scale,
                 sample_num=0,
                 use_torchvision=False):
        super(RoIAlign, self).__init__()

        self.out_size = out_size
        # Coerce once here so both backends always receive plain Python
        # float / int values for these scalar arguments.
        self.spatial_scale = float(spatial_scale)
        self.sample_num = int(sample_num)
        self.use_torchvision = use_torchvision

    def forward(self, features, rois):
        """Pool a fixed-size feature patch for every RoI.

        Args:
            features (Tensor): Input feature map; presumably (N, C, H, W).
                TODO confirm.
            rois (Tensor): Regions to pool. torchvision documents this as a
                (K, 5) tensor of ``[batch_idx, x1, y1, x2, y2]``. The
                package's own op presumably uses the same layout (TODO
                confirm).

        Returns:
            Tensor: The selected backend's output (presumably
            (K, C, out_h, out_w); TODO confirm).
        """
        if self.use_torchvision:
            # Import lazily so torchvision is only required when this backend
            # is actually selected.
            from torchvision.ops import roi_align as tv_roi_align
            # torchvision expects an explicit (h, w) output size.
            return tv_roi_align(features, rois, _pair(self.out_size),
                                self.spatial_scale, self.sample_num)
        return roi_align(features, rois, self.out_size, self.spatial_scale,
                         self.sample_num)

    def extra_repr(self):
        """Describe the layer configuration for ``repr(module)``."""
        return ('out_size={}, spatial_scale={}, sample_num={}, '
                'use_torchvision={}').format(self.out_size, self.spatial_scale,
                                             self.sample_num,
                                             self.use_torchvision)