Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
0af471ff
Commit
0af471ff
authored
Nov 02, 2022
by
charlie
Browse files
Fixing non-standard shape literal construction
parent
1820198e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
63 additions
and
10 deletions
+63
-10
src/include/migraphx/literal.hpp
src/include/migraphx/literal.hpp
+5
-1
src/include/migraphx/shape_for_each.hpp
src/include/migraphx/shape_for_each.hpp
+28
-0
test/literal_test.cpp
test/literal_test.cpp
+14
-0
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+16
-9
No files found.
src/include/migraphx/literal.hpp
View file @
0af471ff
...
@@ -40,6 +40,8 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -40,6 +40,8 @@ inline namespace MIGRAPHX_INLINE_NS {
/**
/**
* @brief Represents a raw literal
* @brief Represents a raw literal
* @details This stores the literal has a raw buffer that is owned by this class
* @details This stores the literal has a raw buffer that is owned by this class
* If the given shape is non-standard, the literal will be converted to a standard shape at
* construction.
*/
*/
struct
literal
:
raw_data
<
literal
>
struct
literal
:
raw_data
<
literal
>
{
{
...
@@ -117,14 +119,16 @@ struct literal : raw_data<literal>
...
@@ -117,14 +119,16 @@ struct literal : raw_data<literal>
}
}
else
else
{
{
// make the literal into a standard shape (contiguous)
auto
it
=
start
;
auto
it
=
start
;
m_shape
.
visit_type
([
&
](
auto
as
)
{
m_shape
.
visit_type
([
&
](
auto
as
)
{
auto
output
=
make_view
(
m_shape
,
as
.
from
(
buffer
.
get
()));
auto
output
=
make_view
(
m_shape
,
as
.
from
(
buffer
.
get
()));
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
shape_for_each
_nstd
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
output
(
idx
.
begin
(),
idx
.
end
())
=
*
it
;
// NOLINT(bugprone-signed-char-misuse)
output
(
idx
.
begin
(),
idx
.
end
())
=
*
it
;
// NOLINT(bugprone-signed-char-misuse)
it
++
;
it
++
;
});
});
});
});
m_shape
=
{
m_shape
.
type
(),
m_shape
.
lens
()};
}
}
}
}
};
};
...
...
src/include/migraphx/shape_for_each.hpp
View file @
0af471ff
...
@@ -31,6 +31,10 @@
...
@@ -31,6 +31,10 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/**
* Iterates the given function over the standard shape indices.
* Will iterate using standard strides if given a non-standard shape.
*/
template
<
class
F
>
template
<
class
F
>
void
shape_for_each
(
const
migraphx
::
shape
&
s
,
F
f
)
void
shape_for_each
(
const
migraphx
::
shape
&
s
,
F
f
)
{
{
...
@@ -52,6 +56,30 @@ void shape_for_each(const migraphx::shape& s, F f)
...
@@ -52,6 +56,30 @@ void shape_for_each(const migraphx::shape& s, F f)
}
}
}
}
/**
* Iterates the given function over the given shape indices.
* Will iterate using non-standard strides if given a non-standard shape.
*/
template
<
class
F
>
void
shape_for_each_nstd
(
const
migraphx
::
shape
&
s
,
F
f
)
{
// Ensure calls to f use const ref to vector
auto
call
=
[
&
f
](
const
std
::
vector
<
std
::
size_t
>&
i
)
{
f
(
i
);
};
std
::
vector
<
std
::
size_t
>
indices
(
s
.
lens
().
size
());
for
(
std
::
size_t
i
=
0
;
i
<
s
.
elements
();
i
++
)
{
std
::
transform
(
s
.
strides
().
begin
(),
s
.
strides
().
end
(),
s
.
lens
().
begin
(),
indices
.
begin
(),
[
&
](
std
::
size_t
stride
,
std
::
size_t
len
)
{
assert
(
len
>
0
and
stride
>
0
);
return
(
i
/
stride
)
%
len
;
});
call
(
indices
);
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
test/literal_test.cpp
View file @
0af471ff
...
@@ -49,6 +49,20 @@ TEST_CASE(literal_test)
...
@@ -49,6 +49,20 @@ TEST_CASE(literal_test)
EXPECT
(
l4
.
empty
());
EXPECT
(
l4
.
empty
());
}
}
TEST_CASE
(
literal_nstd_shape
)
{
migraphx
::
shape
nstd_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
2
,
2
},
{
12
,
1
,
6
,
3
}};
std
::
vector
<
float
>
nstd_data
(
12
);
std
::
iota
(
nstd_data
.
begin
(),
nstd_data
.
end
(),
0
);
migraphx
::
shape
std_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
2
,
2
}};
std
::
vector
<
float
>
std_data
=
{
0
,
3
,
6
,
9
,
1
,
4
,
7
,
10
,
2
,
5
,
8
,
11
};
auto
l0
=
migraphx
::
literal
{
nstd_shape
,
nstd_data
};
auto
l1
=
migraphx
::
literal
{
std_shape
,
std_data
};
EXPECT
(
l0
==
l1
);
}
TEST_CASE
(
literal_os1
)
TEST_CASE
(
literal_os1
)
{
{
migraphx
::
literal
l
{
1
};
migraphx
::
literal
l
{
1
};
...
...
test/ref_ops_test.cpp
View file @
0af471ff
...
@@ -835,24 +835,31 @@ TEST_CASE(concat_test)
...
@@ -835,24 +835,31 @@ TEST_CASE(concat_test)
}
}
}
}
TEST_CASE(contiguous_test)
TEST_CASE(contiguous_
param_
test)
{
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}};
migraphx::shape a_shape{migraphx::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}};
std::vector<float> data(12);
std::iota(data.begin(), data.end(), 0);
migraphx::program p;
migraphx::program p;
auto* mm = p.get_main_module();
auto* mm = p.get_main_module();
auto
l
= mm->add_
literal(migraphx::literal{a_shape, data}
);
auto
a
= mm->add_
parameter("X", a_shape
);
mm->add_instruction(migraphx::make_op("contiguous"),
l
);
mm->add_instruction(migraphx::make_op("contiguous"),
a
);
p.compile(migraphx::ref::target{});
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> data(12);
std::iota(data.begin(), data.end(), 0);
migraphx::parameter_map params;
params["X"] = migraphx::argument(a_shape, data.data());
auto result = p.eval(params).back();
result.visit([&](auto output) {
std::vector<size_t> new_strides = {12, 4, 2, 1};
EXPECT(bool{output.get_shape().strides() == new_strides});
});
std::vector<float> results_vector(12);
std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<size_t> new_lens = {1, 3, 2, 2};
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
std::vector<size_t> new_strides = {12, 1, 6, 3};
EXPECT(migraphx::verify_range(results_vector, gold));
EXPECT(migraphx::verify_range(results_vector, data));
}
}
TEST_CASE(conv_dynamic_batch_test)
TEST_CASE(conv_dynamic_batch_test)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment