roi_pool.py 399 Bytes
Newer Older
Kai Chen's avatar
Kai Chen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from torch.nn.modules.module import Module
from ..functions.roi_pool import roi_pool


class RoIPool(Module):
    """Module wrapper around the functional ``roi_pool`` operation.

    The pooling configuration is captured once at construction time and
    handed to ``roi_pool`` unchanged on every forward call.
    """

    def __init__(self, out_size, spatial_scale):
        """Record the pooling configuration on the instance.

        Args:
            out_size: Stored as given and forwarded to ``roi_pool``
                (presumably the pooled output size -- confirm against
                ``functions.roi_pool``).
            spatial_scale: Converted with ``float()`` before being stored,
                so any value ``float()`` accepts is allowed.
        """
        super(RoIPool, self).__init__()

        # Both settings are plain attributes; only the scale is normalized.
        self.out_size, self.spatial_scale = out_size, float(spatial_scale)

    def forward(self, features, rois):
        """Call ``roi_pool`` on ``features`` and ``rois`` with the stored settings.

        Returns:
            Whatever ``roi_pool`` returns for the given inputs.
        """
        # Arguments are kept positional, in the same order as the original call.
        pooled = roi_pool(features, rois, self.out_size, self.spatial_scale)
        return pooled