Unverified Commit 4508c84e authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto][tests] Fix missing extra_dims in cxcywh (#6906)

parent cb4413a3
...@@ -373,8 +373,8 @@ def make_bounding_box_loader(*, extra_dims=(), format, spatial_size="random", dt ...@@ -373,8 +373,8 @@ def make_bounding_box_loader(*, extra_dims=(), format, spatial_size="random", dt
h = randint_with_tensor_bounds(1, height - y) h = randint_with_tensor_bounds(1, height - y)
parts = (x, y, w, h) parts = (x, y, w, h)
else: # format == features.BoundingBoxFormat.CXCYWH: else: # format == features.BoundingBoxFormat.CXCYWH:
cx = torch.randint(1, width - 1, ()) cx = torch.randint(1, width - 1, extra_dims)
cy = torch.randint(1, height - 1, ()) cy = torch.randint(1, height - 1, extra_dims)
w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1) w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1)
h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1) h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1)
parts = (cx, cy, w, h) parts = (cx, cy, w, h)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment