Commit cbf4c8d6 authored by Paul's avatar Paul
Browse files

More formatting

parent 8f330074
...@@ -238,17 +238,20 @@ struct transpose ...@@ -238,17 +238,20 @@ struct transpose
auto input_lens = input.lens(); auto input_lens = input.lens();
auto input_strides = input.strides(); auto input_strides = input.strides();
auto t = input.type(); auto t = input.type();
if (dims.size() != input_lens.size()) { if(dims.size() != input_lens.size())
{
RTG_THROW("Permutation has wrong number of axes"); RTG_THROW("Permutation has wrong number of axes");
} }
std::vector<int64_t> axes(dims.size()); std::vector<int64_t> axes(dims.size());
std::iota(axes.begin(), axes.end(), 0); std::iota(axes.begin(), axes.end(), 0);
if (!std::is_permutation(axes.begin(), axes.end(), dims.begin())) { if(!std::is_permutation(axes.begin(), axes.end(), dims.begin()))
{
RTG_THROW("Invalid permutation"); RTG_THROW("Invalid permutation");
} }
std::vector<size_t> output_lens(input_lens.size()); std::vector<size_t> output_lens(input_lens.size());
std::vector<size_t> output_strides(input_lens.size()); std::vector<size_t> output_strides(input_lens.size());
for (int i = 0; i < output_lens.size(); i++) { for(int i = 0; i < output_lens.size(); i++)
{
output_lens[i] = input_lens[dims[i]]; output_lens[i] = input_lens[dims[i]];
output_strides[i] = input_strides[dims[i]]; output_strides[i] = input_strides[dims[i]];
} }
...@@ -265,7 +268,8 @@ struct contiguous ...@@ -265,7 +268,8 @@ struct contiguous
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
auto lens = inputs.at(0).lens(); auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
if (lens.size() < 2) { if(lens.size() < 2)
{
RTG_THROW("Number of dimensions should exceed 1"); RTG_THROW("Number of dimensions should exceed 1");
} }
return {t, lens}; return {t, lens};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment