roi_align.py 535 Bytes
Newer Older
Kai Chen's avatar
Kai Chen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from torch.nn.modules.module import Module
from ..functions.roi_align import RoIAlignFunction


class RoIAlign(Module):
    """``Module`` wrapper that forwards its inputs to ``RoIAlignFunction``.

    The constructor records three settings on the instance; ``forward``
    passes them, together with the call-time inputs, to
    ``RoIAlignFunction.apply``.

    Args:
        out_size: Stored unchanged and handed to ``RoIAlignFunction``.
            Presumably the pooled output size (int or pair) -- confirm
            against ``RoIAlignFunction``.
        spatial_scale: Converted with ``float()`` before being stored.
            Presumably the ratio between feature-map and input-image
            coordinates -- confirm against ``RoIAlignFunction``.
        sample_num: Converted with ``int()`` before being stored.
            Defaults to 0; the meaning of 0 is decided by
            ``RoIAlignFunction``, not by this class.
    """

    def __init__(self, out_size, spatial_scale, sample_num=0):
        super(RoIAlign, self).__init__()

        # Coerce the numeric settings up front so that an invalid value
        # fails at construction time rather than on the first forward pass.
        scale = float(spatial_scale)
        samples = int(sample_num)

        self.out_size = out_size
        self.spatial_scale = scale
        self.sample_num = samples

    def forward(self, features, rois):
        """Apply ``RoIAlignFunction`` to ``features`` and ``rois``.

        Args:
            features: First positional input passed to the function.
            rois: Second positional input passed to the function.

        Returns:
            Whatever ``RoIAlignFunction.apply`` returns for these inputs
            and the settings stored on this module.
        """
        call_args = (
            features,
            rois,
            self.out_size,
            self.spatial_scale,
            self.sample_num,
        )
        return RoIAlignFunction.apply(*call_args)