Unverified Commit 467b4883 authored by q.yao's avatar q.yao Committed by GitHub
Browse files

add torchvision roi_align with aligned=True (#581)

* add torchvision roi_align with aligned=True

* fix for lint test

* fix for lint test2

* format use yapf
parent 6b52e9b5
...@@ -152,8 +152,6 @@ class RoIAlign(nn.Module): ...@@ -152,8 +152,6 @@ class RoIAlign(nn.Module):
self.pool_mode = pool_mode self.pool_mode = pool_mode
self.aligned = aligned self.aligned = aligned
self.use_torchvision = use_torchvision self.use_torchvision = use_torchvision
assert not (use_torchvision and
aligned), 'Torchvision does not support aligned RoIAlgin'
def forward(self, input, rois): def forward(self, input, rois):
""" """
...@@ -164,6 +162,14 @@ class RoIAlign(nn.Module): ...@@ -164,6 +162,14 @@ class RoIAlign(nn.Module):
""" """
if self.use_torchvision: if self.use_torchvision:
from torchvision.ops import roi_align as tv_roi_align from torchvision.ops import roi_align as tv_roi_align
if 'aligned' in tv_roi_align.__code__.co_varnames:
return tv_roi_align(input, rois, self.output_size,
self.spatial_scale, self.sampling_ratio,
self.aligned)
else:
if self.aligned:
rois -= rois.new_tensor([0.] +
[0.5 / self.spatial_scale] * 4)
return tv_roi_align(input, rois, self.output_size, return tv_roi_align(input, rois, self.output_size,
self.spatial_scale, self.sampling_ratio) self.spatial_scale, self.sampling_ratio)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment