Commit 570db377 authored by Paul's avatar Paul
Browse files

Add assert for strides

parent ce23285b
...@@ -18,7 +18,7 @@ void shape_for_each(const rtg::shape& s, F f) ...@@ -18,7 +18,7 @@ void shape_for_each(const rtg::shape& s, F f)
s.strides().end(), s.strides().end(),
s.lens().begin(), s.lens().begin(),
indices.begin(), indices.begin(),
[&](std::size_t stride, std::size_t len) { return (i / stride) % len; }); [&](std::size_t stride, std::size_t len) { assert(len > 0 and stride > 0); return (i / stride) % len; });
call(indices); call(indices);
} }
} }
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <numeric> #include <numeric>
#include <algorithm> #include <algorithm>
#include <functional> #include <functional>
#include <iostream>
namespace rtg { namespace rtg {
...@@ -19,6 +20,7 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) ...@@ -19,6 +20,7 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
: m_type(t), m_lens(std::move(l)), m_strides(std::move(s)) : m_type(t), m_lens(std::move(l)), m_strides(std::move(s))
{ {
assert(m_lens.size() == m_strides.size()); assert(m_lens.size() == m_strides.size());
assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and "At least one stride must be non-zero");
m_packed = this->elements() == this->element_space(); m_packed = this->elements() == this->element_space();
} }
...@@ -72,7 +74,7 @@ std::size_t shape::index(std::size_t i) const ...@@ -72,7 +74,7 @@ std::size_t shape::index(std::size_t i) const
this->strides().begin(), this->strides().begin(),
std::size_t{0}, std::size_t{0},
std::plus<std::size_t>{}, std::plus<std::size_t>{},
[&](std::size_t len, std::size_t stride) { return ((i / stride) % len) * stride; }); [&](std::size_t len, std::size_t stride) { assert(stride > 0 and len > 0); return ((i / stride) % len) * stride; });
} }
bool shape::packed() const { return this->m_packed; } bool shape::packed() const { return this->m_packed; }
std::size_t shape::element_space() const std::size_t shape::element_space() const
......
...@@ -149,60 +149,9 @@ struct cpu_contiguous ...@@ -149,60 +149,9 @@ struct cpu_contiguous
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
auto input_shape = args[0].get_shape(); shape_for_each(output.get_shape(), [&](const auto& idx) {
auto ndim = output_shape.lens().size(); output(idx.begin(), idx.end()) = input(idx.begin(), idx.end());
using value_type = typename decltype(input)::value_type;
value_type* ptr = static_cast<value_type*>(output.data());
if(ndim == 2)
{
dfor(input_shape.lens()[0], input_shape.lens()[1])(
[&](std::size_t i0, std::size_t i1) { *ptr++ = input(i0, i1); });
}
else if(ndim == 3)
{
dfor(input_shape.lens()[0], input_shape.lens()[1], input_shape.lens()[2])(
[&](std::size_t i0, std::size_t i1, std::size_t i2) {
*ptr++ = input(i0, i1, i2);
});
}
else if(ndim == 4)
{
dfor(input_shape.lens()[0],
input_shape.lens()[1],
input_shape.lens()[2],
input_shape.lens()[3])(
[&](std::size_t i0, std::size_t i1, std::size_t i2, std::size_t i3) {
*ptr++ = input(i0, i1, i2, i3);
}); });
}
else if(ndim == 5)
{
dfor(input_shape.lens()[0],
input_shape.lens()[1],
input_shape.lens()[2],
input_shape.lens()[3],
input_shape.lens()[4])(
[&](std::size_t i0,
std::size_t i1,
std::size_t i2,
std::size_t i3,
std::size_t i4) { *ptr++ = input(i0, i1, i2, i3, i4); });
}
else if(ndim == 6)
{
dfor(input_shape.lens()[0],
input_shape.lens()[1],
input_shape.lens()[2],
input_shape.lens()[3],
input_shape.lens()[4],
input_shape.lens()[5])(
[&](std::size_t i0,
std::size_t i1,
std::size_t i2,
std::size_t i3,
std::size_t i4,
std::size_t i5) { *ptr++ = input(i0, i1, i2, i3, i4, i5); });
}
}); });
return result; return result;
} }
......
...@@ -42,18 +42,20 @@ function(add_test_command NAME EXE) ...@@ -42,18 +42,20 @@ function(add_test_command NAME EXE)
# -ex run # -ex run
# -ex backtrace # -ex backtrace
# --args $<TARGET_FILE:${EXE}> ${ARGN}) # --args $<TARGET_FILE:${EXE}> ${ARGN})
file(GENERATE OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/test_${NAME}.cmake" set(TEST_DIR ${CMAKE_CURRENT_BINARY_DIR}/gdb/test_${NAME})
file(MAKE_DIRECTORY ${TEST_DIR})
file(GENERATE OUTPUT "${TEST_DIR}/run.cmake"
CONTENT " CONTENT "
execute_process(COMMAND $<TARGET_FILE:${EXE}> ${ARGN} RESULT_VARIABLE RESULT) execute_process(COMMAND $<TARGET_FILE:${EXE}> ${ARGN} WORKING_DIRECTORY ${TEST_DIR} RESULT_VARIABLE RESULT)
if(NOT RESULT EQUAL 0) if(NOT RESULT EQUAL 0)
# TODO: check for core files based on pid when setting /proc/sys/kernel/core_uses_pid # TODO: check for core files based on pid when setting /proc/sys/kernel/core_uses_pid
if(EXISTS core) if(EXISTS ${TEST_DIR}/core)
execute_process(COMMAND ${RTG_GDB} $<TARGET_FILE:${EXE}> core -batch -ex bt) execute_process(COMMAND ${RTG_GDB} $<TARGET_FILE:${EXE}> ${TEST_DIR}/core -batch -ex bt)
endif() endif()
message(FATAL_ERROR \"Test failed\") message(FATAL_ERROR \"Test failed\")
endif() endif()
") ")
add_test(NAME ${NAME} COMMAND ${CMAKE_COMMAND} -P "${CMAKE_CURRENT_BINARY_DIR}/test_${NAME}.cmake") add_test(NAME ${NAME} COMMAND ${CMAKE_COMMAND} -P "${TEST_DIR}/run.cmake")
else() else()
add_test(NAME ${NAME} COMMAND ${EXE} ${ARGN}) add_test(NAME ${NAME} COMMAND ${EXE} ${ARGN})
endif() endif()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment