Commit 65e6664a authored by charlie's avatar charlie
Browse files

Merge branch 'dyn_contiguous' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_unsqueeze

parents bc082a4b 95d0bc93
...@@ -74,7 +74,8 @@ RUN cget -p $PREFIX install facebook/zstd@v1.4.5 -X subdir -DCMAKE_DIR=build/cma ...@@ -74,7 +74,8 @@ RUN cget -p $PREFIX install facebook/zstd@v1.4.5 -X subdir -DCMAKE_DIR=build/cma
RUN cget -p $PREFIX install ccache@v4.1 -DENABLE_TESTING=OFF RUN cget -p $PREFIX install ccache@v4.1 -DENABLE_TESTING=OFF
# Install newer cmake for onnx runtime # Install newer cmake for onnx runtime
RUN cget -p /opt/cmake install kitware/cmake@v3.13.4 ARG CMAKE_VERSION=3.24.2
RUN cget -p /opt/cmake install -X binary https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}-Linux-x86_64.tar.gz
ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
ARG ONNXRUNTIME_BRANCH=main ARG ONNXRUNTIME_BRANCH=main
......
...@@ -111,7 +111,21 @@ struct literal : raw_data<literal> ...@@ -111,7 +111,21 @@ struct literal : raw_data<literal>
void fill(Iterator start, Iterator end) void fill(Iterator start, Iterator end)
{ {
assert(std::distance(start, end) == m_shape.elements()); assert(std::distance(start, end) == m_shape.elements());
m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.get())); }); if(m_shape.standard())
{
m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.get())); });
}
else
{
auto it = start;
m_shape.visit_type([&](auto as) {
auto output = make_view(m_shape, as.from(buffer.get()));
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = *it; // NOLINT(bugprone-signed-char-misuse)
it++;
});
});
}
} }
}; };
......
...@@ -48,18 +48,14 @@ struct contiguous ...@@ -48,18 +48,14 @@ struct contiguous
{ {
check_shapes{inputs, *this, true}.has(1); check_shapes{inputs, *this, true}.has(1);
auto s0 = inputs.front(); auto s0 = inputs.front();
if(s0.dynamic()) if(s0.dynamic() or s0.standard())
{ {
return s0; return s0;
} }
else else
{ {
if(s0.standard()) const auto& lens = s0.lens();
{ auto t = s0.type();
return inputs.front();
}
auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type();
return {t, lens}; return {t, lens};
} }
} }
......
...@@ -31,10 +31,6 @@ ...@@ -31,10 +31,6 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
/**
* Iterates the given function over the standard shape indices.
* Will iterate using standard strides if given a non-standard shape.
*/
template <class F> template <class F>
void shape_for_each(const migraphx::shape& s, F f) void shape_for_each(const migraphx::shape& s, F f)
{ {
...@@ -55,6 +51,7 @@ void shape_for_each(const migraphx::shape& s, F f) ...@@ -55,6 +51,7 @@ void shape_for_each(const migraphx::shape& s, F f)
call(indices); call(indices);
} }
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -49,20 +49,6 @@ TEST_CASE(literal_test) ...@@ -49,20 +49,6 @@ TEST_CASE(literal_test)
EXPECT(l4.empty()); EXPECT(l4.empty());
} }
TEST_CASE(literal_nstd_shape)
{
migraphx::shape nstd_shape{migraphx::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}};
std::vector<float> nstd_data(12);
std::iota(nstd_data.begin(), nstd_data.end(), 0);
migraphx::shape std_shape{migraphx::shape::float_type, {1, 3, 2, 2}};
std::vector<float> std_data = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
auto l0 = migraphx::literal{nstd_shape, nstd_data};
auto l1 = migraphx::literal{std_shape, std_data};
EXPECT(l0 != l1);
}
TEST_CASE(literal_os1) TEST_CASE(literal_os1)
{ {
migraphx::literal l{1}; migraphx::literal l{1};
......
...@@ -918,42 +918,11 @@ TEST_CASE(contiguous_test) ...@@ -918,42 +918,11 @@ TEST_CASE(contiguous_test)
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
result.visit([&](auto output) {
std::vector<size_t> new_strides = {12, 4, 2, 1};
EXPECT(bool{output.get_shape().strides() == new_strides});
});
std::vector<float> results_vector(12); std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<size_t> new_lens = {1, 3, 2, 2};
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11}; std::vector<size_t> new_strides = {12, 1, 6, 3};
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, data));
}
TEST_CASE(contiguous_param_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}};
auto a = mm->add_parameter("X", a_shape);
mm->add_instruction(migraphx::make_op("contiguous"), a);
p.compile(migraphx::ref::target{});
std::vector<float> data(12);
std::iota(data.begin(), data.end(), 0);
migraphx::parameter_map params;
params["X"] = migraphx::argument(a_shape, data.data());
auto result = p.eval(params).back();
result.visit([&](auto output) {
std::vector<size_t> new_strides = {12, 4, 2, 1};
EXPECT(bool{output.get_shape().strides() == new_strides});
});
std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(contiguous_dyn_test) TEST_CASE(contiguous_dyn_test)
...@@ -973,10 +942,8 @@ TEST_CASE(contiguous_dyn_test) ...@@ -973,10 +942,8 @@ TEST_CASE(contiguous_dyn_test)
params["X"] = migraphx::argument(static_shape, data.data()); params["X"] = migraphx::argument(static_shape, data.data());
auto result = p.eval(params).back(); auto result = p.eval(params).back();
result.visit([&](auto output) { std::vector<size_t> new_strides = {12, 4, 2, 1};
std::vector<size_t> new_strides = {12, 4, 2, 1}; EXPECT(result.get_shape().strides() == new_strides);
EXPECT(bool{output.get_shape().strides() == new_strides});
});
std::vector<float> results_vector(12); std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment