Unverified Commit f04e9cb9 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Fix NMS and IoU overflows for fp16 (#3383)

* Replace type T with accumulator.

* Upcast tensors of box ops to avoid overflow in multiplications.
parent af97ec2f
...@@ -449,6 +449,18 @@ class NMSTester(unittest.TestCase): ...@@ -449,6 +449,18 @@ class NMSTester(unittest.TestCase):
with torch.cuda.amp.autocast(): with torch.cuda.amp.autocast():
self.test_nms_cuda(dtype=dtype) self.test_nms_cuda(dtype=dtype)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_nms_cuda_float16(self):
boxes = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019]]).cuda()
scores = torch.tensor([0.6370, 0.7569, 0.3966]).cuda()
iou_thres = 0.2
keep32 = ops.nms(boxes, scores, iou_thres)
keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres)
self.assertTrue(torch.all(torch.eq(keep32, keep16)))
class DeformConvTester(OpTester, unittest.TestCase): class DeformConvTester(OpTester, unittest.TestCase):
def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1): def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):
...@@ -829,48 +841,75 @@ class BoxTester(unittest.TestCase): ...@@ -829,48 +841,75 @@ class BoxTester(unittest.TestCase):
class BoxAreaTester(unittest.TestCase): class BoxAreaTester(unittest.TestCase):
def test_box_area(self): def test_box_area(self):
# A bounding box of area 10000 and a degenerate case def area_check(box, expected, tolerance=1e-4):
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float) out = ops.box_area(box)
assert out.size() == expected.size()
assert ((out - expected).abs().max() < tolerance).item()
# Check for int boxes
for dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype)
expected = torch.tensor([10000, 0]) expected = torch.tensor([10000, 0])
calc_area = ops.box_area(box_tensor) area_check(box_tensor, expected)
assert calc_area.size() == torch.Size([2])
assert calc_area.dtype == box_tensor.dtype # Check for float32 and float64 boxes
assert torch.all(torch.eq(calc_area, expected)).item() is True for dtype in [torch.float32, torch.float64]:
box_tensor = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019]], dtype=dtype)
expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=torch.float64)
area_check(box_tensor, expected, tolerance=0.05)
# Check for float16 box
box_tensor = torch.tensor([[285.25, 185.625, 1194.0, 851.5],
[285.25, 188.75, 1192.0, 851.0],
[279.25, 198.0, 1189.0, 849.0]], dtype=torch.float16)
expected = torch.tensor([605113.875, 600495.1875, 592247.25])
area_check(box_tensor, expected)
class BoxIouTester(unittest.TestCase): class BoxIouTester(unittest.TestCase):
def test_iou(self): def test_iou(self):
# Boxes to test Iou def iou_check(box, expected, tolerance=1e-4):
boxes1 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float) out = ops.box_iou(box, box)
boxes2 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float) assert out.size() == expected.size()
assert ((out - expected).abs().max() < tolerance).item()
# Expected IoU matrix for these boxes
# Check for int boxes
for dtype in [torch.int16, torch.int32, torch.int64]:
box = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=dtype)
expected = torch.tensor([[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]) expected = torch.tensor([[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]])
iou_check(box, expected)
out = ops.box_iou(boxes1, boxes2) # Check for float boxes
for dtype in [torch.float16, torch.float32, torch.float64]:
# Check if all elements of tensor are as expected. box_tensor = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551],
assert out.size() == torch.Size([3, 3]) [285.1472, 188.7374, 1192.4984, 851.0669],
tolerance = 1e-4 [279.2440, 197.9812, 1189.4746, 849.2019]], dtype=dtype)
assert ((out - expected).abs().max() < tolerance).item() is True expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]])
iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-4)
class GenBoxIouTester(unittest.TestCase): class GenBoxIouTester(unittest.TestCase):
def test_gen_iou(self): def test_gen_iou(self):
# Test Generalized IoU def gen_iou_check(box, expected, tolerance=1e-4):
boxes1 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float) out = ops.generalized_box_iou(box, box)
boxes2 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float) assert out.size() == expected.size()
assert ((out - expected).abs().max() < tolerance).item()
# Expected gIoU matrix for these boxes
expected = torch.tensor([[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], # Check for int boxes
[-0.7778, -0.8611, 1.0]]) for dtype in [torch.int16, torch.int32, torch.int64]:
box = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=dtype)
out = ops.generalized_box_iou(boxes1, boxes2) expected = torch.tensor([[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]])
gen_iou_check(box, expected)
# Check if all elements of tensor are as expected.
assert out.size() == torch.Size([3, 3]) # Check for float boxes
tolerance = 1e-4 for dtype in [torch.float16, torch.float32, torch.float64]:
assert ((out - expected).abs().max() < tolerance).item() is True box_tensor = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019]], dtype=dtype)
expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]])
gen_iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-3)
if __name__ == '__main__': if __name__ == '__main__':
......
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <torch/library.h> #include <torch/library.h>
...@@ -20,9 +21,10 @@ __device__ inline bool devIoU( ...@@ -20,9 +21,10 @@ __device__ inline bool devIoU(
T left = max(a[0], b[0]), right = min(a[2], b[2]); T left = max(a[0], b[0]), right = min(a[2], b[2]);
T top = max(a[1], b[1]), bottom = min(a[3], b[3]); T top = max(a[1], b[1]), bottom = min(a[3], b[3]);
T width = max(right - left, (T)0), height = max(bottom - top, (T)0); T width = max(right - left, (T)0), height = max(bottom - top, (T)0);
T interS = width * height; using acc_T = at::acc_type<T, /*is_cuda=*/true>;
T Sa = (a[2] - a[0]) * (a[3] - a[1]); acc_T interS = (acc_T)width * height;
T Sb = (b[2] - b[0]) * (b[3] - b[1]); acc_T Sa = ((acc_T)a[2] - a[0]) * (a[3] - a[1]);
acc_T Sb = ((acc_T)b[2] - b[0]) * (b[3] - b[1]);
return (interS / (Sa + Sb - interS)) > threshold; return (interS / (Sa + Sb - interS)) > threshold;
} }
......
...@@ -170,6 +170,14 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor: ...@@ -170,6 +170,14 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
return boxes return boxes
def _upcast(t: Tensor) -> Tensor:
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
if t.is_floating_point():
return t if t.dtype in (torch.float32, torch.float64) else t.float()
else:
return t if t.dtype in (torch.int32, torch.int64) else t.int()
def box_area(boxes: Tensor) -> Tensor: def box_area(boxes: Tensor) -> Tensor:
""" """
Computes the area of a set of bounding boxes, which are specified by its Computes the area of a set of bounding boxes, which are specified by its
...@@ -182,6 +190,7 @@ def box_area(boxes: Tensor) -> Tensor: ...@@ -182,6 +190,7 @@ def box_area(boxes: Tensor) -> Tensor:
Returns: Returns:
area (Tensor[N]): area for each box area (Tensor[N]): area for each box
""" """
boxes = _upcast(boxes)
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
...@@ -194,7 +203,7 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]: ...@@ -194,7 +203,7 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]:
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
wh = (rb - lt).clamp(min=0) # [N,M,2] wh = _upcast(rb - lt).clamp(min=0) # [N,M,2]
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
union = area1[:, None] + area2 - inter union = area1[:, None] + area2 - inter
...@@ -247,7 +256,7 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: ...@@ -247,7 +256,7 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
lti = torch.min(boxes1[:, None, :2], boxes2[:, :2]) lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
whi = (rbi - lti).clamp(min=0) # [N,M,2] whi = _upcast(rbi - lti).clamp(min=0) # [N,M,2]
areai = whi[:, :, 0] * whi[:, :, 1] areai = whi[:, :, 0] * whi[:, :, 1]
return iou - (areai - union) / areai return iou - (areai - union) / areai
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment