"...text-generation-inference.git" did not exist on "5b6b74e21d6cfa961afe3338fc5cfd45fa357b50"
Unverified commit 86a14cba, authored by Shubham Bhokare, committed by GitHub

Add topk min function for trace and onnx (#5310)



* Add topk minimizer function to _utils

* Apply ufmt formatting

* Apply min function for tracing and scripting

* Add type ignore to avoid cast

* fix flake

* Fix python_type_check
Co-authored-by: Prabhat Roy <prabhatroy@fb.com>
parent 8097370e
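The motivation for the change, in a minimal sketch (illustrative code, not part of this commit): under torch.jit.trace, Python's min() is evaluated once with the example input, so the selected k is frozen into the graph and no longer adapts to other input sizes.

import torch

def naive_topk(scores: torch.Tensor, k: int = 400) -> torch.Tensor:
    # Plain Python min(): resolved eagerly while tracing, so the result is a constant.
    num = min(k, scores.size(0))
    return scores.topk(num).values

# Traced with a 10-element example, so num is frozen at min(400, 10) == 10.
traced = torch.jit.trace(naive_topk, (torch.rand(10),))

small = torch.rand(5)
naive_topk(small)   # eager mode: fine, returns the top 5 values
# traced(small) fails: the frozen graph still asks for the top 10 of a 5-element tensor.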
@@ -468,3 +468,27 @@ def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
     model.train()
 
     return out_channels
+
+
+def _topk_min(input: Tensor, orig_kval: int, axis: int) -> Tensor:
+    """
+    ONNX spec requires the k-value to be less than or equal to the number of inputs along the
+    provided dim. Certain models use the number of elements along a particular axis instead of K
+    if K exceeds the number of elements along that axis. Previously, Python's min() function was
+    used to determine whether to use the provided k-value or the specified dim axis value.
+    However, in cases where the model is being exported in tracing mode, Python's min() is
+    static, causing the model to be traced incorrectly and to eventually fail at the topk node.
+    To avoid this situation, in tracing mode, torch.min() is used instead.
+
+    Args:
+        input (Tensor): The original input tensor.
+        orig_kval (int): The provided k-value.
+        axis (int): Axis along which we retrieve the input size.
+
+    Returns:
+        min_kval (Tensor): Appropriately selected k-value.
+    """
+    axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0)
+    min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0))
+    return min_kval  # type: ignore[arg-type]
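A hedged usage sketch of the helper added above, assuming it is importable as torchvision.models.detection._utils (the file this hunk patches). Because the clamping is expressed with torch ops (_shape_as_tensor, cat, min), it remains part of a traced graph instead of collapsing to a constant.

import torch
from torchvision.models.detection import _utils as det_utils

scores = torch.rand(30)
# Clamp the requested k (400) to the actual size of axis 0; returns a 0-dim int64 tensor.
num_topk = det_utils._topk_min(scores, 400, 0)   # tensor(30)
# topk accepts the 0-dim tensor directly, as the updated call sites in this diff do.
values, idxs = scores.topk(num_topk)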
@@ -501,7 +501,7 @@ class FCOS(nn.Module):
                topk_idxs = torch.where(keep_idxs)[0]

                # keep only topk scoring predictions
-               num_topk = min(self.topk_candidates, topk_idxs.size(0))
+               num_topk = det_utils._topk_min(topk_idxs, self.topk_candidates, 0)
                scores_per_level, idxs = scores_per_level.topk(num_topk)
                topk_idxs = topk_idxs[idxs]
...
@@ -436,7 +436,7 @@ class RetinaNet(nn.Module):
                topk_idxs = torch.where(keep_idxs)[0]

                # keep only topk scoring predictions
-               num_topk = min(self.topk_candidates, topk_idxs.size(0))
+               num_topk = det_utils._topk_min(topk_idxs, self.topk_candidates, 0)
                scores_per_level, idxs = scores_per_level.topk(num_topk)
                topk_idxs = topk_idxs[idxs]
...
-from typing import List, Optional, Dict, Tuple, cast
+from typing import List, Optional, Dict, Tuple

 import torch
-import torchvision
 from torch import nn, Tensor
 from torch.nn import functional as F

 from torchvision.ops import boxes as box_ops
@@ -13,17 +12,6 @@ from .anchor_utils import AnchorGenerator  # noqa: 401
 from .image_list import ImageList


-@torch.jit.unused
-def _onnx_get_num_anchors_and_pre_nms_top_n(ob: Tensor, orig_pre_nms_top_n: int) -> Tuple[int, int]:
-    from torch.onnx import operators
-
-    num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)
-    pre_nms_top_n = torch.min(torch.cat((torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype), num_anchors), 0))
-    # for mypy we cast at runtime
-    return cast(int, num_anchors), cast(int, pre_nms_top_n)
-
-
 class RPNHead(nn.Module):
     """
     Adds a simple RPN Head with classification and regression heads
@@ -206,11 +194,8 @@ class RegionProposalNetwork(torch.nn.Module):
        r = []
        offset = 0
        for ob in objectness.split(num_anchors_per_level, 1):
-           if torchvision._is_tracing():
-               num_anchors, pre_nms_top_n = _onnx_get_num_anchors_and_pre_nms_top_n(ob, self.pre_nms_top_n())
-           else:
-               num_anchors = ob.shape[1]
-               pre_nms_top_n = min(self.pre_nms_top_n(), num_anchors)
+           num_anchors = ob.shape[1]
+           pre_nms_top_n = det_utils._topk_min(ob, self.pre_nms_top_n(), 1)
            _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
            r.append(top_n_idx + offset)
            offset += num_anchors
...
@@ -407,7 +407,7 @@ class SSD(nn.Module):
                box = boxes[keep_idxs]

                # keep only topk scoring predictions
-               num_topk = min(self.topk_candidates, score.size(0))
+               num_topk = det_utils._topk_min(score, self.topk_candidates, 0)
                score, idxs = score.topk(num_topk)
                box = box[idxs]
...
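With the call sites above routed through det_utils._topk_min, a tracing-based ONNX export should no longer bake the example input's candidate count into the TopK nodes. A rough export sketch; the model constructor, input size, and opset below are assumptions, not taken from this commit.

import torch
import torchvision

# Assumed setup: a detection model that uses the updated topk path (RetinaNet here).
model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=False).eval()
images = [torch.rand(3, 320, 320)]

# opset_version=11 is an assumption; these models need a recent opset for ops such as
# TopK with a runtime-computed k.
torch.onnx.export(model, (images,), "retinanet.onnx", opset_version=11)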