Commit 8f0108f8 authored by Paul's avatar Paul
Browse files

Add reshape operator

parent 9f046d67
......@@ -109,6 +109,38 @@ struct activation
}
};
struct reshape
{
std::vector<int64_t> dims;
std::string name() const
{
return "reshape[dims={" + to_string(dims) +
"}]";
}
shape compute_shape(std::vector<shape> inputs) const
{
if(inputs.empty()) throw std::runtime_error("Wrong number of arguments");
auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.begin(), dims.end());
for(std::size_t i = 0;i < dims.size();i++)
{
if(dims[i] == 0)
rdims[i] = idims[i];
}
if(dims.back() == -1)
{
rdims.pop_back();
std::copy(idims.begin()+rdims.size(), idims.end(), std::back_inserter(rdims));
}
return {inputs.front().type(), rdims};
}
argument compute(std::vector<argument>) const
{
throw std::runtime_error("not computable");
}
};
} // namespace rtg
......
......@@ -89,6 +89,15 @@ struct onnx_parser
add_op("Relu", [this](attribute_map attributes, std::vector<rtg::instruction*> args) {
return prog->add_instruction(rtg::activation{"relu"}, args);
});
add_op("Reshape", [this](attribute_map attributes, std::vector<rtg::instruction*> args) {
rtg::reshape op;
rtg::literal s = parse_value(attributes.at("shape"));
s.visit([&](auto v)
{
copy(v, std::back_inserter(op.dims));
});
return prog->add_instruction(op, args);
});
add_op("Constant", [this](attribute_map attributes, std::vector<rtg::instruction*>) {
rtg::literal v = parse_value(attributes.at("value"));
return prog->add_literal(v);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment