"src/targets/vscode:/vscode.git/clone" did not exist on "2152df829a9776ee486877545197b4eeb677e442"
Commit 11e155c2 authored by Paul's avatar Paul
Browse files

Merge

parents 8a9c5bce aa7ff911
......@@ -35,7 +35,7 @@ struct shape
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t)
// clang-format on
// clang-format on
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
enum type_t
......@@ -131,6 +131,8 @@ struct shape
shape with_lens(type_t t, const std::vector<std::size_t>& l) const;
shape with_lens(const std::vector<std::size_t>& l) const;
shape with_type(type_t t) const;
friend bool operator==(const shape& x, const shape& y);
friend bool operator!=(const shape& x, const shape& y);
friend std::ostream& operator<<(std::ostream& os, const shape& x);
......@@ -186,8 +188,7 @@ struct shape
{
switch(t)
{
case tuple_type:
{
case tuple_type: {
tv();
return;
}
......@@ -224,10 +225,11 @@ struct shape
const std::vector<shape>& sub_shapes() const;
std::size_t element_space() const;
private:
shape(std::shared_ptr<shape_impl> pimpl);
std::shared_ptr<const shape_impl> impl;
std::size_t element_space() const;
};
void migraphx_to_value(value& v, const shape& s);
......
......@@ -15,7 +15,7 @@ struct module;
struct simplify_algebra
{
std::string name() const { return "simplify_algebra"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -16,7 +16,7 @@ struct module;
struct simplify_reshapes
{
std::string name() const { return "simplify_reshapes"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -36,20 +36,26 @@ struct stream_model
#else
/*
* Type-erased interface for:
*
* struct stream_model
* {
* std::size_t get_nstream() const;
* std::size_t get_stream(instruction_ref ins) const;
* std::size_t get_event_id(instruction_ref ins) const;
* bool has_stream(instruction_ref ins) const;
* bool is_record(instruction_ref ins) const;
* bool is_wait(instruction_ref ins) const;
* };
*
*/
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct stream_model
{
//
std::size_t get_nstream() const;
//
std::size_t get_stream(instruction_ref ins) const;
//
std::size_t get_event_id(instruction_ref ins) const;
//
bool has_stream(instruction_ref ins) const;
//
bool is_record(instruction_ref ins) const;
//
bool is_wait(instruction_ref ins) const;
};
#else
struct stream_model
{
......@@ -296,6 +302,7 @@ inline const ValueType& any_cast(const stream_model& x)
throw std::bad_cast();
return *y;
}
#endif
#endif
......
......@@ -137,16 +137,17 @@ inline std::string interpolate_string(const std::string& input,
std::string start = "${",
std::string end = "}")
{
return interpolate_string(input,
[&](auto start_it, auto last_it) {
auto key = trim({start_it, last_it});
auto it = vars.find(key);
if(it == vars.end())
throw std::runtime_error("Unknown key: " + key);
return it->second;
},
std::move(start),
std::move(end));
return interpolate_string(
input,
[&](auto start_it, auto last_it) {
auto key = trim({start_it, last_it});
auto it = vars.find(key);
if(it == vars.end())
throw std::runtime_error("Unknown key: " + key);
return it->second;
},
std::move(start),
std::move(end));
}
template <class Iterator>
......
......@@ -82,20 +82,26 @@ argument copy_from_target(T&, const argument& arg)
return arg;
}
/*
* Type-erased interface for:
*
* struct target
* {
* std::string name() const;
* std::vector<pass> get_passes(context& ctx,const compile_options& options) const;
* context get_context() const;
* argument copy_to(const argument& input) const;
* argument copy_from(const argument& input) const;
* argument allocate(const shape& s) const;
* };
*
*/
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct target
{
//
std::string name() const;
//
std::vector<pass> get_passes(context& ctx, const compile_options& options) const;
//
context get_context() const;
// (optional)
argument copy_to(const argument& input) const;
// (optional)
argument copy_from(const argument& input) const;
// (optional)
argument allocate(const shape& s) const;
};
#else
struct target
{
......@@ -382,6 +388,7 @@ inline const ValueType& any_cast(const target& x)
throw std::bad_cast();
return *y;
}
#endif
#endif
......
......@@ -120,10 +120,8 @@ struct tensor_view
return m_data[m_shape.index(this->size() - 1)];
}
// cppcheck-suppress functionConst
iterator begin() { return {0, {this}}; }
// cppcheck-suppress functionConst
iterator end() { return {this->size(), {this}}; }
const_iterator begin() const { return {0, {this}}; }
......
......@@ -178,6 +178,7 @@ struct value
value(std::nullptr_t);
value(const char* i);
value(const std::string& pkey, const char* i);
#define MIGRAPHX_VALUE_GENERATE_DECL_METHODS(vt, cpp_type) \
value(cpp_type i); \
......@@ -188,6 +189,12 @@ struct value
const cpp_type* if_##vt() const;
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_DECL_METHODS)
template <class T>
using literal_to_string = std::conditional_t<(std::is_convertible<T, const char*>{} and
std::is_convertible<T, std::string>{}),
std::string,
T>;
template <class T>
using pick_numeric = std::conditional_t<
std::is_floating_point<T>{},
......@@ -246,6 +253,7 @@ struct value
return *this = from_values(rhs); // NOLINT
}
value& operator=(const char* c);
value& operator=(std::nullptr_t);
value& operator=(const std::initializer_list<value>& i);
......@@ -315,8 +323,7 @@ struct value
{
switch(this->get_type())
{
case null_type:
{
case null_type: {
std::nullptr_t null{};
if(this->key.empty())
v(null);
......@@ -325,8 +332,7 @@ struct value
return;
}
#define MIGRAPHX_VALUE_GENERATE_CASE(vt, cpp_type) \
case vt##_type: \
{ \
case vt##_type: { \
if(this->key.empty()) \
v(this->get_##vt()); \
else \
......@@ -346,19 +352,17 @@ struct value
{
switch(this->get_type())
{
case null_type:
{
case null_type: {
std::nullptr_t null{};
v(null);
return;
}
#define MIGRAPHX_VALUE_GENERATE_CASE_VALUE(vt, cpp_type) \
case vt##_type: \
{ \
case vt##_type: { \
v(this->get_##vt()); \
return; \
}
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE)
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE_VALUE)
MIGRAPHX_VALUE_GENERATE_CASE(array, )
MIGRAPHX_VALUE_GENERATE_CASE(object, )
}
......@@ -374,11 +378,11 @@ struct value
}
template <class To>
To value_or(const To& default_value) const
literal_to_string<To> value_or(const To& default_value) const
{
if(this->is_null())
return default_value;
return to<To>();
return to<literal_to_string<To>>();
}
template <class To>
......@@ -394,12 +398,12 @@ struct value
}
template <class To>
To get(const std::string& pkey, const To& default_value) const
literal_to_string<To> get(const std::string& pkey, const To& default_value) const
{
const auto* v = find(pkey);
if(v == this->end())
return default_value;
return v->to<To>();
return v->to<literal_to_string<To>>();
}
template <class To>
......@@ -412,10 +416,11 @@ struct value
}
template <class To>
std::vector<To> get(const std::string& pkey,
const std::initializer_list<To>& default_value) const
std::vector<literal_to_string<To>> get(const std::string& pkey,
const std::initializer_list<To>& default_value) const
{
return get<std::vector<To>>(pkey, default_value);
return get(pkey,
std::vector<literal_to_string<To>>{default_value.begin(), default_value.end()});
}
friend bool operator==(const value& x, const value& y);
......@@ -429,6 +434,8 @@ struct value
void debug_print(bool show_type = false) const;
type_t get_type() const;
private:
template <class T>
std::vector<value> from_values(const T& r)
......@@ -438,7 +445,6 @@ struct value
r.begin(), r.end(), std::back_inserter(v), [&](auto&& e) { return value(e); });
return v;
}
type_t get_type() const;
std::shared_ptr<value_base_impl> x;
std::string key;
};
......
......@@ -168,7 +168,6 @@ bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out
{
double threshold = std::numeric_limits<range_value<R1>>::epsilon() * tolerance;
auto error = rms_range(r1, r2);
// cppcheck-suppress uninitvar
if(out_error != nullptr)
*out_error = error;
return error <= threshold;
......
......@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m)
{
auto op = any_cast<op::pooling>(ins->get_operator());
if(op.mode == "average")
if(op.mode == op::pooling_mode::average)
{
return;
}
......
......@@ -133,6 +133,12 @@ std::string to_json_string(const value& val)
return j.dump();
}
std::string to_pretty_json_string(const value& val, std::size_t indent)
{
json j = val;
return j.dump(indent);
}
migraphx::value from_json_string(const char* str, std::size_t size)
{
json j = json::parse(str, str + size);
......
......@@ -5,20 +5,41 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
operation make_op(const std::string& name) { return load_op(name); }
operation make_op(const std::string& name, const value& v)
template <class F>
operation make_op_generic(const std::string& name, F for_each)
{
if(not(v.is_object() or (v.empty() and v.is_array())))
MIGRAPHX_THROW("Value is not an object");
auto op = load_op(name);
// Merge values
value w = op.to_value();
for(auto&& x : v)
{
w.at(x.get_key()) = x.without_key();
}
for_each([&](const auto& key, const auto& x) {
if(not w.contains(key))
// NOLINTNEXTLINE(performance-inefficient-string-concatenation)
MIGRAPHX_THROW("No key '" + key + "' in " + name);
w.at(key) = x;
});
op.from_value(w);
return op;
}
operation make_op(const std::string& name,
const std::initializer_list<std::pair<std::string, value>>& v)
{
return make_op_generic(name, [&](auto f) {
for(auto&& [key, x] : v)
f(key, x);
});
}
operation make_op_from_value(const std::string& name, const value& v)
{
if(not(v.is_object() or (v.empty() and v.is_array())))
MIGRAPHX_THROW("Value is not an object for make_op: " + name);
return make_op_generic(name, [&](auto f) {
for(auto&& x : v)
f(x.get_key(), x.without_key());
});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -23,6 +23,8 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_FINALIZE)
struct module_impl
{
// A list is used to keep references to an instruction stable
......@@ -554,8 +556,14 @@ instruction_ref module::find_dangling_reference() const
void module::finalize(context& ctx)
{
const bool trace = enabled(MIGRAPHX_TRACE_FINALIZE{});
for(auto ins : iterator_for(*this))
{
if(trace)
{
std::cout << "Finalize: ";
this->debug_print(ins);
}
ins->finalize(ctx);
for(const auto& smod : ins->module_inputs())
{
......@@ -628,8 +636,9 @@ std::unordered_map<instruction_ref, std::string> module::print(
var_name = this->name();
var_name.append((this->name().empty() ? "@" : ":@"));
var_name.append(std::to_string(count));
count++;
}
// count every instruction so index matches loc in the printout program
count++;
names.emplace(ins, var_name);
print_func(ins, names);
......@@ -719,7 +728,6 @@ module::print_cpp(std::ostream& os,
const std::string& mname,
std::unordered_map<instruction_ref, std::string> names) const
{
// cppcheck-suppress variableScope
unsigned long seed = names.size();
auto last = std::prev(this->end());
......
......@@ -14,44 +14,36 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
{
switch(o.type)
{
case msgpack::type::NIL:
{
case msgpack::type::NIL: {
v = nullptr;
break;
}
case msgpack::type::BOOLEAN:
{
case msgpack::type::BOOLEAN: {
v = o.as<bool>();
break;
}
case msgpack::type::POSITIVE_INTEGER:
{
case msgpack::type::POSITIVE_INTEGER: {
v = o.as<std::uint64_t>();
break;
}
case msgpack::type::NEGATIVE_INTEGER:
{
case msgpack::type::NEGATIVE_INTEGER: {
v = o.as<std::int64_t>();
break;
}
case msgpack::type::FLOAT32:
case msgpack::type::FLOAT64:
{
case msgpack::type::FLOAT64: {
v = o.as<double>();
break;
}
case msgpack::type::STR:
{
case msgpack::type::STR: {
v = o.as<std::string>();
break;
}
case msgpack::type::BIN:
{
case msgpack::type::BIN: {
v = migraphx::value::binary{o.via.bin.ptr, o.via.bin.size};
break;
}
case msgpack::type::ARRAY:
{
case msgpack::type::ARRAY: {
migraphx::value r = migraphx::value::array{};
std::for_each(
o.via.array.ptr,
......@@ -60,8 +52,7 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
v = r;
break;
}
case msgpack::type::MAP:
{
case msgpack::type::MAP: {
migraphx::value r = migraphx::value::object{};
std::for_each(o.via.map.ptr,
o.via.map.ptr + o.via.map.size,
......@@ -71,7 +62,8 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
v = r;
break;
}
case msgpack::type::EXT: { MIGRAPHX_THROW("msgpack EXT type not supported.");
case msgpack::type::EXT: {
MIGRAPHX_THROW("msgpack EXT type not supported.");
}
}
return o;
......
......@@ -7,7 +7,7 @@ target_compile_options(onnx-proto PRIVATE -w)
target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY})
set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
file(GLOB ONNX_SRCS *.cpp)
file(GLOB ONNX_SRCS ${CONFIGURE_DEPENDS} *.cpp)
add_library(migraphx_onnx ${ONNX_SRCS})
target_include_directories(migraphx_onnx PRIVATE include)
set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx)
......
......@@ -32,9 +32,20 @@ struct onnx_parser
instruction_ref add_bias(const std::vector<instruction_ref>& args,
instruction_ref curr_ins,
uint64_t axis) const;
instruction_ref add_broadcastable_binary_op(const std::string& op_name,
instruction_ref arg0,
instruction_ref arg1) const;
instruction_ref add_common_op(const std::string& op_name,
std::vector<instruction_ref> inputs) const;
template <class... Ts>
instruction_ref add_common_op(const std::string& op_name, Ts... xs) const
{
return add_common_op(op_name, {xs...});
}
instruction_ref add_instruction(const operation& op,
const std::vector<instruction_ref>& args) const;
......
......@@ -70,12 +70,14 @@ static literal from_repeated(shape::type_t t, const T& r)
instruction_ref onnx_parser::node_info::make_contiguous(instruction_ref ins) const
{
if(ins->get_shape().standard())
auto attr = ins->get_operator().to_value();
std::string key = "require_std_shape";
if((attr.get(key, false)) or (not ins->get_shape().standard()))
{
return ins;
return add_instruction(make_op("contiguous"), ins);
}
return add_instruction(make_op("contiguous"), ins);
return ins;
}
instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_ref>& args,
......@@ -96,7 +98,13 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s
instruction_ref arg0,
instruction_ref arg1) const
{
return add_common_op(*mod, make_op(op_name), {arg0, arg1});
return this->add_common_op(op_name, arg0, arg1);
}
instruction_ref onnx_parser::node_info::add_common_op(const std::string& op_name,
std::vector<instruction_ref> inputs) const
{
return migraphx::add_common_op(*mod, make_op(op_name), std::move(inputs));
}
instruction_ref
......@@ -380,8 +388,7 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, t.int64_data());
case onnx::TensorProto::UINT64:
return create_literal(shape::uint64_type, dims, t.uint64_data());
case onnx::TensorProto::FLOAT16:
{
case onnx::TensorProto::FLOAT16: {
std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
std::vector<half> data_half;
std::transform(data_uint16.begin(),
......@@ -451,7 +458,8 @@ shape::type_t get_type(int dtype)
case 11: return shape::double_type;
case 12: return shape::uint32_type;
case 13: return shape::uint64_type;
default: { MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
default: {
MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
}
}
}
......
......@@ -3,6 +3,7 @@
#include <migraphx/pad_calc.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -94,7 +95,7 @@ void tune_padding_size(const value& v,
std::vector<int64_t>& s_start)
{
// maxpooling or count_include_pad is 1, no change is required.
if(v.at("mode").to<std::string>() == "max" or count_include_pad == 1)
if(v.at("mode").to<op::pooling_mode>() == op::pooling_mode::max or count_include_pad == 1)
{
return;
}
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_celu : op_parser<parse_celu>
{
std::vector<op_desc> operators() const { return {{"Celu"}}; }
instruction_ref parse(const op_desc&,
const onnx_parser&,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
float alpha = 1.0;
if(contains(info.attributes, "alpha"))
{
alpha = info.attributes.at("alpha").f();
}
if(float_equal(alpha, 0.0f))
{
MIGRAPHX_THROW("CELU: alpha is zero (division by zero)");
}
auto input_lens = args[0]->get_shape().lens();
auto input_type = args[0]->get_shape().type();
if(input_type != migraphx::shape::float_type)
{
MIGRAPHX_THROW("CELU: input tensor not float type");
}
auto zero_lit = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {0.}}));
auto one_lit = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {1.}}));
auto alpha_lit = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {alpha}}));
auto linear_part = info.add_instruction(migraphx::make_op("max"), zero_lit, args[0]);
auto divi = info.add_instruction(migraphx::make_op("div"), args[0], alpha_lit);
auto expo = info.add_instruction(migraphx::make_op("exp"), divi);
auto sub = info.add_instruction(migraphx::make_op("sub"), expo, one_lit);
auto mul = info.add_instruction(migraphx::make_op("mul"), alpha_lit, sub);
auto exp_part = info.add_instruction(migraphx::make_op("min"), zero_lit, mul);
return info.add_instruction(migraphx::make_op("add"), linear_part, exp_part);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -16,7 +16,6 @@ struct parse_clip : op_parser<parse_clip>
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
auto input_lens = args[0]->get_shape().lens();
instruction_ref min_arg;
instruction_ref max_arg;
bool min_used = false;
......@@ -45,29 +44,17 @@ struct parse_clip : op_parser<parse_clip>
max_used = true;
}
if(min_used)
{
min_arg = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
min_arg);
}
if(max_used)
{
max_arg = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
max_arg);
}
if(min_used and max_used)
{
return info.add_instruction(make_op("clip"), args[0], min_arg, max_arg);
return info.add_common_op("clip", args[0], min_arg, max_arg);
}
else if(max_used)
{
return info.add_instruction(make_op("min"), args[0], max_arg);
return info.add_broadcastable_binary_op("min", args[0], max_arg);
}
else if(min_used)
{
return info.add_instruction(make_op("max"), args[0], min_arg);
return info.add_broadcastable_binary_op("max", args[0], min_arg);
}
else
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment