Commit 3e46f548 authored by charlie's avatar charlie
Browse files

Removed the automatic conversion to a standard shape

parent 0af471ff
......@@ -113,24 +113,8 @@ struct literal : raw_data<literal>
void fill(Iterator start, Iterator end)
{
assert(std::distance(start, end) == m_shape.elements());
if(m_shape.standard())
{
m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.get())); });
}
else
{
// make the literal into a standard shape (contiguous)
auto it = start;
m_shape.visit_type([&](auto as) {
auto output = make_view(m_shape, as.from(buffer.get()));
shape_for_each_nstd(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = *it; // NOLINT(bugprone-signed-char-misuse)
it++;
});
});
m_shape = {m_shape.type(), m_shape.lens()};
}
}
};
template <class F>
......
......@@ -55,31 +55,6 @@ void shape_for_each(const migraphx::shape& s, F f)
call(indices);
}
}
/**
* Iterates the given function over the given shape indices.
* Will iterate using non-standard strides if given a non-standard shape.
*/
template <class F>
void shape_for_each_nstd(const migraphx::shape& s, F f)
{
// Ensure calls to f use const ref to vector
auto call = [&f](const std::vector<std::size_t>& i) { f(i); };
std::vector<std::size_t> indices(s.lens().size());
for(std::size_t i = 0; i < s.elements(); i++)
{
std::transform(s.strides().begin(),
s.strides().end(),
s.lens().begin(),
indices.begin(),
[&](std::size_t stride, std::size_t len) {
assert(len > 0 and stride > 0);
return (i / stride) % len;
});
call(indices);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -60,7 +60,7 @@ TEST_CASE(literal_nstd_shape)
auto l0 = migraphx::literal{nstd_shape, nstd_data};
auto l1 = migraphx::literal{std_shape, std_data};
EXPECT(l0 == l1);
EXPECT(l0 != l1);
}
TEST_CASE(literal_os1)
......
......@@ -835,12 +835,36 @@ TEST_CASE(concat_test)
}
}
TEST_CASE(contiguous_param_test)
TEST_CASE(contiguous_test)
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}};
std::vector<float> data(12);
std::iota(data.begin(), data.end(), 0);
migraphx::program p;
auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{a_shape, data});
mm->add_instruction(migraphx::make_op("contiguous"), l);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
result.visit([&](auto output) {
std::vector<size_t> new_strides = {12, 4, 2, 1};
EXPECT(bool{output.get_shape().strides() == new_strides});
});
std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(contiguous_param_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}};
auto a = mm->add_parameter("X", a_shape);
mm->add_instruction(migraphx::make_op("contiguous"), a);
p.compile(migraphx::ref::target{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment