Commit e2eb6036 authored by Paul

Merge

parents 298c93d5 1e0bbd78
@@ -36,20 +36,26 @@ struct stream_model
#else
-/*
- * Type-erased interface for:
- *
- * struct stream_model
- * {
- *     std::size_t get_nstream() const;
- *     std::size_t get_stream(instruction_ref ins) const;
- *     std::size_t get_event_id(instruction_ref ins) const;
- *     bool has_stream(instruction_ref ins) const;
- *     bool is_record(instruction_ref ins) const;
- *     bool is_wait(instruction_ref ins) const;
- * };
- *
- */
+#ifdef TYPE_ERASED_DECLARATION
+// Type-erased interface for:
+struct stream_model
+{
+    //
+    std::size_t get_nstream() const;
+    //
+    std::size_t get_stream(instruction_ref ins) const;
+    //
+    std::size_t get_event_id(instruction_ref ins) const;
+    //
+    bool has_stream(instruction_ref ins) const;
+    //
+    bool is_record(instruction_ref ins) const;
+    //
+    bool is_wait(instruction_ref ins) const;
+};
+#else
struct stream_model
{
@@ -296,6 +302,7 @@ inline const ValueType& any_cast(const stream_model& x)
throw std::bad_cast();
return *y;
}
+#endif
#endif
......
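Note: the `#ifdef TYPE_ERASED_DECLARATION` block added above is a documentation-friendly declaration of the interface; the `#else` branch holds the real wrapper that erases the concrete type. As a minimal, self-contained sketch of the technique (not MIGraphX's generated code, and reduced to a single member function):

```cpp
#include <cstddef>
#include <iostream>
#include <memory>

// Minimal type-erasure sketch: any object with a matching get_nstream()
// can be stored in stream_model without inheriting from a common base.
struct stream_model
{
    template <class T>
    stream_model(T x) : self(std::make_shared<model<T>>(std::move(x)))
    {
    }

    std::size_t get_nstream() const { return self->get_nstream(); }

    private:
    // Internal interface: the only place virtual dispatch appears.
    struct concept_t
    {
        virtual ~concept_t() = default;
        virtual std::size_t get_nstream() const = 0;
    };

    // Adapter from a concrete T to the internal interface.
    template <class T>
    struct model : concept_t
    {
        model(T x) : data(std::move(x)) {}
        std::size_t get_nstream() const override { return data.get_nstream(); }
        T data;
    };

    std::shared_ptr<const concept_t> self;
};

// An unrelated type with the right member satisfies the interface.
struct two_streams
{
    std::size_t get_nstream() const { return 2; }
};

int main()
{
    stream_model m{two_streams{}};
    std::cout << m.get_nstream() << "\n"; // prints 2
}
```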
@@ -131,16 +131,17 @@ inline std::string interpolate_string(const std::string& input,
std::string start = "${",
std::string end = "}")
{
-    return interpolate_string(input,
-                              [&](auto start_it, auto last_it) {
-                                  auto key = trim({start_it, last_it});
-                                  auto it = vars.find(key);
-                                  if(it == vars.end())
-                                      throw std::runtime_error("Unknown key: " + key);
-                                  return it->second;
-                              },
-                              std::move(start),
-                              std::move(end));
+    return interpolate_string(
+        input,
+        [&](auto start_it, auto last_it) {
+            auto key = trim({start_it, last_it});
+            auto it = vars.find(key);
+            if(it == vars.end())
+                throw std::runtime_error("Unknown key: " + key);
+            return it->second;
+        },
+        std::move(start),
+        std::move(end));
}
template <class Iterator>
......
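For reference, `interpolate_string` substitutes `${key}` placeholders; the lambda above resolves each key through `vars` and throws on unknown keys, and only the call's formatting changed here. A rough standalone sketch of that behavior, assuming `${`/`}` delimiters and omitting the key trimming the real header performs:

```cpp
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>

// Rough sketch of ${key} substitution: scan for "${", look the key up in
// vars, splice in the value, and throw on unknown keys.
std::string interpolate(const std::string& input,
                        const std::map<std::string, std::string>& vars)
{
    std::string result;
    std::size_t pos = 0;
    while(pos < input.size())
    {
        auto start = input.find("${", pos);
        if(start == std::string::npos)
        {
            result.append(input, pos, std::string::npos);
            break;
        }
        auto end = input.find('}', start);
        if(end == std::string::npos)
            throw std::runtime_error("Unterminated ${");
        result.append(input, pos, start - pos);
        auto key = input.substr(start + 2, end - start - 2);
        auto it  = vars.find(key);
        if(it == vars.end())
            throw std::runtime_error("Unknown key: " + key);
        result += it->second;
        pos = end + 1;
    }
    return result;
}

int main()
{
    std::cout << interpolate("hello ${name}", {{"name", "world"}}) << "\n"; // hello world
}
```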
@@ -82,20 +82,26 @@ argument copy_from_target(T&, const argument& arg)
return arg;
}
-/*
- * Type-erased interface for:
- *
- * struct target
- * {
- *     std::string name() const;
- *     std::vector<pass> get_passes(context& ctx,const compile_options& options) const;
- *     context get_context() const;
- *     argument copy_to(const argument& input) const;
- *     argument copy_from(const argument& input) const;
- *     argument allocate(const shape& s) const;
- * };
- *
- */
+#ifdef TYPE_ERASED_DECLARATION
+// Type-erased interface for:
+struct target
+{
+    //
+    std::string name() const;
+    //
+    std::vector<pass> get_passes(context& ctx, const compile_options& options) const;
+    //
+    context get_context() const;
+    // (optional)
+    argument copy_to(const argument& input) const;
+    // (optional)
+    argument copy_from(const argument& input) const;
+    // (optional)
+    argument allocate(const shape& s) const;
+};
+#else
struct target
{
@@ -382,6 +388,7 @@ inline const ValueType& any_cast(const target& x)
throw std::bad_cast();
return *y;
}
+#endif
#endif
......
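The declaration block marks `copy_to`, `copy_from`, and `allocate` as `(optional)`, i.e. a concrete target does not have to provide them. One plausible way a wrapper can detect such an optional member at compile time is a `void_t` trait plus `if constexpr`; this is a sketch of the idea only, not necessarily how MIGraphX's generated wrapper does it, and `gpu_like`/`cpu_like` are made-up stand-ins:

```cpp
#include <iostream>
#include <string>
#include <type_traits>

// Trait: does T have copy_to(const std::string&)? Primary template says no;
// the partial specialization matches when the call expression is well-formed.
template <class T, class = void>
struct has_copy_to : std::false_type
{
};

template <class T>
struct has_copy_to<T,
                   std::void_t<decltype(std::declval<const T&>().copy_to(
                       std::declval<const std::string&>()))>> : std::true_type
{
};

template <class T>
std::string call_copy_to(const T& t, const std::string& input)
{
    if constexpr(has_copy_to<T>{})
        return t.copy_to(input); // the target provides its own copy
    else
        return input; // fallback: pass the argument through unchanged
}

struct gpu_like
{
    std::string copy_to(const std::string& s) const { return "gpu:" + s; }
};

struct cpu_like
{
    // no copy_to member: the fallback is used
};

int main()
{
    std::cout << call_copy_to(gpu_like{}, "x") << "\n"; // gpu:x
    std::cout << call_copy_to(cpu_like{}, "x") << "\n"; // x
}
```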
@@ -178,6 +178,7 @@ struct value
+value(std::nullptr_t);
value(const char* i);
value(const std::string& pkey, const char* i);
#define MIGRAPHX_VALUE_GENERATE_DECL_METHODS(vt, cpp_type) \
value(cpp_type i); \
@@ -188,6 +189,12 @@ struct value
const cpp_type* if_##vt() const;
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_DECL_METHODS)
+template <class T>
+using literal_to_string = std::conditional_t<(std::is_convertible<T, const char*>{} and
+                                              std::is_convertible<T, std::string>{}),
+                                             std::string,
+                                             T>;
template <class T>
using pick_numeric = std::conditional_t<
std::is_floating_point<T>{},
@@ -246,6 +253,7 @@ struct value
return *this = from_values(rhs); // NOLINT
}
value& operator=(const char* c);
+value& operator=(std::nullptr_t);
value& operator=(const std::initializer_list<value>& i);
@@ -315,8 +323,7 @@ struct value
{
switch(this->get_type())
{
-case null_type:
-{
+case null_type: {
std::nullptr_t null{};
if(this->key.empty())
v(null);
@@ -325,8 +332,7 @@ struct value
return;
}
#define MIGRAPHX_VALUE_GENERATE_CASE(vt, cpp_type) \
-case vt##_type: \
-{ \
+case vt##_type: { \
if(this->key.empty()) \
v(this->get_##vt()); \
else \
@@ -346,19 +352,17 @@ struct value
{
switch(this->get_type())
{
-case null_type:
-{
+case null_type: {
std::nullptr_t null{};
v(null);
return;
}
#define MIGRAPHX_VALUE_GENERATE_CASE_VALUE(vt, cpp_type) \
-case vt##_type: \
-{ \
+case vt##_type: { \
v(this->get_##vt()); \
return; \
}
-MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE)
+MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE_VALUE)
MIGRAPHX_VALUE_GENERATE_CASE(array, )
MIGRAPHX_VALUE_GENERATE_CASE(object, )
}
@@ -374,11 +378,11 @@ struct value
}
template <class To>
-To value_or(const To& default_value) const
+literal_to_string<To> value_or(const To& default_value) const
{
if(this->is_null())
return default_value;
-return to<To>();
+return to<literal_to_string<To>>();
}
template <class To>
@@ -394,12 +398,12 @@ struct value
}
template <class To>
-To get(const std::string& pkey, const To& default_value) const
+literal_to_string<To> get(const std::string& pkey, const To& default_value) const
{
const auto* v = find(pkey);
if(v == this->end())
return default_value;
-return v->to<To>();
+return v->to<literal_to_string<To>>();
}
template <class To>
@@ -412,10 +416,11 @@ struct value
}
template <class To>
-std::vector<To> get(const std::string& pkey,
-                    const std::initializer_list<To>& default_value) const
+std::vector<literal_to_string<To>> get(const std::string& pkey,
+                                       const std::initializer_list<To>& default_value) const
{
-return get<std::vector<To>>(pkey, default_value);
+return get(pkey,
+           std::vector<literal_to_string<To>>{default_value.begin(), default_value.end()});
}
friend bool operator==(const value& x, const value& y);
@@ -429,6 +434,8 @@ struct value
void debug_print(bool show_type = false) const;
+type_t get_type() const;
private:
template <class T>
std::vector<value> from_values(const T& r)
@@ -438,7 +445,6 @@ struct value
r.begin(), r.end(), std::back_inserter(v), [&](auto&& e) { return value(e); });
return v;
}
-type_t get_type() const;
std::shared_ptr<value_base_impl> x;
std::string key;
};
......
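The new `literal_to_string` alias is the key to the `value_or`/`get` signature changes above: with a plain `To`, `value_or("fallback")` deduces a character-array/pointer type, whereas the alias maps anything convertible to both `const char*` and `std::string` onto an owning `std::string`. A standalone illustration of the `std::conditional_t` mapping:

```cpp
#include <iostream>
#include <string>
#include <type_traits>

// Map types that behave like string literals (convertible to both
// const char* and std::string) to std::string; leave everything else alone.
template <class T>
using literal_to_string =
    std::conditional_t<(std::is_convertible<T, const char*>{} and
                        std::is_convertible<T, std::string>{}),
                       std::string,
                       T>;

int main()
{
    // Pointer-to-char types become std::string...
    static_assert(std::is_same<literal_to_string<const char*>, std::string>{}, "");
    // ...but unrelated types pass through unchanged.
    static_assert(std::is_same<literal_to_string<int>, int>{}, "");

    // With this mapping, a value_or-style API can accept a literal default
    // and still hand back an owning std::string.
    literal_to_string<const char*> s = "default";
    std::cout << s.size() << "\n"; // 7: s is a std::string, not a raw pointer
}
```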
@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m)
{
auto op = any_cast<op::pooling>(ins->get_operator());
-if(op.mode == "average")
+if(op.mode == op::pooling_mode::average)
{
return;
}
......
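The pooling mode is now compared as the `op::pooling_mode` enumeration rather than the string `"average"`, so a misspelled mode is a compile error instead of a comparison that silently yields false. The pattern in isolation (simplified names, not the MIGraphX definitions):

```cpp
#include <iostream>

// Sketch: replacing string-typed modes with a scoped enum. A misspelled
// string compares false silently; a misspelled enumerator fails to compile.
enum class pooling_mode { average, max };

const char* describe(pooling_mode mode)
{
    switch(mode)
    {
    case pooling_mode::average: return "average pooling";
    case pooling_mode::max: return "max pooling";
    }
    return "unknown";
}

int main()
{
    pooling_mode mode = pooling_mode::average;
    if(mode == pooling_mode::average) // was: op.mode == "average"
        std::cout << describe(mode) << "\n";
}
```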
@@ -133,6 +133,12 @@ std::string to_json_string(const value& val)
return j.dump();
}
+std::string to_pretty_json_string(const value& val, std::size_t indent)
+{
+    json j = val;
+    return j.dump(indent);
+}
migraphx::value from_json_string(const char* str, std::size_t size)
{
json j = json::parse(str, str + size);
......
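`to_pretty_json_string` forwards `indent` to `json::dump`. The `json j = val;` conversions suggest nlohmann::json underneath; assuming that library, the compact/pretty difference looks like this:

```cpp
#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main()
{
    json j = {{"name", "relu"}, {"axis", 1}};
    std::cout << j.dump() << "\n";  // compact: {"axis":1,"name":"relu"}
    std::cout << j.dump(4) << "\n"; // pretty-printed with a 4-space indent
}
```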
@@ -629,8 +629,9 @@ std::unordered_map<instruction_ref, std::string> module::print(
var_name = this->name();
var_name.append((this->name().empty() ? "@" : ":@"));
var_name.append(std::to_string(count));
-count++;
}
+// count every instruction so the index matches its location in the printed program
+count++;
names.emplace(ins, var_name);
print_func(ins, names);
......
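Before this fix, `count` advanced only when a name was auto-generated, so once an explicitly named instruction appeared, every later `@N` drifted away from the instruction's actual position. A simplified sketch of the corrected numbering (hypothetical data, not the module API):

```cpp
#include <iostream>
#include <string>
#include <vector>

int main()
{
    // An empty string stands for "needs an auto-generated @N name".
    std::vector<std::string> given_names = {"", "x", "", ""};
    std::size_t count = 0;
    for(const auto& name : given_names)
    {
        std::string var_name = name.empty() ? "@" + std::to_string(count) : name;
        count++; // advance unconditionally so @N always equals position N
        std::cout << (count - 1) << ": " << var_name << "\n";
    }
    // prints 0: @0, 1: x, 2: @2, 3: @3 -- with the old placement of count++,
    // the last two would have been misnumbered @1 and @2.
}
```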
@@ -14,44 +14,36 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
{
switch(o.type)
{
-case msgpack::type::NIL:
-{
+case msgpack::type::NIL: {
v = nullptr;
break;
}
-case msgpack::type::BOOLEAN:
-{
+case msgpack::type::BOOLEAN: {
v = o.as<bool>();
break;
}
-case msgpack::type::POSITIVE_INTEGER:
-{
+case msgpack::type::POSITIVE_INTEGER: {
v = o.as<std::uint64_t>();
break;
}
-case msgpack::type::NEGATIVE_INTEGER:
-{
+case msgpack::type::NEGATIVE_INTEGER: {
v = o.as<std::int64_t>();
break;
}
case msgpack::type::FLOAT32:
-case msgpack::type::FLOAT64:
-{
+case msgpack::type::FLOAT64: {
v = o.as<double>();
break;
}
-case msgpack::type::STR:
-{
+case msgpack::type::STR: {
v = o.as<std::string>();
break;
}
-case msgpack::type::BIN:
-{
+case msgpack::type::BIN: {
v = migraphx::value::binary{o.via.bin.ptr, o.via.bin.size};
break;
}
-case msgpack::type::ARRAY:
-{
+case msgpack::type::ARRAY: {
migraphx::value r = migraphx::value::array{};
std::for_each(
o.via.array.ptr,
@@ -60,8 +52,7 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
v = r;
break;
}
-case msgpack::type::MAP:
-{
+case msgpack::type::MAP: {
migraphx::value r = migraphx::value::object{};
std::for_each(o.via.map.ptr,
o.via.map.ptr + o.via.map.size,
@@ -71,7 +62,8 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
v = r;
break;
}
-case msgpack::type::EXT: { MIGRAPHX_THROW("msgpack EXT type not supported.");
+case msgpack::type::EXT: {
+    MIGRAPHX_THROW("msgpack EXT type not supported.");
}
}
return o;
......
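The converter above switches on `msgpack::object::type` to rebuild a `migraphx::value` from each msgpack variant. A minimal standalone example of that dispatch point, assuming the msgpack-c library:

```cpp
#include <cstdint>
#include <iostream>
#include <msgpack.hpp>
#include <sstream>

int main()
{
    // Pack a value, then unpack it and inspect the type tag that the
    // converter above switches on.
    std::stringstream buffer;
    msgpack::pack(buffer, 42);

    std::string data          = buffer.str();
    msgpack::object_handle oh = msgpack::unpack(data.data(), data.size());
    msgpack::object o         = oh.get();

    if(o.type == msgpack::type::POSITIVE_INTEGER)
        std::cout << o.as<std::uint64_t>() << "\n"; // prints 42
}
```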
@@ -7,7 +7,7 @@ target_compile_options(onnx-proto PRIVATE -w)
target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY})
set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
-file(GLOB ONNX_SRCS *.cpp)
+file(GLOB ONNX_SRCS ${CONFIGURE_DEPENDS} *.cpp)
add_library(migraphx_onnx ${ONNX_SRCS})
target_include_directories(migraphx_onnx PRIVATE include)
set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx)
......
@@ -32,9 +32,20 @@ struct onnx_parser
instruction_ref add_bias(const std::vector<instruction_ref>& args,
instruction_ref curr_ins,
uint64_t axis) const;
instruction_ref add_broadcastable_binary_op(const std::string& op_name,
instruction_ref arg0,
instruction_ref arg1) const;
+instruction_ref add_common_op(const std::string& op_name,
+                              std::vector<instruction_ref> inputs) const;
+
+template <class... Ts>
+instruction_ref add_common_op(const std::string& op_name, Ts... xs) const
+{
+    return add_common_op(op_name, {xs...});
+}
instruction_ref add_instruction(const operation& op,
const std::vector<instruction_ref>& args) const;
......
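The variadic `add_common_op` overload added above forwards its arguments as a braced list, which can only bind to the `std::vector` overload (a braced-init-list is a non-deduced context), so `add_common_op("add", a, b)` and `add_common_op("add", {a, b})` end up in the same implementation. The forwarding pattern in isolation, with `instruction_ref` as a placeholder type:

```cpp
#include <iostream>
#include <string>
#include <vector>

// Placeholder for instruction_ref, just to show the forwarding pattern.
using instruction_ref = int;

instruction_ref add_common_op(const std::string& op_name,
                              std::vector<instruction_ref> inputs)
{
    std::cout << op_name << " called with " << inputs.size() << " inputs\n";
    return 0;
}

// Variadic convenience overload: {xs...} builds the vector, so callers can
// write add_common_op("add", a, b) instead of add_common_op("add", {a, b}).
template <class... Ts>
instruction_ref add_common_op(const std::string& op_name, Ts... xs)
{
    return add_common_op(op_name, {xs...});
}

int main()
{
    instruction_ref a = 1;
    instruction_ref b = 2;
    add_common_op("add", a, b); // prints: add called with 2 inputs
}
```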
@@ -70,12 +70,14 @@ static literal from_repeated(shape::type_t t, const T& r)
instruction_ref onnx_parser::node_info::make_contiguous(instruction_ref ins) const
{
-if(ins->get_shape().standard())
+auto attr = ins->get_operator().to_value();
+std::string key = "require_std_shape";
+if((attr.get(key, false)) or (not ins->get_shape().standard()))
{
-    return ins;
+    return add_instruction(make_op("contiguous"), ins);
}
-return add_instruction(make_op("contiguous"), ins);
+return ins;
}
instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_ref>& args,
@@ -96,7 +98,13 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s
instruction_ref arg0,
instruction_ref arg1) const
{
-return add_common_op(*mod, make_op(op_name), {arg0, arg1});
+return this->add_common_op(op_name, arg0, arg1);
}
+instruction_ref onnx_parser::node_info::add_common_op(const std::string& op_name,
+                                                      std::vector<instruction_ref> inputs) const
+{
+    return migraphx::add_common_op(*mod, make_op(op_name), std::move(inputs));
+}
instruction_ref
@@ -380,8 +388,7 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, t.int64_data());
case onnx::TensorProto::UINT64:
return create_literal(shape::uint64_type, dims, t.uint64_data());
-case onnx::TensorProto::FLOAT16:
-{
+case onnx::TensorProto::FLOAT16: {
std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
std::vector<half> data_half;
std::transform(data_uint16.begin(),
@@ -451,7 +458,8 @@ shape::type_t get_type(int dtype)
case 11: return shape::double_type;
case 12: return shape::uint32_type;
case 13: return shape::uint64_type;
-default: { MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
+default: {
+    MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
}
}
}
......
@@ -3,6 +3,7 @@
#include <migraphx/pad_calc.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp>
+#include <migraphx/op/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -94,7 +95,7 @@ void tune_padding_size(const value& v,
std::vector<int64_t>& s_start)
{
// For max pooling, or when count_include_pad is 1, no change is required.
-if(v.at("mode").to<std::string>() == "max" or count_include_pad == 1)
+if(v.at("mode").to<op::pooling_mode>() == op::pooling_mode::max or count_include_pad == 1)
{
return;
}
......
@@ -28,6 +28,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Floor", "floor"},
{"Gather", "gather"},
{"Identity", "identity"},
{"IsNaN", "isnan"},
{"LeakyRelu", "leaky_relu"},
{"Log", "log"},
{"LRN", "lrn"},
......