"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "196835695ed6fa3ec53b888088d9d5581e8f8e94"
Unverified Commit 3099e0cc authored by vsuryamurthy's avatar vsuryamurthy Committed by GitHub
Browse files

Add missing type hints to anchor_utils (#6735)

* Use the variable name sizes instead of scales for consistency

* Add the missing type hints

* Restore the naming back to scales instead of sizes to avoid backwards incompatibility
parent 12adc542
...@@ -61,7 +61,7 @@ class AnchorGenerator(nn.Module): ...@@ -61,7 +61,7 @@ class AnchorGenerator(nn.Module):
aspect_ratios: List[float], aspect_ratios: List[float],
dtype: torch.dtype = torch.float32, dtype: torch.dtype = torch.float32,
device: torch.device = torch.device("cpu"), device: torch.device = torch.device("cpu"),
): ) -> Tensor:
scales = torch.as_tensor(scales, dtype=dtype, device=device) scales = torch.as_tensor(scales, dtype=dtype, device=device)
aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device) aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
h_ratios = torch.sqrt(aspect_ratios) h_ratios = torch.sqrt(aspect_ratios)
...@@ -76,7 +76,7 @@ class AnchorGenerator(nn.Module): ...@@ -76,7 +76,7 @@ class AnchorGenerator(nn.Module):
def set_cell_anchors(self, dtype: torch.dtype, device: torch.device): def set_cell_anchors(self, dtype: torch.dtype, device: torch.device):
self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) for cell_anchor in self.cell_anchors] self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) for cell_anchor in self.cell_anchors]
def num_anchors_per_location(self): def num_anchors_per_location(self) -> List[int]:
return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)] return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]
# For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2), # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
...@@ -201,7 +201,7 @@ class DefaultBoxGenerator(nn.Module): ...@@ -201,7 +201,7 @@ class DefaultBoxGenerator(nn.Module):
_wh_pairs.append(torch.as_tensor(wh_pairs, dtype=dtype, device=device)) _wh_pairs.append(torch.as_tensor(wh_pairs, dtype=dtype, device=device))
return _wh_pairs return _wh_pairs
def num_anchors_per_location(self): def num_anchors_per_location(self) -> List[int]:
# Estimate num of anchors based on aspect ratios: 2 default boxes + 2 * ratios of feaure map. # Estimate num of anchors based on aspect ratios: 2 default boxes + 2 * ratios of feaure map.
return [2 + 2 * len(r) for r in self.aspect_ratios] return [2 + 2 * len(r) for r in self.aspect_ratios]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment