Unverified Commit 4b2ad55f authored by Joao Gomes, committed by GitHub

Refactor poolers (#4951)

* refactoring methods from MultiScaleRoIAlign
parent 9841a907
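Since the commit only moves two private helpers out of `MultiScaleRoIAlign` to module scope, the public behaviour of `torchvision.ops.MultiScaleRoIAlign` should be unchanged. Below is a minimal sketch of that unchanged public usage; the feature names, shapes, and boxes are made up for illustration and are not part of this commit.

```python
from collections import OrderedDict

import torch
from torchvision.ops import MultiScaleRoIAlign

# Pool 3x3 crops from the 'feat1' and 'feat3' maps, with sampling_ratio=2.
m = MultiScaleRoIAlign(["feat1", "feat3"], output_size=3, sampling_ratio=2)

# Fake FPN-style feature maps for one 512x512 image; 'feat2' is ignored
# because it is not listed in featmap_names.
feats = OrderedDict()
feats["feat1"] = torch.rand(1, 5, 64, 64)
feats["feat2"] = torch.rand(1, 5, 32, 32)
feats["feat3"] = torch.rand(1, 5, 16, 16)

# Six random boxes for that image, in (x1, y1, x2, y2) format.
boxes = torch.rand(6, 4) * 256
boxes[:, 2:] += boxes[:, :2]

out = m(feats, [boxes], image_shapes=[(512, 512)])
print(out.shape)  # torch.Size([6, 5, 3, 3])
```

The output is one 3x3 pooled feature per box, with the pooling level chosen per box from the listed feature maps.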
@@ -83,6 +83,29 @@ class LevelMapper:
         return (target_lvls.to(torch.int64) - self.k_min).to(torch.int64)

+def _convert_to_roi_format(boxes: List[Tensor]) -> Tensor:
+    concat_boxes = torch.cat(boxes, dim=0)
+    device, dtype = concat_boxes.device, concat_boxes.dtype
+    ids = torch.cat(
+        [torch.full_like(b[:, :1], i, dtype=dtype, layout=torch.strided, device=device) for i, b in enumerate(boxes)],
+        dim=0,
+    )
+    rois = torch.cat([ids, concat_boxes], dim=1)
+    return rois
+
+
+def _infer_scale(feature: Tensor, original_size: List[int]) -> float:
+    # assumption: the scale is of the form 2 ** (-k), with k integer
+    size = feature.shape[-2:]
+    possible_scales: List[float] = []
+    for s1, s2 in zip(size, original_size):
+        approx_scale = float(s1) / float(s2)
+        scale = 2 ** float(torch.tensor(approx_scale).log2().round())
+        possible_scales.append(scale)
+    assert possible_scales[0] == possible_scales[1]
+    return possible_scales[0]
+
+
 class MultiScaleRoIAlign(nn.Module):
     """
     Multi-scale RoIAlign pooling, which is useful for detection with or without FPN.
@@ -142,30 +165,6 @@ class MultiScaleRoIAlign(nn.Module):
         self.canonical_scale = canonical_scale
         self.canonical_level = canonical_level

-    def convert_to_roi_format(self, boxes: List[Tensor]) -> Tensor:
-        concat_boxes = torch.cat(boxes, dim=0)
-        device, dtype = concat_boxes.device, concat_boxes.dtype
-        ids = torch.cat(
-            [
-                torch.full_like(b[:, :1], i, dtype=dtype, layout=torch.strided, device=device)
-                for i, b in enumerate(boxes)
-            ],
-            dim=0,
-        )
-        rois = torch.cat([ids, concat_boxes], dim=1)
-        return rois
-
-    def infer_scale(self, feature: Tensor, original_size: List[int]) -> float:
-        # assumption: the scale is of the form 2 ** (-k), with k integer
-        size = feature.shape[-2:]
-        possible_scales: List[float] = []
-        for s1, s2 in zip(size, original_size):
-            approx_scale = float(s1) / float(s2)
-            scale = 2 ** float(torch.tensor(approx_scale).log2().round())
-            possible_scales.append(scale)
-        assert possible_scales[0] == possible_scales[1]
-        return possible_scales[0]
-
     def setup_scales(
         self,
         features: List[Tensor],
@@ -179,7 +178,7 @@ class MultiScaleRoIAlign(nn.Module):
             max_y = max(shape[1], max_y)
         original_input_shape = (max_x, max_y)

-        scales = [self.infer_scale(feat, original_input_shape) for feat in features]
+        scales = [_infer_scale(feat, original_input_shape) for feat in features]
         # get the levels in the feature map by leveraging the fact that the network always
         # downsamples by a factor of 2 at each level.
         lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
@@ -216,7 +215,7 @@ class MultiScaleRoIAlign(nn.Module):
             if k in self.featmap_names:
                 x_filtered.append(v)
         num_levels = len(x_filtered)
-        rois = self.convert_to_roi_format(boxes)
+        rois = _convert_to_roi_format(boxes)
         if self.scales is None:
             self.setup_scales(x_filtered, image_shapes)
...
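After this change the two relocated helpers live at module scope in torchvision.ops.poolers, so their behaviour can be sanity-checked directly. The sketch below uses made-up boxes and feature shapes; importing underscore-prefixed names is shown for illustration only and is not a supported API.

```python
import torch

# Private module-level helpers introduced by this commit; imported here only
# to illustrate what they compute, not as a public API.
from torchvision.ops.poolers import _convert_to_roi_format, _infer_scale

# Two images: one with a single box, one with two boxes, in (x1, y1, x2, y2).
boxes_per_image = [
    torch.tensor([[0.0, 0.0, 10.0, 10.0]]),
    torch.tensor([[5.0, 5.0, 20.0, 20.0], [1.0, 2.0, 3.0, 4.0]]),
]

# The helper prepends a batch-index column, producing (K, 5) rows of
# (batch_idx, x1, y1, x2, y2) as expected by roi_align.
rois = _convert_to_roi_format(boxes_per_image)
print(rois.shape)  # torch.Size([3, 5])
print(rois[:, 0])  # tensor([0., 1., 1.])

# A 32x32 feature map for a 512x512 image corresponds to a 2 ** -4 scale.
feature = torch.rand(1, 256, 32, 32)
print(_infer_scale(feature, [512, 512]))  # 0.0625
```

The batch-index column is what ties each box back to its image during pooling, and the inferred power-of-two scales are what setup_scales uses to build the level mapper.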