"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "a47adf4b7b6adc2f7498c3c5e0a6aa4ead57e09e"
Commit b9d37172 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

manual merge

parents 1af66a1c ea62d7aa
@@ -21,6 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
+#include <migraphx/instruction.hpp>
 #include <migraphx/load_save.hpp>
 #include <migraphx/file_buffer.hpp>
 #include <migraphx/json.hpp>
@@ -60,9 +61,29 @@ void save(const program& p, const std::string& filename, const file_options& opt
 {
     write_buffer(filename, save_buffer(p, options));
 }

+// MIOpen doesn't support serializing fusion plans with Find-2.0 APIs
+void print_miopen_warning(const program& p)
+{
+    auto mods = p.get_modules();
+    if(std::any_of(mods.begin(), mods.end(), [](const auto* m) {
+           return std::any_of(m->begin(), m->end(), [](const instruction& i) {
+               return i.name() == "gpu::miopen_fusion";
+           });
+       }))
+    {
+        std::cout << "[WARNING]: Program has miopen_fusion instructions for which tuned solutions "
+                     "are not stored inside serialized MIGraphX program. Consider serializing with "
+                     "MIGRAPHX_DISABLE_MIOPEN_FUSION=1 flag set."
+                  << std::endl;
+    }
+}

 std::vector<char> save_buffer(const program& p, const file_options& options)
 {
     value v = p.to_value();
+    print_miopen_warning(p);
     std::vector<char> buffer;
     if(options.format == "msgpack")
     {
......
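The warning above fires at save time because MIOpen's Find-2.0 API cannot serialize tuned fusion plans into the MXR file. A minimal sketch of the save/load round trip this code sits in, assuming a compiled program `p`; the exact `load` overload and file name are illustrative:

    #include <migraphx/load_save.hpp>
    #include <migraphx/program.hpp>

    void round_trip(const migraphx::program& p)
    {
        migraphx::file_options options;
        options.format = "msgpack"; // the MXR encoding checked in save_buffer above
        // save_buffer() now prints the miopen_fusion warning before encoding
        migraphx::save(p, "model.mxr", options);
        migraphx::program q = migraphx::load("model.mxr", options);
    }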
@@ -23,9 +23,9 @@
 */
 #include <migraphx/memory_coloring.hpp>
 #include <migraphx/module.hpp>
-#include <migraphx/operators.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/iterator_for.hpp>
+#include <migraphx/make_op.hpp>
 #include <migraphx/functional.hpp>
 #include <migraphx/algorithm.hpp>
 #include <migraphx/ranges.hpp>
@@ -382,7 +382,8 @@ void memory_coloring::apply(module& m) const
         auto s             = ins->get_shape();
         std::size_t offset = seg.first * alignment;
         assert(offset < n);
-        m.replace_instruction(ins, op::load{s, offset}, mem);
+        m.replace_instruction(
+            ins, make_op("load", {{"shape", to_value(s)}, {"offset", offset}}), mem);
     }

     // Replace zero allocation
@@ -391,7 +392,8 @@ void memory_coloring::apply(module& m) const
         if(ins->name() != allocation_op)
             continue;
         assert(ins->get_shape().bytes() == 0);
-        m.replace_instruction(ins, op::load{ins->get_shape(), 0}, mem);
+        m.replace_instruction(
+            ins, make_op("load", {{"shape", to_value(ins->get_shape())}, {"offset", 0}}), mem);
     }

     // Remove scratch parameter if it's not used
......
@@ -873,12 +873,11 @@ module::print_py(std::ostream& os,
     if(ins->name() == "@literal")
     {
         os << mname << ".add_literal(";
-        bool use_abs = false;
-        ins->get_literal().visit([&](auto v) {
-            use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; });
-        });
+        const bool use_abs = false;
         // Disable abs for now
-        use_abs = false;
+        // ins->get_literal().visit([&](auto v) {
+        //     use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; });
+        // });
         if(use_abs)
             os << "migraphx.abs_literal(";
         os << "migraphx.generate_argument(";
......
@@ -25,6 +25,33 @@
 #include <migraphx/serialize.hpp>

 #include <msgpack.hpp>

+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+
+// Leave an extra byte for error checking
+constexpr std::size_t msgpack_size_limit = std::numeric_limits<uint32_t>::max() - 1;
+
+template <class Range>
+std::size_t msgpack_chunk_size(const Range& r)
+{
+    return 1 + (r.size() - 1) / msgpack_size_limit;
+}
+
+template <class Iterator, class F>
+void msgpack_chunk_for_each(Iterator start, Iterator last, F f)
+{
+    while(std::distance(start, last) > msgpack_size_limit)
+    {
+        auto next = std::next(start, msgpack_size_limit);
+        f(start, next);
+        start = next;
+    }
+    f(start, last);
+}
+
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
+
 namespace msgpack {
 MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
 {
@@ -63,16 +90,31 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
             break;
         }
         case msgpack::type::BIN: {
+            // For backwards compatibility
             v = migraphx::value::binary{o.via.bin.ptr, o.via.bin.size};
             break;
         }
         case msgpack::type::ARRAY: {
-            migraphx::value r = migraphx::value::array{};
-            std::for_each(
-                o.via.array.ptr,
-                o.via.array.ptr + o.via.array.size,
-                [&](const msgpack::object& so) { r.push_back(so.as<migraphx::value>()); });
-            v = r;
+            if(o.via.array.size != 0 and o.via.array.ptr->type == msgpack::type::BIN)
+            {
+                auto bin = migraphx::value::binary{};
+                std::for_each(
+                    o.via.array.ptr,
+                    o.via.array.ptr + o.via.array.size,
+                    [&](const msgpack::object& so) {
+                        bin.insert(bin.end(), so.via.bin.ptr, so.via.bin.ptr + so.via.bin.size);
+                    });
+                v = bin;
+            }
+            else
+            {
+                migraphx::value r = migraphx::value::array{};
+                std::for_each(
+                    o.via.array.ptr,
+                    o.via.array.ptr + o.via.array.size,
+                    [&](const msgpack::object& so) { r.push_back(so.as<migraphx::value>()); });
+                v = r;
+            }
             break;
         }
         case msgpack::type::MAP: {
@@ -102,8 +144,12 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
         {
             const auto* data = reinterpret_cast<const char*>(x.data());
             auto size        = x.size();
-            o.pack_bin(size);
-            o.pack_bin_body(data, size);
+            o.pack_array(migraphx::msgpack_chunk_size(x));
+            migraphx::msgpack_chunk_for_each(
+                data, data + size, [&](const char* start, const char* last) {
+                    o.pack_bin(last - start);
+                    o.pack_bin_body(start, last - start);
+                });
             return o;
         }
     };
@@ -129,6 +175,8 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
         o.pack_array(0);
         return;
     }
+    if(v.size() > migraphx::msgpack_size_limit)
+        MIGRAPHX_THROW("Size is too large for msgpack");
     if(not v.front().get_key().empty())
     {
         o.pack_map(v.size());
......
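msgpack stores a BIN length in a uint32, so a buffer of 2^32 - 1 bytes or more cannot be packed as one BIN; the packer above therefore emits an array of BIN chunks, and the unpacker splices an array of BINs back into a single binary value. A standalone sketch of the chunk arithmetic, with a small limit substituted for msgpack_size_limit so it can be exercised directly:

    #include <cassert>
    #include <cstddef>

    // Same ceiling-division formula as msgpack_chunk_size, with the limit as a parameter
    std::size_t chunk_count(std::size_t bytes, std::size_t limit)
    {
        return 1 + (bytes - 1) / limit; // valid for bytes >= 1
    }

    int main()
    {
        const std::size_t limit = 4;         // stand-in for msgpack_size_limit
        assert(chunk_count(4, limit) == 1);  // fits in one BIN
        assert(chunk_count(5, limit) == 2);  // 4-byte chunk + 1-byte remainder
        assert(chunk_count(12, limit) == 3); // exactly three full chunks
    }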
@@ -26,7 +26,7 @@
 #include <migraphx/normalize_attributes.hpp>
 #include <migraphx/stringutils.hpp>
 #include <migraphx/op/normalize_attribute.hpp>
+#include <migraphx/op/common.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -49,6 +49,10 @@ auto tune_attribute(const std::vector<int64_t>& vec,
                     Message m)
 {
     std::vector<int64_t> result(vec);
+    if(result.empty())
+    {
+        return result;
+    }
     int64_t n_rank = input_shape.ndim();
     std::vector<op::normalize_attribute> vec_attrs = val.to_vector<op::normalize_attribute>();
     if(contains(vec_attrs, op::normalize_attribute::use_output))
@@ -188,20 +192,27 @@ bool normalize_attributes(operation& op, const shape& input_shape)
     auto val = op.to_value();
     if(attrs.contains("normalize_padding"))
     {
-        auto padding       = val.at(attrs.at("normalize_padding").to<std::string>());
-        auto padding_size  = padding.size();
-        auto padding_start = 2;
-
-        if(padding_size == 2 * (input_shape.ndim() - padding_start))
-            tuned = true;
-        else if(padding_size != (input_shape.ndim() - padding_start))
-            MIGRAPHX_THROW("inconsistent padding size");
-        else
+        bool use_auto_padding =
+            (val.contains("padding_mode") and
+             (val.at("padding_mode").to<int>() != migraphx::op::padding_mode_t::default_));
+        if(not use_auto_padding)
         {
-            auto result    = tune_pad_attribute(padding);
-            val["padding"] = result;
-            op.from_value(val);
-            tuned = true;
+            auto padding       = val.at(attrs.at("normalize_padding").to<std::string>());
+            auto padding_size  = padding.size();
+            auto padding_start = 2;
+            if(padding_size == 2 * (input_shape.ndim() - padding_start))
+                tuned = true;
+            else if(padding_size != (input_shape.ndim() - padding_start))
+            {
+                MIGRAPHX_THROW("normalize_attributes: inconsistent padding vector size");
+            }
+            else
+            {
+                auto result    = tune_pad_attribute(padding);
+                val["padding"] = result;
+                op.from_value(val);
+                tuned = true;
+            }
         }
     }
     if(not attrs.contains("normalize_axes"))
@@ -251,5 +262,22 @@ bool normalize_attributes(operation& op, const shape& input_shape)
     return tuned;
 }

+std::vector<int64_t> normalize_axes(const std::vector<int64_t>& axes,
+                                    const shape& input_shape,
+                                    const value& attr_val,
+                                    const std::string& prefix)
+{
+    return tune_attribute(axes, {}, attr_val, input_shape, [&] { return prefix; });
+}
+
+std::vector<int64_t> normalize_indices(const std::vector<int64_t>& indices,
+                                       const std::vector<int64_t>& axes,
+                                       const shape& input_shape,
+                                       const value& attr_val,
+                                       const std::string& prefix)
+{
+    return tune_attribute(indices, axes, attr_val, input_shape, [&] { return prefix; });
+}
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
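The new normalize_axes/normalize_indices entry points wrap tune_attribute so other passes can map negative axes into [0, ndim). The core rule, in a self-contained sketch (the helper name is mine, not the library's):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Axis a on a rank-n tensor maps to a + n when a is negative
    std::vector<int64_t> normalize(std::vector<int64_t> axes, int64_t n_rank)
    {
        for(auto& a : axes)
            if(a < 0)
                a += n_rank;
        return axes;
    }

    int main() { assert((normalize({-1, 1}, 4) == std::vector<int64_t>{3, 1})); }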
@@ -117,6 +117,7 @@ struct onnx_parser
     parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining = false);
     literal parse_value(const onnx::AttributeProto& attr) const;
     literal parse_tensor(const onnx::TensorProto& t) const;
+    shape parse_type(const onnx::TypeProto& t) const;
     shape parse_type(const onnx::TypeProto& t, const std::vector<std::size_t>& input_dims) const;
 };
......
@@ -357,10 +357,9 @@ parse_inputs(const onnx_parser& parser,
         }

         shape s;
-        std::vector<std::size_t> dims;
         if(parser.map_input_dims.count(name) > 0)
         {
-            dims = parser.map_input_dims.at(name);
+            std::vector<std::size_t> dims = parser.map_input_dims.at(name);
             s = parser.parse_type(input.type(), dims);
         }
         else if(parser.map_dyn_input_dims.count(name) > 0)
@@ -370,7 +369,7 @@ parse_inputs(const onnx_parser& parser,
         }
         else
         {
-            s = parser.parse_type(input.type(), dims);
+            s = parser.parse_type(input.type());
         }
         mod_insts[name] = mod->add_parameter(name, s);
     }
@@ -553,14 +552,9 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
     }
     MIGRAPHX_THROW("PARSE_TENSOR: Invalid tensor type");
 }

-shape onnx_parser::parse_type(const onnx::TypeProto& t,
-                              const std::vector<std::size_t>& input_dims) const
+shape onnx_parser::parse_type(const onnx::TypeProto& t) const
 {
     shape::type_t shape_type = get_type(t.tensor_type().elem_type());
-    if(not input_dims.empty())
-    {
-        return {shape_type, input_dims};
-    }

     std::vector<shape::dynamic_dimension> dynamic_dims;
     auto&& tensor_dims = t.tensor_type().shape().dim();
@@ -590,6 +584,15 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
     return shape_from_dyn_dims(shape_type, dynamic_dims);
 }

+shape onnx_parser::parse_type(const onnx::TypeProto& t,
+                              const std::vector<std::size_t>& input_dims) const
+{
+    shape::type_t shape_type = get_type(t.tensor_type().elem_type());
+    if(input_dims.empty())
+        return {shape_type};
+    return {shape_type, input_dims};
+}
+
 shape::type_t get_type(int dtype)
 {
     switch(dtype)
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_castlike : op_parser<parse_castlike>
{
std::vector<op_desc> operators() const { return {{"CastLike"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const
{
if(not(args.size() == 2))
{
MIGRAPHX_THROW("PARSE_CASTLIKE: CastLike must have exactly 2 inputs!");
}
shape::type_t target_type = args[1]->get_shape().type();
return info.add_instruction(make_op("convert", {{"target_type", target_type}}), args[0]);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
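CastLike ignores the values of its second input and only reads its element type, so the parser lowers it to a plain convert. An equivalent hand-built form, sketched against the public module API (the helper name is illustrative):

    // Convert x to y's element type, as the CastLike parser above does via node_info
    migraphx::instruction_ref cast_like(migraphx::module& m,
                                        migraphx::instruction_ref x,
                                        migraphx::instruction_ref y)
    {
        auto target_type = y->get_shape().type();
        return m.add_instruction(
            migraphx::make_op("convert", {{"target_type", target_type}}), x);
    }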
@@ -25,6 +25,7 @@
 #include <migraphx/ranges.hpp>
 #include <migraphx/literal.hpp>
 #include <migraphx/make_op.hpp>
+#include <migraphx/stringutils.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -39,16 +40,38 @@ struct parse_constant : op_parser<parse_constant>
                           onnx_parser::node_info info,
                           const std::vector<instruction_ref>& /*args*/) const
     {
-        literal v = parser.parse_value(info.attributes.at("value"));
+        static const std::vector<std::string> attributes = {
+            "value", "value_float", "value_floats", "value_int", "value_ints"};
+
+        std::vector<std::string> present_attributes;
+        std::copy_if(attributes.begin(),
+                     attributes.end(),
+                     std::back_inserter(present_attributes),
+                     [&](const std::string& a) { return contains(info.attributes, a); });
+
+        if(present_attributes.empty())
+        {
+            MIGRAPHX_THROW("Constant node does not contain any supported attribute");
+        }
+
+        if(present_attributes.size() > 1)
+        {
+            MIGRAPHX_THROW("Constant contains multiple attributes: " +
+                           join_strings(std::move(present_attributes), ", "));
+        }
+
+        // cppcheck-suppress accessMoved
+        auto&& attr = info.attributes[present_attributes[0]];
+        literal v   = parser.parse_value(attr);
+
         // return empty literal
         if(v.get_shape().elements() == 0)
         {
             return info.add_literal(literal{v.get_shape().type()});
         }

-        auto dim_size = info.attributes.at("value").t().dims_size();
         // if dim_size is 0, it is a scalar
-        if(dim_size == 0)
+        if(attr.has_t() and attr.t().dims_size() == 0)
         {
             migraphx::shape scalar_shape{v.get_shape().type()};
             return info.add_literal(migraphx::literal{scalar_shape, v.data()});
......
@@ -49,15 +49,14 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
             {
                 MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 elements!");
             }
+            // convert to a scalar literal
+            l_val = literal(shape{l_val.get_shape().type(), {1}, {0}}, l_val.data());
         }
         else
         {
             l_val = literal({shape::float_type, {1}, {0}}, {0.0f});
         }

-        // input is empty, output is a scalar
-        auto type = l_val.get_shape().type();
         if(args.empty())
         {
             MIGRAPHX_THROW("ConstantOfShape : must have 1 input!");
@@ -65,30 +64,39 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
         else
         {
             migraphx::shape s;
-            // empty input tensor, output is a scalar
-            if(args[0]->get_shape().elements() == 0)
+            // input is empty, output is a scalar
+            auto type                = l_val.get_shape().type();
+            migraphx::argument input = args[0]->eval();
+            if(not input.empty())
             {
-                s = migraphx::shape{type, {1}, {0}};
+                // empty input tensor, output is a scalar
+                if(args[0]->get_shape().elements() == 0)
+                {
+                    s = migraphx::shape{type, {1}, {0}};
+                }
+                else
+                {
+                    std::vector<std::size_t> dims;
+                    input.visit([&](auto ia) { dims.assign(ia.begin(), ia.end()); });
+                    s = migraphx::shape{type, dims};
+                }
+                literal l_out{};
+                l_val.visit([&](auto val) {
+                    using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
+                    // l_val contains only one element
+                    std::vector<val_type> out_vec(s.elements(), val.front());
+                    l_out = literal(s, out_vec);
+                });
+                return info.add_literal(l_out);
             }
+            // has variable input (dynamic shape buffer)
             else
             {
-                migraphx::argument in = args[0]->eval();
-                check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");
-
-                std::vector<std::size_t> dims;
-                in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
-                s = migraphx::shape{type, dims};
+                auto dv_lit    = info.add_literal(l_val);
+                auto alloc_ins =
+                    info.add_instruction(make_op("allocate", {{"buf_type", type}}), args[0]);
+                return info.add_instruction(make_op("fill"), dv_lit, alloc_ins);
             }
-            literal l_out{};
-            l_val.visit([&](auto val) {
-                using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
-                // l_val contains only one element
-                std::vector<val_type> out_vec(s.elements(), val.front());
-                l_out = literal(s, out_vec);
-            });
-            return info.add_literal(l_out);
         }
     }
 };
......
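When the shape input cannot be evaluated at parse time, ConstantOfShape now lowers to a literal plus an allocate/fill pair so the buffer is sized at runtime. The decomposition in isolation, sketched with an illustrative helper name:

    // dims_ins holds the runtime shape; fill_value is the single-element literal
    migraphx::instruction_ref constant_of_shape_dyn(migraphx::module& m,
                                                    migraphx::instruction_ref dims_ins,
                                                    const migraphx::literal& fill_value)
    {
        auto dv_lit = m.add_literal(fill_value);
        auto alloc  = m.add_instruction(
            migraphx::make_op("allocate", {{"buf_type", fill_value.get_shape().type()}}),
            dims_ins);
        return m.add_instruction(migraphx::make_op("fill"), dv_lit, alloc);
    }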
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -151,26 +151,6 @@ struct parse_pooling : op_parser<parse_pooling>
                 kdims, paddings.size() / 2, "PARSE_POOLING: inconsistent explicit paddings");
         }

-        if(contains(info.attributes, "auto_pad"))
-        {
-            if(in_shape.dynamic())
-            {
-                MIGRAPHX_THROW(
-                    "PARSE_POOLING: Auto padding pooling with dynamic input shape not supported");
-            }
-            else
-            {
-                values["padding"].clear();
-                // return paddings could be empty, then setting to 0 for no padding
-                cal_auto_padding_size(info,
-                                      values,
-                                      values["lengths"].to_vector<std::size_t>(),
-                                      {1, 1},
-                                      in_shape.lens(),
-                                      paddings);
-            }
-        }
-
         if(paddings.size() != 2 * kdims)
         {
             paddings.resize(kdims * 2);
@@ -192,6 +172,36 @@ struct parse_pooling : op_parser<parse_pooling>
         // used to calculate the supposed output shape
         std::vector<int64_t> orig_padding = paddings;

+        // TODO: add parsing for dilations
+        if(contains(info.attributes, "auto_pad") and
+           to_upper(info.attributes["auto_pad"].s()) != "NOTSET")
+        {
+            auto auto_pad = to_upper(info.attributes["auto_pad"].s());
+            // don't use the given padding sizes, if any
+            // values["padding"].clear();
+            if(in_shape.dynamic())
+            {
+                // set padding_mode to trigger auto padding at runtime
+                bool is_same_upper     = (auto_pad.find("SAME_UPPER") != std::string::npos);
+                values["padding_mode"] = is_same_upper ? to_value(op::padding_mode_t::same_upper)
+                                                       : to_value(op::padding_mode_t::same_lower);
+            }
+            else
+            {
+                // Calculate auto padding
+                // dilations (argument 4) not supported; default to all 1's
+                cal_auto_padding_size(info,
+                                      values,
+                                      values["lengths"].to_vector<std::size_t>(),
+                                      std::vector<size_t>(in_shape.ndim() - 2, 1),
+                                      in_shape.lens(),
+                                      paddings);
+                values["padding"] = paddings;
+                // default padding_mode indicates that padding sizes are not calculated dynamically
+                values["padding_mode"] = migraphx::op::padding_mode_t::default_;
+            }
+        }
+
         std::vector<int64_t> slice_start;
         std::vector<int64_t> slice_end;
         tune_padding_size(values, paddings, count_include_pad, slice_start);
@@ -208,8 +218,9 @@ struct parse_pooling : op_parser<parse_pooling>
             orig_padding.insert(orig_padding.begin(), 2, 0);
             op::pad pad{orig_padding, 0.0f};
             shape padded_shape = pad.compute_shape({l0->get_shape()});
-            auto out_lens      = make_op("pooling", values).compute_shape({padded_shape}).lens();
+            // make an op just to get its output shape
+            auto out_lens = make_op("pooling", values).compute_shape({padded_shape}).lens();

             // compute slice_end information
             slice_end.resize(slice_start.size());
             std::transform(out_lens.begin() + 2,
......
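For static shapes, SAME auto-padding chooses a total pad so the output length is ceil(in / stride). A worked instance of that rule (this mirrors the standard ONNX SAME formula; matching cal_auto_padding_size exactly, including dilation handling, is an assumption):

    #include <cassert>
    #include <cstddef>

    // Total padding for one spatial dim so that out = ceil(in / stride), dilation = 1
    std::size_t same_total_pad(std::size_t in, std::size_t kernel, std::size_t stride)
    {
        std::size_t out    = (in + stride - 1) / stride; // ceil(in / stride)
        std::ptrdiff_t pad = static_cast<std::ptrdiff_t>((out - 1) * stride + kernel) -
                             static_cast<std::ptrdiff_t>(in); // can go negative
        return pad > 0 ? static_cast<std::size_t>(pad) : 0;
    }

    int main()
    {
        // in=224, kernel=3, stride=2: out=112, pad = 111*2 + 3 - 224 = 1
        // SAME_UPPER places the odd unit at the end: pad_begin=0, pad_end=1
        assert(same_total_pad(224, 3, 2) == 1);
    }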
@@ -96,7 +96,7 @@ struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops>
         if(contains(info.attributes, "seed"))
             gen.seed(info.attributes.at("seed").f());

-        std::uniform_real_distribution<> d(high, low);
+        std::uniform_real_distribution<> d(low, high);
         std::vector<double> rand_vals(out_shape.elements());
         std::generate(rand_vals.begin(), rand_vals.end(), [&]() { return d(gen); });
......
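std::uniform_real_distribution takes its bounds as (min, max) and requires min <= max; the swapped d(high, low) call this hunk fixes is undefined behavior whenever high > low. A small check of the corrected form:

    #include <cassert>
    #include <random>

    int main()
    {
        std::mt19937 gen{42};
        const double low = 0.0, high = 1.0;
        std::uniform_real_distribution<> d(low, high); // (min, max), not (max, min)
        for(int i = 0; i < 1000; ++i)
        {
            double x = d(gen);
            assert(x >= low and x < high);
        }
    }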
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -97,22 +97,19 @@ const auto& get_original_idx_op(const std::string& mode)
 static std::vector<int>
 calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& vvv_ind,
                      int i_dim,
-                     const std::vector<std::vector<std::size_t>>& vec_dims,
+                     std::vector<std::vector<std::size_t>> vec_dims,
                      const shape& in_s)
 {
     if(i_dim == vvv_ind.size())
     {
-        std::vector<int> vec_ind;
-        vec_ind.resize(vec_dims.size());
+        std::vector<int> vec_ind(vec_dims.size());
         std::transform(vec_dims.begin(), vec_dims.end(), vec_ind.begin(), [&](auto idx) {
             return static_cast<int>(in_s.index(idx));
         });
         return vec_ind;
     }

-    const auto& vv_ind = vvv_ind[i_dim];
-    const auto& vv_lo  = vv_ind.at(0);
+    const auto& vv_lo = vvv_ind[i_dim][0];
     std::vector<std::vector<std::size_t>> vec_dims1;
     for(std::size_t start = 0; start < vec_dims.size(); start += vv_lo.size())
     {
@@ -126,8 +123,8 @@ calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& v
         });
     }

-    const auto& vv_hi = vv_ind.at(1);
-    for(std::size_t start = 0; start < vec_dims.size(); start += vv_lo.size())
+    const auto& vv_hi = vvv_ind[i_dim][1];
+    for(std::size_t start = 0; start < vec_dims.size(); start += vv_hi.size())
     {
         std::transform(vv_hi.begin(),
                        vv_hi.end(),
@@ -138,8 +135,8 @@ calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& v
             return dim;
         });
     }
-    vec_dims.clear();

-    return calc_neighbor_points(vvv_ind, i_dim + 1, vec_dims1, in_s);
+    return calc_neighbor_points(vvv_ind, i_dim + 1, std::move(vec_dims1), in_s);
 }

 static std::string get_coord_trans_mode(const onnx_parser::attribute_map& attr)
@@ -240,7 +237,7 @@ struct parse_resize : op_parser<parse_resize>
                 auto arg_out_s = arg->eval();
                 check_arg_empty(arg_out_s,
                                 "PARSE_" + opd.op_name + ": dynamic output size is not supported!");
-                arg_out_s.visit([&](auto ol) { out_lens.assign(ol.begin(), ol.end()); });
+                arg_out_s.visit([&](const auto& ol) { out_lens.assign(ol.begin(), ol.end()); });

                 if(out_lens.size() != in_lens.size())
                 {
@@ -267,7 +264,7 @@ struct parse_resize : op_parser<parse_resize>
                                 "PARSE_" + opd.op_name +
                                     ": dynamic input scale is not supported!");

-                arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); });
+                arg_scale.visit([&](const auto& v) { vec_scale.assign(v.begin(), v.end()); });
                 if(in_lens.size() != vec_scale.size())
                 {
                     MIGRAPHX_THROW("PARSE_" + opd.op_name +
@@ -300,15 +297,15 @@ struct parse_resize : op_parser<parse_resize>
             // map out_idx to in_idx
             auto nearest_op = get_nearest_op(nearest_mode);

-            shape_for_each(out_s, [&](auto idx) {
-                auto in_idx = idx;
+            shape_for_each(out_s, [&](const auto& out_idx_v, size_t out_idx) {
+                std::vector<size_t> in_idx(out_idx_v.size());
                 for(auto ii = 0; ii < in_lens.size(); ++ii)
                 {
-                    auto idx_val = idx_op(in_lens[ii], out_lens[ii], idx[ii], vec_scale[ii]);
+                    auto idx_val = idx_op(in_lens[ii], out_lens[ii], out_idx_v[ii], vec_scale[ii]);
                     in_idx[ii]   = nearest_op(in_lens[ii], idx_val);
                 }

-                ind[out_s.index(idx)] = static_cast<int64_t>(in_s.index(in_idx));
+                ind[out_idx] = static_cast<int64_t>(in_s.index(in_idx));
             });

             shape ind_s{shape::int32_type, out_lens};
@@ -327,20 +324,18 @@ struct parse_resize : op_parser<parse_resize>
             std::vector<std::vector<std::vector<std::size_t>>> vvv_ind(n_dim, vv_ind);
             std::vector<std::vector<float>> delta(n_dim, std::vector<float>(out_elements));

-            shape_for_each(out_s, [&](auto idx) {
-                auto in_idx  = idx;
-                auto out_idx = out_s.index(idx);
+            shape_for_each(out_s, [&](const auto& out_idx_v, size_t out_idx) {
                 for(auto ii = 0; ii < in_lens.size(); ++ii)
                 {
-                    auto idx_val = idx_op(in_lens[ii], out_lens[ii], idx[ii], vec_scale[ii]);
+                    auto idx_val = idx_op(in_lens[ii], out_lens[ii], out_idx_v[ii], vec_scale[ii]);
                     vvv_ind[ii][0][out_idx] = nearest_floor(in_lens[ii], idx_val);
                     vvv_ind[ii][1][out_idx] = nearest_ceil(in_lens[ii], idx_val);
                     delta[ii][out_idx]      = idx_val - vvv_ind[ii][0][out_idx];
                 }
             });

-            std::vector<std::vector<std::size_t>> vec_dims(out_elements);
-            auto ind = calc_neighbor_points(vvv_ind, 0, vec_dims, in_s);
+            auto ind = calc_neighbor_points(
+                vvv_ind, 0, std::vector<std::vector<std::size_t>>(out_elements), in_s);

             auto ind_lens = out_lens;
             ind_lens[0] *= (std::size_t{1} << n_dim);
             shape ind_s{shape::int32_type, ind_lens};
......
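For linear modes, the two shape_for_each loops record, per axis, the clamped floor/ceil input neighbors and a fractional delta; calc_neighbor_points then expands these into the 2^n corner points gathered for interpolation. The per-axis bookkeeping in a 1-D sketch (names are mine):

    #include <cassert>
    #include <cmath>
    #include <cstddef>

    struct neighbor
    {
        std::size_t lo; // nearest_floor result
        std::size_t hi; // nearest_ceil result
        float delta;    // weight of 'hi' when blending lo and hi
    };

    neighbor make_neighbor(float idx_val, std::size_t in_len)
    {
        auto clamp = [&](float v) {
            return static_cast<std::size_t>(
                std::fmin(std::fmax(v, 0.0f), static_cast<float>(in_len - 1)));
        };
        return {clamp(std::floor(idx_val)),
                clamp(std::ceil(idx_val)),
                idx_val - std::floor(idx_val)};
    }

    int main()
    {
        auto n = make_neighbor(2.25f, 8);
        assert(n.lo == 2 and n.hi == 3);
        assert(std::fabs(n.delta - 0.25f) < 1e-6f);
    }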
@@ -37,15 +37,18 @@ struct parse_roialign : op_parser<parse_roialign>
     std::vector<op_desc> operators() const { return {{"RoiAlign"}}; }

     instruction_ref parse(const op_desc& /*opd*/,
-                          const onnx_parser& /*parser*/,
+                          const onnx_parser& parser,
                           onnx_parser::node_info info,
                           const std::vector<instruction_ref>& args) const
     {
-        std::string coord_trans_mode = "half_pixel";
-        if(contains(info.attributes, "coordinate_transformation_mode"))
+        std::string coord_trans_mode =
+            parser.opset_version >= 16 ? "half_pixel" : "output_half_pixel";
+        if(const auto* a = "coordinate_transformation_mode"; contains(info.attributes, a))
         {
-            coord_trans_mode = info.attributes.at("coordinate_transformation_mode").s();
+            coord_trans_mode = info.attributes.at(a).s();
         }

         if(not contains({"half_pixel", "output_half_pixel"}, coord_trans_mode))
         {
             MIGRAPHX_THROW("coordinate_transformation_mode \"" + coord_trans_mode +
......
@@ -34,16 +34,65 @@ namespace onnx {
 struct parse_slice : op_parser<parse_slice>
 {
     std::vector<op_desc> operators() const { return {{"Slice"}}; }

+    struct slice_desc
+    {
+        op::slice op;
+        std::vector<instruction_ref> op_args;
+        std::vector<int64_t> steps;
+        std::vector<int64_t> raxes;
+
+        void always_insert(instruction_ref arg) { op_args.insert(op_args.begin(), arg); }
+
+        std::vector<int64_t> insert(instruction_ref arg)
+        {
+            std::vector<int64_t> result;
+            migraphx::argument arg_value = arg->eval();
+            if(arg_value.empty())
+            {
+                op_args.insert(op_args.begin(), arg);
+            }
+            else
+            {
+                arg_value.visit([&](auto s) { result.assign(s.begin(), s.end()); });
+            }
+            return result;
+        }
+    };
+
     instruction_ref parse(const op_desc& /*opd*/,
                           const onnx_parser& parser,
-                          onnx_parser::node_info info,
-                          std::vector<instruction_ref> args) const
+                          const onnx_parser::node_info& info,
+                          const std::vector<instruction_ref>& args) const
     {
-        op::slice op;
+        auto sd  = construct_slice_desc(parser, info, args);
+        auto ins = info.add_instruction(sd.op, sd.op_args);
+        if(not sd.raxes.empty())
+        {
+            ins = info.add_instruction(make_op("reverse", {{"axes", sd.raxes}}), ins);
+        }

-        std::vector<int64_t> steps;
+        // If any steps are other than default 1, add a "steps" op
+        if(std::any_of(sd.steps.begin(), sd.steps.end(), [](auto s) { return std::abs(s) != 1; }))
+        {
+            std::vector<int64_t> nsteps;
+            std::transform(sd.steps.begin(),
+                           sd.steps.end(),
+                           std::back_inserter(nsteps),
+                           [](auto s) { return std::abs(s); });
+            return ins = info.add_instruction(
+                       make_op("step", {{"axes", sd.op.axes}, {"steps", nsteps}}), ins);
+        }
+        else
+            return ins;
+    }
+
+    slice_desc construct_slice_desc(const onnx_parser& parser,
+                                    onnx_parser::node_info info,
+                                    std::vector<instruction_ref> args) const
+    {
+        slice_desc sd;

         // slice can have up to 5 inputs, we first check the 5th one
         // to decide whether MIGRAPHX can handle this slice.
@@ -51,89 +100,73 @@ struct parse_slice : op_parser<parse_slice>
         {
             migraphx::argument step_arg = args.back()->eval();
             check_arg_empty(step_arg, "PARSE_SLICE: cannot handle variable steps for slice");
-            step_arg.visit([&](auto s) { steps.assign(s.begin(), s.end()); });
+            step_arg.visit([&](auto s) { sd.steps.assign(s.begin(), s.end()); });
         }

         if(args.size() >= 4)
         {
-            migraphx::argument axes_arg = args.at(3)->eval();
-            check_arg_empty(axes_arg, "PARSE_SLICE: cannot handle variable axes for slice");
-            axes_arg.visit([&](auto s) { op.axes.assign(s.begin(), s.end()); });
+            sd.op.axes = sd.insert(args.at(3));
         }
         else if(contains(info.attributes, "axes"))
         {
             literal s = parser.parse_value(info.attributes.at("axes"));
-            s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
+            s.visit([&](auto v) { copy(v, std::back_inserter(sd.op.axes)); });
         }

         if(args.size() >= 3)
         {
-            migraphx::argument end_arg = args.at(2)->eval();
-            check_arg_empty(end_arg, "PARSE_SLICE: cannot handle variable ends for slice");
-            end_arg.visit([&](auto s) { op.ends.assign(s.begin(), s.end()); });
+            sd.op.ends = sd.insert(args.at(2));
         }
         else if(contains(info.attributes, "ends"))
         {
             literal s = parser.parse_value(info.attributes.at("ends"));
-            s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); });
+            s.visit([&](auto v) { copy(v, std::back_inserter(sd.op.ends)); });
         }

         if(args.size() >= 2)
         {
-            migraphx::argument start_arg = args.at(1)->eval();
-            check_arg_empty(start_arg, "PARSE_SLICE: cannot handle variable starts for slice");
-            start_arg.visit([&](auto s) { op.starts.assign(s.begin(), s.end()); });
+            sd.op.starts = sd.insert(args.at(1));
         }
         else if(contains(info.attributes, "starts"))
         {
             literal s = parser.parse_value(info.attributes.at("starts"));
-            s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
+            s.visit([&](auto v) { copy(v, std::back_inserter(sd.op.starts)); });
         }

+        // data input argument
+        sd.always_insert(args.at(0));
+
         // If axes arg is not given, the default is all of them.
-        if(op.axes.empty())
+        if(sd.op.axes.empty() and sd.op_args.size() < 3)
         {
             std::vector<int64_t> axes(args[0]->get_shape().ndim());
             std::iota(axes.begin(), axes.end(), int64_t{0});
-            op.axes = axes;
+            sd.op.axes = axes;
         }

-        std::vector<int64_t> raxes;
+        if(not sd.steps.empty())
+        {
+            if(sd.op.starts.empty() or sd.op.ends.empty())
+                MIGRAPHX_THROW("PARSE_SLICE: steps and variable starts and ends is not supported");
+            if(sd.op.axes.empty())
+                MIGRAPHX_THROW("PARSE_SLICE: steps and variable axes is not supported");
+        }

-        assert(steps.empty() or steps.size() == op.axes.size());
-        assert(op.axes.size() == op.starts.size());
-        assert(op.axes.size() == op.ends.size());
+        assert(sd.steps.empty() or sd.steps.size() == sd.op.axes.size());

         // If any axes have negative step, prepare to add a "reverse" op
-        for(auto i : range(steps.size()))
+        for(auto i : range(sd.steps.size()))
         {
-            if(steps[i] >= 0)
+            if(sd.steps[i] >= 0)
                 continue;
-            op.starts[i] += 1;
-            if(op.starts[i] == 0)
-                op.starts[i] = INT_MAX;
-            op.ends[i] += 1;
-            raxes.push_back(op.axes[i]);
-            std::swap(op.starts[i], op.ends[i]);
+            sd.op.starts[i] += 1;
+            if(sd.op.starts[i] == 0)
+                sd.op.starts[i] = INT_MAX;
+            sd.op.ends[i] += 1;
+            sd.raxes.push_back(sd.op.axes[i]);
+            std::swap(sd.op.starts[i], sd.op.ends[i]);
         }

-        auto ins = info.add_instruction(op, args[0]);
-        if(not raxes.empty())
-        {
-            ins = info.add_instruction(make_op("reverse", {{"axes", raxes}}), ins);
-        }
-
-        // If any steps are other than default 1, add a "steps" op
-        if(std::any_of(steps.begin(), steps.end(), [](auto s) { return std::abs(s) != 1; }))
-        {
-            std::vector<int64_t> nsteps;
-            std::transform(steps.begin(), steps.end(), std::back_inserter(nsteps), [](auto s) {
-                return std::abs(s);
-            });
-            return ins = info.add_instruction(
-                       make_op("step", {{"axes", op.axes}, {"steps", nsteps}}), ins);
-        }
-        else
-            return ins;
+        return sd;
     }
 };
......
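A negative-step ONNX Slice is decomposed above into slice, then reverse, then step. Tracing starts={8}, ends={2}, steps={-2} on a length-10 axis: the parser rewrites starts to 9 and ends to 3, swaps them, and the pipeline below reproduces ONNX's {8, 6, 4}:

    #include <cassert>
    #include <cstddef>
    #include <numeric>
    #include <vector>

    int main()
    {
        std::vector<int> data(10);
        std::iota(data.begin(), data.end(), 0);

        std::vector<int> sliced(data.begin() + 3, data.begin() + 9); // slice [3, 9)
        std::vector<int> reversed(sliced.rbegin(), sliced.rend());   // reverse axis 0
        std::vector<int> stepped;                                    // step |-2| = 2
        for(std::size_t i = 0; i < reversed.size(); i += 2)
            stepped.push_back(reversed[i]);

        assert((stepped == std::vector<int>{8, 6, 4})); // matches ONNX Slice semantics
    }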
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -36,8 +36,12 @@ void optimize_module::apply(module_pass_manager& mpm) const
 {
     for(int i = 0; i < 2; i++)
     {
-        mpm.run_pass(simplify_reshapes{});
-        mpm.run_pass(simplify_algebra{});
+        // loop to further optimize after initial transformations
+        for(int j = 0; j < 2; j++)
+        {
+            mpm.run_pass(simplify_reshapes{});
+            mpm.run_pass(simplify_algebra{});
+        }
         mpm.run_pass(eliminate_common_subexpression{});
         mpm.run_pass(dead_code_elimination{});
         mpm.run_pass(propagate_constant{});
......
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -52,6 +52,11 @@ void calculate_padding(int64_t idx,
     }
 }

+/**
+ * Given the input array dimensions, kernel (wei_lens), strides, and dilations,
+ * calculate the padding value in each dimension.
+ */
 std::vector<std::size_t> calc_dyn_auto_pad(const std::vector<std::size_t>& input_lens,
                                            const std::vector<std::size_t>& wei_lens,
                                            const std::vector<std::size_t>& strides,
@@ -60,6 +65,7 @@ std::vector<std::size_t> calc_dyn_auto_pad(const std::vector<std::size_t>& input
 {
     std::vector<std::size_t> padding;
     assert(input_lens.size() >= 3);
+    assert(input_lens.size() == wei_lens.size());
     std::size_t num_spatial_dims = input_lens.size() - 2;
     padding.resize(2 * num_spatial_dims);
     for(std::size_t i = 0; i < num_spatial_dims; i++)
@@ -88,6 +94,11 @@ std::vector<std::size_t> calc_dyn_auto_pad(const std::vector<std::size_t>& input
     return padding;
 }

+/**
+ * Calculate the correct output shape for a convolution with
+ * a given input size and other parameters.
+ */
 shape compute_padded_shape(const shape& input,
                            const shape& weights,
                            const std::vector<std::size_t>& padding,
@@ -111,5 +122,33 @@ shape compute_padded_shape(const shape& input,
     return input.with_lens(output_lens);
 }

+/**
+ * Calculate the correct output shape for a pooling with
+ * a given input size and other parameters. This uses
+ * the same formula for pooling that compute_padded_shape() uses
+ * for convolutions, but takes slightly different inputs.
+ */
+shape compute_padded_pool_shape(const shape& input,
+                                const shape& kernel,
+                                const std::vector<std::size_t>& padding,
+                                const std::vector<std::size_t>& stride,
+                                const std::vector<std::size_t>& dilation)
+{
+    const size_t num_spatial_dims = input.lens().size() - 2;
+    std::vector<size_t> output_lens{input.lens()[0], input.lens()[1]};
+    // calculate the output shape of the pooling: ((W - K + 2P) / S) + 1
+    for(size_t i = 0; i < num_spatial_dims; ++i)
+    {
+        auto padding_factor = padding[i] + padding[i + num_spatial_dims];
+        output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
+            1,
+            (input.lens()[i + 2] - (1 + dilation[i] * (kernel.lens()[i] - 1)) + padding_factor) /
+                    stride[i] +
+                1)));
+    }
+    return input.with_lens(output_lens);
+}
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
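compute_padded_pool_shape applies the usual windowed-output formula with an effective kernel of 1 + dilation * (k - 1). One numeric instance:

    #include <cassert>
    #include <cstddef>

    int main()
    {
        std::size_t w = 224, k = 3, pad_total = 2, stride = 2, dilation = 1;
        std::size_t k_eff = 1 + dilation * (k - 1);               // 3
        std::size_t out   = (w - k_eff + pad_total) / stride + 1; // floor(223 / 2) + 1
        assert(out == 112);
    }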
@@ -223,7 +223,7 @@ void program::compile(const std::vector<target>& targets, std::vector<compile_op
     // Gather all the target roots
     std::unordered_multimap<std::size_t, module_ref> roots;
     auto mods = this->get_modules();
-    for(auto* mod : mods)
+    for(const auto* mod : mods)
     {
         for(const auto& ins : *mod)
         {
@@ -347,7 +347,7 @@ void program::finalize()
 template <class T>
 std::string classify(T x)
 {
-    switch(std::fpclassify(x))
+    switch(std::fpclassify(static_cast<double>(x)))
     {
     case FP_INFINITE: return "inf";
     case FP_NAN: return "nan";
@@ -548,7 +548,7 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
         ins_out[x] = ss.str();
     });
     ret = generic_eval(*this, contexts, std::move(params), [&](instruction_ref ins, auto f) {
-        auto& ctx = contexts[ins->get_target_id()];
+        const auto& ctx = contexts[ins->get_target_id()];
         ctx.finish();
         std::cout << "Run instruction: " << ins_out.at(ins) << std::endl;
         timer t{};
@@ -624,7 +624,7 @@ std::string get_migraphx_version()
 program file version is for the data structure or format of the MXR file. Version should be bumped
 if any changes occur to the format of the MXR file.
 */
-const int program_file_version = 6;
+const int program_file_version = 7;

 value program::to_value() const
 {
@@ -728,7 +728,7 @@ static void mod_from_val(module_ref mod,
                    std::back_inserter(module_inputs),
                    [&](const value& i) { return map_mods.at(i.to<std::string>()); });

-    for(auto& smod : module_inputs)
+    for(const auto& smod : module_inputs)
     {
         mod_from_val(smod, v, instructions, map_mods);
     }
@@ -1186,7 +1186,7 @@ void program::remove_unused_modules()
     std::vector<module*> unused;
     generic_get_unused_modules(
         impl->modules, generic_get_modules(this->get_main_module()), std::back_inserter(unused));
-    for(auto* m : unused)
+    for(const auto* m : unused)
         this->remove_module(m->name());
 }
......
This diff is collapsed.
@@ -24,6 +24,7 @@
 option(MIGRAPHX_ENABLE_PYTHON "Enable python bindings" ON)

 add_library(migraphx_py py_loader.cpp)
+migraphx_generate_export_header(migraphx_py)
 target_include_directories(migraphx_py PRIVATE include)
 target_link_libraries(migraphx_py PUBLIC migraphx)
 rocm_install_targets(TARGETS migraphx_py INCLUDE include)
......