Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2c1cdd15
Commit
2c1cdd15
authored
Jun 24, 2022
by
charlie
Browse files
Use iterators and fix basic_iota_iterator
parent
80f7c2b7
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
93 additions
and
78 deletions
+93
-78
src/include/migraphx/iota_iterator.hpp
src/include/migraphx/iota_iterator.hpp
+2
-1
src/include/migraphx/op/nonmaxsuppression.hpp
src/include/migraphx/op/nonmaxsuppression.hpp
+91
-77
No files found.
src/include/migraphx/iota_iterator.hpp
View file @
2c1cdd15
...
@@ -81,8 +81,9 @@ struct basic_iota_iterator
...
@@ -81,8 +81,9 @@ struct basic_iota_iterator
index
--
;
index
--
;
return
it
;
return
it
;
}
}
// TODO: operator->
reference
operator
*
()
const
{
return
f
(
index
);
}
reference
operator
*
()
const
{
return
f
(
index
);
}
pointer
operator
->
()
const
{
return
&
f
(
index
);
}
reference
operator
[](
int
n
)
const
{
return
f
(
index
+
n
);
}
};
};
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
...
src/include/migraphx/op/nonmaxsuppression.hpp
View file @
2c1cdd15
...
@@ -81,8 +81,8 @@ struct nonmaxsuppression
...
@@ -81,8 +81,8 @@ struct nonmaxsuppression
struct
box
struct
box
{
{
std
::
array
<
float
,
2
>
x
;
std
::
array
<
double
,
2
>
x
;
std
::
array
<
float
,
2
>
y
;
std
::
array
<
double
,
2
>
y
;
void
sort
()
void
sort
()
{
{
...
@@ -90,9 +90,9 @@ struct nonmaxsuppression
...
@@ -90,9 +90,9 @@ struct nonmaxsuppression
std
::
sort
(
y
.
begin
(),
y
.
end
());
std
::
sort
(
y
.
begin
(),
y
.
end
());
}
}
std
::
array
<
float
,
2
>&
operator
[](
std
::
size_t
i
)
{
return
i
==
0
?
x
:
y
;
}
std
::
array
<
double
,
2
>&
operator
[](
std
::
size_t
i
)
{
return
i
==
0
?
x
:
y
;
}
float
area
()
const
double
area
()
const
{
{
assert
(
std
::
is_sorted
(
x
.
begin
(),
x
.
end
()));
assert
(
std
::
is_sorted
(
x
.
begin
(),
x
.
end
()));
assert
(
std
::
is_sorted
(
y
.
begin
(),
y
.
end
()));
assert
(
std
::
is_sorted
(
y
.
begin
(),
y
.
end
()));
...
@@ -101,29 +101,29 @@ struct nonmaxsuppression
...
@@ -101,29 +101,29 @@ struct nonmaxsuppression
};
};
template
<
class
T
>
template
<
class
T
>
box
batch_box
(
const
T
&
boxes
,
std
::
size_t
box_ind
,
std
::
size_t
box_idx
)
const
box
batch_box
(
T
boxes
,
std
::
size_t
box_idx
)
const
{
{
box
result
{};
box
result
{};
auto
start
=
box
_ind
+
4
*
box_idx
;
auto
start
=
box
es
+
4
*
box_idx
;
if
(
center_point_box
)
if
(
center_point_box
)
{
{
float
half_width
=
boxes
[
start
+
2
]
/
2.0
;
double
half_width
=
start
[
2
]
/
2.0
;
float
half_height
=
boxes
[
start
+
3
]
/
2.0
;
double
half_height
=
start
[
3
]
/
2.0
;
float
x_center
=
boxes
[
start
+
0
];
double
x_center
=
start
[
0
];
float
y_center
=
boxes
[
start
+
1
];
double
y_center
=
start
[
1
];
result
.
x
=
{
x_center
-
half_width
,
x_center
+
half_width
};
result
.
x
=
{
x_center
-
half_width
,
x_center
+
half_width
};
result
.
y
=
{
y_center
-
half_height
,
y_center
+
half_height
};
result
.
y
=
{
y_center
-
half_height
,
y_center
+
half_height
};
}
}
else
else
{
{
result
.
x
=
{
static_cast
<
float
>
(
boxes
[
start
+
1
]),
static_cast
<
float
>
(
boxes
[
start
+
3
])};
result
.
x
=
{
static_cast
<
double
>
(
start
[
1
]),
static_cast
<
double
>
(
start
[
3
])};
result
.
y
=
{
static_cast
<
float
>
(
boxes
[
start
+
0
]),
static_cast
<
float
>
(
boxes
[
start
+
2
])};
result
.
y
=
{
static_cast
<
double
>
(
start
[
0
]),
static_cast
<
double
>
(
start
[
2
])};
}
}
return
result
;
return
result
;
}
}
inline
bool
suppress_by_iou
(
box
b1
,
box
b2
,
float
iou_threshold
)
const
inline
bool
suppress_by_iou
(
box
b1
,
box
b2
,
double
iou_threshold
)
const
{
{
b1
.
sort
();
b1
.
sort
();
b2
.
sort
();
b2
.
sort
();
...
@@ -135,7 +135,7 @@ struct nonmaxsuppression
...
@@ -135,7 +135,7 @@ struct nonmaxsuppression
intersection
[
i
][
1
]
=
std
::
min
(
b1
[
i
][
1
],
b2
[
i
][
1
]);
intersection
[
i
][
1
]
=
std
::
min
(
b1
[
i
][
1
],
b2
[
i
][
1
]);
}
}
std
::
vector
<
std
::
array
<
float
,
2
>>
bbox
=
{
intersection
.
x
,
intersection
.
y
};
std
::
vector
<
std
::
array
<
double
,
2
>>
bbox
=
{
intersection
.
x
,
intersection
.
y
};
if
(
std
::
any_of
(
bbox
.
begin
(),
bbox
.
end
(),
[](
auto
bx
)
{
if
(
std
::
any_of
(
bbox
.
begin
(),
bbox
.
end
(),
[](
auto
bx
)
{
return
not
std
::
is_sorted
(
bx
.
begin
(),
bx
.
end
());
return
not
std
::
is_sorted
(
bx
.
begin
(),
bx
.
end
());
}))
}))
...
@@ -143,33 +143,33 @@ struct nonmaxsuppression
...
@@ -143,33 +143,33 @@ struct nonmaxsuppression
return
false
;
return
false
;
}
}
const
float
area1
=
b1
.
area
();
const
double
area1
=
b1
.
area
();
const
float
area2
=
b2
.
area
();
const
double
area2
=
b2
.
area
();
const
float
intersection_area
=
intersection
.
area
();
const
double
intersection_area
=
intersection
.
area
();
const
float
union_area
=
area1
+
area2
-
intersection_area
;
const
double
union_area
=
area1
+
area2
-
intersection_area
;
if
(
area1
<=
.0
f
or
area2
<=
.0
f
or
union_area
<=
.0
f
)
if
(
area1
<=
.0
f
or
area2
<=
.0
f
or
union_area
<=
.0
f
)
{
{
return
false
;
return
false
;
}
}
const
float
intersection_over_union
=
intersection_area
/
union_area
;
const
double
intersection_over_union
=
intersection_area
/
union_area
;
return
intersection_over_union
>
iou_threshold
;
return
intersection_over_union
>
iou_threshold
;
}
}
// filter boxes below score_threshold
// filter boxes below score_threshold
template
<
class
T
>
template
<
class
T
>
std
::
priority_queue
<
std
::
pair
<
float
,
int64_t
>>
filter_boxes_by_score
(
std
::
priority_queue
<
std
::
pair
<
double
,
int64_t
>>
filter_boxes_by_score
(
T
scores
,
std
::
size_t
score_offset_ind
,
std
::
size_t
num_boxes
,
float
score_threshold
)
const
T
scores
_start
,
std
::
size_t
num_boxes
,
double
score_threshold
)
const
{
{
std
::
priority_queue
<
std
::
pair
<
float
,
int64_t
>>
boxes_heap
;
std
::
priority_queue
<
std
::
pair
<
double
,
int64_t
>>
boxes_heap
;
auto
insert_to_boxes_heap
=
auto
insert_to_boxes_heap
=
make_function_output_iterator
([
&
](
const
auto
&
x
)
{
boxes_heap
.
push
(
x
);
});
make_function_output_iterator
([
&
](
const
auto
&
x
)
{
boxes_heap
.
push
(
x
);
});
int64_t
box_idx
=
0
;
int64_t
box_idx
=
0
;
transform_if
(
transform_if
(
scores
.
begin
()
+
score_offset_ind
,
scores
_start
,
scores
.
begin
()
+
score_offset_ind
+
num_boxes
,
scores
_start
+
num_boxes
,
insert_to_boxes_heap
,
insert_to_boxes_heap
,
[
&
](
auto
sc
)
{
[
&
](
auto
sc
)
{
box_idx
++
;
box_idx
++
;
...
@@ -179,6 +179,47 @@ struct nonmaxsuppression
...
@@ -179,6 +179,47 @@ struct nonmaxsuppression
return
boxes_heap
;
return
boxes_heap
;
}
}
template
<
class
H
,
class
S
>
void
select_boxes
(
H
&
boxes_heap
,
std
::
vector
<
std
::
pair
<
double
,
int64_t
>>&
selected_boxes_inside_class
,
std
::
vector
<
int64_t
>&
selected_indices
,
S
batch_boxes_start
,
std
::
size_t
max_output_boxes_per_class
,
double
iou_threshold
,
std
::
size_t
batch_idx
,
std
::
size_t
class_idx
)
const
{
selected_boxes_inside_class
.
clear
();
// Get the next box with top score, filter by iou_threshold
while
(
!
boxes_heap
.
empty
()
&&
selected_boxes_inside_class
.
size
()
<
max_output_boxes_per_class
)
{
// Check with existing selected boxes for this class, remove box if it
// exceeds the IOU (Intersection Over Union) threshold
const
auto
next_top_score
=
boxes_heap
.
top
();
bool
not_selected
=
std
::
any_of
(
selected_boxes_inside_class
.
begin
(),
selected_boxes_inside_class
.
end
(),
[
&
](
auto
selected_index
)
{
return
this
->
suppress_by_iou
(
batch_box
(
batch_boxes_start
,
next_top_score
.
second
),
batch_box
(
batch_boxes_start
,
selected_index
.
second
),
iou_threshold
);
});
if
(
not
not_selected
)
{
selected_boxes_inside_class
.
push_back
(
next_top_score
);
selected_indices
.
push_back
(
batch_idx
);
selected_indices
.
push_back
(
class_idx
);
selected_indices
.
push_back
(
next_top_score
.
second
);
}
boxes_heap
.
pop
();
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
output_shape
};
argument
result
{
output_shape
};
...
@@ -189,8 +230,8 @@ struct nonmaxsuppression
...
@@ -189,8 +230,8 @@ struct nonmaxsuppression
{
{
return
result
;
return
result
;
}
}
float
iou_threshold
=
(
args
.
size
()
>
3
)
?
(
args
.
at
(
3
).
at
<
float
>
())
:
0.0
f
;
double
iou_threshold
=
(
args
.
size
()
>
3
)
?
(
args
.
at
(
3
).
at
<
double
>
())
:
0.0
f
;
float
score_threshold
=
(
args
.
size
()
>
4
)
?
(
args
.
at
(
4
).
at
<
float
>
())
:
0.0
f
;
double
score_threshold
=
(
args
.
size
()
>
4
)
?
(
args
.
at
(
4
).
at
<
double
>
())
:
0.0
f
;
result
.
visit
([
&
](
auto
output
)
{
result
.
visit
([
&
](
auto
output
)
{
visit_all
(
args
[
0
],
args
[
1
])([
&
](
auto
boxes
,
auto
scores
)
{
visit_all
(
args
[
0
],
args
[
1
])([
&
](
auto
boxes
,
auto
scores
)
{
...
@@ -200,48 +241,21 @@ struct nonmaxsuppression
...
@@ -200,48 +241,21 @@ struct nonmaxsuppression
const
auto
num_classes
=
lens
[
1
];
const
auto
num_classes
=
lens
[
1
];
const
auto
num_boxes
=
lens
[
2
];
const
auto
num_boxes
=
lens
[
2
];
// boxes of a class with NMS applied [score, index]
// boxes of a class with NMS applied [score, index]
std
::
vector
<
std
::
pair
<
float
,
int64_t
>>
selected_boxes_inside_class
;
std
::
vector
<
std
::
pair
<
double
,
int64_t
>>
selected_boxes_inside_class
;
std
::
vector
<
int64_t
>
selected_indices
;
std
::
vector
<
int64_t
>
selected_indices
;
selected_boxes_inside_class
.
reserve
(
output_shape
.
elements
());
selected_boxes_inside_class
.
reserve
(
output_shape
.
elements
());
// iterate over batches and classes
// iterate over batches and classes
shape
comp_s
{
shape
::
float
_type
,
{
num_batches
,
num_classes
}};
shape
comp_s
{
shape
::
double
_type
,
{
num_batches
,
num_classes
}};
shape_for_each
(
comp_s
,
[
&
](
auto
idx
)
{
shape_for_each
(
comp_s
,
[
&
](
auto
idx
)
{
auto
batch_idx
=
idx
[
0
];
auto
batch_idx
=
idx
[
0
];
auto
class_idx
=
idx
[
1
];
auto
class_idx
=
idx
[
1
];
// index offset for this class
// index offset for this class
std
::
size_t
score_offset_ind
=
auto
scores_start
=
scores
.
begin
()
+
(
batch_idx
*
num_classes
+
class_idx
)
*
num_boxes
;
(
batch_idx
*
num_classes
+
class_idx
)
*
num_boxes
;
// iterator to first value of this batch
// index to first value of this batch
auto
batch_boxes_start
=
boxes
.
begin
()
+
batch_idx
*
num_boxes
*
4
;
std
::
size_t
batch_boxes_ind
=
batch_idx
*
num_boxes
*
4
;
auto
boxes_heap
=
auto
boxes_heap
=
filter_boxes_by_score
(
scores
,
score_offset_ind
,
num_boxes
,
score_threshold
);
filter_boxes_by_score
(
scores_start
,
num_boxes
,
score_threshold
);
selected_boxes_inside_class
.
clear
();
select_boxes
(
boxes_heap
,
selected_boxes_inside_class
,
selected_indices
,
batch_boxes_start
,
max_output_boxes_per_class
,
iou_threshold
,
batch_idx
,
class_idx
);
// Get the next box with top score, filter by iou_threshold
while
(
!
boxes_heap
.
empty
()
&&
selected_boxes_inside_class
.
size
()
<
max_output_boxes_per_class
)
{
// Check with existing selected boxes for this class, remove box if it
// exceeds the IOU (Intersection Over Union) threshold
const
auto
next_top_score
=
boxes_heap
.
top
();
bool
not_selected
=
std
::
any_of
(
selected_boxes_inside_class
.
begin
(),
selected_boxes_inside_class
.
end
(),
[
&
](
auto
selected_index
)
{
return
this
->
suppress_by_iou
(
batch_box
(
boxes
,
batch_boxes_ind
,
next_top_score
.
second
),
batch_box
(
boxes
,
batch_boxes_ind
,
selected_index
.
second
),
iou_threshold
);
});
if
(
not
not_selected
)
{
selected_boxes_inside_class
.
push_back
(
next_top_score
);
selected_indices
.
push_back
(
batch_idx
);
selected_indices
.
push_back
(
class_idx
);
selected_indices
.
push_back
(
next_top_score
.
second
);
}
boxes_heap
.
pop
();
}
});
});
std
::
copy
(
selected_indices
.
begin
(),
selected_indices
.
end
(),
output
.
begin
());
std
::
copy
(
selected_indices
.
begin
(),
selected_indices
.
end
(),
output
.
begin
());
});
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment