Commit 8a5bc2fb authored by Paul

Merge

parents 868230f5 bb0e04ce
@@ -26,6 +26,7 @@
 #include <algorithm>
 #include <migraphx/check_shapes.hpp>
+#include <migraphx/argument.hpp>
 #include <migraphx/config.hpp>
 #include <migraphx/op/normalize_attribute.hpp>
 #include <migraphx/par_for.hpp>
...
@@ -24,14 +24,11 @@
 #ifndef MIGRAPHX_GUARD_OPERATORS_TRANSPOSE_HPP
 #define MIGRAPHX_GUARD_OPERATORS_TRANSPOSE_HPP
-#include <array>
 #include <migraphx/check_shapes.hpp>
 #include <migraphx/argument.hpp>
-#include <migraphx/functional.hpp>
 #include <migraphx/config.hpp>
-#include <migraphx/lifetime.hpp>
-#include <cmath>
-#include <utility>
+#include <migraphx/value.hpp>
+#include <migraphx/op/normalize_attribute.hpp>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
...
@@ -24,10 +24,9 @@
 #ifndef MIGRAPHX_GUARD_OPERATORS_UNARY_NOT_HPP
 #define MIGRAPHX_GUARD_OPERATORS_UNARY_NOT_HPP
-#include <migraphx/operation.hpp>
-#include <migraphx/check_shapes.hpp>
+#include <migraphx/op/unary.hpp>
 #include <migraphx/config.hpp>
-#include <migraphx/op/unary.hpp>
-#include <cmath>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
...
@@ -25,7 +25,6 @@
 #define MIGRAPHX_GUARD_RTGLIB_UNKNOWN_HPP
 #include <migraphx/config.hpp>
-#include <migraphx/argument.hpp>
 #include <migraphx/check_shapes.hpp>
 namespace migraphx {
...
@@ -24,16 +24,11 @@
 #ifndef MIGRAPHX_GUARD_OPERATORS_UNSQUEEZE_HPP
 #define MIGRAPHX_GUARD_OPERATORS_UNSQUEEZE_HPP
-#include <array>
 #include <migraphx/check_shapes.hpp>
-#include <migraphx/stringutils.hpp>
-#include <migraphx/streamutils.hpp>
-#include <migraphx/shape_for_each.hpp>
+#include <migraphx/argument.hpp>
 #include <migraphx/config.hpp>
-#include <migraphx/value.hpp>
 #include <migraphx/op/normalize_attribute.hpp>
-#include <migraphx/lifetime.hpp>
-#include <cmath>
-#include <utility>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
...
@@ -24,18 +24,11 @@
 #ifndef MIGRAPHX_GUARD_OPERATORS_WHERE_HPP
 #define MIGRAPHX_GUARD_OPERATORS_WHERE_HPP
-#include <array>
-#include <migraphx/argument.hpp>
-#include <migraphx/par_for.hpp>
 #include <migraphx/check_shapes.hpp>
-#include <migraphx/stringutils.hpp>
-#include <migraphx/streamutils.hpp>
-#include <migraphx/shape_for_each.hpp>
+#include <migraphx/argument.hpp>
 #include <migraphx/config.hpp>
 #include <migraphx/value.hpp>
-#include <migraphx/op/normalize_attribute.hpp>
-#include <cmath>
-#include <utility>
+#include <migraphx/par_for.hpp>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
...
@@ -68,8 +68,10 @@ struct operation
     *
     * @param ctx This is the context created by the `target` during compilation. Implementations
     * can use the target's `context` class rather than the `context` interface class.
-    * @param output This is the output shape. It is equivalent to running `compute_shape` with each
-    * `shape` of the `argument`.
+    * @param output Equivalent to running `compute_shape` with each `shape` of the `argument`.
+    * For a fixed shape, the returned `argument` will have the same shape as `output`.
+    * For a dynamic shape, the returned `argument` will have a fixed shape within the bounds
+    * set in the dynamic shape `output`.
     * @param input This is the `argument` result from the previous instruction's computation.
     * @return Return an `argument` of the result computation. The `shape` of `argument` should be
     * the same as the `output` shape.

@@ -137,7 +139,7 @@ auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
     -> decltype(x.normalize_compute_shape(inputs))
 {
     dependent_type<operation, T> y = x;
-    normalize_attributes(y, inputs[0].lens());
+    normalize_attributes(y, inputs[0].max_lens());
     return any_cast<T>(y).normalize_compute_shape(inputs);
 }
...
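Context for the `max_lens()` switch above: a dynamic dimension carries a {min, max, optimal} range, and for a static shape every range is collapsed (min == max), so normalizing attributes such as negative axes against the per-dimension maxima covers both cases. A standalone sketch of the idea, using stand-in types rather than MIGraphX's actual classes:

    // Stand-in types for illustration only (not MIGraphX's real classes).
    #include <cstdint>
    #include <iostream>
    #include <vector>

    struct dyn_dim
    {
        std::size_t min, max, optimal;
    };

    // For a static shape min == max, so the maxima are just the lengths.
    std::vector<std::size_t> max_lens(const std::vector<dyn_dim>& dims)
    {
        std::vector<std::size_t> out;
        for(const auto& d : dims)
            out.push_back(d.max);
        return out;
    }

    // Normalize an ONNX-style negative axis against the rank.
    std::int64_t normalize_axis(std::int64_t axis, std::size_t rank)
    {
        return axis < 0 ? axis + static_cast<std::int64_t>(rank) : axis;
    }

    int main()
    {
        std::vector<dyn_dim> dims = {{1, 4, 0}, {3, 3, 0}, {224, 224, 0}};
        auto lens = max_lens(dims);                           // {4, 3, 224}
        std::cout << normalize_axis(-1, lens.size()) << "\n"; // prints 2
    }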
@@ -57,6 +57,7 @@
 #include <migraphx/op/exp.hpp>
 #include <migraphx/op/flatten.hpp>
 #include <migraphx/op/floor.hpp>
+#include <migraphx/op/fmod.hpp>
 #include <migraphx/op/gather.hpp>
 #include <migraphx/op/gathernd.hpp>
 #include <migraphx/op/get_tuple_elem.hpp>

@@ -79,6 +80,7 @@
 #include <migraphx/op/lstm.hpp>
 #include <migraphx/op/max.hpp>
 #include <migraphx/op/min.hpp>
+#include <migraphx/op/mod.hpp>
 #include <migraphx/op/mul.hpp>
 #include <migraphx/op/multibroadcast.hpp>
 #include <migraphx/op/neg.hpp>
...
@@ -24,38 +24,36 @@
 #ifndef MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP
 #define MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP
-#include <utility>
+#include <migraphx/config.hpp>
 #include <cstdint>
 #include <vector>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
-inline void calculate_padding(int64_t idx,
-                              std::vector<int64_t>& pads,
-                              int64_t input_dim,
-                              int64_t stride,
-                              int64_t dilation,
-                              int64_t weight_dim,
-                              bool is_same_upper = true)
-{
-    int64_t output_dim     = (input_dim + stride - 1) / stride; // round up result
-    int64_t new_weight_dim = weight_dim + (weight_dim - 1) * (dilation - 1);
-    int64_t pad =
-        std::max(static_cast<int64_t>(0), (output_dim - 1) * stride + new_weight_dim - input_dim);
-    auto pad_ndims = pads.size() / 2;
-    if(is_same_upper)
-    {
-        pads[idx]             = pad / 2;
-        pads[idx + pad_ndims] = pad - pad / 2;
-    }
-    else
-    {
-        pads[idx + pad_ndims] = pad / 2;
-        pads[idx]             = pad - pad / 2;
-    }
-}
+void calculate_padding(int64_t idx,
+                       std::vector<int64_t>& pads,
+                       int64_t input_dim,
+                       int64_t stride,
+                       int64_t dilation,
+                       int64_t weight_dim,
+                       bool is_same_upper = true);
+
+/*!
+ * Calculate the padding for auto_padding. Used for dynamic shapes
+ * where the padding calculation must be done at evaluation time.
+ * \param tensor_lens input tensor image shape
+ * \param k_lens weights kernel shape
+ * \param strides strides for the kernel
+ * \param dilations dilations for the kernel
+ * \param use_upper put odd padding on upper or lower side
+ * \return padding in the form of {x0_begin, x1_begin, ... x0_end, x1_end, ...}
+ */
+std::vector<std::size_t> calc_dyn_auto_pad(std::vector<std::size_t> tensor_lens,
+                                           std::vector<std::size_t> k_lens,
+                                           std::vector<std::size_t> strides,
+                                           std::vector<std::size_t> dilations,
+                                           bool use_upper = true);
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
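For reference, the "same" padding arithmetic in the (now out-of-line) `calculate_padding` works out as follows; a compilable sketch with illustrative values:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    int main()
    {
        // Illustrative values: a 224-wide image, stride 2, 3-wide kernel, dilation 1.
        std::int64_t input_dim = 224, stride = 2, dilation = 1, weight_dim = 3;
        std::int64_t output_dim     = (input_dim + stride - 1) / stride;              // ceil: 112
        std::int64_t new_weight_dim = weight_dim + (weight_dim - 1) * (dilation - 1); // dilated: 3
        std::int64_t pad            = std::max<std::int64_t>(
            0, (output_dim - 1) * stride + new_weight_dim - input_dim); // total padding: 1
        // SAME_UPPER puts the odd element at the end of the axis.
        std::cout << "begin=" << pad / 2 << " end=" << pad - pad / 2 << "\n"; // begin=0 end=1
    }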
@@ -25,6 +25,7 @@
 #define MIGRAPHX_GUARD_RTGLIB_PAR_DFOR_HPP
 #include <migraphx/par_for.hpp>
+#include <migraphx/dfor.hpp>
 #include <migraphx/functional.hpp>
 #include <array>
 #include <numeric>
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_SQLITE_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_SQLITE_HPP
#include <migraphx/config.hpp>
#include <migraphx/filesystem.hpp>
#include <memory>
#include <unordered_map>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct sqlite_impl;
struct sqlite
{
sqlite() = default;
static sqlite read(const fs::path& p);
static sqlite write(const fs::path& p);
std::vector<std::unordered_map<std::string, std::string>> execute(const std::string& s);
private:
std::shared_ptr<sqlite_impl> impl;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_SQLITE_HPP
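A plausible usage sketch for this new wrapper; the database path, table, and query below are hypothetical, not taken from this commit:

    #include <migraphx/sqlite.hpp>
    #include <iostream>

    int main()
    {
        // Hypothetical file and query, for illustration only.
        auto db   = migraphx::sqlite::read("tuning.db");
        auto rows = db.execute("SELECT * FROM configs;");
        for(const auto& row : rows)
            for(const auto& [column, value] : row)
                std::cout << column << " = " << value << "\n";
    }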
@@ -174,27 +174,27 @@ inline std::string interpolate_string(const std::string& input,
 }

 template <class Iterator>
-inline std::string to_string_range(Iterator start, Iterator last)
+inline std::string to_string_range(Iterator start, Iterator last, const char* delim = ", ")
 {
     std::stringstream ss;
     if(start != last)
     {
         ss << *start;
-        std::for_each(std::next(start), last, [&](auto&& x) { ss << ", " << x; });
+        std::for_each(std::next(start), last, [&](auto&& x) { ss << delim << x; });
     }
     return ss.str();
 }

 template <class Range>
-inline std::string to_string_range(const Range& r)
+inline std::string to_string_range(const Range& r, const char* delim = ", ")
 {
-    return to_string_range(r.begin(), r.end());
+    return to_string_range(r.begin(), r.end(), delim);
 }

 template <class T>
-inline std::string to_string_range(const std::initializer_list<T>& r)
+inline std::string to_string_range(const std::initializer_list<T>& r, const char* delim = ", ")
 {
-    return to_string_range(r.begin(), r.end());
+    return to_string_range(r.begin(), r.end(), delim);
 }

 template <class T>
...
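The new `delim` parameter defaults to the old ", " separator, so existing callers are unaffected; a short usage sketch once this change is in place:

    #include <migraphx/stringutils.hpp>
    #include <iostream>
    #include <vector>

    int main()
    {
        std::vector<int> lens = {1, 3, 224, 224};
        std::cout << migraphx::to_string_range(lens) << "\n";      // 1, 3, 224, 224
        std::cout << migraphx::to_string_range(lens, "x") << "\n"; // 1x3x224x224
    }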
@@ -40,6 +40,12 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
     auto val        = op.to_value();
     auto op_padding = val.at("padding").to_vector<size_t>();

+    // skip if shape is dynamic
+    if(input->get_shape().dynamic())
+    {
+        return;
+    }
+
     auto kdims = input->get_shape().lens().size() - 2;
     if(std::equal(op_padding.begin(),
                   op_padding.begin() + kdims,
...
@@ -445,8 +445,8 @@ operation instruction::normalized_operator() const
     operation o = this->get_operator();
     if(this->need_normalization())
     {
-        auto lens = this->inputs().front()->get_shape().lens();
-        if(!normalize_attributes(o, lens))
+        auto s = this->inputs().front()->get_shape();
+        if(!normalize_attributes(o, s.max_lens()))
             return this->get_operator();
     }
     return o;
...
@@ -43,9 +43,9 @@ void normalize_ops::apply(module& m) const
         if(inputs.empty())
             continue;

-        auto lens                    = inputs[0]->get_shape().lens();
+        auto s                       = inputs[0]->get_shape();
         migraphx::operation tuned_op = ins->get_operator();
-        if(normalize_attributes(tuned_op, lens))
+        if(normalize_attributes(tuned_op, s.max_lens()))
         {
             m.replace_instruction(ins, tuned_op, inputs);
             ins->set_normalized();
...
@@ -93,9 +93,10 @@ struct onnx_parser
         onnx_parser&, const node_info&, std::vector<instruction_ref>)>;
     node_map nodes;
     std::unordered_map<std::string, instruction_ref> instructions;
     program prog = program();
-    std::size_t default_dim_value = 1;
+    shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0};
     std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
+    std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims;
     bool skip_unknown_operators = false;
     int64_t max_loop_iterations = 10;
     int64_t opset_version       = 13;

@@ -118,6 +119,7 @@ struct onnx_parser
 };

 shape::type_t get_type(int dtype);
+bool is_type_float(shape::type_t dtype);
 } // namespace onnx
 } // namespace MIGRAPHX_INLINE_NS
...
@@ -41,8 +41,25 @@ template <class... Ts>
 program parse_onnx_from(const onnx_options& options, Ts&&... xs)
 {
     onnx::onnx_parser parser;
     parser.map_input_dims     = options.map_input_dims;
-    parser.default_dim_value  = options.default_dim_value;
+    parser.map_dyn_input_dims = options.map_dyn_input_dims;
+    auto dim_val = options.default_dim_value;
+    if(dim_val != 0)
+    {
+        if(options.default_dyn_dim_value != shape::dynamic_dimension{1, 1, 0})
+        {
+            MIGRAPHX_THROW("PARSE_ONNX_FROM: both default_dim_value and default_dyn_dim_value "
+                           "set to non-default value");
+        }
+        else
+        {
+            parser.default_dyn_dim_value = {dim_val, dim_val, 0};
+        }
+    }
+    else
+    {
+        parser.default_dyn_dim_value = options.default_dyn_dim_value;
+    }
     parser.skip_unknown_operators = options.skip_unknown_operators;
     parser.max_loop_iterations    = options.max_loop_iterations;
...
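A usage sketch of the options resolved above (model path and tensor name are hypothetical): `default_dim_value` keeps working for existing callers, while `map_dyn_input_dims` and `default_dyn_dim_value` express {min, max, optimal} ranges per dimension:

    #include <migraphx/onnx.hpp>
    #include <iostream>

    int main()
    {
        migraphx::onnx_options options;
        // Hypothetical input name: let the batch dimension range over [1, 64]
        // (an optimal value of 0 means none is given).
        options.map_dyn_input_dims["input"] = {
            {1, 64, 0}, {3, 3, 0}, {224, 224, 0}, {224, 224, 0}};
        auto prog = migraphx::parse_onnx("model.onnx", options);
        std::cout << prog << "\n";
    }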
@@ -28,16 +28,17 @@
 #include <migraphx/stringutils.hpp>
 #include <migraphx/ranges.hpp>
 #include <migraphx/instruction.hpp>
+#include <migraphx/pad_calc.hpp>
 #include <migraphx/common.hpp>
 #include <migraphx/type_traits.hpp>
 #include <migraphx/float_equal.hpp>
 #include <migraphx/file_buffer.hpp>
 #include <migraphx/filesystem.hpp>
 #include <migraphx/op/unknown.hpp>
+#include <migraphx/env.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace onnx {

 static onnx_parser::attribute_map get_attributes(const onnx::NodeProto& node)
@@ -58,7 +59,7 @@ create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const
         std::accumulate(dims.begin(), dims.end(), std::size_t(1), std::multiplies<std::size_t>());
     if(elem_num == 0)
     {
-        return {};
+        return literal{shape_type};
     }

     // in case of scalar constants in onnx file, use dims=1 to fill initializer data

@@ -75,7 +76,7 @@ static literal create_literal(shape::type_t shape_type, const std::vector<size_t
         std::accumulate(dims.begin(), dims.end(), std::size_t(1), std::multiplies<std::size_t>());
     if(elem_num == 0)
     {
-        return {};
+        return literal{shape_type};
     }

     // scalar input
@@ -255,6 +256,11 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)

 void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
 {
+    if(not map_input_dims.empty() and not map_dyn_input_dims.empty())
+    {
+        MIGRAPHX_THROW("PARSE_GRAPH: both map_input_dims and map_dyn_input_dims non-empty, only "
+                       "one should be used");
+    }
     std::unordered_map<std::string, instruction_ref> mod_insts;
     for(auto&& f : graph.initializer())
     {
@@ -268,7 +274,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
             // input not in initializer_data, so it is a real input
             if(!contains(mod_insts, name))
             {
-                // ONNX specification does not specify hwo to deal with the
+                // ONNX specification does not specify how to deal with the
                 // scenario that a nested subgraph contains a parameter with the
                 // name existed in its parent graph.
                 // In the current implementation, MIGraphX throws an exception for that.

@@ -278,13 +284,22 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
                     "\" existing in parent graph!");
             }

+            shape s;
             std::vector<std::size_t> dims;
             if(map_input_dims.count(name) > 0)
             {
                 dims = map_input_dims.at(name);
+                s    = parse_type(input.type(), dims);
+            }
+            else if(map_dyn_input_dims.count(name) > 0)
+            {
+                shape::type_t shape_type = get_type(input.type().tensor_type().elem_type());
+                s = {shape_type, map_dyn_input_dims.at(name)};
+            }
+            else
+            {
+                s = parse_type(input.type(), dims);
             }
-            shape s = parse_type(input.type(), dims);
             mod_insts[name] = mod->add_parameter(name, s);
         }
     }
@@ -439,30 +454,41 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
         return {shape_type, input_dims};
     }

-    std::vector<std::size_t> dims;
+    std::vector<shape::dynamic_dimension> dynamic_dims;
     auto&& tensor_dims = t.tensor_type().shape().dim();
     std::transform(tensor_dims.begin(),
                    tensor_dims.end(),
-                   std::back_inserter(dims),
-                   [&](auto&& d) -> std::size_t {
+                   std::back_inserter(dynamic_dims),
+                   [&](auto&& d) -> shape::dynamic_dimension {
                        if(d.has_dim_value())
                        {
                            if(static_cast<int>(d.dim_value()) <= 0)
                            {
-                               return default_dim_value;
+                               return default_dyn_dim_value;
                            }
-                           return d.dim_value();
+                           std::size_t tmp = d.dim_value();
+                           return {tmp, tmp, 0};
                        }
                        else
                        {
-                           return default_dim_value;
+                           return default_dyn_dim_value;
                        }
                    });

-    if(dims.empty())
+    if(dynamic_dims.empty())
+    {
         return {shape_type};
+    }

-    return {shape_type, dims};
+    if(std::all_of(
+           dynamic_dims.begin(), dynamic_dims.end(), [](auto dd) { return dd.is_fixed(); }))
+    {
+        std::vector<std::size_t> dims;
+        std::transform(dynamic_dims.begin(),
+                       dynamic_dims.end(),
+                       std::back_inserter(dims),
+                       [](auto d) { return d.max; });
+        return {shape_type, dims};
+    }
+    return {shape_type, dynamic_dims};
 }
@@ -487,6 +513,16 @@ shape::type_t get_type(int dtype)
     }
 }

+bool is_type_float(shape::type_t dtype)
+{
+    bool r = false;
+    if(dtype == shape::float_type || dtype == shape::double_type || dtype == shape::half_type)
+    {
+        r = true;
+    }
+    return r;
+}
+
 } // namespace onnx
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
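The collapse-to-static branch at the end of `parse_type` can be read in isolation; here is a standalone sketch with a stand-in for `shape::dynamic_dimension` (illustrative, not the library's own type):

    #include <algorithm>
    #include <cstddef>
    #include <iostream>
    #include <iterator>
    #include <vector>

    struct dynamic_dimension // stand-in for shape::dynamic_dimension
    {
        std::size_t min, max, optimal;
        bool is_fixed() const { return min == max; }
    };

    int main()
    {
        // ONNX dims that all carry known values parse to fixed ranges,
        // so the shape collapses back to a static one.
        std::vector<dynamic_dimension> dds = {{3, 3, 0}, {224, 224, 0}, {224, 224, 0}};
        if(std::all_of(dds.begin(), dds.end(), [](auto dd) { return dd.is_fixed(); }))
        {
            std::vector<std::size_t> dims;
            std::transform(dds.begin(), dds.end(), std::back_inserter(dims), [](auto dd) {
                return dd.max;
            });
            std::cout << "static shape of rank " << dims.size() << "\n";
        }
    }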
@@ -43,7 +43,7 @@ struct parse_constant : op_parser<parse_constant>
         // return empty literal
         if(v.get_shape().elements() == 0)
         {
-            return info.add_literal(literal{});
+            return info.add_literal(literal{v.get_shape().type()});
         }

         auto dim_size = info.attributes.at("value").t().dims_size();
...
@@ -47,15 +47,17 @@ struct parse_convolution : op_parser<parse_convolution>
                           onnx_parser::node_info info,
                           std::vector<instruction_ref> args) const
     {
         auto op      = make_op(opd.op_name);
         auto values  = op.to_value();
         auto l0      = args[0];
         auto weights = args[1];
-        auto in_lens = l0->get_shape().lens();
+        auto l0_shape = l0->get_shape();
+        auto w_shape  = weights->get_shape();
+        auto in_lens  = l0_shape.max_lens();
         assert(in_lens.size() > 2);
         auto kdims = in_lens.size() - 2;

-        // ensure pads availabe only when auto_pad is "NOT_SET"
+        // ensure pads available only when auto_pad is "NOT_SET"
         check_padding_mode(info, "CONV");

         if(contains(info.attributes, "strides"))

@@ -79,21 +81,65 @@ struct parse_convolution : op_parser<parse_convolution>
             copy(info.attributes["pads"].ints(), std::back_inserter(padding));
             check_attr_sizes(kdims, padding.size() / 2, "PARSE_CONV: inconsistent paddings");
         }

         if(contains(info.attributes, "auto_pad"))
         {
-            auto weight_lens = weights->get_shape().lens();
-            std::vector<std::size_t> k_lens(weight_lens.begin() + 2, weight_lens.end());
-            cal_auto_padding_size(info,
-                                  values,
-                                  k_lens,
-                                  values["dilation"].to_vector<std::size_t>(),
-                                  in_lens,
-                                  padding);
-            auto auto_pad = info.attributes["auto_pad"].s();
+            bool is_same_padding = false;
+            auto auto_pad        = info.attributes["auto_pad"].s();
             if(auto_pad.find("SAME") != std::string::npos)
             {
-                values["padding_mode"] = to_value(op::padding_mode_t::same);
+                is_same_padding = true;
+            }
+
+            // check if image shape is dynamic
+            bool image_shape_dynamic = false;
+            if(l0_shape.dynamic())
+            {
+                auto dyn_dims = l0_shape.dyn_dims();
+                std::for_each(dyn_dims.begin() + 2, dyn_dims.end(), [&](auto dyn_dim) {
+                    if(not dyn_dim.is_fixed())
+                    {
+                        image_shape_dynamic = true;
+                    }
+                });
+            }
+            // check if kernel shape is dynamic
+            bool kernel_shape_dynamic = false;
+            if(w_shape.dynamic())
+            {
+                auto dyn_dims = w_shape.dyn_dims();
+                std::for_each(dyn_dims.begin() + 2, dyn_dims.end(), [&](auto dyn_dim) {
+                    if(not dyn_dim.is_fixed())
+                    {
+                        kernel_shape_dynamic = true;
+                    }
+                });
+            }
+
+            if(is_same_padding)
+            {
+                if(image_shape_dynamic or kernel_shape_dynamic)
+                {
+                    // must calculate "same" padding with input shape data
+                    bool is_same_upper     = (auto_pad.find("SAME_UPPER") != std::string::npos);
+                    values["padding_mode"] = is_same_upper
+                                                 ? to_value(op::padding_mode_t::same_upper)
+                                                 : to_value(op::padding_mode_t::same_lower);
+                    values["use_dynamic_same_auto_pad"] = true;
+                }
+                else
+                {
+                    values["padding_mode"] = to_value(op::padding_mode_t::same);
+                    // kernel shape will be fixed, so max_lens() == min_lens() for kernel lengths
+                    auto weight_lens = weights->get_shape().max_lens();
+                    std::vector<std::size_t> k_lens(weight_lens.begin() + 2, weight_lens.end());
+                    cal_auto_padding_size(info,
+                                          values,
+                                          k_lens,
+                                          values["dilation"].to_vector<std::size_t>(),
+                                          in_lens,
+                                          padding);
+                }
             }
         }
         values["padding"] = std::vector<size_t>(padding.begin(), padding.end());
...
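For reference, SAME_UPPER and SAME_LOWER differ only in where the odd padding element lands; a tiny sketch mirroring the split used by `calculate_padding`:

    #include <cstdint>
    #include <iostream>

    int main()
    {
        std::int64_t pad = 3; // total padding for one spatial axis (illustrative)
        std::cout << "SAME_UPPER: begin=" << pad / 2 << " end=" << pad - pad / 2 << "\n"; // 1, 2
        std::cout << "SAME_LOWER: begin=" << pad - pad / 2 << " end=" << pad / 2 << "\n"; // 2, 1
    }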