Unverified Commit dfaab007 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Make contiguous preserve scalar shape (#500)



* Make contiguous preserve scalar shape

* Formatting

* Improve standard shape calulation

* Formatting

* Enable some optimizations on debug build

* Up optimization level

* Remove debug symbols
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent ffde465d
......@@ -106,7 +106,7 @@ rocmtest tidy: rocmnode('rocmtest') { cmake_build ->
stage('Clang Debug') {
// TODO: Enable integer
def sanitizers = "undefined"
def debug_flags = "-g -fno-omit-frame-pointer -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}"
def debug_flags = "-O2 -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}"
cmake_build("hcc", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'")
}
}, clang_release: rocmnode('vega') { cmake_build ->
......
......@@ -28,6 +28,8 @@ struct contiguous
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
if(inputs.front().standard())
return inputs.front();
auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type();
return {t, lens};
......
......@@ -32,10 +32,8 @@ struct shape_impl
assert(m_lens.size() == m_strides.size());
// assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and
// "At least one stride must be non-zero");
m_standard =
this->elements() == this->element_space() and
std::is_sorted(m_strides.rbegin(), m_strides.rend()) and
std::none_of(m_strides.begin(), m_strides.end(), [](auto x) { return x == 0; });
m_standard = this->elements() == this->element_space() and
std::is_sorted(m_strides.rbegin(), m_strides.rend());
}
shape::type_t m_type;
std::vector<std::size_t> m_lens;
......
......@@ -117,6 +117,13 @@ TEST_CASE(contiguous_shape)
expect_shape(single, migraphx::op::contiguous{}, single);
}
TEST_CASE(contiguous_shape_scalar)
{
migraphx::shape output{migraphx::shape::float_type};
migraphx::shape input{migraphx::shape::float_type};
expect_shape(output, migraphx::op::contiguous{}, input);
}
TEST_CASE(reshape_shape)
{
migraphx::shape input{migraphx::shape::float_type, {24, 1, 1, 1}};
......
......@@ -92,6 +92,33 @@ TEST_CASE(test_shape_overlap3)
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_scalar1)
{
migraphx::shape s{migraphx::shape::float_type};
EXPECT(s.standard());
EXPECT(s.packed());
EXPECT(not s.transposed());
EXPECT(s.broadcasted());
}
TEST_CASE(test_shape_scalar2)
{
migraphx::shape s{migraphx::shape::float_type, {1}, {0}};
EXPECT(s.standard());
EXPECT(s.packed());
EXPECT(not s.transposed());
EXPECT(s.broadcasted());
}
TEST_CASE(test_shape_scalar_broadcast)
{
migraphx::shape s{migraphx::shape::float_type, {1, 2, 3, 3}, {0, 0, 0, 0}};
EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.transposed());
EXPECT(s.broadcasted());
}
TEST_CASE(test_shape_broadcasted)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 0}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment