"doc/git@developer.sourcefind.cn:OpenDAS/ktransformers.git" did not exist on "2e905c8bd4d0ea6b2c3ea3dcbdcefaab163905c1"
Unverified Commit e9fca7bd authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Use onnx function only in tracing mode (#5468)



* Use onnx function only in tracing mode

* Add missing import

* Address review comments

* Fix type annotation

* Ignore return type error

* Remove unused import

* Add fake cast

* Fix mypy error

* Fix mypy error

* Update torchvision/models/detection/_utils.py
Co-authored-by: default avatarShubham Bhokare <32080845+shubhambhokare1@users.noreply.github.com>

* _fake_cast_onnx approach
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: default avatarShubham Bhokare <32080845+shubhambhokare1@users.noreply.github.com>
parent e92f1195
...@@ -470,7 +470,12 @@ def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]: ...@@ -470,7 +470,12 @@ def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
return out_channels return out_channels
def _topk_min(input: Tensor, orig_kval: int, axis: int) -> Tensor: @torch.jit.unused
def _fake_cast_onnx(v: Tensor) -> int:
return v # type: ignore[return-value]
def _topk_min(input: Tensor, orig_kval: int, axis: int) -> int:
""" """
ONNX spec requires the k-value to be less than or equal to the number of inputs along ONNX spec requires the k-value to be less than or equal to the number of inputs along
provided dim. Certain models use the number of elements along a particular axis instead of K provided dim. Certain models use the number of elements along a particular axis instead of K
...@@ -487,8 +492,10 @@ def _topk_min(input: Tensor, orig_kval: int, axis: int) -> Tensor: ...@@ -487,8 +492,10 @@ def _topk_min(input: Tensor, orig_kval: int, axis: int) -> Tensor:
axis(int): Axis along which we retreive the input size. axis(int): Axis along which we retreive the input size.
Returns: Returns:
min_kval (Tensor): Appropriately selected k-value. min_kval (int): Appropriately selected k-value.
""" """
if not torch.jit.is_tracing():
return min(orig_kval, input.size(axis))
axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0) axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0)
min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0)) min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0))
return min_kval # type: ignore[arg-type] return _fake_cast_onnx(min_kval)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment