Unverified Commit 3099e0cc authored by vsuryamurthy's avatar vsuryamurthy Committed by GitHub
Browse files

Add missing type hints to anchor_utils (#6735)

* Use the variable name sizes instead of scales for consistency

* Add the missing type hints

* Restore the naming back to scales instead of sizes to avoid backwards incompatibility
parent 12adc542
......@@ -61,7 +61,7 @@ class AnchorGenerator(nn.Module):
aspect_ratios: List[float],
dtype: torch.dtype = torch.float32,
device: torch.device = torch.device("cpu"),
):
) -> Tensor:
scales = torch.as_tensor(scales, dtype=dtype, device=device)
aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
h_ratios = torch.sqrt(aspect_ratios)
......@@ -76,7 +76,7 @@ class AnchorGenerator(nn.Module):
def set_cell_anchors(self, dtype: torch.dtype, device: torch.device):
self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) for cell_anchor in self.cell_anchors]
def num_anchors_per_location(self):
def num_anchors_per_location(self) -> List[int]:
return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]
# For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
......@@ -201,7 +201,7 @@ class DefaultBoxGenerator(nn.Module):
_wh_pairs.append(torch.as_tensor(wh_pairs, dtype=dtype, device=device))
return _wh_pairs
def num_anchors_per_location(self):
def num_anchors_per_location(self) -> List[int]:
# Estimate num of anchors based on aspect ratios: 2 default boxes + 2 * ratios of feaure map.
return [2 + 2 * len(r) for r in self.aspect_ratios]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment