Commit 7ebb1cbf authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Add checks in helper functions

parent e1f959fd
......@@ -203,7 +203,10 @@ inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Prob
constexpr int num_execution_config_args = 2;
constexpr int num_problem_args = 8;
assert(num_problem_args == size(problem.shape) + size(problem.axes));
if(!(num_problem_args == size(problem.shape) + size(problem.axes)))
{
return false;
}
if(argc == 1)
{
......@@ -265,7 +268,10 @@ template <typename Shape, typename Indices>
inline std::enable_if_t<detail::is_sized_range_v<Shape> && detail::is_sized_range_v<Indices>, bool>
is_valid_indices(const Shape& shape, const Indices& indices)
{
assert(is_valid_shape(shape));
if(!is_valid_shape(shape))
{
return false;
}
using std::empty;
if(empty(indices))
......@@ -320,9 +326,10 @@ std::enable_if_t<detail::is_bidirectional_range_v<Shape> && detail::is_sized_ran
advance_indices(const Shape& shape, Indices& indices)
{
using std::size;
assert(is_valid_shape(shape));
assert(is_valid_indices(indices));
assert(size(shape) == size(indices));
if(!(is_valid_shape(shape) && is_valid_indices(shape, indices) && size(shape) == size(indices)))
{
return false;
}
bool carry = true;
......@@ -340,29 +347,70 @@ advance_indices(const Shape& shape, Indices& indices)
return !carry;
}
template <typename Src, typename Functor, typename Dest>
std::enable_if_t<std::is_invocable_v<Functor,
std::add_lvalue_reference_t<Dest>,
std::add_lvalue_reference_t<Src>>>
host_elementwise_permute(const Tensor<Src>& src, Functor functor, Tensor<Dest>& dest)
template <typename Src, typename Axes, typename Functor, typename Dest>
std::enable_if_t<detail::is_random_access_range_v<Axes> && detail::is_sized_range_v<Axes> &&
std::is_invocable_v<Functor,
std::add_lvalue_reference_t<Dest>,
std::add_lvalue_reference_t<Src>>,
bool>
host_elementwise_permute(const Tensor<Src>& src,
const Axes& axes,
Functor functor,
Tensor<Dest>& dest)
{
const auto& shape = src.mDesc.GetLengths();
const auto& transposed_shape = dest.mDesc.GetLengths();
assert(is_valid_shape(shape) && is_valid_shape(transposed_shape));
std::copy(begin(shape), end(shape), std::ostream_iterator<std::size_t>(std::cerr, " "));
std::cerr << std::endl;
std::copy(begin(transposed_shape),
end(transposed_shape),
std::ostream_iterator<std::size_t>(std::cerr, " "));
std::cerr << std::endl;
if(!(is_valid_shape(shape) && is_valid_shape(transposed_shape)))
{
return false;
}
using std::size;
if(!(is_valid_axes(axes) && size(axes) == 4))
{
return false;
}
static_assert(detail::is_sized_range_v<ck::remove_cvref_t<decltype(shape)>> &&
detail::is_sized_range_v<ck::remove_cvref_t<decltype(transposed_shape)>>);
using std::size;
assert(size(shape) == 4 && size(transposed_shape) == 4);
if(!(size(shape) == 4 && size(transposed_shape) == 4))
{
return false;
}
static_assert(detail::is_random_access_range_v<ck::remove_cvref_t<decltype(shape)>> &&
detail::is_random_access_range_v<ck::remove_cvref_t<decltype(transposed_shape)>>);
{
for(std::size_t idx = 0; idx < size(shape); ++idx)
{
if(transposed_shape[idx] != shape[axes[idx]])
{
return false;
}
}
}
std::array<std::size_t, 4> indices{};
assert(is_valid_indices(indices));
if(!is_valid_indices(shape, indices))
{
return false;
}
do
{
Dest b_val = 0;
functor(b_val, src(indices[0], indices[1], indices[2], indices[3]));
dest(indices[0], indices[2], indices[3], indices[1]) = b_val;
Dest output = 0;
functor(output, src(indices[0], indices[1], indices[2], indices[3]));
dest(indices[axes[0]], indices[axes[1]], indices[axes[2]], indices[axes[3]]) = output;
} while(advance_indices(shape, indices));
return true;
}
......@@ -49,7 +49,7 @@ bool run_elementwise_permute(const ExecutionConfig& config, const Problem& probl
if(config.do_verification)
{
Tensor<BDataType> host_b(nhwc);
host_elementwise_permute(a, PassThrough{}, host_b);
host_elementwise_permute(a, problem.axes, PassThrough{}, host_b);
b_device_buf.FromDevice(b.mData.data());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment