Commit cbf4c8d6 authored by Paul's avatar Paul
Browse files

More formatting

parent 8f330074
......@@ -238,18 +238,21 @@ struct transpose
auto input_lens = input.lens();
auto input_strides = input.strides();
auto t = input.type();
if (dims.size() != input_lens.size()) {
if(dims.size() != input_lens.size())
{
RTG_THROW("Permutation has wrong number of axes");
}
std::vector<int64_t> axes(dims.size());
std::iota(axes.begin(), axes.end(), 0);
if (!std::is_permutation(axes.begin(), axes.end(), dims.begin())) {
if(!std::is_permutation(axes.begin(), axes.end(), dims.begin()))
{
RTG_THROW("Invalid permutation");
}
std::vector<size_t> output_lens(input_lens.size());
std::vector<size_t> output_strides(input_lens.size());
for (int i = 0; i < output_lens.size(); i++) {
output_lens[i] = input_lens[dims[i]];
for(int i = 0; i < output_lens.size(); i++)
{
output_lens[i] = input_lens[dims[i]];
output_strides[i] = input_strides[dims[i]];
}
return {t, output_lens, output_strides};
......@@ -257,15 +260,16 @@ struct transpose
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
};
struct contiguous
struct contiguous
{
std::string name() const { return "contiguous"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type();
if (lens.size() < 2) {
auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type();
if(lens.size() < 2)
{
RTG_THROW("Number of dimensions should exceed 1");
}
return {t, lens};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment