Unverified Commit 96aecd2d authored by Yonghye Kwon's avatar Yonghye Kwon Committed by GitHub
Browse files

call _upcast to consider overflow for computing giou loss (#5685)

* call _upcast to consider overflow

giou loss is weak at overflow problem because it computes area of box.

* cast datatype to float

* add ":" to if

* lint Test for membership should be 'not in' (E713)
parent da03f51a
import torch import torch
from torch import Tensor
def _upcast(t: Tensor) -> Tensor:
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
if t.dtype not in (torch.float32, torch.float64):
return t.float()
return t
def generalized_box_iou_loss( def generalized_box_iou_loss(
...@@ -34,6 +42,8 @@ def generalized_box_iou_loss( ...@@ -34,6 +42,8 @@ def generalized_box_iou_loss(
https://arxiv.org/abs/1902.09630 https://arxiv.org/abs/1902.09630
""" """
boxes1 = _upcast(boxes1)
boxes2 = _upcast(boxes2)
x1, y1, x2, y2 = boxes1.unbind(dim=-1) x1, y1, x2, y2 = boxes1.unbind(dim=-1)
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment