OpenDAS / vision, commit a7501e13

Unverified commit, authored Aug 18, 2023 by Philip Meier and committed by GitHub on Aug 18, 2023.
remove batch_dims from make bounding boxes and detection masks (#7855)
Parent: 59b27ed6

Showing 1 changed file with 6 additions and 10 deletions.

test/common_utils.py (+6, -10)
@@ -406,26 +406,21 @@ def make_bounding_boxes(
     canvas_size=DEFAULT_SIZE,
     *,
     format=datapoints.BoundingBoxFormat.XYXY,
-    batch_dims=(),
     dtype=None,
     device="cpu",
 ):
     def sample_position(values, max_value):
         # We cannot use torch.randint directly here, because it only allows integer scalars as values for low and high.
         # However, if we have batch_dims, we need tensors as limits.
-        return torch.stack([torch.randint(max_value - v, ()) for v in values.flatten().tolist()]).reshape(values.shape)
+        return torch.stack([torch.randint(max_value - v, ()) for v in values.tolist()])
 
     if isinstance(format, str):
         format = datapoints.BoundingBoxFormat[format]
 
     dtype = dtype or torch.float32
 
-    if any(dim == 0 for dim in batch_dims):
-        return datapoints.BoundingBoxes(
-            torch.empty(*batch_dims, 4, dtype=dtype, device=device), format=format, canvas_size=canvas_size
-        )
-
-    h, w = [torch.randint(1, c, batch_dims) for c in canvas_size]
+    num_objects = 1
+    h, w = [torch.randint(1, c, (num_objects,)) for c in canvas_size]
     y = sample_position(h, canvas_size[0])
     x = sample_position(w, canvas_size[1])
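The sample_position change drops the flatten/reshape round trip: with batch_dims gone, values is always a 1-D tensor holding one size per object, so stacking the per-entry scalars already has the right shape. A minimal standalone sketch of the new sampling logic, runnable with plain torch and a made-up canvas size (the module's real DEFAULT_SIZE is not shown in this hunk):

import torch

def sample_position(values, max_value):
    # torch.randint only accepts integer scalars for low/high, so draw one
    # scalar per entry and stack them into a 1-D tensor of positions.
    return torch.stack([torch.randint(max_value - v, ()) for v in values.tolist()])

canvas_size = (32, 48)  # hypothetical (height, width), not the real DEFAULT_SIZE
num_objects = 1  # hard-coded by this commit
h, w = [torch.randint(1, c, (num_objects,)) for c in canvas_size]
y = sample_position(h, canvas_size[0])
x = sample_position(w, canvas_size[1])

# Every sampled top-left corner leaves room for the sampled extent inside the canvas.
assert ((y >= 0) & (y + h <= canvas_size[0])).all()
assert ((x >= 0) & (x + w <= canvas_size[1])).all()

The asserts only illustrate the invariant the helper relies on: each box is sampled so that it fits entirely within the canvas.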
@@ -448,11 +443,12 @@ def make_bounding_boxes(
     )
 
 
-def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtype=None, device="cpu"):
+def make_detection_mask(size=DEFAULT_SIZE, *, dtype=None, device="cpu"):
     """Make a "detection" mask, i.e. (*, N, H, W), where each object is encoded as one of N boolean masks"""
+    num_objects = 1
     return datapoints.Mask(
         torch.testing.make_tensor(
-            (*batch_dims, num_objects, *size),
+            (num_objects, *size),
             low=0,
             high=2,
             dtype=dtype or torch.bool,
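With num_objects and batch_dims removed from the signature, make_detection_mask now always produces a single-object mask of shape (1, H, W). A rough sketch of that behavior, assuming a placeholder size and skipping the datapoints.Mask wrapper so it runs with plain torch:

import torch

DEFAULT_SIZE = (17, 11)  # placeholder (H, W), not necessarily the value defined in common_utils.py

def make_detection_mask(size=DEFAULT_SIZE, *, dtype=None, device="cpu"):
    # One boolean mask per object; this commit hard-codes a single object.
    num_objects = 1
    return torch.testing.make_tensor(
        (num_objects, *size),
        low=0,
        high=2,
        dtype=dtype or torch.bool,
        device=device,
    )

mask = make_detection_mask()
assert mask.shape == (1, *DEFAULT_SIZE)
assert mask.dtype is torch.bool

torch.testing.make_tensor with low=0, high=2 and a boolean dtype yields random True/False entries, mirroring the call shown in the hunk above.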