Commit 0af471ff authored by charlie's avatar charlie
Browse files

Fixing non-standard shape literal construction

parent 1820198e
......@@ -40,6 +40,8 @@ inline namespace MIGRAPHX_INLINE_NS {
/**
* @brief Represents a raw literal
* @details This stores the literal has a raw buffer that is owned by this class
* If the given shape is non-standard, the literal will be converted to a standard shape at
* construction.
*/
struct literal : raw_data<literal>
{
......@@ -117,14 +119,16 @@ struct literal : raw_data<literal>
}
else
{
// make the literal into a standard shape (contiguous)
auto it = start;
m_shape.visit_type([&](auto as) {
auto output = make_view(m_shape, as.from(buffer.get()));
shape_for_each(output.get_shape(), [&](const auto& idx) {
shape_for_each_nstd(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = *it; // NOLINT(bugprone-signed-char-misuse)
it++;
});
});
m_shape = {m_shape.type(), m_shape.lens()};
}
}
};
......
......@@ -31,6 +31,10 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/**
* Iterates the given function over the standard shape indices.
* Will iterate using standard strides if given a non-standard shape.
*/
template <class F>
void shape_for_each(const migraphx::shape& s, F f)
{
......@@ -52,6 +56,30 @@ void shape_for_each(const migraphx::shape& s, F f)
}
}
/**
* Iterates the given function over the given shape indices.
* Will iterate using non-standard strides if given a non-standard shape.
*/
template <class F>
void shape_for_each_nstd(const migraphx::shape& s, F f)
{
// Ensure calls to f use const ref to vector
auto call = [&f](const std::vector<std::size_t>& i) { f(i); };
std::vector<std::size_t> indices(s.lens().size());
for(std::size_t i = 0; i < s.elements(); i++)
{
std::transform(s.strides().begin(),
s.strides().end(),
s.lens().begin(),
indices.begin(),
[&](std::size_t stride, std::size_t len) {
assert(len > 0 and stride > 0);
return (i / stride) % len;
});
call(indices);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -49,6 +49,20 @@ TEST_CASE(literal_test)
EXPECT(l4.empty());
}
TEST_CASE(literal_nstd_shape)
{
migraphx::shape nstd_shape{migraphx::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}};
std::vector<float> nstd_data(12);
std::iota(nstd_data.begin(), nstd_data.end(), 0);
migraphx::shape std_shape{migraphx::shape::float_type, {1, 3, 2, 2}};
std::vector<float> std_data = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
auto l0 = migraphx::literal{nstd_shape, nstd_data};
auto l1 = migraphx::literal{std_shape, std_data};
EXPECT(l0 == l1);
}
TEST_CASE(literal_os1)
{
migraphx::literal l{1};
......
......@@ -835,24 +835,31 @@ TEST_CASE(concat_test)
}
}
TEST_CASE(contiguous_test)
TEST_CASE(contiguous_param_test)
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}};
std::vector<float> data(12);
std::iota(data.begin(), data.end(), 0);
migraphx::program p;
auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{a_shape, data});
mm->add_instruction(migraphx::make_op("contiguous"), l);
auto a = mm->add_parameter("X", a_shape);
mm->add_instruction(migraphx::make_op("contiguous"), a);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> data(12);
std::iota(data.begin(), data.end(), 0);
migraphx::parameter_map params;
params["X"] = migraphx::argument(a_shape, data.data());
auto result = p.eval(params).back();
result.visit([&](auto output) {
std::vector<size_t> new_strides = {12, 4, 2, 1};
EXPECT(bool{output.get_shape().strides() == new_strides});
});
std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<size_t> new_lens = {1, 3, 2, 2};
std::vector<size_t> new_strides = {12, 1, 6, 3};
EXPECT(migraphx::verify_range(results_vector, data));
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(conv_dynamic_batch_test)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment