Unverified Commit 48441cc5 authored by Zhiqiang Wang's avatar Zhiqiang Wang Committed by GitHub
Browse files

Refactor grid default boxes with torch meshgrid (#3799)



* Refactor grid default boxes with torch.meshgrid

* Fix torch jit tracing

* Only doing the list multiplication once
Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>

* Make grid_default_box private as suggested
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>

* Replace list multiplication with torch.repeat

* Move the clipping into _grid_default_boxes to accelerate
Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 5dd7dfe3
...@@ -170,26 +170,59 @@ class DefaultBoxGenerator(nn.Module): ...@@ -170,26 +170,59 @@ class DefaultBoxGenerator(nn.Module):
else: else:
self.scales = scales self.scales = scales
self._wh_pairs = [] self._wh_pairs = self._generate_wh_pairs(num_outputs)
def _generate_wh_pairs(self, num_outputs: int, dtype: torch.dtype = torch.float32,
device: torch.device = torch.device("cpu")) -> List[Tensor]:
_wh_pairs: List[Tensor] = []
for k in range(num_outputs): for k in range(num_outputs):
# Adding the 2 default width-height pairs for aspect ratio 1 and scale s'k # Adding the 2 default width-height pairs for aspect ratio 1 and scale s'k
s_k = self.scales[k] s_k = self.scales[k]
s_prime_k = math.sqrt(self.scales[k] * self.scales[k + 1]) s_prime_k = math.sqrt(self.scales[k] * self.scales[k + 1])
wh_pairs = [(s_k, s_k), (s_prime_k, s_prime_k)] wh_pairs = [[s_k, s_k], [s_prime_k, s_prime_k]]
# Adding 2 pairs for each aspect ratio of the feature map k # Adding 2 pairs for each aspect ratio of the feature map k
for ar in self.aspect_ratios[k]: for ar in self.aspect_ratios[k]:
sq_ar = math.sqrt(ar) sq_ar = math.sqrt(ar)
w = self.scales[k] * sq_ar w = self.scales[k] * sq_ar
h = self.scales[k] / sq_ar h = self.scales[k] / sq_ar
wh_pairs.extend([(w, h), (h, w)]) wh_pairs.extend([[w, h], [h, w]])
self._wh_pairs.append(wh_pairs) _wh_pairs.append(torch.as_tensor(wh_pairs, dtype=dtype, device=device))
return _wh_pairs
def num_anchors_per_location(self): def num_anchors_per_location(self):
# Estimate num of anchors based on aspect ratios: 2 default boxes + 2 * ratios of feaure map. # Estimate num of anchors based on aspect ratios: 2 default boxes + 2 * ratios of feaure map.
return [2 + 2 * len(r) for r in self.aspect_ratios] return [2 + 2 * len(r) for r in self.aspect_ratios]
# Default Boxes calculation based on page 6 of SSD paper
def _grid_default_boxes(self, grid_sizes: List[List[int]], image_size: List[int],
dtype: torch.dtype = torch.float32) -> Tensor:
default_boxes = []
for k, f_k in enumerate(grid_sizes):
# Now add the default boxes for each width-height pair
if self.steps is not None:
x_f_k, y_f_k = [img_shape / self.steps[k] for img_shape in image_size]
else:
y_f_k, x_f_k = f_k
shifts_x = (torch.arange(0, f_k[1], dtype=dtype) + 0.5) / x_f_k
shifts_y = (torch.arange(0, f_k[0], dtype=dtype) + 0.5) / y_f_k
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)
shifts = torch.stack((shift_x, shift_y) * len(self._wh_pairs[k]), dim=-1).reshape(-1, 2)
# Clipping the default boxes while the boxes are encoded in format (cx, cy, w, h)
_wh_pair = self._wh_pairs[k].clamp(min=0, max=1) if self.clip else self._wh_pairs[k]
wh_pairs = _wh_pair.repeat((f_k[0] * f_k[1]), 1)
default_box = torch.cat((shifts, wh_pairs), dim=1)
default_boxes.append(default_box)
return torch.cat(default_boxes, dim=0)
def __repr__(self) -> str: def __repr__(self) -> str:
s = self.__class__.__name__ + '(' s = self.__class__.__name__ + '('
s += 'aspect_ratios={aspect_ratios}' s += 'aspect_ratios={aspect_ratios}'
...@@ -203,30 +236,12 @@ class DefaultBoxGenerator(nn.Module): ...@@ -203,30 +236,12 @@ class DefaultBoxGenerator(nn.Module):
grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps] grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
image_size = image_list.tensors.shape[-2:] image_size = image_list.tensors.shape[-2:]
dtype, device = feature_maps[0].dtype, feature_maps[0].device dtype, device = feature_maps[0].dtype, feature_maps[0].device
default_boxes = self._grid_default_boxes(grid_sizes, image_size, dtype=dtype)
# Default Boxes calculation based on page 6 of SSD paper default_boxes = default_boxes.to(device)
default_boxes: List[List[float]] = []
for k, f_k in enumerate(grid_sizes):
# Now add the default boxes for each width-height pair
for j in range(f_k[0]):
if self.steps is not None:
y_f_k = image_size[1] / self.steps[k]
else:
y_f_k = float(f_k[0])
cy = (j + 0.5) / y_f_k
for i in range(f_k[1]):
if self.steps is not None:
x_f_k = image_size[0] / self.steps[k]
else:
x_f_k = float(f_k[1])
cx = (i + 0.5) / x_f_k
default_boxes.extend([[cx, cy, w, h] for w, h in self._wh_pairs[k]])
dboxes = [] dboxes = []
for _ in image_list.image_sizes: for _ in image_list.image_sizes:
dboxes_in_image = torch.tensor(default_boxes, dtype=dtype, device=device) dboxes_in_image = default_boxes
if self.clip:
dboxes_in_image.clamp_(min=0, max=1)
dboxes_in_image = torch.cat([dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:], dboxes_in_image = torch.cat([dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:],
dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:]], -1) dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:]], -1)
dboxes_in_image[:, 0::2] *= image_size[1] dboxes_in_image[:, 0::2] *= image_size[1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment