Commit 5f3f6e73 authored by charlie's avatar charlie
Browse files

Detach from non-std literal PR

parent 2cf7ae45
...@@ -111,7 +111,21 @@ struct literal : raw_data<literal> ...@@ -111,7 +111,21 @@ struct literal : raw_data<literal>
void fill(Iterator start, Iterator end) void fill(Iterator start, Iterator end)
{ {
assert(std::distance(start, end) == m_shape.elements()); assert(std::distance(start, end) == m_shape.elements());
m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.get())); }); if(m_shape.standard())
{
m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.get())); });
}
else
{
auto it = start;
m_shape.visit_type([&](auto as) {
auto output = make_view(m_shape, as.from(buffer.get()));
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = *it; // NOLINT(bugprone-signed-char-misuse)
it++;
});
});
}
} }
}; };
......
...@@ -31,10 +31,6 @@ ...@@ -31,10 +31,6 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
/**
* Iterates the given function over the standard shape indices.
* Will iterate using standard strides if given a non-standard shape.
*/
template <class F> template <class F>
void shape_for_each(const migraphx::shape& s, F f) void shape_for_each(const migraphx::shape& s, F f)
{ {
...@@ -55,6 +51,7 @@ void shape_for_each(const migraphx::shape& s, F f) ...@@ -55,6 +51,7 @@ void shape_for_each(const migraphx::shape& s, F f)
call(indices); call(indices);
} }
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -49,20 +49,6 @@ TEST_CASE(literal_test) ...@@ -49,20 +49,6 @@ TEST_CASE(literal_test)
EXPECT(l4.empty()); EXPECT(l4.empty());
} }
TEST_CASE(literal_nstd_shape)
{
migraphx::shape nstd_shape{migraphx::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}};
std::vector<float> nstd_data(12);
std::iota(nstd_data.begin(), nstd_data.end(), 0);
migraphx::shape std_shape{migraphx::shape::float_type, {1, 3, 2, 2}};
std::vector<float> std_data = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
auto l0 = migraphx::literal{nstd_shape, nstd_data};
auto l1 = migraphx::literal{std_shape, std_data};
EXPECT(l0 != l1);
}
TEST_CASE(literal_os1) TEST_CASE(literal_os1)
{ {
migraphx::literal l{1}; migraphx::literal l{1};
......
...@@ -848,16 +848,11 @@ TEST_CASE(contiguous_test) ...@@ -848,16 +848,11 @@ TEST_CASE(contiguous_test)
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
result.visit([&](auto output) {
std::vector<size_t> new_strides = {12, 4, 2, 1};
EXPECT(bool{output.get_shape().strides() == new_strides});
});
std::vector<float> results_vector(12); std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<size_t> new_lens = {1, 3, 2, 2};
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11}; std::vector<size_t> new_strides = {12, 1, 6, 3};
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, data));
} }
TEST_CASE(contiguous_param_test) TEST_CASE(contiguous_param_test)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment