Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
6382ff10
Commit
6382ff10
authored
Jun 28, 2022
by
charlie
Browse files
Dynamic output shape works
parent
39f5fe15
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
9 deletions
+11
-9
src/include/migraphx/op/nonmaxsuppression.hpp
src/include/migraphx/op/nonmaxsuppression.hpp
+7
-5
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+4
-4
No files found.
src/include/migraphx/op/nonmaxsuppression.hpp
View file @
6382ff10
...
@@ -68,7 +68,7 @@ struct nonmaxsuppression
...
@@ -68,7 +68,7 @@ struct nonmaxsuppression
MIGRAPHX_THROW
(
"NonMaxSuppression: dynamic spatial dimension mismatch between "
MIGRAPHX_THROW
(
"NonMaxSuppression: dynamic spatial dimension mismatch between "
"boxes and scores input"
);
"boxes and scores input"
);
}
}
if
(
boxes_dims
.
at
(
0
)
!=
inputs
.
at
(
1
).
lens
()[
0
]
)
if
(
boxes_dims
.
at
(
0
)
!=
scores_dims
.
at
(
0
)
)
{
{
MIGRAPHX_THROW
(
"NonMaxSuppression: dynamic number of batches mismatch between "
MIGRAPHX_THROW
(
"NonMaxSuppression: dynamic number of batches mismatch between "
"boxes and scores input"
);
"boxes and scores input"
);
...
@@ -237,7 +237,9 @@ struct nonmaxsuppression
...
@@ -237,7 +237,9 @@ struct nonmaxsuppression
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
output_shape
};
// make buffer of maximum size
shape
max_output_shape
=
{
output_shape
.
type
(),
output_shape
.
max_lens
()};
argument
result
{
max_output_shape
};
std
::
size_t
max_output_boxes_per_class
=
std
::
size_t
max_output_boxes_per_class
=
(
args
.
size
()
>
2
)
?
(
args
.
at
(
2
).
at
<
std
::
size_t
>
())
:
0
;
(
args
.
size
()
>
2
)
?
(
args
.
at
(
2
).
at
<
std
::
size_t
>
())
:
0
;
...
@@ -258,7 +260,7 @@ struct nonmaxsuppression
...
@@ -258,7 +260,7 @@ struct nonmaxsuppression
// boxes of a class with NMS applied [score, index]
// boxes of a class with NMS applied [score, index]
std
::
vector
<
std
::
pair
<
double
,
int64_t
>>
selected_boxes_inside_class
;
std
::
vector
<
std
::
pair
<
double
,
int64_t
>>
selected_boxes_inside_class
;
std
::
vector
<
int64_t
>
selected_indices
;
std
::
vector
<
int64_t
>
selected_indices
;
selected_boxes_inside_class
.
reserve
(
output_shape
.
elements
());
selected_boxes_inside_class
.
reserve
(
max_
output_shape
.
elements
());
// iterate over batches and classes
// iterate over batches and classes
shape
comp_s
{
shape
::
double_type
,
{
num_batches
,
num_classes
}};
shape
comp_s
{
shape
::
double_type
,
{
num_batches
,
num_classes
}};
shape_for_each
(
comp_s
,
[
&
](
auto
idx
)
{
shape_for_each
(
comp_s
,
[
&
](
auto
idx
)
{
...
@@ -281,11 +283,11 @@ struct nonmaxsuppression
...
@@ -281,11 +283,11 @@ struct nonmaxsuppression
class_idx
);
class_idx
);
});
});
std
::
copy
(
selected_indices
.
begin
(),
selected_indices
.
end
(),
output
.
begin
());
std
::
copy
(
selected_indices
.
begin
(),
selected_indices
.
end
(),
output
.
begin
());
num_selected
=
selected_indices
.
size
();
num_selected
=
selected_indices
.
size
()
/
3
;
});
});
});
});
return
result
.
reshape
({
num_selected
,
3
});
return
result
.
reshape
(
{
output_shape
.
type
(),
{
num_selected
,
3
}
}
);
}
}
};
};
...
...
test/ref_ops_test.cpp
View file @
6382ff10
...
@@ -3433,7 +3433,7 @@ TEST_CASE(nms_not_center_test)
...
@@ -3433,7 +3433,7 @@ TEST_CASE(nms_not_center_test)
auto
output
=
p
.
eval
({}).
back
();
auto
output
=
p
.
eval
({}).
back
();
std
::
vector
<
int64_t
>
result
;
std
::
vector
<
int64_t
>
result
;
output
.
visit
([
&
](
auto
out
)
{
result
.
assign
(
out
.
begin
(),
out
.
end
());
});
output
.
visit
([
&
](
auto
out
)
{
result
.
assign
(
out
.
begin
(),
out
.
end
());
});
std
::
vector
<
int64_t
>
gold
=
{
0
,
0
,
3
,
0
,
0
,
0
,
0
,
0
,
5
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
};
std
::
vector
<
int64_t
>
gold
=
{
0
,
0
,
3
,
0
,
0
,
0
,
0
,
0
,
5
};
EXPECT
(
migraphx
::
verify_range
(
result
,
gold
));
EXPECT
(
migraphx
::
verify_range
(
result
,
gold
));
}
}
...
@@ -3466,7 +3466,7 @@ TEST_CASE(nms_test)
...
@@ -3466,7 +3466,7 @@ TEST_CASE(nms_test)
auto
output
=
p
.
eval
({}).
back
();
auto
output
=
p
.
eval
({}).
back
();
std
::
vector
<
int64_t
>
result
;
std
::
vector
<
int64_t
>
result
;
output
.
visit
([
&
](
auto
out
)
{
result
.
assign
(
out
.
begin
(),
out
.
end
());
});
output
.
visit
([
&
](
auto
out
)
{
result
.
assign
(
out
.
begin
(),
out
.
end
());
});
std
::
vector
<
int64_t
>
gold
=
{
0
,
0
,
3
,
0
,
0
,
0
,
0
,
0
,
5
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
};
std
::
vector
<
int64_t
>
gold
=
{
0
,
0
,
3
,
0
,
0
,
0
,
0
,
0
,
5
};
EXPECT
(
migraphx
::
verify_range
(
result
,
gold
));
EXPECT
(
migraphx
::
verify_range
(
result
,
gold
));
}
}
...
@@ -3503,7 +3503,7 @@ TEST_CASE(nms_transpose1_test)
...
@@ -3503,7 +3503,7 @@ TEST_CASE(nms_transpose1_test)
auto
output
=
p
.
eval
({}).
back
();
auto
output
=
p
.
eval
({}).
back
();
std
::
vector
<
int64_t
>
result
;
std
::
vector
<
int64_t
>
result
;
output
.
visit
([
&
](
auto
out
)
{
result
.
assign
(
out
.
begin
(),
out
.
end
());
});
output
.
visit
([
&
](
auto
out
)
{
result
.
assign
(
out
.
begin
(),
out
.
end
());
});
std
::
vector
<
int64_t
>
gold
=
{
0
,
0
,
3
,
0
,
0
,
0
,
0
,
0
,
5
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
};
std
::
vector
<
int64_t
>
gold
=
{
0
,
0
,
3
,
0
,
0
,
0
,
0
,
0
,
5
};
EXPECT
(
migraphx
::
verify_range
(
result
,
gold
));
EXPECT
(
migraphx
::
verify_range
(
result
,
gold
));
}
}
...
@@ -3540,7 +3540,7 @@ TEST_CASE(nms_transpose2_test)
...
@@ -3540,7 +3540,7 @@ TEST_CASE(nms_transpose2_test)
auto
output
=
p
.
eval
({}).
back
();
auto
output
=
p
.
eval
({}).
back
();
std
::
vector
<
int64_t
>
result
;
std
::
vector
<
int64_t
>
result
;
output
.
visit
([
&
](
auto
out
)
{
result
.
assign
(
out
.
begin
(),
out
.
end
());
});
output
.
visit
([
&
](
auto
out
)
{
result
.
assign
(
out
.
begin
(),
out
.
end
());
});
std
::
vector
<
int64_t
>
gold
=
{
0
,
0
,
3
,
0
,
0
,
0
,
0
,
0
,
5
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
};
std
::
vector
<
int64_t
>
gold
=
{
0
,
0
,
3
,
0
,
0
,
0
,
0
,
0
,
5
};
EXPECT
(
migraphx
::
verify_range
(
result
,
gold
));
EXPECT
(
migraphx
::
verify_range
(
result
,
gold
));
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment