Commit cd4ab535 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

manual merge

parents 3891ee58 a0fa3742
...@@ -172,6 +172,22 @@ struct vector_stream ...@@ -172,6 +172,22 @@ struct vector_stream
} }
}; };
struct writer_stream
{
std::function<void(const char*, std::size_t)> writer;
writer_stream& write(const char* b, std::size_t n)
{
writer(b, n);
return *this;
}
};
void to_msgpack(const value& v, std::function<void(const char*, std::size_t)> writer)
{
writer_stream ws{std::move(writer)};
msgpack::pack(ws, v);
}
std::vector<char> to_msgpack(const value& v) std::vector<char> to_msgpack(const value& v)
{ {
vector_stream vs; vector_stream vs;
......
...@@ -40,10 +40,12 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -40,10 +40,12 @@ inline namespace MIGRAPHX_INLINE_NS {
* *
* See normalize_attribute.hpp for explaining the options. * See normalize_attribute.hpp for explaining the options.
*/ */
template <class Message>
auto tune_attribute(const std::vector<int64_t>& vec, auto tune_attribute(const std::vector<int64_t>& vec,
const std::vector<int64_t>& axes, const std::vector<int64_t>& axes,
const value& val, const value& val,
const std::vector<std::size_t>& lens) const std::vector<std::size_t>& lens,
Message m)
{ {
std::vector<int64_t> result(vec); std::vector<int64_t> result(vec);
int64_t n_rank = lens.size(); int64_t n_rank = lens.size();
...@@ -84,14 +86,14 @@ auto tune_attribute(const std::vector<int64_t>& vec, ...@@ -84,14 +86,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
{ {
if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less_equal<>{})) if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less_equal<>{}))
{ {
MIGRAPHX_THROW("TUNE_VECTOR: value out of range!"); MIGRAPHX_THROW(m() + "value out of range!");
} }
} }
else else
{ {
if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less<>{})) if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less<>{}))
{ {
MIGRAPHX_THROW("TUNE_VECTOR: value out of range!"); MIGRAPHX_THROW(m() + "value out of range!");
} }
} }
} }
...@@ -124,14 +126,14 @@ auto tune_attribute(const std::vector<int64_t>& vec, ...@@ -124,14 +126,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
if(not std::equal( if(not std::equal(
min_vals.begin(), min_vals.end(), result.begin(), std::less_equal<>{})) min_vals.begin(), min_vals.end(), result.begin(), std::less_equal<>{}))
{ {
MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!"); MIGRAPHX_THROW(m() + "attribute out of range!");
} }
} }
else else
{ {
if(not std::equal(result.begin(), result.end(), min_vals.begin(), std::less<>{})) if(not std::equal(result.begin(), result.end(), min_vals.begin(), std::less<>{}))
{ {
MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!"); MIGRAPHX_THROW(m() + "attribute out of range!");
} }
} }
} }
...@@ -193,7 +195,8 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens) ...@@ -193,7 +195,8 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
const auto& key = rv.get_key(); const auto& key = rv.get_key();
if(val.contains(key)) if(val.contains(key))
{ {
auto vv = val.at(key).without_key(); auto message = [&] { return op.name() + ": " + key + ": "; };
auto vv = val.at(key).without_key();
if(vv.is_array()) if(vv.is_array())
{ {
std::vector<int64_t> axes; std::vector<int64_t> axes;
...@@ -202,7 +205,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens) ...@@ -202,7 +205,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
axes = val.at("axes").without_key().to_vector<int64_t>(); axes = val.at("axes").without_key().to_vector<int64_t>();
} }
auto vec = vv.to_vector<int64_t>(); auto vec = vv.to_vector<int64_t>();
auto result = tune_attribute(vec, axes, rv.without_key(), lens); auto result = tune_attribute(vec, axes, rv.without_key(), lens, message);
val[key] = result; val[key] = result;
op.from_value(val); op.from_value(val);
val = op.to_value(); val = op.to_value();
...@@ -211,7 +214,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens) ...@@ -211,7 +214,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
else else
{ {
auto num = vv.to<int64_t>(); auto num = vv.to<int64_t>();
auto result = tune_attribute({num}, {num}, rv.without_key(), lens); auto result = tune_attribute({num}, {num}, rv.without_key(), lens, message);
val[key] = result.front(); val[key] = result.front();
op.from_value(val); op.from_value(val);
val = op.to_value(); val = op.to_value();
......
...@@ -94,7 +94,7 @@ struct onnx_parser ...@@ -94,7 +94,7 @@ struct onnx_parser
node_map nodes; node_map nodes;
std::unordered_map<std::string, instruction_ref> instructions; std::unordered_map<std::string, instruction_ref> instructions;
program prog = program(); program prog = program();
shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0}; shape::dynamic_dimension default_dyn_dim_value = {1, 1};
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims; std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims; std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims;
bool use_dyn_output = false; bool use_dyn_output = false;
......
...@@ -46,14 +46,14 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs) ...@@ -46,14 +46,14 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
auto dim_val = options.default_dim_value; auto dim_val = options.default_dim_value;
if(dim_val != 0) if(dim_val != 0)
{ {
if(options.default_dyn_dim_value != shape::dynamic_dimension{1, 1, 0}) if(options.default_dyn_dim_value != shape::dynamic_dimension{1, 1})
{ {
MIGRAPHX_THROW("PARSE_ONNX_FROM: both default_dim_value and default_dyn_dim_value" MIGRAPHX_THROW("PARSE_ONNX_FROM: both default_dim_value and default_dyn_dim_value"
"set to non-default value"); "set to non-default value");
} }
else else
{ {
parser.default_dyn_dim_value = {dim_val, dim_val, 0}; parser.default_dyn_dim_value = {dim_val, dim_val};
} }
} }
else else
......
...@@ -39,6 +39,20 @@ ...@@ -39,6 +39,20 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
static shape shape_from_dyn_dims(shape::type_t shape_type,
const std::vector<shape::dynamic_dimension>& dyn_dims)
{
if(std::all_of(dyn_dims.begin(), dyn_dims.end(), [](auto dd) { return dd.is_fixed(); }))
{
std::vector<std::size_t> dims;
std::transform(dyn_dims.cbegin(), dyn_dims.cend(), std::back_inserter(dims), [](auto d) {
return d.max;
});
return {shape_type, dims};
}
return {shape_type, dyn_dims};
}
namespace onnx { namespace onnx {
static onnx_parser::attribute_map get_attributes(const onnx::NodeProto& node) static onnx_parser::attribute_map get_attributes(const onnx::NodeProto& node)
...@@ -135,6 +149,25 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s ...@@ -135,6 +149,25 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s
return this->add_common_op(op_name, arg0, arg1); return this->add_common_op(op_name, arg0, arg1);
} }
/**
* @brief A wrapper for insert_common_args(), which constructs an argument list
* and inserts multibroadcast and convert ops to match inputs to a common shape and type
* as required. The requested operation is placed after the added multibroadcast and convert ops,
* if any, so that their results are transparent to the programmer.
*
* Use add_common_op() to match input sizes when inputs may be
* either static or dynamic.
*
* @param op_name string; Name of operation (op) to add; valid names are the same as
* for make_op()
*
* @param inputs vector of instruction_ref. List of instructions for the new
* operator. Multibroadcast and convert operations, if needed, are deduced from these too.
*
* @return instruction_ref Returns an instruction_ref which is the result of the requested
* operation.
*
*/
instruction_ref onnx_parser::node_info::add_common_op(const std::string& op_name, instruction_ref onnx_parser::node_info::add_common_op(const std::string& op_name,
std::vector<instruction_ref> inputs) const std::vector<instruction_ref> inputs) const
{ {
...@@ -300,7 +333,7 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini ...@@ -300,7 +333,7 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini
else if(map_dyn_input_dims.count(name) > 0) else if(map_dyn_input_dims.count(name) > 0)
{ {
shape::type_t shape_type = get_type(input.type().tensor_type().elem_type()); shape::type_t shape_type = get_type(input.type().tensor_type().elem_type());
s = {shape_type, map_dyn_input_dims.at(name)}; s = shape_from_dyn_dims(shape_type, map_dyn_input_dims.at(name));
} }
else else
{ {
...@@ -491,7 +524,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t, ...@@ -491,7 +524,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
return default_dyn_dim_value; return default_dyn_dim_value;
} }
std::size_t tmp = d.dim_value(); std::size_t tmp = d.dim_value();
return {tmp, tmp, 0}; return {tmp, tmp};
} }
else else
{ {
...@@ -503,16 +536,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t, ...@@ -503,16 +536,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
{ {
return {shape_type}; return {shape_type};
} }
if(std::all_of(dynamic_dims.begin(), dynamic_dims.end(), [](auto dd) { return dd.is_fixed(); })) return shape_from_dyn_dims(shape_type, dynamic_dims);
{
std::vector<std::size_t> dims;
std::transform(dynamic_dims.begin(),
dynamic_dims.end(),
std::back_inserter(dims),
[](auto d) { return d.max; });
return {shape_type, dims};
}
return {shape_type, dynamic_dims};
} }
shape::type_t get_type(int dtype) shape::type_t get_type(int dtype)
......
...@@ -46,6 +46,7 @@ std::vector<std::string> get_op_parsers() ...@@ -46,6 +46,7 @@ std::vector<std::string> get_op_parsers()
op_parser_map().end(), op_parser_map().end(),
std::back_inserter(result), std::back_inserter(result),
[&](auto&& p) { return p.first; }); [&](auto&& p) { return p.first; });
std::sort(result.begin(), result.end());
return result; return result;
} }
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -21,10 +21,14 @@ ...@@ -21,10 +21,14 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <iterator>
#include <migraphx/onnx/op_parser.hpp> #include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/env.hpp>
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_FP16_INSTANCENORM_CONVERT);
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -32,62 +36,112 @@ namespace onnx { ...@@ -32,62 +36,112 @@ namespace onnx {
struct parse_instancenorm : op_parser<parse_instancenorm> struct parse_instancenorm : op_parser<parse_instancenorm>
{ {
const std::set<shape::type_t> valid_types = { std::set<shape::type_t> valid_types = {shape::float_type, shape::half_type, shape::double_type};
shape::float_type, shape::half_type, shape::double_type};
std::vector<op_desc> operators() const { return {{"InstanceNormalization"}}; } std::vector<op_desc> operators() const { return {{"InstanceNormalization"}}; }
instruction_ref parse(const op_desc& opd, instruction_ref parse(const op_desc& opd,
const onnx_parser& parser, const onnx_parser& parser,
onnx_parser::node_info info, onnx_parser::node_info info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> oargs) const
{ {
// y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias // y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
// mean = reduce_mean({D1, D2, ... Dk}, x) // mean = reduce_mean({D1, D2, ... Dk}, x)
// variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2) // variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2)
// Convert fp16 to fp32 to workaround for FP16 accuracy issues with reduce_mean/variance.
bool convert_fp16 = true;
if(enabled(MIGRAPHX_DISABLE_FP16_INSTANCENORM_CONVERT{}))
{
convert_fp16 = false;
}
float epsilon = 1e-5f; float epsilon = 1e-5f;
if(contains(info.attributes, "epsilon")) if(contains(info.attributes, "epsilon"))
{ {
epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>(); epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
} }
auto dtype = oargs[0]->get_shape().type();
auto literal_dtype = dtype;
std::vector<instruction_ref> args;
// cppcheck-suppress knownConditionTrueFalse
if(dtype == shape::half_type and convert_fp16)
{
std::transform(oargs.begin(), oargs.end(), std::back_inserter(args), [&](const auto i) {
return info.add_instruction(
make_op("convert", {{"target_type", shape::float_type}}), i);
});
literal_dtype = shape::float_type;
}
else
{
args = oargs;
}
auto x = args[0]; auto x = args[0];
auto scale = args[1]; auto scale = args[1];
auto bias = args[2]; auto bias = args[2];
auto dims = x->get_shape().lens(); auto dims = x->get_shape().lens();
auto dtype = x->get_shape().type();
if(not contains(valid_types, dtype)) if(not contains(valid_types, dtype))
MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + std::to_string(dtype) + MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + std::to_string(dtype) +
". Valid types are 1 (float), 10 (half), and 11 (double)."); ". Valid types are 1 (float), 10 (half), and 11 (double).");
auto ndims = dims.size(); bool dyn_input = x->get_shape().dynamic();
auto ndims = x->get_shape().ndim();
assert(ndims >= 2); assert(ndims >= 2);
auto kdims = ndims - 2; auto kdims = ndims - 2;
std::vector<int64_t> axes(kdims); std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), 2); std::iota(axes.begin(), axes.end(), 2);
auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x); auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
auto mean_bcast =
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), mean); // Use add_common_op() to insert multibroadcast/convert instructions where needed when
auto l0 = info.add_instruction(make_op("sqdiff"), x, mean_bcast); // inputs may be either static or dynamic.
auto variance = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0); auto l1 = info.add_common_op("sub", x, mean);
auto l1 = info.add_instruction(make_op("sub"), x, mean_bcast); // for the fp16, if not converting to fp32 then divide `x` and `mean` by `sqrt(n)` and take
auto epsilon_literal = info.add_literal(literal{shape{dtype}, {epsilon}}); // reduce_sum to calculate variance i.e.
auto epsilon_bcast = // var = reduce_sum((x/s_n - mean/s_n)^2) where s_n = sqrt(n)
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal); std::string reduce_op_name =
auto variance_bcast = (dtype == shape::half_type and not convert_fp16) ? "reduce_sum" : "reduce_mean";
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), variance); if(dtype == shape::half_type and not convert_fp16)
auto l2 = info.add_instruction(make_op("add"), variance_bcast, epsilon_bcast); {
double n =
std::accumulate(dims.begin() + 2, dims.end(), 1, [&](const auto& i, const auto& j) {
return i * j;
});
n = 1.0 / std::sqrt(n);
auto n_literal = info.add_literal(literal{dtype, {n}});
x = info.add_common_op("mul", {x, n_literal});
}
auto l0 = info.add_common_op("sqdiff", x, mean);
auto variance = info.add_instruction(make_op(reduce_op_name, {{"axes", axes}}), l0);
auto epsilon_literal = info.add_literal(literal{shape{literal_dtype}, {epsilon}});
auto l2 = info.add_common_op("add", variance, epsilon_literal);
auto l3 = info.add_instruction(make_op("rsqrt"), l2); auto l3 = info.add_instruction(make_op("rsqrt"), l2);
auto l4 = info.add_instruction(make_op("mul"), l1, l3); auto l4 = info.add_common_op("mul", l1, l3);
auto scale_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale); // add_common_op() doesn't apply the plain broadcast op, so we add that op explicitly for
; // both scale and bias.
auto bias_bcast = instruction_ref scale_bcast;
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias); instruction_ref bias_bcast;
auto l5 = info.add_instruction(make_op("mul"), l4, scale_bcast); if(dyn_input)
return info.add_instruction(make_op("add"), l5, bias_bcast); {
scale_bcast = info.add_instruction(make_op("broadcast", {{"axis", 1}}), scale, x);
bias_bcast = info.add_instruction(make_op("broadcast", {{"axis", 1}}), bias, x);
}
else
{
scale_bcast = info.add_instruction(
make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
bias_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
}
auto l5 = info.add_instruction(make_op("mul"), l4, scale_bcast);
auto ret = info.add_instruction(make_op("add"), l5, bias_bcast);
if(dtype == shape::half_type and convert_fp16)
{
return info.add_instruction(make_op("convert", {{"target_type", shape::half_type}}),
ret);
}
return ret;
} }
}; };
......
...@@ -33,8 +33,7 @@ namespace onnx { ...@@ -33,8 +33,7 @@ namespace onnx {
struct parse_mean : op_parser<parse_mean> struct parse_mean : op_parser<parse_mean>
{ {
const std::set<shape::type_t> float_types = { std::set<shape::type_t> float_types = {shape::float_type, shape::half_type, shape::double_type};
shape::float_type, shape::half_type, shape::double_type};
std::vector<op_desc> operators() const { return {{"Mean"}}; } std::vector<op_desc> operators() const { return {{"Mean"}}; }
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp> #include <migraphx/tune_axis.hpp>
#include <migraphx/common.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -47,18 +48,15 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear> ...@@ -47,18 +48,15 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
auto input_lens = args[0]->get_shape().lens(); auto input_lens = args[0]->get_shape().lens();
auto n_dim = input_lens.size(); auto n_dim = input_lens.size();
instruction_ref y_scale; instruction_ref y_scale = args[1];
if(args[1]->get_shape().elements() != 1) if(args[1]->get_shape().elements() != 1)
{ {
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name); auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
y_scale = info.add_instruction( y_scale = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]); make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]);
} }
else
{ auto common_args = add_common_args(*info.mod, {args[0], y_scale});
y_scale = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
args[1]);
}
if(args.size() == 3) if(args.size() == 3)
{ {
...@@ -76,10 +74,10 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear> ...@@ -76,10 +74,10 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
make_op("multibroadcast", {{"out_lens", input_lens}}), y_zero_point); make_op("multibroadcast", {{"out_lens", input_lens}}), y_zero_point);
} }
return info.add_instruction(make_op("quantizelinear"), args[0], y_scale, y_zero_point); common_args.push_back(y_zero_point);
} }
return info.add_instruction(make_op("quantizelinear"), args[0], y_scale); return info.add_instruction(make_op("quantizelinear"), common_args);
} }
}; };
......
...@@ -35,8 +35,7 @@ namespace onnx { ...@@ -35,8 +35,7 @@ namespace onnx {
struct parse_randomnormal_ops : op_parser<parse_randomnormal_ops> struct parse_randomnormal_ops : op_parser<parse_randomnormal_ops>
{ {
const std::set<shape::type_t> valid_types = { std::set<shape::type_t> valid_types = {shape::float_type, shape::half_type, shape::double_type};
shape::float_type, shape::half_type, shape::double_type};
std::vector<op_desc> operators() const { return {{"RandomNormal"}, {"RandomNormalLike"}}; } std::vector<op_desc> operators() const { return {{"RandomNormal"}, {"RandomNormalLike"}}; }
......
...@@ -35,8 +35,7 @@ namespace onnx { ...@@ -35,8 +35,7 @@ namespace onnx {
struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops> struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops>
{ {
const std::set<shape::type_t> valid_types = { std::set<shape::type_t> valid_types = {shape::float_type, shape::half_type, shape::double_type};
shape::float_type, shape::half_type, shape::double_type};
std::vector<op_desc> operators() const { return {{"RandomUniform"}, {"RandomUniformLike"}}; } std::vector<op_desc> operators() const { return {{"RandomUniform"}, {"RandomUniformLike"}}; }
......
...@@ -53,8 +53,8 @@ struct parse_reshape : op_parser<parse_reshape> ...@@ -53,8 +53,8 @@ struct parse_reshape : op_parser<parse_reshape>
s.visit([&](auto v) { copy(v, std::back_inserter(dims)); }); s.visit([&](auto v) { copy(v, std::back_inserter(dims)); });
} }
return info.add_instruction(make_op("reshape", {{"dims", dims}}), auto cont = info.add_instruction(make_op("contiguous"), args[0]);
info.make_contiguous(args[0])); return info.add_instruction(make_op("reshape", {{"dims", dims}}), cont);
} }
}; };
......
...@@ -31,6 +31,8 @@ namespace migraphx { ...@@ -31,6 +31,8 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace onnx { namespace onnx {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK)
struct parse_where : op_parser<parse_where> struct parse_where : op_parser<parse_where>
{ {
std::vector<op_desc> operators() const { return {{"Where"}}; } std::vector<op_desc> operators() const { return {{"Where"}}; }
...@@ -56,6 +58,14 @@ struct parse_where : op_parser<parse_where> ...@@ -56,6 +58,14 @@ struct parse_where : op_parser<parse_where>
auto lens = auto lens =
compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens()); compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens());
lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens()); lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
if(enabled(MIGRAPHX_ENABLE_CK{}))
{
// Convert condition tensor to int32 to work around CK not supporting bool type
args[0] = info.add_instruction(
make_op("convert", {{"target_type", shape::int32_type}}), args[0]);
}
if(args[0]->get_shape().lens() != lens) if(args[0]->get_shape().lens() != lens)
{ {
args[0] = args[0] =
......
...@@ -86,14 +86,24 @@ struct module_pm : module_pass_manager ...@@ -86,14 +86,24 @@ struct module_pm : module_pass_manager
assert(mod); assert(mod);
return *mod; return *mod;
} }
virtual module* create_module(const std::string& name) override virtual module* create_module(const std::string& name) override
{ {
assert(prog); assert(prog);
return prog->create_module(name); return prog->create_module(name);
} }
virtual module* get_common_parent() override { return common_parent; } virtual module* get_common_parent() override { return common_parent; }
virtual module* get_root_module() override
{
assert(prog);
return prog->get_main_module();
}
virtual void run_pass(const pass& p) override virtual void run_pass(const pass& p) override
{ {
trace("Pass: ", p.name());
assert(mod); assert(mod);
assert(mod->validate() == mod->end()); assert(mod->validate() == mod->end());
if(enabled(MIGRAPHX_TIME_PASSES{})) if(enabled(MIGRAPHX_TIME_PASSES{}))
...@@ -113,27 +123,18 @@ struct module_pm : module_pass_manager ...@@ -113,27 +123,18 @@ struct module_pm : module_pass_manager
module& get_module(module_pass_manager& mpm) { return mpm.get_module(); } module& get_module(module_pass_manager& mpm) { return mpm.get_module(); }
void run_passes(module& mod, const std::vector<pass>& passes, tracer trace) void run_passes(program& prog, module_ref root_mod, const std::vector<pass>& passes, tracer trace)
{
if(enabled(MIGRAPHX_TRACE_PASSES{}))
trace = tracer{std::cout};
for(const auto& p : passes)
{
module_pm{&mod, &trace}.run_pass(p);
}
}
void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
{ {
if(enabled(MIGRAPHX_TRACE_PASSES{})) if(enabled(MIGRAPHX_TRACE_PASSES{}))
trace = tracer{std::cout}; trace = tracer{std::cout};
std::unordered_set<module_ref> visited; std::unordered_set<module_ref> visited;
for(const auto& p : passes) for(const auto& p : passes)
{ {
auto mods = prog.get_modules(); auto tree = prog.get_module_tree();
auto tree = prog.get_module_tree(); std::vector<module_ref> sub_mods = root_mod->get_sub_modules();
sub_mods.insert(sub_mods.begin(), root_mod);
visited.clear(); visited.clear();
for(const auto& mod : reverse(mods)) for(const auto& mod : reverse(sub_mods))
{ {
if(mod->bypass()) if(mod->bypass())
continue; continue;
...@@ -157,5 +158,20 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace) ...@@ -157,5 +158,20 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
} }
} }
void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
{
if(enabled(MIGRAPHX_TRACE_PASSES{}))
trace = tracer{std::cout};
for(const auto& p : passes)
{
module_pm{&mod, &trace}.run_pass(p);
}
}
void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
{
run_passes(prog, prog.get_main_module(), passes, trace);
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -38,27 +38,42 @@ std::function<void(const char*)> redirect_to(std::ostream& os) ...@@ -38,27 +38,42 @@ std::function<void(const char*)> redirect_to(std::ostream& os)
return [&](const char* x) { os << x; }; return [&](const char* x) { os << x; };
} }
int exec(const std::string& cmd, const std::function<void(const char*)>& std_out) template <class F>
int exec(const std::string& cmd, const char* type, F f)
{ {
int ec = 0; int ec = 0;
if(enabled(MIGRAPHX_TRACE_CMD_EXECUTE{})) if(enabled(MIGRAPHX_TRACE_CMD_EXECUTE{}))
std::cout << cmd << std::endl; std::cout << cmd << std::endl;
auto closer = [&](FILE* stream) { auto closer = [&](FILE* stream) {
auto status = pclose(stream); auto status = pclose(stream);
ec = WIFEXITED(status) ? 0 : WEXITSTATUS(status); // NOLINT ec = WIFEXITED(status) ? WEXITSTATUS(status) : 0; // NOLINT
}; };
{ {
// TODO: Use execve instead of popen // TODO: Use execve instead of popen
std::unique_ptr<FILE, decltype(closer)> pipe(popen(cmd.c_str(), "r"), closer); // NOLINT std::unique_ptr<FILE, decltype(closer)> pipe(popen(cmd.c_str(), type), closer); // NOLINT
if(not pipe) if(not pipe)
MIGRAPHX_THROW("popen() failed: " + cmd); MIGRAPHX_THROW("popen() failed: " + cmd);
std::array<char, 128> buffer; f(pipe.get());
while(fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr)
std_out(buffer.data());
} }
return ec; return ec;
} }
int exec(const std::string& cmd, const std::function<void(const char*)>& std_out)
{
return exec(cmd, "r", [&](FILE* f) {
std::array<char, 128> buffer;
while(fgets(buffer.data(), buffer.size(), f) != nullptr)
std_out(buffer.data());
});
}
int exec(const std::string& cmd, std::function<void(process::writer)> std_in)
{
return exec(cmd, "w", [&](FILE* f) {
std_in([&](const char* buffer, std::size_t n) { std::fwrite(buffer, 1, n, f); });
});
}
struct process_impl struct process_impl
{ {
std::string command{}; std::string command{};
...@@ -72,6 +87,15 @@ struct process_impl ...@@ -72,6 +87,15 @@ struct process_impl
result += command; result += command;
return result; return result;
} }
template <class... Ts>
void check_exec(Ts&&... xs) const
{
int ec = migraphx::exec(std::forward<Ts>(xs)...);
if(ec != 0)
MIGRAPHX_THROW("Command " + get_command() + " exited with status " +
std::to_string(ec));
}
}; };
process::process(const std::string& cmd) : impl(std::make_unique<process_impl>()) process::process(const std::string& cmd) : impl(std::make_unique<process_impl>())
...@@ -95,12 +119,11 @@ process& process::cwd(const fs::path& p) ...@@ -95,12 +119,11 @@ process& process::cwd(const fs::path& p)
return *this; return *this;
} }
void process::exec() void process::exec() { impl->check_exec(impl->get_command(), redirect_to(std::cout)); }
void process::write(std::function<void(process::writer)> pipe_in)
{ {
auto ec = migraphx::exec(impl->get_command(), redirect_to(std::cout)); impl->check_exec(impl->get_command(), std::move(pipe_in));
if(ec != 0)
MIGRAPHX_THROW("Command " + impl->get_command() + " exited with status " +
std::to_string(ec));
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/compile_options.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
...@@ -42,6 +43,7 @@ ...@@ -42,6 +43,7 @@
#include <sstream> #include <sstream>
#include <algorithm> #include <algorithm>
#include <set> #include <set>
#include <unordered_map>
#include <utility> #include <utility>
#include <unordered_set> #include <unordered_set>
...@@ -53,12 +55,24 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -53,12 +55,24 @@ inline namespace MIGRAPHX_INLINE_NS {
using milliseconds = std::chrono::duration<double, std::milli>; using milliseconds = std::chrono::duration<double, std::milli>;
struct mark_instruction_target
{
std::size_t target_id = 0;
std::string name() const { return "mark_instruction_target"; }
void apply(module& m) const
{
for(auto& ins : m)
ins.set_target_id(target_id);
}
};
struct program_impl struct program_impl
{ {
// A map is used to keep references to modules of the program // A map is used to keep references to modules of the program
std::unordered_map<std::string, module> modules; std::unordered_map<std::string, module> modules;
context ctx; context ctx;
std::string target_name; std::string target_name;
std::vector<context> contexts;
}; };
program::program() : impl(std::make_unique<program_impl>()) { this->create_module("main"); } program::program() : impl(std::make_unique<program_impl>()) { this->create_module("main"); }
...@@ -205,8 +219,94 @@ target_assignments program::get_target_assignments(const std::vector<target>& ta ...@@ -205,8 +219,94 @@ target_assignments program::get_target_assignments(const std::vector<target>& ta
bool program::is_compiled() const { return not this->impl->target_name.empty(); } bool program::is_compiled() const { return not this->impl->target_name.empty(); }
void program::compile(const std::vector<target>& targets, std::vector<compile_options> compile_opts)
{
// Gather all the target roots
std::unordered_multimap<std::size_t, module_ref> roots;
auto mods = this->get_modules();
for(auto* mod : mods)
{
for(const auto& ins : *mod)
{
if(ins.name() != "run_on_target")
continue;
auto v = ins.get_operator().to_value();
module_ref root = ins.module_inputs().front();
std::size_t root_target_id = v.at("target_id").to<std::size_t>();
assert(root_target_id < targets.size());
roots.insert({root_target_id, root});
}
}
auto trace = tracer{};
// TODO: Add tracer based on compile options
if(enabled(MIGRAPHX_TRACE_COMPILE{}))
trace = tracer{std::cout};
trace(*this);
trace();
// It is assumed that all instructions outside of any root module would run on "ref" target
// Ref target may or may not be passed as one of the target for the "compile()".
// If it is not passed, Create one and add context of it into the map.
auto target_idx = [&](const std::string& t_name) {
return static_cast<std::size_t>(
std::find_if(
targets.begin(), targets.end(), [&](const auto& t) { return t.name() == t_name; }) -
targets.begin());
};
std::size_t ref_target_id = target_idx("ref");
if(ref_target_id == targets.size())
{
this->impl->contexts.resize(targets.size() + 1);
this->impl->contexts[ref_target_id] = migraphx::make_target("ref").get_context();
// users could pass lessers compile_ops than targets, in that case use default compile_opts
compile_opts.resize(targets.size() + 1, migraphx::compile_options{});
}
else
{
this->impl->contexts.resize(targets.size());
compile_opts.resize(targets.size(), migraphx::compile_options{});
}
// mark all the instruction as ref target first, later change target_id based on root-target
run_passes(*this, {mark_instruction_target{ref_target_id}});
// Run passes on each root target
for(const auto i : range(targets.size()))
{
const auto& root_target = targets.at(i);
auto root_target_id = i;
auto root_modules_range = roots.equal_range(root_target_id);
this->impl->contexts[root_target_id] = root_target.get_context();
for(const auto& [id, current_mod] : range(root_modules_range))
{
auto passes = root_target.get_passes(this->impl->contexts[root_target_id],
compile_opts[root_target_id]);
passes.push_back(mark_instruction_target{static_cast<size_t>(root_target_id)});
run_passes(*this, current_mod, passes, trace);
auto invalid = current_mod->validate();
if(invalid != current_mod->end())
{
MIGRAPHX_THROW("Invalid module " + current_mod->name() +
" from compilation at instruction " +
std::to_string(std::distance(current_mod->begin(), invalid)));
}
auto dangling = current_mod->find_dangling_reference();
if(dangling != current_mod->end())
{
auto index = std::distance(current_mod->begin(), dangling);
MIGRAPHX_THROW("Dangling reference in module " + current_mod->name() +
" from instruction " + std::to_string(index));
}
current_mod->finalize(this->impl->contexts[root_target_id]);
}
}
}
void program::compile(const target& t, compile_options options) void program::compile(const target& t, compile_options options)
{ {
// todo: combine with multi-target compile method
assert(not this->is_compiled()); assert(not this->is_compiled());
this->impl->target_name = t.name(); this->impl->target_name = t.name();
this->impl->ctx = t.get_context(); this->impl->ctx = t.get_context();
...@@ -331,7 +431,8 @@ std::vector<argument> generic_eval(const module* mod, ...@@ -331,7 +431,8 @@ std::vector<argument> generic_eval(const module* mod,
MIGRAPHX_THROW("Parameter not found: " + param_name); MIGRAPHX_THROW("Parameter not found: " + param_name);
auto param = params[param_name]; auto param = params[param_name];
// TODO: may want to check correct number of dimensions and/or was within bounds // TODO: may want to check correct number of dimensions and/or was within bounds
if(not ins->get_shape().dynamic() and param.get_shape() != ins->get_shape()) if(not ins->get_shape().any_of_dynamic() and
param.get_shape() != ins->get_shape())
{ {
MIGRAPHX_THROW("Incorrect shape {" + to_string(param.get_shape()) + MIGRAPHX_THROW("Incorrect shape {" + to_string(param.get_shape()) +
"} for parameter: " + param_name + "} for parameter: " + param_name +
...@@ -365,7 +466,6 @@ std::vector<argument> generic_eval(const module* mod, ...@@ -365,7 +466,6 @@ std::vector<argument> generic_eval(const module* mod,
assert(results.find(i) != results.end()); assert(results.find(i) != results.end());
return results[i]; return results[i];
}); });
const auto& mod_args = ins->module_inputs(); const auto& mod_args = ins->module_inputs();
auto module_eval = [&](module_ref smod, auto module_eval = [&](module_ref smod,
const std::unordered_map<std::string, argument>& inputs) { const std::unordered_map<std::string, argument>& inputs) {
...@@ -439,39 +539,53 @@ std::vector<argument> program::eval(parameter_map params, execution_environment ...@@ -439,39 +539,53 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
ins_out[x] = ss.str(); ins_out[x] = ss.str();
}); });
ret = generic_eval(*this, ret = generic_eval(
ctx, *this,
std::move(params), ctx,
with_check_context([&](auto& ins, auto f, auto&& check_context) { std::move(params),
ctx.finish(); with_check_context([&](auto& ins, auto f, auto&& check_context) {
std::cout << "Run instruction: " << ins_out.at(ins) << std::endl; ctx.finish();
timer t{}; std::cout << "Run instruction: " << ins_out.at(ins) << std::endl;
auto result = check_context(f); timer t{};
double t1 = t.record<milliseconds>(); auto result = check_context(f);
ctx.finish(); double t1 = t.record<milliseconds>();
double t2 = t.record<milliseconds>(); ctx.finish();
std::cout << "Time: " << t1 << "ms, " << t2 << "ms" << std::endl; double t2 = t.record<milliseconds>();
if(trace_level > 1 and ins->name().front() != '@' and std::cout << "Time: " << t1 << "ms, " << t2 << "ms" << std::endl;
ins->name() != "load" and not result.empty()) if(trace_level > 1 and ins->name().front() != '@' and ins->name() != "load" and
{ not result.empty())
target tgt = make_target(this->impl->target_name); {
auto buffer = tgt.copy_from(result); migraphx::argument buffer;
if(trace_level == 2) try
{ {
std::cout << "Output has " target tgt = make_target(this->impl->target_name);
<< to_string_range(classify_argument(buffer)) buffer = tgt.copy_from(result);
<< std::endl; }
std::cout << "Output: "; catch(const migraphx::exception&)
preview_argument(std::cout, buffer); {
std::cout << std::endl; // instruction was run on host then no need to copy buffer from target
} buffer = result;
else }
{ catch(...)
std::cout << "Output: " << buffer << std::endl; {
} MIGRAPHX_THROW(
} "MIGraphX program execution with MIGRAPHX_TRACE_EVAL failed.\n");
return result; }
})); if(trace_level == 2)
{
std::cout << "Output has " << to_string_range(classify_argument(buffer))
<< std::endl;
std::cout << "Output: ";
preview_argument(std::cout, buffer);
std::cout << std::endl;
}
else
{
std::cout << "Output: " << buffer << std::endl;
}
}
return result;
}));
} }
else else
{ {
...@@ -860,7 +974,9 @@ void program::print_py(std::ostream& os) const ...@@ -860,7 +974,9 @@ void program::print_py(std::ostream& os) const
os << "p = migraphx.program()\n"; os << "p = migraphx.program()\n";
for(auto& mod : vec_modules) for(auto& mod : vec_modules)
{ {
std::string var_name = "m" + mod->name(); std::string var_name = "m";
if(mod->name() != "main")
var_name += mod->name();
os << var_name << " = "; os << var_name << " = ";
if(mod->name() == "main") if(mod->name() == "main")
os << "p.get_main_module()"; os << "p.get_main_module()";
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/promote_literals.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/module.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void promote_literals::apply(module_pass_manager& mpm) const
{
module& m = mpm.get_module();
module_ref root_module = mpm.get_root_module();
if(m.name() == "main")
return;
for(auto ins : iterator_for(m))
{
if(ins->name() == "@literal")
{
auto new_lit = root_module->add_literal(ins->get_literal());
for(auto out_ins : ins->outputs())
{
out_ins->replace_argument(out_ins, ins, new_lit);
}
}
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -27,11 +27,14 @@ ...@@ -27,11 +27,14 @@
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/par_for.hpp> #include <migraphx/par_for.hpp>
#include <migraphx/env.hpp>
#include <unordered_set> #include <unordered_set>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PROPAGATE_CONSTANT)
bool skip_propogate(instruction_ref ins) bool skip_propogate(instruction_ref ins)
{ {
if(ins->name() == "contiguous") if(ins->name() == "contiguous")
...@@ -85,6 +88,19 @@ void propagate_constant::apply(module& m) const ...@@ -85,6 +88,19 @@ void propagate_constant::apply(module& m) const
{ {
if(not literals[i].empty()) if(not literals[i].empty())
{ {
if(enabled(MIGRAPHX_TRACE_PROPAGATE_CONSTANT{}))
{
std::cout << "Constant replace: " << std::endl;
std::vector<instruction_ref> inss;
fix([&](auto self, auto ins) {
if(contains(inss, ins))
return;
for(auto input : ins->inputs())
self(input);
inss.push_back(ins);
})(const_instrs_vec[i]);
m.debug_print(inss);
}
assert(literals[i].get_shape() == const_instrs_vec[i]->get_shape()); assert(literals[i].get_shape() == const_instrs_vec[i]->get_shape());
auto l = m.add_literal(literals[i].get_shape(), literals[i].data()); auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
m.replace_instruction(const_instrs_vec[i], l); m.replace_instruction(const_instrs_vec[i], l);
......
...@@ -35,7 +35,6 @@ ...@@ -35,7 +35,6 @@
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/tf.hpp> #include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/load_save.hpp> #include <migraphx/load_save.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/json.hpp> #include <migraphx/json.hpp>
...@@ -63,6 +62,7 @@ namespace py = pybind11; ...@@ -63,6 +62,7 @@ namespace py = pybind11;
PYBIND11_MODULE(__VA_ARGS__) \ PYBIND11_MODULE(__VA_ARGS__) \
MIGRAPHX_POP_WARNING MIGRAPHX_POP_WARNING
#define MIGRAPHX_PYTHON_GENERATE_SHAPE_ENUM(x, t) .value(#x, migraphx::shape::type_t::x)
namespace migraphx { namespace migraphx {
migraphx::value to_value(py::kwargs kwargs); migraphx::value to_value(py::kwargs kwargs);
...@@ -95,6 +95,10 @@ void visit_py(T x, F f) ...@@ -95,6 +95,10 @@ void visit_py(T x, F f)
{ {
f(x.template cast<std::string>()); f(x.template cast<std::string>());
} }
else if(py::isinstance<migraphx::shape::dynamic_dimension>(x))
{
f(migraphx::to_value(x.template cast<migraphx::shape::dynamic_dimension>()));
}
else else
{ {
MIGRAPHX_THROW("VISIT_PY: Unsupported data type!"); MIGRAPHX_THROW("VISIT_PY: Unsupported data type!");
...@@ -165,7 +169,10 @@ template <class T> ...@@ -165,7 +169,10 @@ template <class T>
py::buffer_info to_buffer_info(T& x) py::buffer_info to_buffer_info(T& x)
{ {
migraphx::shape s = x.get_shape(); migraphx::shape s = x.get_shape();
auto strides = s.strides(); assert(s.type() != migraphx::shape::tuple_type);
if(s.dynamic())
MIGRAPHX_THROW("MIGRAPHX PYTHON: dynamic shape argument passed to to_buffer_info");
auto strides = s.strides();
std::transform( std::transform(
strides.begin(), strides.end(), strides.begin(), [&](auto i) { return i * s.type_size(); }); strides.begin(), strides.end(), strides.begin(), [&](auto i) { return i * s.type_size(); });
py::buffer_info b; py::buffer_info b;
...@@ -177,7 +184,7 @@ py::buffer_info to_buffer_info(T& x) ...@@ -177,7 +184,7 @@ py::buffer_info to_buffer_info(T& x)
b = py::buffer_info(x.data(), b = py::buffer_info(x.data(),
as.size(), as.size(),
py::format_descriptor<bool>::format(), py::format_descriptor<bool>::format(),
s.lens().size(), s.ndim(),
s.lens(), s.lens(),
strides); strides);
} }
...@@ -186,7 +193,7 @@ py::buffer_info to_buffer_info(T& x) ...@@ -186,7 +193,7 @@ py::buffer_info to_buffer_info(T& x)
b = py::buffer_info(x.data(), b = py::buffer_info(x.data(),
as.size(), as.size(),
py::format_descriptor<decltype(as())>::format(), py::format_descriptor<decltype(as())>::format(),
s.lens().size(), s.ndim(),
s.lens(), s.lens(),
strides); strides);
} }
...@@ -236,10 +243,18 @@ migraphx::shape to_shape(const py::buffer_info& info) ...@@ -236,10 +243,18 @@ migraphx::shape to_shape(const py::buffer_info& info)
MIGRAPHX_PYBIND11_MODULE(migraphx, m) MIGRAPHX_PYBIND11_MODULE(migraphx, m)
{ {
py::class_<migraphx::shape>(m, "shape") py::class_<migraphx::shape> shape_cls(m, "shape");
shape_cls
.def(py::init([](py::kwargs kwargs) { .def(py::init([](py::kwargs kwargs) {
auto v = migraphx::to_value(kwargs); auto v = migraphx::to_value(kwargs);
auto t = migraphx::shape::parse_type(v.get("type", "float")); auto t = migraphx::shape::parse_type(v.get("type", "float"));
if(v.contains("dyn_dims"))
{
auto dyn_dims =
migraphx::from_value<std::vector<migraphx::shape::dynamic_dimension>>(
v.at("dyn_dims"));
return migraphx::shape(t, dyn_dims);
}
auto lens = v.get<std::size_t>("lens", {1}); auto lens = v.get<std::size_t>("lens", {1});
if(v.contains("strides")) if(v.contains("strides"))
return migraphx::shape(t, lens, v.at("strides").to_vector<std::size_t>()); return migraphx::shape(t, lens, v.at("strides").to_vector<std::size_t>());
...@@ -249,19 +264,34 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -249,19 +264,34 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("type", &migraphx::shape::type) .def("type", &migraphx::shape::type)
.def("lens", &migraphx::shape::lens) .def("lens", &migraphx::shape::lens)
.def("strides", &migraphx::shape::strides) .def("strides", &migraphx::shape::strides)
.def("ndim", &migraphx::shape::ndim)
.def("elements", &migraphx::shape::elements) .def("elements", &migraphx::shape::elements)
.def("bytes", &migraphx::shape::bytes) .def("bytes", &migraphx::shape::bytes)
.def("type_string", &migraphx::shape::type_string) .def("type_string", &migraphx::shape::type_string)
.def("type_size", &migraphx::shape::type_size) .def("type_size", &migraphx::shape::type_size)
.def("dyn_dims", &migraphx::shape::dyn_dims)
.def("packed", &migraphx::shape::packed) .def("packed", &migraphx::shape::packed)
.def("transposed", &migraphx::shape::transposed) .def("transposed", &migraphx::shape::transposed)
.def("broadcasted", &migraphx::shape::broadcasted) .def("broadcasted", &migraphx::shape::broadcasted)
.def("standard", &migraphx::shape::standard) .def("standard", &migraphx::shape::standard)
.def("scalar", &migraphx::shape::scalar) .def("scalar", &migraphx::shape::scalar)
.def("dynamic", &migraphx::shape::dynamic)
.def("__eq__", std::equal_to<migraphx::shape>{}) .def("__eq__", std::equal_to<migraphx::shape>{})
.def("__ne__", std::not_equal_to<migraphx::shape>{}) .def("__ne__", std::not_equal_to<migraphx::shape>{})
.def("__repr__", [](const migraphx::shape& s) { return migraphx::to_string(s); }); .def("__repr__", [](const migraphx::shape& s) { return migraphx::to_string(s); });
py::enum_<migraphx::shape::type_t>(shape_cls, "type_t")
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_PYTHON_GENERATE_SHAPE_ENUM);
py::class_<migraphx::shape::dynamic_dimension>(shape_cls, "dynamic_dimension")
.def(py::init<>())
.def(py::init<std::size_t, std::size_t>())
.def(py::init<std::size_t, std::size_t, std::set<std::size_t>>())
.def_readwrite("min", &migraphx::shape::dynamic_dimension::min)
.def_readwrite("max", &migraphx::shape::dynamic_dimension::max)
.def_readwrite("optimals", &migraphx::shape::dynamic_dimension::optimals)
.def("is_fixed", &migraphx::shape::dynamic_dimension::is_fixed);
py::class_<migraphx::argument>(m, "argument", py::buffer_protocol()) py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
.def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); }) .def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); })
.def(py::init([](py::buffer b) { .def(py::init([](py::buffer b) {
...@@ -283,7 +313,9 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -283,7 +313,9 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::target>(m, "target"); py::class_<migraphx::target>(m, "target");
py::class_<migraphx::instruction_ref>(m, "instruction_ref"); py::class_<migraphx::instruction_ref>(m, "instruction_ref")
.def("shape", [](migraphx::instruction_ref i) { return i->get_shape(); })
.def("op", [](migraphx::instruction_ref i) { return i->get_operator(); });
py::class_<migraphx::module, std::unique_ptr<migraphx::module, py::nodelete>>(m, "module") py::class_<migraphx::module, std::unique_ptr<migraphx::module, py::nodelete>>(m, "module")
.def("print", [](const migraphx::module& mm) { std::cout << mm << std::endl; }) .def("print", [](const migraphx::module& mm) { std::cout << mm << std::endl; })
...@@ -434,13 +466,18 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -434,13 +466,18 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
"parse_onnx", "parse_onnx",
[](const std::string& filename, [](const std::string& filename,
unsigned int default_dim_value, unsigned int default_dim_value,
migraphx::shape::dynamic_dimension default_dyn_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims, std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>
map_dyn_input_dims,
bool skip_unknown_operators, bool skip_unknown_operators,
bool print_program_on_error, bool print_program_on_error,
int64_t max_loop_iterations) { int64_t max_loop_iterations) {
migraphx::onnx_options options; migraphx::onnx_options options;
options.default_dim_value = default_dim_value; options.default_dim_value = default_dim_value;
options.default_dyn_dim_value = default_dyn_dim_value;
options.map_input_dims = map_input_dims; options.map_input_dims = map_input_dims;
options.map_dyn_input_dims = map_dyn_input_dims;
options.skip_unknown_operators = skip_unknown_operators; options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error; options.print_program_on_error = print_program_on_error;
options.max_loop_iterations = max_loop_iterations; options.max_loop_iterations = max_loop_iterations;
...@@ -448,8 +485,11 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -448,8 +485,11 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
}, },
"Parse onnx file", "Parse onnx file",
py::arg("filename"), py::arg("filename"),
py::arg("default_dim_value") = 1, py::arg("default_dim_value") = 0,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(), py::arg("default_dyn_dim_value") = migraphx::shape::dynamic_dimension{1, 1},
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("map_dyn_input_dims") =
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>(),
py::arg("skip_unknown_operators") = false, py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false, py::arg("print_program_on_error") = false,
py::arg("max_loop_iterations") = 10); py::arg("max_loop_iterations") = 10);
...@@ -458,20 +498,28 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -458,20 +498,28 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
"parse_onnx_buffer", "parse_onnx_buffer",
[](const std::string& onnx_buffer, [](const std::string& onnx_buffer,
unsigned int default_dim_value, unsigned int default_dim_value,
migraphx::shape::dynamic_dimension default_dyn_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims, std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>
map_dyn_input_dims,
bool skip_unknown_operators, bool skip_unknown_operators,
bool print_program_on_error) { bool print_program_on_error) {
migraphx::onnx_options options; migraphx::onnx_options options;
options.default_dim_value = default_dim_value; options.default_dim_value = default_dim_value;
options.default_dyn_dim_value = default_dyn_dim_value;
options.map_input_dims = map_input_dims; options.map_input_dims = map_input_dims;
options.map_dyn_input_dims = map_dyn_input_dims;
options.skip_unknown_operators = skip_unknown_operators; options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error; options.print_program_on_error = print_program_on_error;
return migraphx::parse_onnx_buffer(onnx_buffer, options); return migraphx::parse_onnx_buffer(onnx_buffer, options);
}, },
"Parse onnx file", "Parse onnx file",
py::arg("filename"), py::arg("filename"),
py::arg("default_dim_value") = 1, py::arg("default_dim_value") = 0,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(), py::arg("default_dyn_dim_value") = migraphx::shape::dynamic_dimension{1, 1},
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("map_dyn_input_dims") =
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>(),
py::arg("skip_unknown_operators") = false, py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false); py::arg("print_program_on_error") = false);
......
...@@ -40,15 +40,18 @@ void apply_quantizelinear(module& m, instruction_ref ins) ...@@ -40,15 +40,18 @@ void apply_quantizelinear(module& m, instruction_ref ins)
if(x->get_shape().type() != y_scale->get_shape().type()) if(x->get_shape().type() != y_scale->get_shape().type())
{ {
x = m.insert_instruction(ins, make_op("convert", {{"target_type", shape::float_type}}), x); x = m.insert_instruction(
ins, make_op("convert", {{"target_type", y_scale->get_shape().type()}}), x);
} }
auto div = m.insert_instruction(ins, make_op("div"), x, y_scale); auto div = m.insert_instruction(ins, make_op("div"), x, y_scale);
auto add_zero_point = m.insert_instruction(ins, make_op("round"), div); auto add_zero_point = m.insert_instruction(ins, make_op("round"), div);
if(ins->inputs().size() == 3) if(ins->inputs().size() == 3)
{ {
auto zero_point = m.insert_instruction( auto zero_point =
ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[2]); m.insert_instruction(ins,
make_op("convert", {{"target_type", y_scale->get_shape().type()}}),
ins->inputs()[2]);
add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point); add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point);
} }
...@@ -72,14 +75,16 @@ void apply_quantizelinear(module& m, instruction_ref ins) ...@@ -72,14 +75,16 @@ void apply_quantizelinear(module& m, instruction_ref ins)
void apply_dequantizelinear(module& m, instruction_ref ins) void apply_dequantizelinear(module& m, instruction_ref ins)
{ {
assert(ins->name() == "dequantizelinear"); assert(ins->name() == "dequantizelinear");
auto x = m.insert_instruction(
ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[0]);
auto x_scale = ins->inputs()[1]; auto x_scale = ins->inputs()[1];
auto x = m.insert_instruction(
ins, make_op("convert", {{"target_type", x_scale->get_shape().type()}}), ins->inputs()[0]);
if(ins->inputs().size() == 3) if(ins->inputs().size() == 3)
{ {
auto x_zero_point = m.insert_instruction( auto x_zero_point =
ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[2]); m.insert_instruction(ins,
make_op("convert", {{"target_type", x_scale->get_shape().type()}}),
ins->inputs()[2]);
x = m.insert_instruction(ins, make_op("sub"), x, x_zero_point); x = m.insert_instruction(ins, make_op("sub"), x, x_zero_point);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment