Commit 65e6664a authored by charlie's avatar charlie
Browse files

Merge branch 'dyn_contiguous' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_unsqueeze

parents bc082a4b 95d0bc93
......@@ -74,7 +74,8 @@ RUN cget -p $PREFIX install facebook/zstd@v1.4.5 -X subdir -DCMAKE_DIR=build/cma
RUN cget -p $PREFIX install ccache@v4.1 -DENABLE_TESTING=OFF
# Install newer cmake for onnx runtime
RUN cget -p /opt/cmake install kitware/cmake@v3.13.4
ARG CMAKE_VERSION=3.24.2
RUN cget -p /opt/cmake install -X binary https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}-Linux-x86_64.tar.gz
ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
ARG ONNXRUNTIME_BRANCH=main
......
......@@ -111,7 +111,21 @@ struct literal : raw_data<literal>
void fill(Iterator start, Iterator end)
{
assert(std::distance(start, end) == m_shape.elements());
m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.get())); });
if(m_shape.standard())
{
m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.get())); });
}
else
{
auto it = start;
m_shape.visit_type([&](auto as) {
auto output = make_view(m_shape, as.from(buffer.get()));
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = *it; // NOLINT(bugprone-signed-char-misuse)
it++;
});
});
}
}
};
......
......@@ -48,18 +48,14 @@ struct contiguous
{
check_shapes{inputs, *this, true}.has(1);
auto s0 = inputs.front();
if(s0.dynamic())
if(s0.dynamic() or s0.standard())
{
return s0;
}
else
{
if(s0.standard())
{
return inputs.front();
}
auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type();
const auto& lens = s0.lens();
auto t = s0.type();
return {t, lens};
}
}
......
......@@ -31,10 +31,6 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/**
* Iterates the given function over the standard shape indices.
* Will iterate using standard strides if given a non-standard shape.
*/
template <class F>
void shape_for_each(const migraphx::shape& s, F f)
{
......@@ -55,6 +51,7 @@ void shape_for_each(const migraphx::shape& s, F f)
call(indices);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -49,20 +49,6 @@ TEST_CASE(literal_test)
EXPECT(l4.empty());
}
TEST_CASE(literal_nstd_shape)
{
migraphx::shape nstd_shape{migraphx::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}};
std::vector<float> nstd_data(12);
std::iota(nstd_data.begin(), nstd_data.end(), 0);
migraphx::shape std_shape{migraphx::shape::float_type, {1, 3, 2, 2}};
std::vector<float> std_data = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
auto l0 = migraphx::literal{nstd_shape, nstd_data};
auto l1 = migraphx::literal{std_shape, std_data};
EXPECT(l0 != l1);
}
TEST_CASE(literal_os1)
{
migraphx::literal l{1};
......
......@@ -918,42 +918,11 @@ TEST_CASE(contiguous_test)
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
result.visit([&](auto output) {
std::vector<size_t> new_strides = {12, 4, 2, 1};
EXPECT(bool{output.get_shape().strides() == new_strides});
});
std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(contiguous_param_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}};
auto a = mm->add_parameter("X", a_shape);
mm->add_instruction(migraphx::make_op("contiguous"), a);
p.compile(migraphx::ref::target{});
std::vector<float> data(12);
std::iota(data.begin(), data.end(), 0);
migraphx::parameter_map params;
params["X"] = migraphx::argument(a_shape, data.data());
auto result = p.eval(params).back();
result.visit([&](auto output) {
std::vector<size_t> new_strides = {12, 4, 2, 1};
EXPECT(bool{output.get_shape().strides() == new_strides});
});
std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
EXPECT(migraphx::verify_range(results_vector, gold));
std::vector<size_t> new_lens = {1, 3, 2, 2};
std::vector<size_t> new_strides = {12, 1, 6, 3};
EXPECT(migraphx::verify_range(results_vector, data));
}
TEST_CASE(contiguous_dyn_test)
......@@ -973,10 +942,8 @@ TEST_CASE(contiguous_dyn_test)
params["X"] = migraphx::argument(static_shape, data.data());
auto result = p.eval(params).back();
result.visit([&](auto output) {
std::vector<size_t> new_strides = {12, 4, 2, 1};
EXPECT(bool{output.get_shape().strides() == new_strides});
});
std::vector<size_t> new_strides = {12, 4, 2, 1};
EXPECT(result.get_shape().strides() == new_strides);
std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment