Unverified Commit 23cb7917 authored by Brian Pickrell's avatar Brian Pickrell Committed by GitHub
Browse files

Merge branch 'develop' into blas_tuning

parents b5fcc0bc ea32ca70
...@@ -45,6 +45,8 @@ ...@@ -45,6 +45,8 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct value;
#ifdef DOXYGEN #ifdef DOXYGEN
/// An interface for a compilation target /// An interface for a compilation target
...@@ -125,7 +127,7 @@ supported_segments target_find_supported(T&, const_module_ref, support_metric) ...@@ -125,7 +127,7 @@ supported_segments target_find_supported(T&, const_module_ref, support_metric)
#ifdef TYPE_ERASED_DECLARATION #ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for: // Type-erased interface for:
struct target struct MIGRAPHX_EXPORT target
{ {
// //
std::string name() const; std::string name() const;
...@@ -165,7 +167,7 @@ struct target ...@@ -165,7 +167,7 @@ struct target
{ {
using std::swap; using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>(); auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique()) if(derived and private_detail_te_handle_mem_var.use_count() == 1)
{ {
*derived = std::forward<PrivateDetailTypeErasedT>(value); *derived = std::forward<PrivateDetailTypeErasedT>(value);
} }
...@@ -426,7 +428,7 @@ struct target ...@@ -426,7 +428,7 @@ struct target
private_detail_te_handle_base_type& private_detail_te_get_handle() private_detail_te_handle_base_type& private_detail_te_get_handle()
{ {
assert(private_detail_te_handle_mem_var != nullptr); assert(private_detail_te_handle_mem_var != nullptr);
if(not private_detail_te_handle_mem_var.unique()) if(private_detail_te_handle_mem_var.use_count() > 1)
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
} }
...@@ -467,6 +469,9 @@ inline const ValueType& any_cast(const target& x) ...@@ -467,6 +469,9 @@ inline const ValueType& any_cast(const target& x)
#endif #endif
void migraphx_to_value(value& v, const target& t);
void migraphx_from_value(const value& v, target& t);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/tf/export.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -41,7 +42,10 @@ struct tf_options ...@@ -41,7 +42,10 @@ struct tf_options
}; };
/// Create a program from a tf pb file (default is nhwc format) /// Create a program from a tf pb file (default is nhwc format)
program parse_tf(const std::string& name, const tf_options& options = tf_options{}); MIGRAPHX_TF_EXPORT program parse_tf(const std::string& name,
const tf_options& options = tf_options{});
MIGRAPHX_TF_EXPORT std::vector<std::string> get_tf_operators();
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct tmp_dir struct MIGRAPHX_EXPORT tmp_dir
{ {
fs::path path; fs::path path;
tmp_dir(const std::string& prefix = ""); tmp_dir(const std::string& prefix = "");
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
#include <memory> #include <memory>
#include <cstdint>
#include <sstream> #include <sstream>
#include <type_traits> #include <type_traits>
#include <tuple> #include <tuple>
...@@ -140,7 +141,7 @@ To try_convert_value(const From& x) ...@@ -140,7 +141,7 @@ To try_convert_value(const From& x)
return detail::try_convert_value_impl<To>(rank<3>{}, x); return detail::try_convert_value_impl<To>(rank<3>{}, x);
} }
struct value struct MIGRAPHX_EXPORT value
{ {
// clang-format off // clang-format off
#define MIGRAPHX_VISIT_VALUE_TYPES(m) \ #define MIGRAPHX_VISIT_VALUE_TYPES(m) \
...@@ -392,8 +393,8 @@ struct value ...@@ -392,8 +393,8 @@ struct value
return; \ return; \
} }
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE_VALUE) MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE_VALUE)
MIGRAPHX_VALUE_GENERATE_CASE(array, ) MIGRAPHX_VALUE_GENERATE_CASE_VALUE(array, )
MIGRAPHX_VALUE_GENERATE_CASE(object, ) MIGRAPHX_VALUE_GENERATE_CASE_VALUE(object, )
} }
MIGRAPHX_THROW("Unknown type"); MIGRAPHX_THROW("Unknown type");
} }
...@@ -452,14 +453,16 @@ struct value ...@@ -452,14 +453,16 @@ struct value
std::vector<literal_to_string<To>>{default_value.begin(), default_value.end()}); std::vector<literal_to_string<To>>{default_value.begin(), default_value.end()});
} }
friend bool operator==(const value& x, const value& y); MIGRAPHX_EXPORT friend bool operator==(const value& x, const value& y);
friend bool operator!=(const value& x, const value& y); MIGRAPHX_EXPORT friend bool operator!=(const value& x, const value& y);
friend bool operator<(const value& x, const value& y); MIGRAPHX_EXPORT friend bool operator<(const value& x, const value& y);
friend bool operator<=(const value& x, const value& y); MIGRAPHX_EXPORT friend bool operator<=(const value& x, const value& y);
friend bool operator>(const value& x, const value& y); MIGRAPHX_EXPORT friend bool operator>(const value& x, const value& y);
friend bool operator>=(const value& x, const value& y); MIGRAPHX_EXPORT friend bool operator>=(const value& x, const value& y);
friend std::ostream& operator<<(std::ostream& os, const value& d); MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const value& d);
std::size_t hash() const;
void debug_print(bool show_type = false) const; void debug_print(bool show_type = false) const;
...@@ -481,4 +484,15 @@ struct value ...@@ -481,4 +484,15 @@ struct value
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
namespace std {
template <>
struct hash<migraphx::value>
{
using argument_type = migraphx::value;
using result_type = std::size_t;
result_type operator()(const migraphx::value& x) const { return x.hash(); }
};
} // namespace std
#endif #endif
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace verify {
// Compute the value of a range // Compute the value of a range
template <class R> template <class R>
...@@ -196,6 +197,7 @@ bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out ...@@ -196,6 +197,7 @@ bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out
return error <= threshold; return error <= threshold;
} }
} // namespace verify
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif #endif
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_EXPORT
bool verify_args(const std::string& name, bool verify_args(const std::string& name,
const argument& ref_arg, const argument& ref_arg,
const argument& target_arg, const argument& target_arg,
......
...@@ -64,10 +64,7 @@ void instruction::replace(const shape& r) ...@@ -64,10 +64,7 @@ void instruction::replace(const shape& r)
result = r; result = r;
for(auto&& ins : output) for(auto&& ins : output)
{ {
if(ins->name() == "@return") assert(ins->name() == "@return" or ins->name().front() != '@');
continue;
assert(ins->name().front() != '@');
ins->recompute_shape(); ins->recompute_shape();
} }
} }
...@@ -122,10 +119,6 @@ bool instruction::valid() const ...@@ -122,10 +119,6 @@ bool instruction::valid() const
{ {
computed = result; computed = result;
} }
else if(op.name() == "@return")
{
computed = {};
}
else else
{ {
try try
...@@ -145,6 +138,7 @@ bool instruction::valid() const ...@@ -145,6 +138,7 @@ bool instruction::valid() const
} }
shape instruction::get_shape() const { return result; } shape instruction::get_shape() const { return result; }
const literal& instruction::get_literal() const const literal& instruction::get_literal() const
{ {
assert(op.name() == "@literal"); assert(op.name() == "@literal");
...@@ -395,7 +389,7 @@ void instruction::print(std::ostream& os, ...@@ -395,7 +389,7 @@ void instruction::print(std::ostream& os,
if(not ins->module_inputs().empty()) if(not ins->module_inputs().empty())
{ {
std::string delim = ", ["; std::string delim = ", [";
for(auto&& mod_arg : ins->module_inputs()) for(const const_module_ref& mod_arg : ins->module_inputs())
{ {
os << delim << mod_arg->name(); os << delim << mod_arg->name();
delim = ", "; delim = ", ";
...@@ -406,6 +400,9 @@ void instruction::print(std::ostream& os, ...@@ -406,6 +400,9 @@ void instruction::print(std::ostream& os,
// skip return instruction shape // skip return instruction shape
if(ins->name() != "@return") if(ins->name() != "@return")
os << " -> " << ins->get_shape(); os << " -> " << ins->get_shape();
// print tid
os << ", target_id=" << ins->target_id;
} }
static void debug_name(std::ostream& os, const instruction& ins) static void debug_name(std::ostream& os, const instruction& ins)
...@@ -464,11 +461,14 @@ operation instruction::normalized_operator() const ...@@ -464,11 +461,14 @@ operation instruction::normalized_operator() const
if(this->need_normalization()) if(this->need_normalization())
{ {
auto s = this->inputs().front()->get_shape(); auto s = this->inputs().front()->get_shape();
if(not normalize_attributes(o, s.max_lens())) if(not normalize_attributes(o, s))
return this->get_operator(); return this->get_operator();
} }
return o; return o;
} }
std::size_t instruction::get_target_id() const { return target_id; }
void instruction::set_target_id(std::size_t tid) { this->target_id = tid; }
std::vector<shape> to_shapes(const std::vector<instruction_ref>& args) std::vector<shape> to_shapes(const std::vector<instruction_ref>& args)
{ {
......
...@@ -326,6 +326,8 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref ...@@ -326,6 +326,8 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref
if(ins == std::prev(this->end())) if(ins == std::prev(this->end()))
{ {
// "rep" instruction could be used earlier in the program and moving it at the end
// may cause invalid program, therefore make an identity operation in this case.
return replace_instruction(ins, make_op("identity"), rep); return replace_instruction(ins, make_op("identity"), rep);
} }
...@@ -458,11 +460,11 @@ instruction_ref module::add_parameter(std::string name, shape s) ...@@ -458,11 +460,11 @@ instruction_ref module::add_parameter(std::string name, shape s)
instruction_ref module::add_return(std::vector<instruction_ref> args) instruction_ref module::add_return(std::vector<instruction_ref> args)
{ {
impl->push_back({builtin::returns{}, {}, std::move(args)}); shape instr_shape = compute_shape(builtin::returns{}, args);
impl->push_back({builtin::returns{}, instr_shape, std::move(args)});
auto result = std::prev(impl->instructions.end()); auto result = std::prev(impl->instructions.end());
instruction::backreference(result); instruction::backreference(result);
assert(result->valid(begin())); assert(result->valid(begin()));
return result; return result;
} }
...@@ -650,8 +652,9 @@ instruction_ref module::find_dangling_reference() const ...@@ -650,8 +652,9 @@ instruction_ref module::find_dangling_reference() const
return end(); return end();
} }
void module::finalize(context& ctx) void module::finalize(std::vector<context>& contexts)
{ {
assert(not contexts.empty());
const bool trace = enabled(MIGRAPHX_TRACE_FINALIZE{}); const bool trace = enabled(MIGRAPHX_TRACE_FINALIZE{});
for(auto ins : iterator_for(*this)) for(auto ins : iterator_for(*this))
{ {
...@@ -660,10 +663,10 @@ void module::finalize(context& ctx) ...@@ -660,10 +663,10 @@ void module::finalize(context& ctx)
std::cout << "Finalize: "; std::cout << "Finalize: ";
this->debug_print(ins); this->debug_print(ins);
} }
ins->finalize(ctx); ins->finalize(contexts[ins->get_target_id()]);
for(const auto& smod : ins->module_inputs()) for(const auto& smod : ins->module_inputs())
{ {
smod->finalize(ctx); smod->finalize(contexts);
} }
} }
...@@ -723,15 +726,15 @@ std::unordered_map<instruction_ref, std::string> module::print( ...@@ -723,15 +726,15 @@ std::unordered_map<instruction_ref, std::string> module::print(
for(auto ins : iterator_for(*this)) for(auto ins : iterator_for(*this))
{ {
std::string var_name; std::string var_name;
if(not this->name().empty() and this->name() != "main")
var_name = this->name() + ":";
if(ins->name() == "@param") if(ins->name() == "@param")
{ {
var_name = any_cast<builtin::param>(ins->get_operator()).parameter; var_name.append(any_cast<builtin::param>(ins->get_operator()).parameter);
} }
else else
{ {
var_name = this->name(); var_name.append("@" + std::to_string(count));
var_name.append((this->name().empty() ? "@" : ":@"));
var_name.append(std::to_string(count));
} }
// count every instruction so index matches loc in the printout program // count every instruction so index matches loc in the printout program
count++; count++;
...@@ -795,7 +798,10 @@ static std::string to_c_id(const std::string& name, char rep = '_') ...@@ -795,7 +798,10 @@ static std::string to_c_id(const std::string& name, char rep = '_')
static std::string cpp_var_name(const std::string& name) static std::string cpp_var_name(const std::string& name)
{ {
return to_c_id("x_" + replace_string(name, ":", "_module_")); std::string prefix = "x_";
if(not contains(name, "@"))
prefix = "p_";
return to_c_id(prefix + replace_string(name, ":", "_module_"));
} }
static void print_py_op(std::ostream& os, const operation& op) static void print_py_op(std::ostream& os, const operation& op)
...@@ -867,15 +873,14 @@ module::print_py(std::ostream& os, ...@@ -867,15 +873,14 @@ module::print_py(std::ostream& os,
if(ins->name() == "@literal") if(ins->name() == "@literal")
{ {
os << mname << ".add_literal("; os << mname << ".add_literal(";
bool use_abs = false; const bool use_abs = false;
ins->get_literal().visit([&](auto v) {
use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; });
});
// Disable abs for now // Disable abs for now
use_abs = false; // ins->get_literal().visit([&](auto v) {
// use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; });
// });
if(use_abs) if(use_abs)
os << "migraphx.abs_literal("; os << "migraphx.abs_literal(";
os << "migraphx.generate_literal("; os << "migraphx.generate_argument(";
print_py_shape(os, ins->get_shape()); print_py_shape(os, ins->get_shape());
os << ", " << seed << ")"; os << ", " << seed << ")";
if(use_abs) if(use_abs)
...@@ -1005,9 +1010,17 @@ std::vector<module_ref> module::get_sub_modules(bool shallow) const ...@@ -1005,9 +1010,17 @@ std::vector<module_ref> module::get_sub_modules(bool shallow) const
module& module::sort() module& module::sort()
{ {
auto implicit_deps = calc_implicit_deps();
fix([&](auto self, auto ins) { fix([&](auto self, auto ins) {
this->move_instruction(ins, this->begin()); this->move_instruction(ins, this->begin());
for(auto child : ins->inputs()) auto ins_inputs = ins->inputs();
if(implicit_deps.find(ins) != implicit_deps.end())
{
auto ins_implict_inputs = implicit_deps.at(ins);
ins_inputs.insert(
ins_inputs.end(), ins_implict_inputs.begin(), ins_implict_inputs.end());
}
for(auto child : ins_inputs)
{ {
if(not contains(this->impl->instructions, child)) if(not contains(this->impl->instructions, child))
{ {
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -35,18 +35,21 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -35,18 +35,21 @@ inline namespace MIGRAPHX_INLINE_NS {
* vec: the vector attribute to normalize * vec: the vector attribute to normalize
* axes: the operator's axes attribute if it exists, empty otherwise * axes: the operator's axes attribute if it exists, empty otherwise
* val: the normalize_axes key and options. Ex: normalize["axes"] = * val: the normalize_axes key and options. Ex: normalize["axes"] =
* value::array{normalize_attribute::include_min}; lens: shape dimensions passed when calling * value::array{normalize_attribute::include_min};
* normalize_attributes(op&, lens) * input_shape: input shape passed when calling
* normalize_attributes(op&, input_shape)
* *
* See normalize_attribute.hpp for explaining the options. * See normalize_attribute.hpp for explaining the options.
*/ */
template <class Message>
auto tune_attribute(const std::vector<int64_t>& vec, auto tune_attribute(const std::vector<int64_t>& vec,
const std::vector<int64_t>& axes, const std::vector<int64_t>& axes,
const value& val, const value& val,
const std::vector<std::size_t>& lens) const shape& input_shape,
Message m)
{ {
std::vector<int64_t> result(vec); std::vector<int64_t> result(vec);
int64_t n_rank = lens.size(); int64_t n_rank = input_shape.ndim();
std::vector<op::normalize_attribute> vec_attrs = val.to_vector<op::normalize_attribute>(); std::vector<op::normalize_attribute> vec_attrs = val.to_vector<op::normalize_attribute>();
if(contains(vec_attrs, op::normalize_attribute::use_output)) if(contains(vec_attrs, op::normalize_attribute::use_output))
{ {
...@@ -54,9 +57,28 @@ auto tune_attribute(const std::vector<int64_t>& vec, ...@@ -54,9 +57,28 @@ auto tune_attribute(const std::vector<int64_t>& vec,
} }
std::vector<int64_t> max_vals(vec.size(), n_rank); std::vector<int64_t> max_vals(vec.size(), n_rank);
if(contains(vec_attrs, op::normalize_attribute::use_len)) if(contains(vec_attrs, op::normalize_attribute::use_len))
{ {
std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) { return lens[i]; }); if(input_shape.dynamic())
{
std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) {
const auto& dd = input_shape.dyn_dims().at(i);
if(not dd.is_fixed())
{
MIGRAPHX_THROW(
"NORMALIZE_ATTR: 'use_lens' on a non-fixed dynamic dimension, axis=" +
std::to_string(i));
}
return dd.max;
});
}
else
{
std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) {
return input_shape.lens().at(i);
});
}
} }
if(contains(vec_attrs, op::normalize_attribute::clip_max)) if(contains(vec_attrs, op::normalize_attribute::clip_max))
...@@ -84,14 +106,14 @@ auto tune_attribute(const std::vector<int64_t>& vec, ...@@ -84,14 +106,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
{ {
if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less_equal<>{})) if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less_equal<>{}))
{ {
MIGRAPHX_THROW("TUNE_VECTOR: value out of range!"); MIGRAPHX_THROW(m() + "value out of range!");
} }
} }
else else
{ {
if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less<>{})) if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less<>{}))
{ {
MIGRAPHX_THROW("TUNE_VECTOR: value out of range!"); MIGRAPHX_THROW(m() + "value out of range!");
} }
} }
} }
...@@ -124,14 +146,14 @@ auto tune_attribute(const std::vector<int64_t>& vec, ...@@ -124,14 +146,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
if(not std::equal( if(not std::equal(
min_vals.begin(), min_vals.end(), result.begin(), std::less_equal<>{})) min_vals.begin(), min_vals.end(), result.begin(), std::less_equal<>{}))
{ {
MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!"); MIGRAPHX_THROW(m() + "attribute out of range!");
} }
} }
else else
{ {
if(not std::equal(result.begin(), result.end(), min_vals.begin(), std::less<>{})) if(not std::equal(result.begin(), result.end(), min_vals.begin(), std::less<>{}))
{ {
MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!"); MIGRAPHX_THROW(m() + "attribute out of range!");
} }
} }
} }
...@@ -157,9 +179,9 @@ auto tune_pad_attribute(const value& val) ...@@ -157,9 +179,9 @@ auto tune_pad_attribute(const value& val)
/** /**
* Assumptions: * Assumptions:
* Dimensions to pad start from the third dimension (index 2). * Dimensions to pad start from the third dimension (index 2).
* Called by compute_shape_op() with the `lens` of the first input. * Called by compute_shape_op() with the shape of the first input.
*/ */
bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens) bool normalize_attributes(operation& op, const shape& input_shape)
{ {
bool tuned = false; bool tuned = false;
auto attrs = op.attributes(); auto attrs = op.attributes();
...@@ -170,9 +192,9 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens) ...@@ -170,9 +192,9 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
auto padding_size = padding.size(); auto padding_size = padding.size();
auto padding_start = 2; auto padding_start = 2;
if(padding_size == 2 * (lens.size() - padding_start)) if(padding_size == 2 * (input_shape.ndim() - padding_start))
tuned = true; tuned = true;
else if(padding_size != (lens.size() - padding_start)) else if(padding_size != (input_shape.ndim() - padding_start))
MIGRAPHX_THROW("inconsistent padding size"); MIGRAPHX_THROW("inconsistent padding size");
else else
{ {
...@@ -193,7 +215,8 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens) ...@@ -193,7 +215,8 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
const auto& key = rv.get_key(); const auto& key = rv.get_key();
if(val.contains(key)) if(val.contains(key))
{ {
auto vv = val.at(key).without_key(); auto message = [&] { return op.name() + ": " + key + ": "; };
auto vv = val.at(key).without_key();
if(vv.is_array()) if(vv.is_array())
{ {
std::vector<int64_t> axes; std::vector<int64_t> axes;
...@@ -202,7 +225,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens) ...@@ -202,7 +225,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
axes = val.at("axes").without_key().to_vector<int64_t>(); axes = val.at("axes").without_key().to_vector<int64_t>();
} }
auto vec = vv.to_vector<int64_t>(); auto vec = vv.to_vector<int64_t>();
auto result = tune_attribute(vec, axes, rv.without_key(), lens); auto result = tune_attribute(vec, axes, rv.without_key(), input_shape, message);
val[key] = result; val[key] = result;
op.from_value(val); op.from_value(val);
val = op.to_value(); val = op.to_value();
...@@ -211,7 +234,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens) ...@@ -211,7 +234,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
else else
{ {
auto num = vv.to<int64_t>(); auto num = vv.to<int64_t>();
auto result = tune_attribute({num}, {num}, rv.without_key(), lens); auto result = tune_attribute({num}, {num}, rv.without_key(), input_shape, message);
val[key] = result.front(); val[key] = result.front();
op.from_value(val); op.from_value(val);
val = op.to_value(); val = op.to_value();
......
...@@ -45,7 +45,7 @@ void normalize_ops::apply(module& m) const ...@@ -45,7 +45,7 @@ void normalize_ops::apply(module& m) const
auto s = inputs[0]->get_shape(); auto s = inputs[0]->get_shape();
migraphx::operation tuned_op = ins->get_operator(); migraphx::operation tuned_op = ins->get_operator();
if(normalize_attributes(tuned_op, s.max_lens())) if(normalize_attributes(tuned_op, s))
{ {
m.replace_instruction(ins, tuned_op, inputs); m.replace_instruction(ins, tuned_op, inputs);
ins->set_normalized(); ins->set_normalized();
......
...@@ -30,10 +30,11 @@ target_compile_options(onnx-proto PRIVATE -w) ...@@ -30,10 +30,11 @@ target_compile_options(onnx-proto PRIVATE -w)
target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY}) target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY})
set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On) set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
file(GLOB ONNX_SRCS ${CONFIGURE_DEPENDS} *.cpp) file(GLOB ONNX_SRCS CONFIGURE_DEPENDS *.cpp)
add_library(migraphx_onnx ${ONNX_SRCS}) add_library(migraphx_onnx ${ONNX_SRCS})
target_include_directories(migraphx_onnx PRIVATE include) target_include_directories(migraphx_onnx PRIVATE include)
set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx) set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx)
migraphx_generate_export_header(migraphx_onnx)
rocm_set_soversion(migraphx_onnx ${MIGRAPHX_SO_VERSION}) rocm_set_soversion(migraphx_onnx ${MIGRAPHX_SO_VERSION})
rocm_clang_tidy_check(migraphx_onnx) rocm_clang_tidy_check(migraphx_onnx)
target_link_libraries(migraphx_onnx PRIVATE onnx-proto "-Wl,--exclude-libs,ALL") target_link_libraries(migraphx_onnx PRIVATE onnx-proto "-Wl,--exclude-libs,ALL")
......
...@@ -117,6 +117,7 @@ struct onnx_parser ...@@ -117,6 +117,7 @@ struct onnx_parser
parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining = false); parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining = false);
literal parse_value(const onnx::AttributeProto& attr) const; literal parse_value(const onnx::AttributeProto& attr) const;
literal parse_tensor(const onnx::TensorProto& t) const; literal parse_tensor(const onnx::TensorProto& t) const;
shape parse_type(const onnx::TypeProto& t) const;
shape parse_type(const onnx::TypeProto& t, const std::vector<std::size_t>& input_dims) const; shape parse_type(const onnx::TypeProto& t, const std::vector<std::size_t>& input_dims) const;
}; };
......
...@@ -38,6 +38,9 @@ ...@@ -38,6 +38,9 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_ONNX_PARSER)
static shape shape_from_dyn_dims(shape::type_t shape_type, static shape shape_from_dyn_dims(shape::type_t shape_type,
const std::vector<shape::dynamic_dimension>& dyn_dims) const std::vector<shape::dynamic_dimension>& dyn_dims)
...@@ -53,8 +56,6 @@ static shape shape_from_dyn_dims(shape::type_t shape_type, ...@@ -53,8 +56,6 @@ static shape shape_from_dyn_dims(shape::type_t shape_type,
return {shape_type, dyn_dims}; return {shape_type, dyn_dims};
} }
namespace onnx {
static onnx_parser::attribute_map get_attributes(const onnx::NodeProto& node) static onnx_parser::attribute_map get_attributes(const onnx::NodeProto& node)
{ {
std::unordered_map<std::string, onnx::AttributeProto> result; std::unordered_map<std::string, onnx::AttributeProto> result;
...@@ -149,6 +150,25 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s ...@@ -149,6 +150,25 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s
return this->add_common_op(op_name, arg0, arg1); return this->add_common_op(op_name, arg0, arg1);
} }
/**
* @brief A wrapper for insert_common_args(), which constructs an argument list
* and inserts multibroadcast and convert ops to match inputs to a common shape and type
* as required. The requested operation is placed after the added multibroadcast and convert ops,
* if any, so that their results are transparent to the programmer.
*
* Use add_common_op() to match input sizes when inputs may be
* either static or dynamic.
*
* @param op_name string; Name of operation (op) to add; valid names are the same as
* for make_op()
*
* @param inputs vector of instruction_ref. List of instructions for the new
* operator. Multibroadcast and convert operations, if needed, are deduced from these too.
*
* @return instruction_ref Returns an instruction_ref which is the result of the requested
* operation.
*
*/
instruction_ref onnx_parser::node_info::add_common_op(const std::string& op_name, instruction_ref onnx_parser::node_info::add_common_op(const std::string& op_name,
std::vector<instruction_ref> inputs) const std::vector<instruction_ref> inputs) const
{ {
...@@ -278,16 +298,48 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model) ...@@ -278,16 +298,48 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
return version; return version;
} }
std::vector<instruction_ref> void print_added_instructions(module* mod,
onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining) const std::vector<instruction_ref>& args,
const std::vector<instruction_ref>& result)
{
// Print instructions added by the parser not in args
std::vector<instruction_ref> added_instructions;
fix([&](auto self, auto r) {
for(auto ins : r)
{
if(contains(args, ins))
continue;
if(contains(added_instructions, ins))
continue;
self(ins->inputs());
added_instructions.push_back(ins);
}
})(result);
mod->debug_print(added_instructions);
}
std::unordered_map<std::string, instruction_ref>
parse_intializer(const onnx_parser& parser, module* mod, const onnx::GraphProto& graph)
{ {
std::unordered_map<std::string, instruction_ref> mod_insts; std::unordered_map<std::string, instruction_ref> mod_insts;
for(auto&& f : graph.initializer()) for(auto&& f : graph.initializer())
{ {
if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{}))
std::cout << "initializer: " << f.name() << std::endl;
// backup instructions in parent mod // backup instructions in parent mod
mod_insts[f.name()] = mod->add_literal(parse_tensor(f)); mod_insts[f.name()] = mod->add_literal(parser.parse_tensor(f));
if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{}))
mod->debug_print(mod_insts[f.name()]);
} }
return mod_insts;
}
std::unordered_map<std::string, instruction_ref>
parse_inputs(const onnx_parser& parser,
module* mod,
const onnx::GraphProto& graph,
std::unordered_map<std::string, instruction_ref> mod_insts)
{
for(auto&& input : graph.input()) for(auto&& input : graph.input())
{ {
const std::string& name = input.name(); const std::string& name = input.name();
...@@ -298,36 +350,48 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini ...@@ -298,36 +350,48 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini
// scenario that a nested subgraph contains a parameter with the // scenario that a nested subgraph contains a parameter with the
// name existed in its parent graph. // name existed in its parent graph.
// In the current implementation, MIGraphX throws an exception for that. // In the current implementation, MIGraphX throws an exception for that.
if(contains(instructions, name)) if(contains(parser.instructions, name))
{ {
MIGRAPHX_THROW("module \"" + mod->name() + "\" has parameter name \"" + name + MIGRAPHX_THROW("module \"" + mod->name() + "\" has parameter name \"" + name +
"\" existing in parent graph!"); "\" existing in parent graph!");
} }
shape s; shape s;
std::vector<std::size_t> dims; if(parser.map_input_dims.count(name) > 0)
if(map_input_dims.count(name) > 0)
{ {
dims = map_input_dims.at(name); std::vector<std::size_t> dims = parser.map_input_dims.at(name);
s = parse_type(input.type(), dims); s = parser.parse_type(input.type(), dims);
} }
else if(map_dyn_input_dims.count(name) > 0) else if(parser.map_dyn_input_dims.count(name) > 0)
{ {
shape::type_t shape_type = get_type(input.type().tensor_type().elem_type()); shape::type_t shape_type = get_type(input.type().tensor_type().elem_type());
s = shape_from_dyn_dims(shape_type, map_dyn_input_dims.at(name)); s = shape_from_dyn_dims(shape_type, parser.map_dyn_input_dims.at(name));
} }
else else
{ {
s = parse_type(input.type(), dims); s = parser.parse_type(input.type());
} }
mod_insts[name] = mod->add_parameter(name, s); mod_insts[name] = mod->add_parameter(name, s);
} }
} }
return mod_insts;
}
std::vector<instruction_ref>
onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining)
{
std::unordered_map<std::string, instruction_ref> mod_insts =
parse_intializer(*this, mod, graph);
mod_insts = parse_inputs(*this, mod, graph, mod_insts);
std::copy(mod_insts.begin(), mod_insts.end(), std::inserter(instructions, instructions.end())); std::copy(mod_insts.begin(), mod_insts.end(), std::inserter(instructions, instructions.end()));
for(auto&& node : graph.node()) for(auto&& node : graph.node())
{ {
if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{}))
std::cout << "operator: " << node.op_type() << std::endl;
std::vector<instruction_ref> args; std::vector<instruction_ref> args;
for(auto&& input : node.input()) for(auto&& input : node.input())
{ {
...@@ -365,6 +429,11 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini ...@@ -365,6 +429,11 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini
result.begin(), result.begin(),
std::inserter(instructions, instructions.end()), std::inserter(instructions, instructions.end()),
[](auto&& x, auto&& y) { return std::make_pair(x, y); }); [](auto&& x, auto&& y) { return std::make_pair(x, y); });
if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{}))
{
print_added_instructions(mod, args, result);
}
} }
// Find instructions corresponding to the output // Find instructions corresponding to the output
...@@ -483,14 +552,9 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const ...@@ -483,14 +552,9 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
} }
MIGRAPHX_THROW("PARSE_TENSOR: Invalid tensor type"); MIGRAPHX_THROW("PARSE_TENSOR: Invalid tensor type");
} }
shape onnx_parser::parse_type(const onnx::TypeProto& t, shape onnx_parser::parse_type(const onnx::TypeProto& t) const
const std::vector<std::size_t>& input_dims) const
{ {
shape::type_t shape_type = get_type(t.tensor_type().elem_type()); shape::type_t shape_type = get_type(t.tensor_type().elem_type());
if(not input_dims.empty())
{
return {shape_type, input_dims};
}
std::vector<shape::dynamic_dimension> dynamic_dims; std::vector<shape::dynamic_dimension> dynamic_dims;
auto&& tensor_dims = t.tensor_type().shape().dim(); auto&& tensor_dims = t.tensor_type().shape().dim();
...@@ -520,6 +584,15 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t, ...@@ -520,6 +584,15 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
return shape_from_dyn_dims(shape_type, dynamic_dims); return shape_from_dyn_dims(shape_type, dynamic_dims);
} }
shape onnx_parser::parse_type(const onnx::TypeProto& t,
const std::vector<std::size_t>& input_dims) const
{
shape::type_t shape_type = get_type(t.tensor_type().elem_type());
if(input_dims.empty())
return {shape_type};
return {shape_type, input_dims};
}
shape::type_t get_type(int dtype) shape::type_t get_type(int dtype)
{ {
switch(dtype) switch(dtype)
......
...@@ -46,6 +46,7 @@ std::vector<std::string> get_op_parsers() ...@@ -46,6 +46,7 @@ std::vector<std::string> get_op_parsers()
op_parser_map().end(), op_parser_map().end(),
std::back_inserter(result), std::back_inserter(result),
[&](auto&& p) { return p.first; }); [&](auto&& p) { return p.first; });
std::sort(result.begin(), result.end());
return result; return result;
} }
......
...@@ -57,13 +57,12 @@ struct parse_batchnorm : op_parser<parse_batchnorm> ...@@ -57,13 +57,12 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
auto x_rank = x_lens.size(); auto x_rank = x_lens.size();
if(x_rank == 1 or x_rank == 2) if(x_rank == 1 or x_rank == 2)
{ {
auto rt = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}}); auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}}); auto x_sub_mean = info.add_broadcastable_binary_op("sub", args[0], args[3]);
auto numer = info.add_broadcastable_binary_op("sub", args[0], args[3]); auto var_eps = info.add_broadcastable_binary_op("add", args[4], eps);
auto var_eps = info.add_broadcastable_binary_op("add", args[4], eps); auto rsqrt = info.add_instruction(make_op("rsqrt"), var_eps);
auto denom = info.add_broadcastable_binary_op("pow", var_eps, rt); auto mul0 = info.add_broadcastable_binary_op("mul", args[1], rsqrt);
auto div0 = info.add_broadcastable_binary_op("div", numer, denom); auto r0 = info.add_broadcastable_binary_op("mul", x_sub_mean, mul0);
auto r0 = info.add_broadcastable_binary_op("mul", div0, args[1]);
return info.add_broadcastable_binary_op("add", r0, args[2]); return info.add_broadcastable_binary_op("add", r0, args[2]);
} }
else if(x_rank > 2) else if(x_rank > 2)
...@@ -71,7 +70,6 @@ struct parse_batchnorm : op_parser<parse_batchnorm> ...@@ -71,7 +70,6 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
// unsqueeze tensors of shape (C) to broadcast correctly // unsqueeze tensors of shape (C) to broadcast correctly
std::vector<int64_t> unsqueeze_axes(x_lens.size() - 2); std::vector<int64_t> unsqueeze_axes(x_lens.size() - 2);
std::iota(unsqueeze_axes.begin(), unsqueeze_axes.end(), 1); std::iota(unsqueeze_axes.begin(), unsqueeze_axes.end(), 1);
auto rt = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}}); auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
auto scale_unsqueeze = info.add_instruction( auto scale_unsqueeze = info.add_instruction(
migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[1]); migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[1]);
...@@ -81,11 +79,11 @@ struct parse_batchnorm : op_parser<parse_batchnorm> ...@@ -81,11 +79,11 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[3]); migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[3]);
auto var_unsqueeze = info.add_instruction( auto var_unsqueeze = info.add_instruction(
migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[4]); migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[4]);
auto numer = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze); auto x_sub_mean = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze);
auto var_eps = info.add_broadcastable_binary_op("add", var_unsqueeze, eps); auto var_eps = info.add_broadcastable_binary_op("add", var_unsqueeze, eps);
auto denom = info.add_broadcastable_binary_op("pow", var_eps, rt); auto rsqrt = info.add_instruction(make_op("rsqrt"), var_eps);
auto div0 = info.add_broadcastable_binary_op("div", numer, denom); auto mul0 = info.add_broadcastable_binary_op("mul", scale_unsqueeze, rsqrt);
auto r0 = info.add_broadcastable_binary_op("mul", div0, scale_unsqueeze); auto r0 = info.add_broadcastable_binary_op("mul", x_sub_mean, mul0);
return info.add_broadcastable_binary_op("add", r0, bias_unsqueeze); return info.add_broadcastable_binary_op("add", r0, bias_unsqueeze);
} }
else else
......
...@@ -55,9 +55,6 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape> ...@@ -55,9 +55,6 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
l_val = literal({shape::float_type, {1}, {0}}, {0.0f}); l_val = literal({shape::float_type, {1}, {0}}, {0.0f});
} }
// input is empty, output is a scalar
auto type = l_val.get_shape().type();
if(args.empty()) if(args.empty())
{ {
MIGRAPHX_THROW("ConstantOfShape : must have 1 input!"); MIGRAPHX_THROW("ConstantOfShape : must have 1 input!");
...@@ -65,6 +62,8 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape> ...@@ -65,6 +62,8 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
else else
{ {
migraphx::shape s; migraphx::shape s;
// input is empty, output is a scalar
auto type = l_val.get_shape().type();
// empty input tensor, output is a scalar // empty input tensor, output is a scalar
if(args[0]->get_shape().elements() == 0) if(args[0]->get_shape().elements() == 0)
{ {
......
...@@ -42,7 +42,7 @@ std::vector<int64_t> to_int64_vector(const std::vector<T>& input_vector) ...@@ -42,7 +42,7 @@ std::vector<int64_t> to_int64_vector(const std::vector<T>& input_vector)
return output_vector; return output_vector;
} }
struct parse_deconvolution : op_parser<parse_deconvolution> struct parse_conv_transpose : op_parser<parse_conv_transpose>
{ {
std::vector<op_desc> operators() const { return {{"ConvTranspose"}}; } std::vector<op_desc> operators() const { return {{"ConvTranspose"}}; }
...@@ -51,17 +51,15 @@ struct parse_deconvolution : op_parser<parse_deconvolution> ...@@ -51,17 +51,15 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
onnx_parser::node_info info, onnx_parser::node_info info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
operation op = make_op("deconvolution"); operation op = make_op("convolution_backwards");
value values = op.to_value(); value values = op.to_value();
// op::deconvolution op; auto l0 = args[0];
auto l0 = args[0];
std::vector<std::int64_t> padding; std::vector<std::int64_t> padding;
bool asym_padding = false; bool asym_padding = false;
auto in_lens = l0->get_shape().lens(); assert(l0->get_shape().ndim() > 2);
assert(in_lens.size() > 2); auto kdims = l0->get_shape().ndim() - 2;
auto kdims = in_lens.size() - 2;
// ensure pads availabe only when auto_pad is "NOT_SET" // ensure pads available only when auto_pad is "NOT_SET"
check_padding_mode(info, "CONV_TRANSPOSE"); check_padding_mode(info, "CONV_TRANSPOSE");
if(contains(info.attributes, "pads")) if(contains(info.attributes, "pads"))
...@@ -70,9 +68,9 @@ struct parse_deconvolution : op_parser<parse_deconvolution> ...@@ -70,9 +68,9 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
asym_padding = is_asym_padding(padding); asym_padding = is_asym_padding(padding);
size_t pad_ndims = padding.size() / 2;
if(not asym_padding) if(not asym_padding)
{ {
size_t pad_ndims = padding.size() / 2;
check_attr_sizes(kdims, pad_ndims, "PARSE_CONV_TRANSPOSE: inconsistent paddings"); check_attr_sizes(kdims, pad_ndims, "PARSE_CONV_TRANSPOSE: inconsistent paddings");
values["padding"].clear(); values["padding"].clear();
std::transform(padding.begin(), std::transform(padding.begin(),
...@@ -80,7 +78,19 @@ struct parse_deconvolution : op_parser<parse_deconvolution> ...@@ -80,7 +78,19 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
std::back_inserter(values["padding"]), std::back_inserter(values["padding"]),
[](auto pad_val) { return pad_val; }); [](auto pad_val) { return pad_val; });
} }
else if(l0->get_shape().dynamic())
{
MIGRAPHX_THROW("PARSE_CONV_TRANSPOSE: asymmetric padding (padding_L != padding_R) "
"not supported with dynamic shapes");
}
else
{
// set padding to 0s, asym_padding handled by parser with slice
// TODO changing parser and op to do asym padding in op
values["padding"] = std::vector<std::size_t>(pad_ndims, 0);
}
} }
if(contains(info.attributes, "strides")) if(contains(info.attributes, "strides"))
{ {
values["stride"].clear(); values["stride"].clear();
...@@ -88,6 +98,7 @@ struct parse_deconvolution : op_parser<parse_deconvolution> ...@@ -88,6 +98,7 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
check_attr_sizes( check_attr_sizes(
kdims, values["stride"].size(), "PARSE_CONV_TRANSPOSE: inconsistent strides"); kdims, values["stride"].size(), "PARSE_CONV_TRANSPOSE: inconsistent strides");
} }
if(contains(info.attributes, "dilations")) if(contains(info.attributes, "dilations"))
{ {
values["dilation"].clear(); values["dilation"].clear();
...@@ -97,21 +108,10 @@ struct parse_deconvolution : op_parser<parse_deconvolution> ...@@ -97,21 +108,10 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
} }
// TODO: auto padding needs to be implemented for this parser and operator // TODO: auto padding needs to be implemented for this parser and operator
if(contains(info.attributes, "auto_pad")) if(contains(info.attributes, "auto_pad") and
to_upper(info.attributes.at("auto_pad").s()) != "NOTSET")
{ {
auto s = info.attributes["auto_pad"].s(); MIGRAPHX_THROW("PARSE_CONV_TRANSPOSE: auto padding not supported");
if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
{
MIGRAPHX_THROW("PARSE_CONV_TRANSPOSE: auto_pad and padding cannot be specified "
"simultaneously");
}
if(s.find("SAME") != std::string::npos)
{
bool is_same_upper = (s.find("SAME_UPPER") != std::string::npos);
values["padding_mode"] = is_same_upper ? to_value(op::padding_mode_t::same_upper)
: to_value(op::padding_mode_t::same_lower);
}
} }
if(contains(info.attributes, "group")) if(contains(info.attributes, "group"))
...@@ -122,11 +122,11 @@ struct parse_deconvolution : op_parser<parse_deconvolution> ...@@ -122,11 +122,11 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
recalc_conv_attributes(values, kdims); recalc_conv_attributes(values, kdims);
op.from_value(values); op.from_value(values);
auto l1 = info.add_instruction(op, l0, args[1]); auto l1 = info.add_instruction(op, l0, args[1]);
std::vector<int64_t> dims = to_int64_vector(l1->get_shape().lens());
std::vector<int64_t> curr_shape(dims.begin() + 2, dims.end());
if(asym_padding) if(asym_padding)
{ {
std::vector<int64_t> dims = to_int64_vector(l1->get_shape().lens());
std::vector<int64_t> curr_shape(dims.begin() + 2, dims.end());
std::vector<int64_t> axes(kdims); std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), 2); // ignore first 2 dims std::iota(axes.begin(), axes.end(), 2); // ignore first 2 dims
...@@ -144,9 +144,11 @@ struct parse_deconvolution : op_parser<parse_deconvolution> ...@@ -144,9 +144,11 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), l1); make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), l1);
} }
if(contains(info.attributes, "output_padding")) // TODO, should check output_padding < (strides or dilations)
if(contains(info.attributes, "output_padding") and
not contains(info.attributes, "output_shape"))
{ {
size_t non_kdims = dims.size() * 2 - kdims; size_t non_kdims = l1->get_shape().ndim() * 2 - kdims;
std::vector<int64_t> output_padding(non_kdims, 0); std::vector<int64_t> output_padding(non_kdims, 0);
copy(info.attributes["output_padding"].ints(), std::back_inserter(output_padding)); copy(info.attributes["output_padding"].ints(), std::back_inserter(output_padding));
check_attr_sizes(kdims, check_attr_sizes(kdims,
...@@ -155,14 +157,21 @@ struct parse_deconvolution : op_parser<parse_deconvolution> ...@@ -155,14 +157,21 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
l1 = info.add_instruction(make_op("pad", {{"pads", output_padding}}), l1); l1 = info.add_instruction(make_op("pad", {{"pads", output_padding}}), l1);
} }
// TODO, doing unnecessary calcuations with this. Could instead
// calculate the padding to conv_transpose that would give the output_shape.
if(contains(info.attributes, "output_shape")) if(contains(info.attributes, "output_shape"))
{ {
if(l1->get_shape().dynamic())
{
MIGRAPHX_THROW("PARSE_CONV_TRANSPOSE: output_shape attribute and dynamic shapes "
"not supported");
}
std::vector<int64_t> dims = to_int64_vector(l1->get_shape().lens());
std::vector<int64_t> curr_shape(dims.begin() + 2, dims.end());
std::vector<int64_t> output_shape; std::vector<int64_t> output_shape;
copy(info.attributes["output_shape"].ints(), std::back_inserter(output_shape)); copy(info.attributes["output_shape"].ints(), std::back_inserter(output_shape));
check_attr_sizes( check_attr_sizes(
kdims, output_shape.size(), "PARSE_CONV_TRANSPOSE: inconsistent output shape"); kdims, output_shape.size(), "PARSE_CONV_TRANSPOSE: inconsistent output shape");
dims = to_int64_vector(l1->get_shape().lens());
copy(dims.begin() + 2, dims.end(), curr_shape.begin());
if(curr_shape != output_shape) if(curr_shape != output_shape)
{ {
std::vector<int64_t> target_padding(dims.size() * 2 - kdims, 0); std::vector<int64_t> target_padding(dims.size() * 2 - kdims, 0);
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -21,10 +21,14 @@ ...@@ -21,10 +21,14 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <iterator>
#include <migraphx/onnx/op_parser.hpp> #include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/env.hpp>
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_FP16_INSTANCENORM_CONVERT);
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -32,62 +36,117 @@ namespace onnx { ...@@ -32,62 +36,117 @@ namespace onnx {
struct parse_instancenorm : op_parser<parse_instancenorm> struct parse_instancenorm : op_parser<parse_instancenorm>
{ {
const std::set<shape::type_t> valid_types = { std::set<shape::type_t> valid_types = {shape::float_type, shape::half_type, shape::double_type};
shape::float_type, shape::half_type, shape::double_type};
std::vector<op_desc> operators() const { return {{"InstanceNormalization"}}; } std::vector<op_desc> operators() const { return {{"InstanceNormalization"}}; }
instruction_ref parse(const op_desc& opd, instruction_ref parse(const op_desc& opd,
const onnx_parser& parser, const onnx_parser& parser,
onnx_parser::node_info info, onnx_parser::node_info info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> oargs) const
{ {
// y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias // y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
// mean = reduce_mean({D1, D2, ... Dk}, x) // mean = reduce_mean({D1, D2, ... Dk}, x)
// variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2) // variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2)
// Convert fp16 to fp32 to workaround for FP16 accuracy issues with reduce_mean/variance.
bool convert_fp16 = true;
if(enabled(MIGRAPHX_DISABLE_FP16_INSTANCENORM_CONVERT{}))
{
convert_fp16 = false;
}
float epsilon = 1e-5f; float epsilon = 1e-5f;
if(contains(info.attributes, "epsilon")) if(contains(info.attributes, "epsilon"))
{ {
epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>(); epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
} }
auto dtype = oargs[0]->get_shape().type();
auto literal_dtype = dtype;
std::vector<instruction_ref> args;
// cppcheck-suppress knownConditionTrueFalse
if(dtype == shape::half_type and convert_fp16)
{
std::transform(oargs.begin(), oargs.end(), std::back_inserter(args), [&](const auto i) {
return info.add_instruction(
make_op("convert", {{"target_type", shape::float_type}}), i);
});
literal_dtype = shape::float_type;
}
else
{
args = oargs;
}
auto x = args[0]; auto x = args[0];
auto scale = args[1]; auto scale = args[1];
auto bias = args[2]; auto bias = args[2];
auto dims = x->get_shape().lens();
auto dtype = x->get_shape().type();
if(not contains(valid_types, dtype)) if(not contains(valid_types, dtype))
MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + std::to_string(dtype) + MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + std::to_string(dtype) +
". Valid types are 1 (float), 10 (half), and 11 (double)."); ". Valid types are 1 (float), 10 (half), and 11 (double).");
auto ndims = dims.size(); auto ndims = x->get_shape().ndim();
assert(ndims >= 2); assert(ndims >= 2);
auto kdims = ndims - 2; auto kdims = ndims - 2;
std::vector<int64_t> axes(kdims); std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), 2); std::iota(axes.begin(), axes.end(), 2);
auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x); auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
auto mean_bcast =
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), mean); // Use add_common_op() to insert multibroadcast/convert instructions where needed when
auto l0 = info.add_instruction(make_op("sqdiff"), x, mean_bcast); // inputs may be either static or dynamic.
auto variance = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0); auto l1 = info.add_common_op("sub", x, mean);
auto l1 = info.add_instruction(make_op("sub"), x, mean_bcast); // for the fp16, if not converting to fp32 then divide `x` and `mean` by `sqrt(n)` and take
auto epsilon_literal = info.add_literal(literal{shape{dtype}, {epsilon}}); // reduce_sum to calculate variance i.e.
auto epsilon_bcast = // var = reduce_sum((x/s_n - mean/s_n)^2) where s_n = sqrt(n)
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal); std::string reduce_op_name =
auto variance_bcast = (dtype == shape::half_type and not convert_fp16) ? "reduce_sum" : "reduce_mean";
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), variance); if(dtype == shape::half_type and not convert_fp16)
auto l2 = info.add_instruction(make_op("add"), variance_bcast, epsilon_bcast); {
if(x->get_shape().dynamic())
{
MIGRAPHX_THROW("PARSE_INSTANCENORM: half type not supported with dynamic shape "
"unless convert_fp16 is TRUE");
}
auto dims = x->get_shape().lens();
double n =
std::accumulate(dims.begin() + 2, dims.end(), 1, [&](const auto& i, const auto& j) {
return i * j;
});
n = 1.0 / std::sqrt(n);
auto n_literal = info.add_literal(literal{dtype, {n}});
x = info.add_common_op("mul", {x, n_literal});
}
auto l0 = info.add_common_op("sqdiff", x, mean);
auto variance = info.add_instruction(make_op(reduce_op_name, {{"axes", axes}}), l0);
auto epsilon_literal = info.add_literal(literal{shape{literal_dtype}, {epsilon}});
auto l2 = info.add_common_op("add", variance, epsilon_literal);
auto l3 = info.add_instruction(make_op("rsqrt"), l2); auto l3 = info.add_instruction(make_op("rsqrt"), l2);
auto l4 = info.add_instruction(make_op("mul"), l1, l3); auto l4 = info.add_common_op("mul", l1, l3);
auto scale_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale); // add_common_op() doesn't apply the plain broadcast op, so we add that op explicitly for
; // both scale and bias.
auto bias_bcast = instruction_ref scale_bcast;
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias); instruction_ref bias_bcast;
auto l5 = info.add_instruction(make_op("mul"), l4, scale_bcast); if(x->get_shape().dynamic())
return info.add_instruction(make_op("add"), l5, bias_bcast); {
scale_bcast = info.add_instruction(make_op("broadcast", {{"axis", 1}}), scale, x);
bias_bcast = info.add_instruction(make_op("broadcast", {{"axis", 1}}), bias, x);
}
else
{
auto dims = x->get_shape().lens();
scale_bcast = info.add_instruction(
make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
bias_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
}
auto l5 = info.add_instruction(make_op("mul"), l4, scale_bcast);
auto ret = info.add_instruction(make_op("add"), l5, bias_bcast);
if(dtype == shape::half_type and convert_fp16)
{
return info.add_instruction(make_op("convert", {{"target_type", shape::half_type}}),
ret);
}
return ret;
} }
}; };
......
...@@ -33,8 +33,7 @@ namespace onnx { ...@@ -33,8 +33,7 @@ namespace onnx {
struct parse_mean : op_parser<parse_mean> struct parse_mean : op_parser<parse_mean>
{ {
const std::set<shape::type_t> float_types = { std::set<shape::type_t> float_types = {shape::float_type, shape::half_type, shape::double_type};
shape::float_type, shape::half_type, shape::double_type};
std::vector<op_desc> operators() const { return {{"Mean"}}; } std::vector<op_desc> operators() const { return {{"Mean"}}; }
......
...@@ -35,8 +35,7 @@ namespace onnx { ...@@ -35,8 +35,7 @@ namespace onnx {
struct parse_randomnormal_ops : op_parser<parse_randomnormal_ops> struct parse_randomnormal_ops : op_parser<parse_randomnormal_ops>
{ {
const std::set<shape::type_t> valid_types = { std::set<shape::type_t> valid_types = {shape::float_type, shape::half_type, shape::double_type};
shape::float_type, shape::half_type, shape::double_type};
std::vector<op_desc> operators() const { return {{"RandomNormal"}, {"RandomNormalLike"}}; } std::vector<op_desc> operators() const { return {{"RandomNormal"}, {"RandomNormalLike"}}; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment