"...resnet50_tensorflow.git" did not exist on "90bb20227e6e2bc65deda56bf96acb896bda2699"
Commit cbf4c8d6 authored by Paul's avatar Paul
Browse files

More formatting

parent 8f330074
...@@ -238,18 +238,21 @@ struct transpose ...@@ -238,18 +238,21 @@ struct transpose
auto input_lens = input.lens(); auto input_lens = input.lens();
auto input_strides = input.strides(); auto input_strides = input.strides();
auto t = input.type(); auto t = input.type();
if (dims.size() != input_lens.size()) { if(dims.size() != input_lens.size())
{
RTG_THROW("Permutation has wrong number of axes"); RTG_THROW("Permutation has wrong number of axes");
} }
std::vector<int64_t> axes(dims.size()); std::vector<int64_t> axes(dims.size());
std::iota(axes.begin(), axes.end(), 0); std::iota(axes.begin(), axes.end(), 0);
if (!std::is_permutation(axes.begin(), axes.end(), dims.begin())) { if(!std::is_permutation(axes.begin(), axes.end(), dims.begin()))
{
RTG_THROW("Invalid permutation"); RTG_THROW("Invalid permutation");
} }
std::vector<size_t> output_lens(input_lens.size()); std::vector<size_t> output_lens(input_lens.size());
std::vector<size_t> output_strides(input_lens.size()); std::vector<size_t> output_strides(input_lens.size());
for (int i = 0; i < output_lens.size(); i++) { for(int i = 0; i < output_lens.size(); i++)
output_lens[i] = input_lens[dims[i]]; {
output_lens[i] = input_lens[dims[i]];
output_strides[i] = input_strides[dims[i]]; output_strides[i] = input_strides[dims[i]];
} }
return {t, output_lens, output_strides}; return {t, output_lens, output_strides};
...@@ -257,15 +260,16 @@ struct transpose ...@@ -257,15 +260,16 @@ struct transpose
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct contiguous struct contiguous
{ {
std::string name() const { return "contiguous"; } std::string name() const { return "contiguous"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
auto lens = inputs.at(0).lens(); auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
if (lens.size() < 2) { if(lens.size() < 2)
{
RTG_THROW("Number of dimensions should exceed 1"); RTG_THROW("Number of dimensions should exceed 1");
} }
return {t, lens}; return {t, lens};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment