Unverified commit 40fbef9b authored by Ted Themistokleous, committed by GitHub

Merge branch 'develop' into threaded_nms

parents d164b151 aeb9f78c
@@ -326,6 +326,8 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref
     if(ins == std::prev(this->end()))
     {
+        // "rep" could be used earlier in the program, and moving it to the end
+        // may produce an invalid program, so insert an identity operation in this case.
         return replace_instruction(ins, make_op("identity"), rep);
     }
@@ -458,11 +460,11 @@ instruction_ref module::add_parameter(std::string name, shape s)
 instruction_ref module::add_return(std::vector<instruction_ref> args)
 {
-    impl->push_back({builtin::returns{}, {}, std::move(args)});
+    shape instr_shape = compute_shape(builtin::returns{}, args);
+    impl->push_back({builtin::returns{}, instr_shape, std::move(args)});
     auto result = std::prev(impl->instructions.end());
     instruction::backreference(result);
     assert(result->valid(begin()));
     return result;
 }
@@ -650,8 +652,9 @@ instruction_ref module::find_dangling_reference() const
     return end();
 }

-void module::finalize(context& ctx)
+void module::finalize(std::vector<context>& contexts)
 {
+    assert(not contexts.empty());
     const bool trace = enabled(MIGRAPHX_TRACE_FINALIZE{});
     for(auto ins : iterator_for(*this))
     {
@@ -660,10 +663,10 @@ void module::finalize(context& ctx)
             std::cout << "Finalize: ";
             this->debug_print(ins);
         }
-        ins->finalize(ctx);
+        ins->finalize(contexts[ins->get_target_id()]);
         for(const auto& smod : ins->module_inputs())
        {
-            smod->finalize(ctx);
+            smod->finalize(contexts);
        }
     }
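The interesting part of this hunk is the dispatch: every instruction now carries a target id that indexes into the vector of contexts built at compile time. A minimal standalone sketch of the pattern (illustrative types, not the MIGraphX ones):

    #include <cassert>
    #include <cstddef>
    #include <vector>

    struct context { /* target-specific state */ };

    struct instr
    {
        std::size_t target_id = 0; // assigned by a pass such as mark_instruction_target
        void finalize(context&) {}
    };

    void finalize_all(std::vector<instr>& instrs, std::vector<context>& contexts)
    {
        assert(not contexts.empty());
        for(auto& ins : instrs)
            ins.finalize(contexts[ins.target_id]); // dispatch on the tagged target id
    }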
@@ -723,15 +726,15 @@ std::unordered_map<instruction_ref, std::string> module::print(
     for(auto ins : iterator_for(*this))
     {
         std::string var_name;
+        if(not this->name().empty() and this->name() != "main")
+            var_name = this->name() + ":";
         if(ins->name() == "@param")
         {
-            var_name = any_cast<builtin::param>(ins->get_operator()).parameter;
+            var_name.append(any_cast<builtin::param>(ins->get_operator()).parameter);
         }
         else
         {
-            var_name = this->name();
-            var_name.append((this->name().empty() ? "@" : ":@"));
-            var_name.append(std::to_string(count));
+            var_name.append("@" + std::to_string(count));
         }
         // count every instruction so index matches loc in the printout program
         count++;
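Derived from the branch logic above: values in the main module (or an unnamed one) now print without a module prefix, while submodule values keep one. For example, for a parameter named x and the instruction at position 5 in the printout:

    main module:       x         @5
    submodule "body":  body:x    body:@5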
@@ -795,7 +798,10 @@ static std::string to_c_id(const std::string& name, char rep = '_')
 static std::string cpp_var_name(const std::string& name)
 {
-    return to_c_id("x_" + replace_string(name, ":", "_module_"));
+    std::string prefix = "x_";
+    if(not contains(name, "@"))
+        prefix = "p_";
+    return to_c_id(prefix + replace_string(name, ":", "_module_"));
 }

 static void print_py_op(std::ostream& os, const operation& op)
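Combined with the naming scheme above, cpp_var_name() now distinguishes parameters from generated values. Assuming to_c_id() maps characters that are invalid in C identifiers to '_' (an assumption; its body is not shown here), the results would be:

    cpp_var_name("x")        // -> "p_x"               (no '@', so it is a parameter)
    cpp_var_name("body:@5")  // -> "x_body_module__5"  ('@' marks a generated instruction)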
@@ -875,7 +881,7 @@ module::print_py(std::ostream& os,
                 use_abs = false;
             if(use_abs)
                 os << "migraphx.abs_literal(";
-            os << "migraphx.generate_literal(";
+            os << "migraphx.generate_argument(";
             print_py_shape(os, ins->get_shape());
             os << ", " << seed << ")";
             if(use_abs)
@@ -1005,9 +1011,17 @@ std::vector<module_ref> module::get_sub_modules(bool shallow) const
 module& module::sort()
 {
+    auto implicit_deps = calc_implicit_deps();
     fix([&](auto self, auto ins) {
         this->move_instruction(ins, this->begin());
-        for(auto child : ins->inputs())
+        auto ins_inputs = ins->inputs();
+        if(implicit_deps.find(ins) != implicit_deps.end())
+        {
+            auto ins_implicit_inputs = implicit_deps.at(ins);
+            ins_inputs.insert(
+                ins_inputs.end(), ins_implicit_inputs.begin(), ins_implicit_inputs.end());
+        }
+        for(auto child : ins_inputs)
         {
             if(not contains(this->impl->instructions, child))
             {
...
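module::sort() above leans on MIGraphX's fix() helper, a fixed-point combinator that lets a lambda recurse over instruction inputs without naming itself. A self-contained sketch of the recursion shape (generic C++, not the MIGraphX implementation):

    #include <iostream>
    #include <utility>
    #include <vector>

    // Minimal fixed-point combinator: hands the wrapped lambda a callable 'self'.
    template <class F>
    auto fix(F f)
    {
        return [=](auto&&... xs) { return f(fix(f), std::forward<decltype(xs)>(xs)...); };
    }

    int main()
    {
        // Depth-first walk over a child list: the same recursion module::sort()
        // performs over each instruction's inputs plus its implicit dependencies.
        std::vector<std::vector<int>> children = {{1, 2}, {2}, {}};
        fix([&](auto self, int node) -> void {
            std::cout << "visit " << node << "\n";
            for(int child : children[node])
                self(child);
        })(0);
    }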
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -35,18 +35,21 @@ inline namespace MIGRAPHX_INLINE_NS {
  * vec: the vector attribute to normalize
  * axes: the operator's axes attribute if it exists, empty otherwise
  * val: the normalize_axes key and options. Ex: normalize["axes"] =
- * value::array{normalize_attribute::include_min}; lens: shape dimensions passed when calling
- * normalize_attributes(op&, lens)
+ * value::array{normalize_attribute::include_min};
+ * input_shape: input shape passed when calling
+ * normalize_attributes(op&, input_shape)
  *
  * See normalize_attribute.hpp for explaining the options.
  */
+template <class Message>
 auto tune_attribute(const std::vector<int64_t>& vec,
                     const std::vector<int64_t>& axes,
                     const value& val,
-                    const std::vector<std::size_t>& lens)
+                    const shape& input_shape,
+                    Message m)
 {
     std::vector<int64_t> result(vec);
-    int64_t n_rank = lens.size();
+    int64_t n_rank = input_shape.ndim();
     std::vector<op::normalize_attribute> vec_attrs = val.to_vector<op::normalize_attribute>();
     if(contains(vec_attrs, op::normalize_attribute::use_output))
     {
@@ -54,9 +57,28 @@ auto tune_attribute(const std::vector<int64_t>& vec,
     }
     std::vector<int64_t> max_vals(vec.size(), n_rank);
     if(contains(vec_attrs, op::normalize_attribute::use_len))
     {
-        std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) { return lens[i]; });
+        if(input_shape.dynamic())
+        {
+            std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) {
+                const auto& dd = input_shape.dyn_dims().at(i);
+                if(not dd.is_fixed())
+                {
+                    MIGRAPHX_THROW(
+                        "NORMALIZE_ATTR: 'use_lens' on a non-fixed dynamic dimension, axis=" +
+                        std::to_string(i));
+                }
+                return dd.max;
+            });
+        }
+        else
+        {
+            std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) {
+                return input_shape.lens().at(i);
+            });
+        }
     }
     if(contains(vec_attrs, op::normalize_attribute::clip_max))
@@ -84,14 +106,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
     {
         if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less_equal<>{}))
         {
-            MIGRAPHX_THROW("TUNE_VECTOR: value out of range!");
+            MIGRAPHX_THROW(m() + "value out of range!");
         }
     }
     else
     {
         if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less<>{}))
         {
-            MIGRAPHX_THROW("TUNE_VECTOR: value out of range!");
+            MIGRAPHX_THROW(m() + "value out of range!");
         }
     }
 }
@@ -124,14 +146,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
             if(not std::equal(
                    min_vals.begin(), min_vals.end(), result.begin(), std::less_equal<>{}))
             {
-                MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!");
+                MIGRAPHX_THROW(m() + "attribute out of range!");
             }
         }
         else
         {
             if(not std::equal(result.begin(), result.end(), min_vals.begin(), std::less<>{}))
             {
-                MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!");
+                MIGRAPHX_THROW(m() + "attribute out of range!");
             }
         }
     }
@@ -157,9 +179,9 @@ auto tune_pad_attribute(const value& val)
 /**
  * Assumptions:
  * Dimensions to pad start from the third dimension (index 2).
- * Called by compute_shape_op() with the `lens` of the first input.
+ * Called by compute_shape_op() with the shape of the first input.
  */
-bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
+bool normalize_attributes(operation& op, const shape& input_shape)
 {
     bool tuned = false;
     auto attrs = op.attributes();
@@ -170,9 +192,9 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
         auto padding_size = padding.size();
         auto padding_start = 2;

-        if(padding_size == 2 * (lens.size() - padding_start))
+        if(padding_size == 2 * (input_shape.ndim() - padding_start))
             tuned = true;
-        else if(padding_size != (lens.size() - padding_start))
+        else if(padding_size != (input_shape.ndim() - padding_start))
             MIGRAPHX_THROW("inconsistent padding size");
         else
         {
@@ -193,7 +215,8 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
         const auto& key = rv.get_key();
         if(val.contains(key))
         {
-            auto vv = val.at(key).without_key();
+            auto message = [&] { return op.name() + ": " + key + ": "; };
+            auto vv = val.at(key).without_key();
             if(vv.is_array())
             {
                 std::vector<int64_t> axes;
@@ -202,7 +225,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
                     axes = val.at("axes").without_key().to_vector<int64_t>();
                 }
                 auto vec = vv.to_vector<int64_t>();
-                auto result = tune_attribute(vec, axes, rv.without_key(), lens);
+                auto result = tune_attribute(vec, axes, rv.without_key(), input_shape, message);
                 val[key] = result;
                 op.from_value(val);
                 val = op.to_value();
@@ -211,7 +234,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
             else
             {
                 auto num = vv.to<int64_t>();
-                auto result = tune_attribute({num}, {num}, rv.without_key(), lens);
+                auto result = tune_attribute({num}, {num}, rv.without_key(), input_shape, message);
                 val[key] = result.front();
                 op.from_value(val);
                 val = op.to_value();
...
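Stripped of the options handling, the per-element work in tune_attribute() is plain index normalization; a standalone sketch:

    #include <cstdint>
    #include <stdexcept>

    // Map a possibly-negative axis into [0, rank) and range-check it,
    // the same arithmetic tune_attribute() applies to each attribute value.
    int64_t normalize_axis(int64_t axis, int64_t rank)
    {
        if(axis < 0)
            axis += rank; // e.g. axis -1 on a rank-4 input becomes 3
        if(axis < 0 or axis >= rank)
            throw std::out_of_range("axis out of range");
        return axis;
    }

With the new Message parameter, an out-of-range value is now reported as "<op name>: <key>: value out of range!" instead of the old context-free "TUNE_VECTOR:" prefix.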
@@ -45,7 +45,7 @@ void normalize_ops::apply(module& m) const
         auto s = inputs[0]->get_shape();
         migraphx::operation tuned_op = ins->get_operator();
-        if(normalize_attributes(tuned_op, s.max_lens()))
+        if(normalize_attributes(tuned_op, s))
         {
             m.replace_instruction(ins, tuned_op, inputs);
             ins->set_normalized();
...
@@ -30,10 +30,11 @@ target_compile_options(onnx-proto PRIVATE -w)
 target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY})
 set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On)

-file(GLOB ONNX_SRCS ${CONFIGURE_DEPENDS} *.cpp)
+file(GLOB ONNX_SRCS CONFIGURE_DEPENDS *.cpp)
 add_library(migraphx_onnx ${ONNX_SRCS})
 target_include_directories(migraphx_onnx PRIVATE include)
 set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx)
+migraphx_generate_export_header(migraphx_onnx)
 rocm_set_soversion(migraphx_onnx ${MIGRAPHX_SO_VERSION})
 rocm_clang_tidy_check(migraphx_onnx)
 target_link_libraries(migraphx_onnx PRIVATE onnx-proto "-Wl,--exclude-libs,ALL")
...
@@ -38,9 +38,24 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace onnx {

+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_ONNX_PARSER)
+
+static shape shape_from_dyn_dims(shape::type_t shape_type,
+                                 const std::vector<shape::dynamic_dimension>& dyn_dims)
+{
+    if(std::all_of(dyn_dims.begin(), dyn_dims.end(), [](auto dd) { return dd.is_fixed(); }))
+    {
+        std::vector<std::size_t> dims;
+        std::transform(dyn_dims.cbegin(), dyn_dims.cend(), std::back_inserter(dims), [](auto d) {
+            return d.max;
+        });
+        return {shape_type, dims};
+    }
+    return {shape_type, dyn_dims};
+}
+
 static onnx_parser::attribute_map get_attributes(const onnx::NodeProto& node)
 {
     std::unordered_map<std::string, onnx::AttributeProto> result;
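shape_from_dyn_dims() collapses a dynamic shape into a static one whenever every dimension is fixed (min == max). Assuming dynamic_dimension aggregate-initializes as {min, max} (consistent with its use above, but an assumption about the type):

    // every dimension fixed -> the static shape {1, 3, 224, 224} is returned
    auto s1 = shape_from_dyn_dims(shape::float_type, {{1, 1}, {3, 3}, {224, 224}, {224, 224}});
    // batch ranges over 1..4 -> the shape stays dynamic
    auto s2 = shape_from_dyn_dims(shape::float_type, {{1, 4}, {3, 3}, {224, 224}, {224, 224}});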
@@ -135,6 +150,25 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s
     return this->add_common_op(op_name, arg0, arg1);
 }

+/**
+ * @brief A wrapper for insert_common_args(), which constructs an argument list
+ * and inserts multibroadcast and convert ops to match inputs to a common shape and type
+ * as required. The requested operation is placed after the added multibroadcast and convert ops,
+ * if any, so that their results are transparent to the programmer.
+ *
+ * Use add_common_op() to match input sizes when inputs may be
+ * either static or dynamic.
+ *
+ * @param op_name string; Name of operation (op) to add; valid names are the same as
+ * for make_op()
+ *
+ * @param inputs vector of instruction_ref. List of instructions for the new
+ * operator. Multibroadcast and convert operations, if needed, are deduced from these too.
+ *
+ * @return instruction_ref Returns an instruction_ref which is the result of the requested
+ * operation.
+ */
 instruction_ref onnx_parser::node_info::add_common_op(const std::string& op_name,
                                                       std::vector<instruction_ref> inputs) const
 {
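A hedged usage sketch of the documented helper, inside a parser where info and args exist as in the surrounding code (the input shapes are assumed examples):

    // args[0] has shape {2, 3, 4}, args[1] has shape {4}: add_common_op() inserts the
    // multibroadcast (and convert, if the types differ) before adding "add" itself.
    auto sum = info.add_common_op("add", {args[0], args[1]});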
@@ -264,16 +298,48 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
     return version;
 }

-std::vector<instruction_ref>
-onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining)
+void print_added_instructions(module* mod,
+                              const std::vector<instruction_ref>& args,
+                              const std::vector<instruction_ref>& result)
+{
+    // Print the instructions added by the parser that are not in args
+    std::vector<instruction_ref> added_instructions;
+    fix([&](auto self, auto r) {
+        for(auto ins : r)
+        {
+            if(contains(args, ins))
+                continue;
+            if(contains(added_instructions, ins))
+                continue;
+            self(ins->inputs());
+            added_instructions.push_back(ins);
+        }
+    })(result);
+    mod->debug_print(added_instructions);
+}
+
+std::unordered_map<std::string, instruction_ref>
+parse_initializer(const onnx_parser& parser, module* mod, const onnx::GraphProto& graph)
 {
     std::unordered_map<std::string, instruction_ref> mod_insts;
     for(auto&& f : graph.initializer())
     {
+        if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{}))
+            std::cout << "initializer: " << f.name() << std::endl;
         // backup instructions in parent mod
-        mod_insts[f.name()] = mod->add_literal(parse_tensor(f));
+        mod_insts[f.name()] = mod->add_literal(parser.parse_tensor(f));
+        if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{}))
+            mod->debug_print(mod_insts[f.name()]);
     }
+    return mod_insts;
+}
+
+std::unordered_map<std::string, instruction_ref>
+parse_inputs(const onnx_parser& parser,
+             module* mod,
+             const onnx::GraphProto& graph,
+             std::unordered_map<std::string, instruction_ref> mod_insts)
+{
     for(auto&& input : graph.input())
     {
         const std::string& name = input.name();
@@ -284,7 +350,7 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini
         // scenario that a nested subgraph contains a parameter with the
         // name existed in its parent graph.
         // In the current implementation, MIGraphX throws an exception for that.
-        if(contains(instructions, name))
+        if(contains(parser.instructions, name))
         {
             MIGRAPHX_THROW("module \"" + mod->name() + "\" has parameter name \"" + name +
                            "\" existing in parent graph!");
@@ -292,28 +358,41 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini
             shape s;
             std::vector<std::size_t> dims;
-            if(map_input_dims.count(name) > 0)
+            if(parser.map_input_dims.count(name) > 0)
             {
-                dims = map_input_dims.at(name);
-                s    = parse_type(input.type(), dims);
+                dims = parser.map_input_dims.at(name);
+                s    = parser.parse_type(input.type(), dims);
             }
-            else if(map_dyn_input_dims.count(name) > 0)
+            else if(parser.map_dyn_input_dims.count(name) > 0)
             {
                 shape::type_t shape_type = get_type(input.type().tensor_type().elem_type());
-                s = {shape_type, map_dyn_input_dims.at(name)};
+                s = shape_from_dyn_dims(shape_type, parser.map_dyn_input_dims.at(name));
             }
             else
             {
-                s = parse_type(input.type(), dims);
+                s = parser.parse_type(input.type(), dims);
             }
             mod_insts[name] = mod->add_parameter(name, s);
         }
     }
+    return mod_insts;
+}
+
+std::vector<instruction_ref>
+onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining)
+{
+    std::unordered_map<std::string, instruction_ref> mod_insts =
+        parse_initializer(*this, mod, graph);
+
+    mod_insts = parse_inputs(*this, mod, graph, mod_insts);
+
     std::copy(mod_insts.begin(), mod_insts.end(), std::inserter(instructions, instructions.end()));

     for(auto&& node : graph.node())
     {
+        if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{}))
+            std::cout << "operator: " << node.op_type() << std::endl;
+
         std::vector<instruction_ref> args;
         for(auto&& input : node.input())
         {
@@ -351,6 +430,11 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini
                        result.begin(),
                        std::inserter(instructions, instructions.end()),
                        [](auto&& x, auto&& y) { return std::make_pair(x, y); });
+
+        if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{}))
+        {
+            print_added_instructions(mod, args, result);
+        }
     }

     // Find instructions corresponding to the output
@@ -503,16 +587,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
     {
         return {shape_type};
     }
-    if(std::all_of(dynamic_dims.begin(), dynamic_dims.end(), [](auto dd) { return dd.is_fixed(); }))
-    {
-        std::vector<std::size_t> dims;
-        std::transform(dynamic_dims.begin(),
-                       dynamic_dims.end(),
-                       std::back_inserter(dims),
-                       [](auto d) { return d.max; });
-        return {shape_type, dims};
-    }
-    return {shape_type, dynamic_dims};
+    return shape_from_dyn_dims(shape_type, dynamic_dims);
 }

 shape::type_t get_type(int dtype)
...
@@ -46,6 +46,7 @@ std::vector<std::string> get_op_parsers()
                    op_parser_map().end(),
                    std::back_inserter(result),
                    [&](auto&& p) { return p.first; });
+    std::sort(result.begin(), result.end());
     return result;
 }
...
@@ -57,13 +57,12 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
         auto x_rank = x_lens.size();
         if(x_rank == 1 or x_rank == 2)
         {
-            auto rt  = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
-            auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
-            auto numer   = info.add_broadcastable_binary_op("sub", args[0], args[3]);
-            auto var_eps = info.add_broadcastable_binary_op("add", args[4], eps);
-            auto denom   = info.add_broadcastable_binary_op("pow", var_eps, rt);
-            auto div0    = info.add_broadcastable_binary_op("div", numer, denom);
-            auto r0      = info.add_broadcastable_binary_op("mul", div0, args[1]);
+            auto eps        = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
+            auto x_sub_mean = info.add_broadcastable_binary_op("sub", args[0], args[3]);
+            auto var_eps    = info.add_broadcastable_binary_op("add", args[4], eps);
+            auto rsqrt      = info.add_instruction(make_op("rsqrt"), var_eps);
+            auto mul0       = info.add_broadcastable_binary_op("mul", args[1], rsqrt);
+            auto r0         = info.add_broadcastable_binary_op("mul", x_sub_mean, mul0);
             return info.add_broadcastable_binary_op("add", r0, args[2]);
         }
         else if(x_rank > 2)
@@ -71,7 +70,6 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
             // unsqueeze tensors of shape (C) to broadcast correctly
             std::vector<int64_t> unsqueeze_axes(x_lens.size() - 2);
             std::iota(unsqueeze_axes.begin(), unsqueeze_axes.end(), 1);
-            auto rt  = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
             auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
             auto scale_unsqueeze = info.add_instruction(
                 migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[1]);
@@ -81,11 +79,11 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
                 migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[3]);
             auto var_unsqueeze = info.add_instruction(
                 migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[4]);
-            auto numer   = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze);
+            auto x_sub_mean = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze);
             auto var_eps = info.add_broadcastable_binary_op("add", var_unsqueeze, eps);
-            auto denom   = info.add_broadcastable_binary_op("pow", var_eps, rt);
-            auto div0    = info.add_broadcastable_binary_op("div", numer, denom);
-            auto r0      = info.add_broadcastable_binary_op("mul", div0, scale_unsqueeze);
+            auto rsqrt = info.add_instruction(make_op("rsqrt"), var_eps);
+            auto mul0  = info.add_broadcastable_binary_op("mul", scale_unsqueeze, rsqrt);
+            auto r0    = info.add_broadcastable_binary_op("mul", x_sub_mean, mul0);
             return info.add_broadcastable_binary_op("add", r0, bias_unsqueeze);
         }
         else
...
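The batchnorm rewrite is algebraic regrouping rather than a behavior change; both forms compute

    y = scale * (x - mean) / sqrt(var + eps) + bias
      = (x - mean) * (scale * rsqrt(var + eps)) + bias

so the pow(var + eps, 0.5) and the division disappear in favor of a single rsqrt, with scale folded into the per-channel factor before the elementwise multiply.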
@@ -42,7 +42,7 @@ std::vector<int64_t> to_int64_vector(const std::vector<T>& input_vector)
     return output_vector;
 }

-struct parse_deconvolution : op_parser<parse_deconvolution>
+struct parse_conv_transpose : op_parser<parse_conv_transpose>
 {
     std::vector<op_desc> operators() const { return {{"ConvTranspose"}}; }

@@ -51,17 +51,15 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
                           onnx_parser::node_info info,
                           std::vector<instruction_ref> args) const
     {
-        operation op = make_op("deconvolution");
+        operation op = make_op("convolution_backwards");
         value values = op.to_value();
-        // op::deconvolution op;
         auto l0 = args[0];
         std::vector<std::int64_t> padding;
         bool asym_padding = false;
-        auto in_lens = l0->get_shape().lens();
-        assert(in_lens.size() > 2);
-        auto kdims = in_lens.size() - 2;
+        assert(l0->get_shape().ndim() > 2);
+        auto kdims = l0->get_shape().ndim() - 2;

-        // ensure pads availabe only when auto_pad is "NOT_SET"
+        // ensure pads available only when auto_pad is "NOT_SET"
         check_padding_mode(info, "CONV_TRANSPOSE");

         if(contains(info.attributes, "pads"))
@@ -70,9 +68,9 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
             asym_padding = is_asym_padding(padding);

+            size_t pad_ndims = padding.size() / 2;
             if(not asym_padding)
             {
-                size_t pad_ndims = padding.size() / 2;
                 check_attr_sizes(kdims, pad_ndims, "PARSE_CONV_TRANSPOSE: inconsistent paddings");
                 values["padding"].clear();
                 std::transform(padding.begin(),
@@ -80,7 +78,19 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
                                std::back_inserter(values["padding"]),
                                [](auto pad_val) { return pad_val; });
             }
+            else if(l0->get_shape().dynamic())
+            {
+                MIGRAPHX_THROW("PARSE_CONV_TRANSPOSE: asymmetric padding (padding_L != padding_R) "
+                               "not supported with dynamic shapes");
+            }
+            else
+            {
+                // set padding to 0s; asym_padding is handled by the parser with slice
+                // TODO: change the parser and op to do asym padding in the op
+                values["padding"] = std::vector<std::size_t>(pad_ndims, 0);
+            }
         }

         if(contains(info.attributes, "strides"))
         {
             values["stride"].clear();
@@ -88,6 +98,7 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
             check_attr_sizes(
                 kdims, values["stride"].size(), "PARSE_CONV_TRANSPOSE: inconsistent strides");
         }
+
         if(contains(info.attributes, "dilations"))
         {
             values["dilation"].clear();
@@ -97,21 +108,10 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
         }

         // TODO: auto padding needs to be implemented for this parser and operator
-        if(contains(info.attributes, "auto_pad"))
+        if(contains(info.attributes, "auto_pad") and
+           to_upper(info.attributes.at("auto_pad").s()) != "NOTSET")
         {
-            auto s = info.attributes["auto_pad"].s();
-            if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
-            {
-                MIGRAPHX_THROW("PARSE_CONV_TRANSPOSE: auto_pad and padding cannot be specified "
-                               "simultaneously");
-            }
-            if(s.find("SAME") != std::string::npos)
-            {
-                bool is_same_upper = (s.find("SAME_UPPER") != std::string::npos);
-                values["padding_mode"] = is_same_upper ? to_value(op::padding_mode_t::same_upper)
-                                                       : to_value(op::padding_mode_t::same_lower);
-            }
+            MIGRAPHX_THROW("PARSE_CONV_TRANSPOSE: auto padding not supported");
         }

         if(contains(info.attributes, "group"))
@@ -122,11 +122,11 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
         recalc_conv_attributes(values, kdims);

         op.from_value(values);
         auto l1 = info.add_instruction(op, l0, args[1]);
-        std::vector<int64_t> dims = to_int64_vector(l1->get_shape().lens());
-        std::vector<int64_t> curr_shape(dims.begin() + 2, dims.end());
         if(asym_padding)
         {
+            std::vector<int64_t> dims = to_int64_vector(l1->get_shape().lens());
+            std::vector<int64_t> curr_shape(dims.begin() + 2, dims.end());
             std::vector<int64_t> axes(kdims);
             std::iota(axes.begin(), axes.end(), 2); // ignore first 2 dims
@@ -144,9 +144,11 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
                 make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), l1);
         }

-        if(contains(info.attributes, "output_padding"))
+        // TODO: should check output_padding < (strides or dilations)
+        if(contains(info.attributes, "output_padding") and
+           not contains(info.attributes, "output_shape"))
         {
-            size_t non_kdims = dims.size() * 2 - kdims;
+            size_t non_kdims = l1->get_shape().ndim() * 2 - kdims;
             std::vector<int64_t> output_padding(non_kdims, 0);
             copy(info.attributes["output_padding"].ints(), std::back_inserter(output_padding));
             check_attr_sizes(kdims,
@@ -155,14 +157,21 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
             l1 = info.add_instruction(make_op("pad", {{"pads", output_padding}}), l1);
         }

+        // TODO: this does unnecessary calculation; could instead compute the padding to
+        // conv_transpose that would give the output_shape.
         if(contains(info.attributes, "output_shape"))
         {
+            if(l1->get_shape().dynamic())
+            {
+                MIGRAPHX_THROW("PARSE_CONV_TRANSPOSE: output_shape attribute and dynamic shapes "
+                               "not supported");
+            }
+            std::vector<int64_t> dims = to_int64_vector(l1->get_shape().lens());
+            std::vector<int64_t> curr_shape(dims.begin() + 2, dims.end());
             std::vector<int64_t> output_shape;
             copy(info.attributes["output_shape"].ints(), std::back_inserter(output_shape));
             check_attr_sizes(
                 kdims, output_shape.size(), "PARSE_CONV_TRANSPOSE: inconsistent output shape");
-            dims = to_int64_vector(l1->get_shape().lens());
-            copy(dims.begin() + 2, dims.end(), curr_shape.begin());
             if(curr_shape != output_shape)
             {
                 std::vector<int64_t> target_padding(dims.size() * 2 - kdims, 0);
...
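For reference, the ONNX ConvTranspose spatial output size that the output_shape branch works backwards from is

    out = stride * (in - 1) + output_padding + ((kernel - 1) * dilation + 1) - pad_begin - pad_end

and, as the diff's comments note, asymmetric pads (pad_begin != pad_end) are handled by running convolution_backwards with zero padding and slicing the excess off the result afterwards.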
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -21,10 +21,14 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
+#include <iterator>
 #include <migraphx/onnx/op_parser.hpp>
 #include <migraphx/ranges.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/make_op.hpp>
+#include <migraphx/env.hpp>
+
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_FP16_INSTANCENORM_CONVERT);

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -32,62 +36,117 @@ namespace onnx {

 struct parse_instancenorm : op_parser<parse_instancenorm>
 {
-    const std::set<shape::type_t> valid_types = {
-        shape::float_type, shape::half_type, shape::double_type};
+    std::set<shape::type_t> valid_types = {shape::float_type, shape::half_type, shape::double_type};

     std::vector<op_desc> operators() const { return {{"InstanceNormalization"}}; }

     instruction_ref parse(const op_desc& opd,
                           const onnx_parser& parser,
                           onnx_parser::node_info info,
-                          std::vector<instruction_ref> args) const
+                          std::vector<instruction_ref> oargs) const
     {
         // y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
         // mean = reduce_mean({D1, D2, ... Dk}, x)
         // variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2)

+        // Convert fp16 to fp32 to work around fp16 accuracy issues with reduce_mean/variance.
+        bool convert_fp16 = true;
+        if(enabled(MIGRAPHX_DISABLE_FP16_INSTANCENORM_CONVERT{}))
+        {
+            convert_fp16 = false;
+        }
+
         float epsilon = 1e-5f;
         if(contains(info.attributes, "epsilon"))
         {
             epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
         }
+        auto dtype         = oargs[0]->get_shape().type();
+        auto literal_dtype = dtype;
+        std::vector<instruction_ref> args;
+        // cppcheck-suppress knownConditionTrueFalse
+        if(dtype == shape::half_type and convert_fp16)
+        {
+            std::transform(oargs.begin(), oargs.end(), std::back_inserter(args), [&](const auto i) {
+                return info.add_instruction(
+                    make_op("convert", {{"target_type", shape::float_type}}), i);
+            });
+            literal_dtype = shape::float_type;
+        }
+        else
+        {
+            args = oargs;
+        }
+
         auto x     = args[0];
         auto scale = args[1];
         auto bias  = args[2];
-        auto dims  = x->get_shape().lens();
-        auto dtype = x->get_shape().type();
         if(not contains(valid_types, dtype))
             MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + std::to_string(dtype) +
                            ". Valid types are 1 (float), 10 (half), and 11 (double).");

-        auto ndims = dims.size();
+        auto ndims = x->get_shape().ndim();
         assert(ndims >= 2);
         auto kdims = ndims - 2;

         std::vector<int64_t> axes(kdims);
         std::iota(axes.begin(), axes.end(), 2);

         auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
-        auto mean_bcast =
-            info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), mean);
-        auto l0       = info.add_instruction(make_op("sqdiff"), x, mean_bcast);
-        auto variance = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0);
-        auto l1       = info.add_instruction(make_op("sub"), x, mean_bcast);
-        auto epsilon_literal = info.add_literal(literal{shape{dtype}, {epsilon}});
-        auto epsilon_bcast =
-            info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal);
-        auto variance_bcast =
-            info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), variance);
-        auto l2 = info.add_instruction(make_op("add"), variance_bcast, epsilon_bcast);
+        // Use add_common_op() to insert multibroadcast/convert instructions where needed when
+        // inputs may be either static or dynamic.
+        auto l1 = info.add_common_op("sub", x, mean);
+        // For fp16, if not converting to fp32, divide `x` and `mean` by `sqrt(n)` and take
+        // reduce_sum to calculate the variance, i.e.
+        // var = reduce_sum((x/s_n - mean/s_n)^2) where s_n = sqrt(n)
+        std::string reduce_op_name =
+            (dtype == shape::half_type and not convert_fp16) ? "reduce_sum" : "reduce_mean";
+        if(dtype == shape::half_type and not convert_fp16)
+        {
+            if(x->get_shape().dynamic())
+            {
+                MIGRAPHX_THROW("PARSE_INSTANCENORM: half type not supported with dynamic shape "
+                               "unless convert_fp16 is TRUE");
+            }
+            auto dims = x->get_shape().lens();
+            double n =
+                std::accumulate(dims.begin() + 2, dims.end(), 1, [&](const auto& i, const auto& j) {
+                    return i * j;
+                });
+            n              = 1.0 / std::sqrt(n);
+            auto n_literal = info.add_literal(literal{dtype, {n}});
+            x              = info.add_common_op("mul", {x, n_literal});
+        }
+        auto l0              = info.add_common_op("sqdiff", x, mean);
+        auto variance        = info.add_instruction(make_op(reduce_op_name, {{"axes", axes}}), l0);
+        auto epsilon_literal = info.add_literal(literal{shape{literal_dtype}, {epsilon}});
+        auto l2              = info.add_common_op("add", variance, epsilon_literal);
         auto l3 = info.add_instruction(make_op("rsqrt"), l2);
-        auto l4 = info.add_instruction(make_op("mul"), l1, l3);
-        auto scale_bcast =
-            info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
-        auto bias_bcast =
-            info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
-        auto l5 = info.add_instruction(make_op("mul"), l4, scale_bcast);
-        return info.add_instruction(make_op("add"), l5, bias_bcast);
+        auto l4 = info.add_common_op("mul", l1, l3);
+
+        // add_common_op() doesn't apply the plain broadcast op, so add that op explicitly for
+        // both scale and bias.
+        instruction_ref scale_bcast;
+        instruction_ref bias_bcast;
+        if(x->get_shape().dynamic())
+        {
+            scale_bcast = info.add_instruction(make_op("broadcast", {{"axis", 1}}), scale, x);
+            bias_bcast  = info.add_instruction(make_op("broadcast", {{"axis", 1}}), bias, x);
+        }
+        else
+        {
+            auto dims   = x->get_shape().lens();
+            scale_bcast = info.add_instruction(
+                make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
+            bias_bcast =
+                info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
+        }
+        auto l5  = info.add_instruction(make_op("mul"), l4, scale_bcast);
+        auto ret = info.add_instruction(make_op("add"), l5, bias_bcast);
+        if(dtype == shape::half_type and convert_fp16)
+        {
+            return info.add_instruction(make_op("convert", {{"target_type", shape::half_type}}),
+                                        ret);
+        }
+        return ret;
     }
 };
...
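The fp16 path relies on a simple identity: reduce_mean divides by the number n of reduced elements, and in half precision that sum can lose accuracy or overflow before the division, so the inputs are pre-scaled instead:

    var = (1/n) * sum((x - mean)^2)
        = sum(((x - mean) / sqrt(n))^2)
        = sum((x/sqrt(n) - mean/sqrt(n))^2)

Multiplying by 1/sqrt(n) up front therefore lets reduce_sum stand in for reduce_mean with no separate division, keeping intermediate magnitudes inside fp16 range.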
@@ -33,8 +33,7 @@ namespace onnx {

 struct parse_mean : op_parser<parse_mean>
 {
-    const std::set<shape::type_t> float_types = {
-        shape::float_type, shape::half_type, shape::double_type};
+    std::set<shape::type_t> float_types = {shape::float_type, shape::half_type, shape::double_type};

     std::vector<op_desc> operators() const { return {{"Mean"}}; }
...
@@ -35,8 +35,7 @@ namespace onnx {

 struct parse_randomnormal_ops : op_parser<parse_randomnormal_ops>
 {
-    const std::set<shape::type_t> valid_types = {
-        shape::float_type, shape::half_type, shape::double_type};
+    std::set<shape::type_t> valid_types = {shape::float_type, shape::half_type, shape::double_type};

     std::vector<op_desc> operators() const { return {{"RandomNormal"}, {"RandomNormalLike"}}; }
...
@@ -35,8 +35,7 @@ namespace onnx {

 struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops>
 {
-    const std::set<shape::type_t> valid_types = {
-        shape::float_type, shape::half_type, shape::double_type};
+    std::set<shape::type_t> valid_types = {shape::float_type, shape::half_type, shape::double_type};

     std::vector<op_desc> operators() const { return {{"RandomUniform"}, {"RandomUniformLike"}}; }
...
@@ -30,8 +30,11 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace onnx {

-// Use a literal instruction to replace the shape since, output of
-// shape operator are literals in migraphx
+/**
+ * If static shape input, creates a literal in migraphx.
+ * If dynamic shape input, creates a dimensions_of operator in migraphx (runtime evaluation of
+ * shape).
+ */
 struct parse_shape : op_parser<parse_shape>
 {
     std::vector<op_desc> operators() const { return {{"Shape"}}; }
@@ -43,13 +46,54 @@ struct parse_shape : op_parser<parse_shape>
     {
         if(args.size() != 1)
             MIGRAPHX_THROW("Shape: operator should have 1 operand");
-        std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
-        std::vector<int64_t> vec_shape(arg_shape.size());
-        migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
-        std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
-            return int64_t(i);
-        });
-        return info.add_literal(migraphx::literal{s, vec_shape});
+        auto input_shape  = args[0]->get_shape();
+        int input_ndim    = input_shape.ndim();
+        std::size_t start = 0;
+        std::size_t end   = input_ndim;
+        // Normalizing the start and end is handled here because of how the static shape version
+        // works. Clamping to [-r, r], where r is ndim of input, and then making positive.
+        auto normalize_ind = [&](int64_t ind) {
+            if(ind < (-1 * input_ndim))
+            {
+                ind = -1 * input_ndim;
+            }
+            if(ind > input_ndim)
+            {
+                ind = input_ndim;
+            }
+            return (ind >= 0) ? ind : input_ndim + ind;
+        };
+        if(contains(info.attributes, "end"))
+        {
+            end = normalize_ind(info.attributes.at("end").i());
+        }
+        if(contains(info.attributes, "start"))
+        {
+            start = normalize_ind(info.attributes.at("start").i());
+        }
+        if(end <= start)
+        {
+            MIGRAPHX_THROW("PARSE_SHAPE: ending axis <= starting axis, end: " +
+                           std::to_string(end) + " start: " + std::to_string(start));
+        }
+        if(input_shape.dynamic())
+        {
+            return info.add_instruction(make_op("dimensions_of", {{"start", start}, {"end", end}}),
+                                        args[0]);
+        }
+        else
+        {
+            std::size_t output_ndim = end - start;
+            std::vector<int64_t> vec_shape(output_ndim);
+            migraphx::shape s(migraphx::shape::int64_type, {output_ndim});
+            std::vector<std::size_t> input_lens = input_shape.lens();
+            std::transform(input_lens.begin() + start,
+                           input_lens.begin() + end,
+                           vec_shape.begin(),
+                           [](auto i) { return int64_t(i); });
+            return info.add_literal(migraphx::literal{s, vec_shape});
+        }
     }
 };
...
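A worked example of the start/end handling above, for an input of rank r = 4 with lens [2, 3, 4, 5]:

    attributes:    start = 1, end = -1
    normalize:     start = 1, end = 4 + (-1) = 3
    static input:  literal {3, 4}                  (int64, 2 elements)
    dynamic input: dimensions_of(start=1, end=3)   (evaluated at runtime)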
@@ -56,6 +56,7 @@ struct parse_where : op_parser<parse_where>
         auto lens =
             compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens());
         lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
+
         if(args[0]->get_shape().lens() != lens)
         {
             args[0] =
...
@@ -68,12 +68,18 @@ void run_pass(program& prog, const pass& p, tracer trace)
 struct module_pm : module_pass_manager
 {
     module* mod = nullptr;
+    module* root_mod = nullptr;
     tracer* t = nullptr;
     module* common_parent = nullptr;
     program* prog = nullptr;

-    module_pm(module* pmod = nullptr, tracer* pt = nullptr) : mod(pmod), t(pt) {}
+    module_pm(module* pmod = nullptr, module* rmod = nullptr, tracer* pt = nullptr)
+        : mod(pmod), root_mod(rmod), t(pt)
+    {
+    }

     template <class... Ts>
     void trace(Ts&&... xs) const
     {
@@ -97,6 +103,8 @@ struct module_pm : module_pass_manager
     virtual module* get_root_module() override
     {
+        if(root_mod != nullptr)
+            return root_mod;
         assert(prog);
         return prog->get_main_module();
     }
@@ -123,33 +131,24 @@ struct module_pm : module_pass_manager
 module& get_module(module_pass_manager& mpm) { return mpm.get_module(); }

-void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
-{
-    if(enabled(MIGRAPHX_TRACE_PASSES{}))
-        trace = tracer{std::cout};
-    for(const auto& p : passes)
-    {
-        module_pm{&mod, &trace}.run_pass(p);
-    }
-}
-
-void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
+void run_passes(program& prog, module_ref root_mod, const std::vector<pass>& passes, tracer trace)
 {
     if(enabled(MIGRAPHX_TRACE_PASSES{}))
         trace = tracer{std::cout};
     std::unordered_set<module_ref> visited;
     for(const auto& p : passes)
     {
-        auto mods = prog.get_modules();
-        auto tree = prog.get_module_tree();
+        auto tree = prog.get_module_tree();
+        std::vector<module_ref> sub_mods = root_mod->get_sub_modules();
+        sub_mods.insert(sub_mods.begin(), root_mod);
         visited.clear();
-        for(const auto& mod : reverse(mods))
+        for(const auto& mod : reverse(sub_mods))
         {
             if(mod->bypass())
                 continue;
             if(not visited.insert(mod).second)
                 continue;
-            module_pm mpm{mod, &trace};
+            module_pm mpm{mod, root_mod, &trace};
             mpm.prog = &prog;
             auto parents = range(tree.equal_range(mod));
             auto nparents = distance(parents);
@@ -167,5 +166,20 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
     }
 }

+void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
+{
+    if(enabled(MIGRAPHX_TRACE_PASSES{}))
+        trace = tracer{std::cout};
+    for(const auto& p : passes)
+    {
+        module_pm{&mod, &mod, &trace}.run_pass(p);
+    }
+}
+
+void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
+{
+    run_passes(prog, prog.get_main_module(), passes, trace);
+}
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
@@ -74,5 +74,15 @@ std::vector<int64_t> find_permutation(const std::vector<shape>& shapes)
     return it->first;
 }

+std::vector<shape> normalize_permutation(const std::vector<shape>& shapes)
+{
+    auto result = shapes;
+    auto perm   = find_permutation(shapes);
+    std::transform(result.begin(), result.end(), result.begin(), [&](auto s) {
+        return reorder_shape(s, perm);
+    });
+    return result;
+}
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
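normalize_permutation() reorders every shape in the list by the single permutation find_permutation() picks across them, so differently-strided (e.g. transposed) shapes end up in one canonical layout. A hedged usage sketch, with the two input shapes as assumed examples:

    // one NCHW-contiguous shape and one transposed (NHWC-strided) view of the same lens
    std::vector<migraphx::shape> shapes = {nchw, nhwc_view};
    auto normalized = migraphx::normalize_permutation(shapes);
    // both entries of 'normalized' now share the same dimension order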
@@ -21,6 +21,8 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
+#include <migraphx/version.h>
+#include <migraphx/compile_options.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/stringutils.hpp>
 #include <migraphx/instruction.hpp>
@@ -38,12 +40,14 @@
 #include <migraphx/make_op.hpp>
 #include <migraphx/marker.hpp>
 #include <migraphx/supported_segments.hpp>

 #include <iostream>
+#include <queue>
 #include <sstream>
 #include <algorithm>
 #include <set>
+#include <unordered_map>
 #include <utility>
 #include <unordered_set>
 #include <map>
 #include <cassert>
@@ -53,12 +57,23 @@ inline namespace MIGRAPHX_INLINE_NS {

 using milliseconds = std::chrono::duration<double, std::milli>;

+struct mark_instruction_target
+{
+    std::size_t target_id = 0;
+    std::string name() const { return "mark_instruction_target"; }
+    void apply(module& m) const
+    {
+        for(auto& ins : m)
+            ins.set_target_id(target_id);
+    }
+};
+
 struct program_impl
 {
     // A map is used to keep references to modules of the program
     std::unordered_map<std::string, module> modules;
-    context ctx;
-    std::string target_name;
+    std::vector<context> contexts;
+    std::vector<target> targets;
 };

 program::program() : impl(std::make_unique<program_impl>()) { this->create_module("main"); }
@@ -82,14 +97,8 @@ void program::assign(const program& p)
     {
         impl = std::make_unique<program_impl>();
     }
-    else if(not impl->modules.empty())
-    {
-        impl->modules.clear();
-    }

-    impl->ctx         = p.impl->ctx;
-    impl->target_name = p.impl->target_name;
-    impl->modules     = p.impl->modules;
+    *impl = *p.impl;

     // build a map from old ins to new ins
     // Build a map from old module to new module
@@ -152,7 +161,11 @@ std::vector<shape> program::get_output_shapes() const
     return mm->get_output_shapes();
 }

-context& program::get_context() const { return impl->ctx; }
+context& program::get_context() const
+{
+    assert(impl->contexts.size() == 1);
+    return impl->contexts.front();
+}

 instruction_ref program::validate() const
 {
@@ -203,20 +216,106 @@ target_assignments program::get_target_assignments(const std::vector<target>& ta
     return p;
 }
bool program::is_compiled() const { return not this->impl->target_name.empty(); } bool program::is_compiled() const { return not this->impl->contexts.empty(); }
+void program::compile(const std::vector<target>& targets, std::vector<compile_options> compile_opts)
+{
+    // Gather all the target roots
+    std::unordered_multimap<std::size_t, module_ref> roots;
+    auto mods = this->get_modules();
+    for(auto* mod : mods)
+    {
+        for(const auto& ins : *mod)
+        {
+            if(ins.name() != "run_on_target")
+                continue;
+            auto v = ins.get_operator().to_value();
+            module_ref root = ins.module_inputs().front();
+            std::size_t root_target_id = v.at("target_id").to<std::size_t>();
+            assert(root_target_id < targets.size());
+            roots.insert({root_target_id, root});
+        }
+    }
+    auto trace = tracer{};
+    // TODO: Add tracer based on compile options
+    if(enabled(MIGRAPHX_TRACE_COMPILE{}))
+        trace = tracer{std::cout};
+
+    trace(*this);
+    trace();
+
+    // It is assumed that all instructions outside of any root module run on the "ref" target.
+    // The ref target may or may not be passed as one of the targets to compile().
+    // If it is not passed, create one and add its context to the list.
+    auto target_idx = [&](const std::string& t_name) {
+        return static_cast<std::size_t>(
+            std::find_if(
+                targets.begin(), targets.end(), [&](const auto& t) { return t.name() == t_name; }) -
+            targets.begin());
+    };
+    std::size_t ref_target_id = target_idx("ref");
+    if(ref_target_id == targets.size())
+    {
+        this->impl->contexts.resize(targets.size() + 1);
+        this->impl->contexts[ref_target_id] = migraphx::make_target("ref").get_context();
+        // users could pass fewer compile_opts than targets; in that case use default compile_opts
+        compile_opts.resize(targets.size() + 1, migraphx::compile_options{});
+    }
+    else
+    {
+        this->impl->contexts.resize(targets.size());
+        compile_opts.resize(targets.size(), migraphx::compile_options{});
+    }
+    // mark all the instructions as ref target first; later change target_id based on root-target
+    run_passes(*this, {mark_instruction_target{ref_target_id}});
+    // Run passes on each root target
+    for(const auto i : range(targets.size()))
+    {
+        const auto& root_target = targets.at(i);
+        auto root_target_id = i;
+        auto root_modules_range = roots.equal_range(root_target_id);
+        this->impl->contexts[root_target_id] = root_target.get_context();
+        for(const auto& [id, current_mod] : range(root_modules_range))
+        {
+            auto passes = root_target.get_passes(this->impl->contexts[root_target_id],
+                                                 compile_opts[root_target_id]);
+            passes.push_back(mark_instruction_target{static_cast<size_t>(root_target_id)});
+            run_passes(*this, current_mod, passes, trace);
+            auto invalid = current_mod->validate();
+            if(invalid != current_mod->end())
+            {
+                MIGRAPHX_THROW("Invalid module " + current_mod->name() +
+                               " from compilation at instruction " +
+                               std::to_string(std::distance(current_mod->begin(), invalid)));
+            }
+            auto dangling = current_mod->find_dangling_reference();
+            if(dangling != current_mod->end())
+            {
+                auto index = std::distance(current_mod->begin(), dangling);
+                MIGRAPHX_THROW("Dangling reference in module " + current_mod->name() +
                               " from instruction " + std::to_string(index));
+            }
+        }
+    }
+    this->finalize();
+}
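For context, a hedged sketch of driving the new multi-target overload; which target names are available ("gpu", "cpu", ...) depends on the build, so only "ref" is used here:

```cpp
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>

void compile_multi_target_example(migraphx::program& p)
{
    // Any instruction outside a run_on_target root module is assigned to "ref";
    // each root module is compiled by the target selected by its target_id.
    std::vector<migraphx::target> targets = {migraphx::make_target("ref")};
    std::vector<migraphx::compile_options> opts(targets.size());
    p.compile(targets, opts);
}
```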
void program::compile(const target& t, compile_options options)
{
+    // todo: combine with multi-target compile method
    assert(not this->is_compiled());
-    this->impl->target_name = t.name();
-    this->impl->ctx = t.get_context();
+    this->impl->targets = {t};
+    this->impl->contexts = {t.get_context()};

    if(enabled(MIGRAPHX_TRACE_COMPILE{}))
        options.trace = tracer{std::cout};

    options.trace(*this);
    options.trace();

-    auto&& passes = t.get_passes(this->impl->ctx, options);
+    auto&& passes = t.get_passes(this->impl->contexts.front(), options);
    run_passes(*this, passes, options.trace);
    auto mods = this->get_modules();
    // Validate and finalize

...@@ -235,14 +334,14 @@ void program::compile(const target& t, compile_options options)
            MIGRAPHX_THROW("Dangling reference in module " + mod->name() + " from instruction " +
                           std::to_string(index));
        }
-        mod->finalize(this->impl->ctx);
+        mod->finalize(this->impl->contexts);
    }
}
void program::finalize()
{
    auto* mm = this->get_main_module();
-    mm->finalize(this->impl->ctx);
+    mm->finalize(this->impl->contexts);
}
template <class T>

...@@ -259,6 +358,31 @@ std::string classify(T x)
    }
}
+void print_statistics(std::ostream& os, const argument& a)
+{
+    a.visit(
+        [&](auto t) {
+            os << "Min value: " << *std::min_element(t.begin(), t.end()) << ", ";
+            os << "Max value: " << *std::max_element(t.begin(), t.end()) << ", ";
+            double num_elements = t.size();
+            auto mean = std::accumulate(t.begin(), t.end(), 0.0) / num_elements;
+            auto stddev = std::sqrt(
+                std::accumulate(t.begin(),
+                                t.end(),
+                                0.0,
+                                [&](auto r, auto v) { return r + std::pow((v - mean), 2.0); }) /
+                num_elements);
+            os << "Mean: " << mean << ", ";
+            os << "StdDev: " << stddev << "\n";
+        },
+        [&](const auto& xs) {
+            for(const auto& x : xs)
+            {
+                print_statistics(os, x);
+            }
+        });
+}
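A short usage sketch for the new `print_statistics` helper (values illustrative); for {1, 2, 3, 4} it should report min 1, max 4, mean 2.5 and a population standard deviation of about 1.118:

```cpp
#include <migraphx/literal.hpp>
#include <iostream>

void print_statistics_example()
{
    migraphx::shape s{migraphx::shape::float_type, {4}};
    migraphx::literal l{s, {1.0f, 2.0f, 3.0f, 4.0f}};
    // Expected output: "Min value: 1, Max value: 4, Mean: 2.5, StdDev: 1.11803"
    print_statistics(std::cout, l.get_argument());
}
```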
std::unordered_set<std::string> classify_argument(const argument& a)
{
    std::unordered_set<std::string> result;

...@@ -304,16 +428,15 @@ void preview_argument(std::ostream& os, const argument& a)
template <class F>
std::vector<argument> generic_eval(const module* mod,
-                                   context& ctx,
+                                   std::vector<context>& ctx,
                                   std::unordered_map<std::string, argument> params,
                                   std::unordered_map<instruction_ref, argument> results,
-                                   F make_trace)
+                                   F trace)
{
    assert(mod->validate() == mod->end());
    results.reserve(mod->size() * 2);
    std::vector<argument> values;
    values.reserve(16);
-    auto trace = make_trace(mod);
    for(auto ins : iterator_for(*mod))
    {
        assert(results.find(ins) == results.end());

...@@ -366,18 +489,22 @@ std::vector<argument> generic_eval(const module* mod,
            assert(results.find(i) != results.end());
            return results[i];
        });

        const auto& mod_args = ins->module_inputs();
        auto module_eval = [&](module_ref smod,
                               const std::unordered_map<std::string, argument>& inputs) {
-            auto ssctx = ctx;
-            return generic_eval(smod, ssctx, inputs, results, make_trace);
+            return generic_eval(smod, ctx, inputs, results, trace);
        };
-        results.emplace(ins, trace(ins, [&] {
-            return ins->normalized_operator().compute(
-                ctx, ins->get_shape(), values, mod_args, module_eval);
-        }));
+        results.emplace(
+            ins, trace(ins, [&] {
+                auto op = ins->normalized_operator();
+                if(op.is_context_free())
+                    return op.compute(ins->get_shape(), values, mod_args, module_eval);
+                if(ins->get_target_id() >= ctx.size())
+                    MIGRAPHX_THROW("No context available for " + op.name());
+                return op.compute(
+                    ctx[ins->get_target_id()], ins->get_shape(), values, mod_args, module_eval);
+            }));
        }
        assert(results.find(ins) != results.end());
        if(not ins->get_shape().any_of_dynamic())

...@@ -390,44 +517,25 @@ std::vector<argument> generic_eval(const module* mod,
template <class F>
std::vector<argument> generic_eval(const program& p,
-                                   context& ctx,
+                                   std::vector<context>& ctx,
                                   std::unordered_map<std::string, argument> params,
-                                   F make_trace)
+                                   F trace)
{
    const module* mm = p.get_main_module();
-    return generic_eval(mm, ctx, params, {}, make_trace);
+    return generic_eval(mm, ctx, params, {}, trace);
}
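The dispatch inside `generic_eval` is the heart of the multi-context change: context-free operations compute without any context, everything else indexes the contexts vector by the instruction's target id. A condensed restatement, as a hypothetical helper that is not part of the commit:

```cpp
// Sketch: how one instruction's operator is computed under the new scheme.
migraphx::argument compute_one(migraphx::instruction_ref ins,
                               std::vector<migraphx::context>& contexts,
                               const std::vector<migraphx::argument>& values)
{
    auto op = ins->normalized_operator();
    if(op.is_context_free())
        return op.compute(ins->get_shape(), values); // no context required
    if(ins->get_target_id() >= contexts.size())
        MIGRAPHX_THROW("No context available for " + op.name());
    return op.compute(contexts[ins->get_target_id()], ins->get_shape(), values);
}
```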
std::vector<argument> program::eval(parameter_map params, execution_environment exec_env) const
{
-    auto& ctx = this->impl->ctx;
-#ifndef NDEBUG
-    auto with_check_context = [&](auto f) {
-        return [=, &ctx](auto&&) {
-            auto sctx = std::make_shared<context>(ctx);
-            auto check_context = [=, &ctx](auto g) {
-                assert(is_shared(ctx, *sctx));
-                auto x = g();
-                *sctx = ctx;
-                return x;
-            };
-            return [=](auto&&... xs) { return f(xs..., check_context); };
-        };
-    };
-#else
-    auto with_check_context = [](auto f) {
-        return [=](auto&&) {
-            return [=](auto&&... xs) { return f(xs..., [](auto g) { return g(); }); };
-        };
-    };
-#endif
+    auto& contexts = this->impl->contexts;

    auto trace_level = value_of(MIGRAPHX_TRACE_EVAL{});
    std::vector<argument> ret;

    if(exec_env.async)
    {
-        ctx.wait_for(exec_env.queue);
+        assert(contexts.size() == 1);
+        contexts.front().wait_for(exec_env.queue);
    }

    if(trace_level > 0)

...@@ -439,70 +547,93 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
        instruction::print(ss, x, ins_names);
        ins_out[x] = ss.str();
    });
-        ret = generic_eval(*this,
-                           ctx,
-                           std::move(params),
-                           with_check_context([&](auto& ins, auto f, auto&& check_context) {
-                               ctx.finish();
-                               std::cout << "Run instruction: " << ins_out.at(ins) << std::endl;
-                               timer t{};
-                               auto result = check_context(f);
-                               double t1 = t.record<milliseconds>();
-                               ctx.finish();
-                               double t2 = t.record<milliseconds>();
-                               std::cout << "Time: " << t1 << "ms, " << t2 << "ms" << std::endl;
-                               if(trace_level > 1 and ins->name().front() != '@' and
-                                  ins->name() != "load" and not result.empty())
-                               {
-                                   target tgt = make_target(this->impl->target_name);
-                                   auto buffer = tgt.copy_from(result);
-                                   if(trace_level == 2)
-                                   {
-                                       std::cout << "Output has "
-                                                 << to_string_range(classify_argument(buffer))
-                                                 << std::endl;
-                                       std::cout << "Output: ";
-                                       preview_argument(std::cout, buffer);
-                                       std::cout << std::endl;
-                                   }
-                                   else
-                                   {
-                                       std::cout << "Output: " << buffer << std::endl;
-                                   }
-                               }
-                               return result;
-                           }));
+        ret = generic_eval(*this, contexts, std::move(params), [&](instruction_ref ins, auto f) {
+            auto& ctx = contexts[ins->get_target_id()];
+            ctx.finish();
+            std::cout << "Run instruction: " << ins_out.at(ins) << std::endl;
+            timer t{};
+            auto result = f();
+            double t1 = t.record<milliseconds>();
+            ctx.finish();
+            double t2 = t.record<milliseconds>();
+            std::cout << "Time: " << t1 << "ms, " << t2 << "ms" << std::endl;
+            if(trace_level > 1 and ins->name().front() != '@' and ins->name() != "load" and
+               not result.empty())
+            {
+                migraphx::argument buffer;
+                try
+                {
+                    const target& tgt = this->impl->targets.at(ins->get_target_id());
+                    buffer = tgt.copy_from(result);
+                }
+                catch(const migraphx::exception&)
+                {
+                    // the instruction ran on the host, so there is no need to copy the
+                    // buffer from the target
+                    buffer = result;
+                }
+                catch(...)
+                {
+                    MIGRAPHX_THROW("MIGraphX program execution with MIGRAPHX_TRACE_EVAL failed.\n");
+                }
+                if(trace_level == 2)
+                {
+                    std::cout << "Output has " << to_string_range(classify_argument(buffer))
+                              << std::endl;
+                    std::cout << "Output: ";
+                    preview_argument(std::cout, buffer);
+                    std::cout << std::endl;
+                    print_statistics(std::cout, buffer);
+                }
+                else
+                {
+                    std::cout << "Output: " << buffer << std::endl;
+                }
+            }
+            return result;
+        });
    }
    else
    {
-        ret = generic_eval(*this,
-                           ctx,
-                           std::move(params),
-                           with_check_context([&](auto&, auto f, auto&& check_context) {
-                               return check_context(f);
-                           }));
+        ret = generic_eval(*this, contexts, std::move(params), [&](auto&&, auto f) { return f(); });
    }

    if(exec_env.async)
    {
-        ctx.finish_on(exec_env.queue);
+        assert(contexts.size() == 1);
+        contexts.front().finish_on(exec_env.queue);
    }
    return ret;
}
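With contexts now resolved per instruction, callers still drive `eval` the same way. A hedged sketch (input generation via `generate_argument` from migraphx/generate.hpp is an assumption about the surrounding codebase):

```cpp
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>

std::vector<migraphx::argument> eval_example(migraphx::program& p)
{
    migraphx::parameter_map params;
    for(auto&& [name, shape] : p.get_parameter_shapes())
        params[name] = migraphx::generate_argument(shape); // fill each input with generated data
    return p.eval(params);
}
```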
-const int program_file_version = 5;
+void program::finish() const
+{
+    for(const auto& ctx : this->impl->contexts)
+        ctx.finish();
+}
+
+std::string get_migraphx_version()
+{
+    std::stringstream ss;
+    ss << std::to_string(MIGRAPHX_VERSION_MAJOR) << "." << std::to_string(MIGRAPHX_VERSION_MINOR)
+       << "." << std::to_string(MIGRAPHX_VERSION_PATCH);
+    return ss.str();
+}
+
+/*
+The program file version tracks the data structure and format of the MXR file. It should be
+bumped whenever the format of the MXR file changes.
+*/
+const int program_file_version = 6;
value program::to_value() const
{
    value result;
    result["version"] = program_file_version;
-    result["target"] = this->impl->target_name;
-    if(not this->impl->target_name.empty())
-        result["context"] = this->impl->ctx.to_value();
+    result["migraphx_version"] = get_migraphx_version();
+    result["targets"] = migraphx::to_value(this->impl->targets);
+    result["contexts"] = migraphx::to_value(this->impl->contexts);

    value module_vals = value::object{};
    std::unordered_map<instruction_ref, std::string> names;
    for(auto& mod : this->get_modules())
    {
...@@ -626,15 +757,27 @@ void program::from_value(const value& v)
    auto version = v.at("version").to<int>();
    if(version != program_file_version)
    {
-        MIGRAPHX_THROW("Warning: Program version mismatch");
+        MIGRAPHX_THROW(
+            "Error: Program version mismatch. The MXR file was created with program file version " +
+            std::to_string(version) + ", while the installed MIGraphX uses program file version " +
+            std::to_string(program_file_version) +
+            ". Try regenerating the MXR file with the installed MIGraphX and running again.");
+    }
+
+    auto migx_version = v.at("migraphx_version").to<std::string>();
+    if(migx_version != get_migraphx_version())
+    {
+        std::cout << "WARNING: the MXR file was created with MIGraphX version " << migx_version
+                  << ", while the installed MIGraphX is at version " << get_migraphx_version()
+                  << "; operator implementations could be mismatched." << std::endl;
    }

-    this->impl->target_name = v.at("target").to<std::string>();
-    if(not this->impl->target_name.empty())
-    {
-        target t = make_target(this->impl->target_name);
-        this->impl->ctx = t.get_context();
-        this->impl->ctx.from_value(v.at("context"));
+    migraphx::from_value(v.at("targets"), this->impl->targets);
+
+    for(auto i : range(this->impl->targets.size()))
+    {
+        this->impl->contexts.push_back(this->impl->targets[i].get_context());
+        this->impl->contexts.back().from_value(v.at("contexts")[i]);
    }
    auto module_vals = v.at("modules");

...@@ -655,7 +798,9 @@ void program::from_value(const value& v)
    auto* mm = get_main_module();
    mod_from_val(mm, module_vals, map_insts, map_mods);

-    this->finalize();
+    // Finalize a compiled model
+    if(not this->impl->contexts.empty())
+        this->finalize();
}
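A round-trip sketch under the new MXR format, assuming the `migraphx::save`/`migraphx::load` helpers from load_save.hpp:

```cpp
#include <migraphx/load_save.hpp>

void mxr_roundtrip_example(const migraphx::program& p)
{
    // to_value() records version 6 plus the serialized targets and contexts.
    migraphx::save(p, "model.mxr");
    // from_value() restores targets/contexts and re-finalizes if the program was compiled.
    migraphx::program q = migraphx::load("model.mxr");
}
```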
double common_average(const std::vector<double>& v)

...@@ -675,19 +820,19 @@ std::string perf_group(const operation& op)
void program::mark(const parameter_map& params, marker&& m)
{
-    auto& ctx = this->impl->ctx;
+    auto& ctx = this->impl->contexts;
    // Run once by itself
    eval(params);
-    ctx.finish();
+    this->finish();
    // Start marking
    m.mark_start(*this);
-    generic_eval(*this, ctx, params, always([&](auto ins, auto f) {
+    generic_eval(*this, ctx, params, [&](auto ins, auto f) {
        argument result;
        m.mark_start(ins);
        result = f();
        m.mark_stop(ins);
        return result;
-    }));
+    });
    m.mark_stop(*this);
}
...@@ -696,10 +841,10 @@ void program::perf_report(std::ostream& os,
                          parameter_map params,
                          std::size_t batch) const
{
-    auto& ctx = this->impl->ctx;
+    auto& ctx = this->impl->contexts;
    // Run once by itself
    eval(params);
-    ctx.finish();
+    this->finish();
    // Run and time entire program
    std::vector<double> total_vec;
    total_vec.reserve(n);

...@@ -707,28 +852,28 @@ void program::perf_report(std::ostream& os,
    {
        total_vec.push_back(time<milliseconds>([&] {
            eval(params);
-            ctx.finish();
+            this->finish();
        }));
    }
    std::sort(total_vec.begin(), total_vec.end());
    std::unordered_map<instruction_ref, std::vector<double>> ins_vec;
    // Fill the map
-    generic_eval(*this, ctx, params, always([&](auto ins, auto) {
+    generic_eval(*this, ctx, params, [&](auto ins, auto) {
        ins_vec[ins].reserve(n);
        return argument{ins->get_shape(), nullptr};
-    }));
+    });
    // Run and time each instruction
    for(std::size_t i = 0; i < n; i++)
    {
-        generic_eval(*this, ctx, params, always([&](auto ins, auto f) {
+        generic_eval(*this, ctx, params, [&](auto ins, auto f) {
            argument result;
            ins_vec[ins].push_back(time<milliseconds>([&] {
                result = f();
-                ctx.finish();
+                this->impl->contexts[ins->get_target_id()].finish();
            }));
            return result;
-        }));
+        });
    }
    for(auto&& p : ins_vec)
        std::sort(p.second.begin(), p.second.end());
...@@ -861,7 +1006,9 @@ void program::print_py(std::ostream& os) const
    os << "p = migraphx.program()\n";
    for(auto& mod : vec_modules)
    {
-        std::string var_name = "m" + mod->name();
+        std::string var_name = "m";
+        if(mod->name() != "main")
+            var_name += mod->name();
        os << var_name << " = ";
        if(mod->name() == "main")
            os << "p.get_main_module()";

...@@ -894,10 +1041,10 @@ void program::print_cpp(std::ostream& os) const
void program::dry_run(std::unordered_map<std::string, argument> params) const
{
-    auto& ctx = this->impl->ctx;
+    auto& ctx = this->impl->contexts;
-    generic_eval(*this, ctx, std::move(params), always([](auto ins, auto&&...) {
+    generic_eval(*this, ctx, std::move(params), [](auto ins, auto&&...) {
        return argument{ins->get_shape(), nullptr};
-    }));
+    });
}
void program::annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const

...@@ -1045,11 +1192,19 @@ void program::remove_unused_modules()
program& program::sort()
{
-    for(auto& pp : this->impl->modules)
-    {
-        pp.second.sort();
+    std::queue<migraphx::module_ref> mqueue;
+    mqueue.push(get_main_module());
+    while(not mqueue.empty())
+    {
+        module_ref current_mod = mqueue.front();
+        current_mod->sort();
+        mqueue.pop();
+        auto child_mods = current_mod->get_sub_modules(true);
+        for(auto& sub_mod : child_mods)
+        {
+            mqueue.push(sub_mod);
+        }
    }
    return *this;
}
...

...@@ -34,7 +34,7 @@ void promote_literals::apply(module_pass_manager& mpm) const
{
    module& m = mpm.get_module();
    module_ref root_module = mpm.get_root_module();
-    if(m.name() == "main")
+    if(m == *root_module)
        return;
    for(auto ins : iterator_for(m))
...@@ -27,11 +27,14 @@
#include <migraphx/literal.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/par_for.hpp>
+#include <migraphx/env.hpp>
#include <unordered_set>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PROPAGATE_CONSTANT)
+
bool skip_propogate(instruction_ref ins)
{
    if(ins->name() == "contiguous")

...@@ -85,6 +88,19 @@ void propagate_constant::apply(module& m) const
    {
        if(not literals[i].empty())
        {
+            if(enabled(MIGRAPHX_TRACE_PROPAGATE_CONSTANT{}))
+            {
+                std::cout << "Constant replace: " << std::endl;
+                std::vector<instruction_ref> inss;
+                fix([&](auto self, auto ins) {
+                    if(contains(inss, ins))
+                        return;
+                    for(auto input : ins->inputs())
+                        self(input);
+                    inss.push_back(ins);
+                })(const_instrs_vec[i]);
+                m.debug_print(inss);
+            }
            assert(literals[i].get_shape() == const_instrs_vec[i]->get_shape());
            auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
            m.replace_instruction(const_instrs_vec[i], l);

...
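The trace block above uses MIGraphX's `fix` combinator to visit an instruction's transitive inputs before the instruction itself. The same idiom as a standalone, hedged sketch (`contains` is replaced by `std::find` to keep it self-contained):

```cpp
#include <migraphx/functional.hpp>  // migraphx::fix
#include <migraphx/instruction.hpp>
#include <algorithm>
#include <vector>

// Collect ins and all of its transitive inputs, inputs first (post-order).
std::vector<migraphx::instruction_ref> collect_transitive_inputs(migraphx::instruction_ref ins)
{
    std::vector<migraphx::instruction_ref> inss;
    migraphx::fix([&](auto self, auto i) {
        if(std::find(inss.begin(), inss.end(), i) != inss.end())
            return; // already visited
        for(auto input : i->inputs())
            self(input);
        inss.push_back(i);
    })(ins);
    return inss;
}
```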
...@@ -23,14 +23,24 @@
#####################################################################################
option(MIGRAPHX_ENABLE_PYTHON "Enable python bindings" ON)

+add_library(migraphx_py py_loader.cpp)
+target_include_directories(migraphx_py PRIVATE include)
+target_link_libraries(migraphx_py PUBLIC migraphx)
+rocm_install_targets(TARGETS migraphx_py INCLUDE include)
+
if(MIGRAPHX_ENABLE_PYTHON)
    include(PythonModules)
-    add_custom_target(migraphx_py)
    foreach(PYTHON_VERSION ${PYTHON_VERSIONS})
-        py_add_module(migraphx_py_${PYTHON_VERSION} migraphx_py.cpp PYTHON_VERSION ${PYTHON_VERSION} PYTHON_MODULE migraphx)
-        target_link_libraries(migraphx_py_${PYTHON_VERSION} PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets)
+        py_add_module(migraphx_pybind_${PYTHON_VERSION} migraphx_py.cpp PYTHON_VERSION ${PYTHON_VERSION} PYTHON_MODULE migraphx)
+        target_link_libraries(migraphx_pybind_${PYTHON_VERSION} PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets)
+        rocm_install_targets(TARGETS migraphx_pybind_${PYTHON_VERSION})
+        add_dependencies(migraphx_py migraphx_pybind_${PYTHON_VERSION})
+        add_library(migraphx_py_${PYTHON_VERSION} py.cpp)
+        target_include_directories(migraphx_py_${PYTHON_VERSION} PRIVATE include)
+        target_link_libraries(migraphx_py_${PYTHON_VERSION} PUBLIC migraphx)
+        target_link_libraries(migraphx_py_${PYTHON_VERSION} PRIVATE pybind11::pybind11 python${PYTHON_VERSION}::runtime)
        rocm_install_targets(TARGETS migraphx_py_${PYTHON_VERSION})
        add_dependencies(migraphx_py migraphx_py_${PYTHON_VERSION})
    endforeach()
...