Commit 570db377 authored by Paul's avatar Paul
Browse files

Add assert for strides

parent ce23285b
......@@ -18,7 +18,7 @@ void shape_for_each(const rtg::shape& s, F f)
s.strides().end(),
s.lens().begin(),
indices.begin(),
[&](std::size_t stride, std::size_t len) { return (i / stride) % len; });
[&](std::size_t stride, std::size_t len) { assert(len > 0 and stride > 0); return (i / stride) % len; });
call(indices);
}
}
......
......@@ -4,6 +4,7 @@
#include <numeric>
#include <algorithm>
#include <functional>
#include <iostream>
namespace rtg {
......@@ -19,6 +20,7 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
: m_type(t), m_lens(std::move(l)), m_strides(std::move(s))
{
assert(m_lens.size() == m_strides.size());
assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and "At least one stride must be non-zero");
m_packed = this->elements() == this->element_space();
}
......@@ -72,7 +74,7 @@ std::size_t shape::index(std::size_t i) const
this->strides().begin(),
std::size_t{0},
std::plus<std::size_t>{},
[&](std::size_t len, std::size_t stride) { return ((i / stride) % len) * stride; });
[&](std::size_t len, std::size_t stride) { assert(stride > 0 and len > 0); return ((i / stride) % len) * stride; });
}
bool shape::packed() const { return this->m_packed; }
std::size_t shape::element_space() const
......
......@@ -149,60 +149,9 @@ struct cpu_contiguous
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
auto input_shape = args[0].get_shape();
auto ndim = output_shape.lens().size();
using value_type = typename decltype(input)::value_type;
value_type* ptr = static_cast<value_type*>(output.data());
if(ndim == 2)
{
dfor(input_shape.lens()[0], input_shape.lens()[1])(
[&](std::size_t i0, std::size_t i1) { *ptr++ = input(i0, i1); });
}
else if(ndim == 3)
{
dfor(input_shape.lens()[0], input_shape.lens()[1], input_shape.lens()[2])(
[&](std::size_t i0, std::size_t i1, std::size_t i2) {
*ptr++ = input(i0, i1, i2);
});
}
else if(ndim == 4)
{
dfor(input_shape.lens()[0],
input_shape.lens()[1],
input_shape.lens()[2],
input_shape.lens()[3])(
[&](std::size_t i0, std::size_t i1, std::size_t i2, std::size_t i3) {
*ptr++ = input(i0, i1, i2, i3);
});
}
else if(ndim == 5)
{
dfor(input_shape.lens()[0],
input_shape.lens()[1],
input_shape.lens()[2],
input_shape.lens()[3],
input_shape.lens()[4])(
[&](std::size_t i0,
std::size_t i1,
std::size_t i2,
std::size_t i3,
std::size_t i4) { *ptr++ = input(i0, i1, i2, i3, i4); });
}
else if(ndim == 6)
{
dfor(input_shape.lens()[0],
input_shape.lens()[1],
input_shape.lens()[2],
input_shape.lens()[3],
input_shape.lens()[4],
input_shape.lens()[5])(
[&](std::size_t i0,
std::size_t i1,
std::size_t i2,
std::size_t i3,
std::size_t i4,
std::size_t i5) { *ptr++ = input(i0, i1, i2, i3, i4, i5); });
}
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end());
});
});
return result;
}
......
......@@ -42,18 +42,20 @@ function(add_test_command NAME EXE)
# -ex run
# -ex backtrace
# --args $<TARGET_FILE:${EXE}> ${ARGN})
file(GENERATE OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/test_${NAME}.cmake"
set(TEST_DIR ${CMAKE_CURRENT_BINARY_DIR}/gdb/test_${NAME})
file(MAKE_DIRECTORY ${TEST_DIR})
file(GENERATE OUTPUT "${TEST_DIR}/run.cmake"
CONTENT "
execute_process(COMMAND $<TARGET_FILE:${EXE}> ${ARGN} RESULT_VARIABLE RESULT)
execute_process(COMMAND $<TARGET_FILE:${EXE}> ${ARGN} WORKING_DIRECTORY ${TEST_DIR} RESULT_VARIABLE RESULT)
if(NOT RESULT EQUAL 0)
# TODO: check for core files based on pid when setting /proc/sys/kernel/core_uses_pid
if(EXISTS core)
execute_process(COMMAND ${RTG_GDB} $<TARGET_FILE:${EXE}> core -batch -ex bt)
if(EXISTS ${TEST_DIR}/core)
execute_process(COMMAND ${RTG_GDB} $<TARGET_FILE:${EXE}> ${TEST_DIR}/core -batch -ex bt)
endif()
message(FATAL_ERROR \"Test failed\")
endif()
")
add_test(NAME ${NAME} COMMAND ${CMAKE_COMMAND} -P "${CMAKE_CURRENT_BINARY_DIR}/test_${NAME}.cmake")
add_test(NAME ${NAME} COMMAND ${CMAKE_COMMAND} -P "${TEST_DIR}/run.cmake")
else()
add_test(NAME ${NAME} COMMAND ${EXE} ${ARGN})
endif()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment