"git@developer.sourcefind.cn:OpenDAS/fastmoe.git" did not exist on "e1282c1420bd5f88c0bde29709698daf558166f7"
Unverified Commit a7501e13 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

remove batch_dims from make bounding boxes and detection masks (#7855)

parent 59b27ed6
...@@ -406,26 +406,21 @@ def make_bounding_boxes( ...@@ -406,26 +406,21 @@ def make_bounding_boxes(
canvas_size=DEFAULT_SIZE, canvas_size=DEFAULT_SIZE,
*, *,
format=datapoints.BoundingBoxFormat.XYXY, format=datapoints.BoundingBoxFormat.XYXY,
batch_dims=(),
dtype=None, dtype=None,
device="cpu", device="cpu",
): ):
def sample_position(values, max_value): def sample_position(values, max_value):
# We cannot use torch.randint directly here, because it only allows integer scalars as values for low and high. # We cannot use torch.randint directly here, because it only allows integer scalars as values for low and high.
# However, if we have batch_dims, we need tensors as limits. # However, if we have batch_dims, we need tensors as limits.
return torch.stack([torch.randint(max_value - v, ()) for v in values.flatten().tolist()]).reshape(values.shape) return torch.stack([torch.randint(max_value - v, ()) for v in values.tolist()])
if isinstance(format, str): if isinstance(format, str):
format = datapoints.BoundingBoxFormat[format] format = datapoints.BoundingBoxFormat[format]
dtype = dtype or torch.float32 dtype = dtype or torch.float32
if any(dim == 0 for dim in batch_dims): num_objects = 1
return datapoints.BoundingBoxes( h, w = [torch.randint(1, c, (num_objects,)) for c in canvas_size]
torch.empty(*batch_dims, 4, dtype=dtype, device=device), format=format, canvas_size=canvas_size
)
h, w = [torch.randint(1, c, batch_dims) for c in canvas_size]
y = sample_position(h, canvas_size[0]) y = sample_position(h, canvas_size[0])
x = sample_position(w, canvas_size[1]) x = sample_position(w, canvas_size[1])
...@@ -448,11 +443,12 @@ def make_bounding_boxes( ...@@ -448,11 +443,12 @@ def make_bounding_boxes(
) )
def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtype=None, device="cpu"): def make_detection_mask(size=DEFAULT_SIZE, *, dtype=None, device="cpu"):
"""Make a "detection" mask, i.e. (*, N, H, W), where each object is encoded as one of N boolean masks""" """Make a "detection" mask, i.e. (*, N, H, W), where each object is encoded as one of N boolean masks"""
num_objects = 1
return datapoints.Mask( return datapoints.Mask(
torch.testing.make_tensor( torch.testing.make_tensor(
(*batch_dims, num_objects, *size), (num_objects, *size),
low=0, low=0,
high=2, high=2,
dtype=dtype or torch.bool, dtype=dtype or torch.bool,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment