import torch.nn as nn
from torch.nn.modules.utils import _pair

from ..functions.roi_align import roi_align


class RoIAlign(nn.Module):
    """Layer that applies RoI Align through one of two backends.

    Depending on ``use_torchvision``, the forward pass calls either the
    package-local ``roi_align`` function or ``torchvision.ops.roi_align``.

    Args:
        out_size: Output size of the pooled region. Stored unchanged; it is
            expanded with ``_pair`` only on the torchvision path.
            (Presumably an int or an ``(h, w)`` pair -- confirm with callers.)
        spatial_scale: Scale factor forwarded to the backend; stored as
            ``float``.
        sample_num: Sampling count forwarded to the backend; stored as
            ``int``. Defaults to 0.
        use_torchvision: When truthy, use the torchvision implementation
            instead of the package-local one. Defaults to False.
    """

    def __init__(self, out_size, spatial_scale, sample_num=0,
                 use_torchvision=False):
        super(RoIAlign, self).__init__()

        self.out_size = out_size
        self.spatial_scale = float(spatial_scale)
        self.sample_num = int(sample_num)
        self.use_torchvision = use_torchvision

    def forward(self, features, rois):
        """Run RoI Align on ``features`` for the given ``rois``.

        Args:
            features: Input passed straight through to the backend.
            rois: Regions of interest passed straight through to the backend.
                NOTE(review): the two backends may not accept identical
                ``rois`` layouts -- confirm before switching the flag.

        Returns:
            Whatever the selected backend's ``roi_align`` returns.
        """
        scale = self.spatial_scale
        samples = self.sample_num

        if not self.use_torchvision:
            # Package-local op receives ``out_size`` exactly as stored.
            return roi_align(features, rois, self.out_size, scale, samples)

        # Imported here so torchvision is only needed when this backend
        # is actually selected.
        from torchvision.ops import roi_align as tv_roi_align

        # Only this path expands ``out_size`` with ``_pair``.
        output_size = _pair(self.out_size)
        return tv_roi_align(features, rois, output_size, scale, samples)