Commit d6b4ae77 authored by Shucai Xiao

merge optimization to print flops branch

parents bdf91961 abe2a889
#ifndef MIGRAPHX_GUARD_OPERATORS_ISNAN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ISNAN_HPP

#include <cmath> // std::isnan
#include <migraphx/op/unary.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct isnan : unary<isnan>
{
    auto apply() const
    {
        return [](auto x) { return std::isnan(x); };
    }

    std::string name() const { return "isnan"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        return unary<isnan>::compute_shape(std::move(inputs)).with_type(shape::bool_type);
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
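The new operator applies std::isnan elementwise and overrides compute_shape so the output is bool-typed regardless of the input type. A minimal standalone sketch of the same elementwise rule (plain C++, not the MIGraphX API):

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

int main()
{
    // Same rule the operator's apply() returns: each element maps to
    // whether it is NaN; the result element type is bool.
    std::vector<float> in = {1.0f, NAN, 2.5f, NAN};
    std::vector<bool> out(in.size());
    std::transform(in.begin(), in.end(), out.begin(), [](float x) { return std::isnan(x); });
    for(bool b : out)
        std::cout << b << ' '; // prints: 0 1 0 1
    std::cout << '\n';
}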
@@ -181,14 +181,15 @@ struct nonmaxsuppression
             make_function_output_iterator([&](const auto& x) { sorted_boxes.push(x); });
         int64_t box_idx = 0;
-        transform_if(scores.begin() + score_offset,
-                     scores.begin() + score_offset + box_num,
-                     insert_to_sorted_boxes,
-                     [&](auto sc) {
-                         box_idx++;
-                         return sc >= score_threshold;
-                     },
-                     [&](auto sc) { return std::make_pair(sc, box_idx - 1); });
+        transform_if(
+            scores.begin() + score_offset,
+            scores.begin() + score_offset + box_num,
+            insert_to_sorted_boxes,
+            [&](auto sc) {
+                box_idx++;
+                return sc >= score_threshold;
+            },
+            [&](auto sc) { return std::make_pair(sc, box_idx - 1); });

         selected_boxes_inside_class.clear();
         // Get the next box with top score, filter by iou_threshold
......
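The hunk above is a formatting-only change, but the call is worth a note: the predicate passed to transform_if has a side effect (it advances box_idx for every score, matched or not), so the helper must invoke the predicate exactly once per element. A generic sketch of the transform_if contract (illustrative only, not the repo's own implementation):

#include <iostream>
#include <iterator>
#include <vector>

// Writes f(x) to the output iterator only for elements where pred(x)
// holds; pred is invoked exactly once per element, in order.
template <class InputIt, class OutputIt, class Pred, class F>
OutputIt transform_if(InputIt first, InputIt last, OutputIt out, Pred pred, F f)
{
    for(; first != last; ++first)
    {
        if(pred(*first))
            *out++ = f(*first);
    }
    return out;
}

int main()
{
    std::vector<float> scores = {0.9f, 0.1f, 0.7f};
    std::vector<int> kept;
    int idx = 0;
    // Keep the index of every score above the threshold, mirroring the
    // (score, box_idx - 1) pairing in the operator above.
    transform_if(scores.begin(),
                 scores.end(),
                 std::back_inserter(kept),
                 [&](float sc) {
                     idx++;
                     return sc >= 0.5f;
                 },
                 [&](float) { return idx - 1; });
    for(int i : kept)
        std::cout << i << ' '; // prints: 0 2
    std::cout << '\n';
}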
@@ -27,7 +27,7 @@ struct reshape
         return pack(f(self.dims, "dims"));
     }

-    value attributes() const { return {{"std_shape", true}}; }
+    value attributes() const { return {{"require_std_shape", true}}; }

     std::string name() const { return "reshape"; }
     shape compute_shape(std::vector<shape> inputs) const
......
@@ -253,7 +253,6 @@ struct roialign
                     max_pool{});
                 output(n, c, ph, pw) = output_val;
             });
-        });
     });
......
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_ADD_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_ADD_HPP

#include <migraphx/op/scatternd_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct scatternd_add : scatternd_op<scatternd_add>
{
    scatternd_add() {}

    auto reduction() const
    {
        return [](auto& x, const auto& y) { x += y; };
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_MUL_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_MUL_HPP

#include <migraphx/op/scatternd_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct scatternd_mul : scatternd_op<scatternd_mul>
{
    scatternd_mul() {}

    auto reduction() const
    {
        return [](auto& x, const auto& y) { x *= y; };
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_NONE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_NONE_HPP

#include <migraphx/op/scatternd_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct scatternd_none : scatternd_op<scatternd_none>
{
    scatternd_none() {}

    auto reduction() const
    {
        return [](auto& x, const auto& y) { x = y; };
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_OP_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_OP_HPP

#include <migraphx/op/name.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/par_for.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

template <class Derived>
struct scatternd_op : op_name<Derived>
{
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(3);
        auto r         = inputs.front().lens().size();
        auto q         = inputs.at(1).lens().size();
        auto k         = inputs.at(1).lens().back();
        auto ind_lens  = inputs.at(1).lens();
        auto upd_lens  = inputs.back().lens();
        auto data_lens = inputs.front().lens();

        if(k > r)
            MIGRAPHX_THROW("ScatterND: index of size " + std::to_string(k) +
                           " is too large for tensor of rank " + std::to_string(r));

        if(not(std::equal(ind_lens.begin(), ind_lens.begin() + q - 1, upd_lens.begin()) and
               std::equal(data_lens.begin() + k, data_lens.end(), upd_lens.begin() + q - 1)))
            MIGRAPHX_THROW("ScatterND: incorrect update shape. update.lens != indices.lens[0:q-1] "
                           "++ data.lens[k:r-1]");

        auto s = inputs.front();
        if(s.broadcasted())
        {
            return {s.type(), s.lens()};
        }
        else
        {
            return s.with_lens(s.lens());
        }
    }

    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        auto& self = static_cast<const Derived&>(*this);

        visit_all(result, args[0], args[2])([&](auto output, auto data, auto updates) {
            std::copy(data.begin(), data.end(), output.begin());

            args[1].visit([&](auto indices) {
                auto updates_shape = updates.get_shape();
                auto updates_std   = shape{updates_shape.type(), updates_shape.lens()};
                auto indices_shape = indices.get_shape();
                auto k             = indices_shape.lens().back();
                auto q             = indices_shape.lens().size();
                auto r             = output_shape.lens().size();

                par_for(updates_shape.elements(), [&](const auto i) {
                    auto updates_idx = updates_std.multi(i);
                    std::vector<std::size_t> indices_idx(q, 0);
                    std::copy(
                        updates_idx.begin(), updates_idx.begin() + q - 1, indices_idx.begin());
                    auto index_start = indices.begin() +
                                       indices_shape.index(indices_idx.begin(), indices_idx.end());
                    auto index_end = index_start + k;

                    std::vector<std::size_t> out_idx(r, 0);
                    std::copy(index_start, index_end, out_idx.begin());
                    std::copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k);

                    self.reduction()(output[output_shape.index(out_idx)], updates[i]);
                });
            });
        });

        return result;
    }

    auto init() const {}
    scatternd_op() {}
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
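The shape bookkeeping above follows the ONNX ScatterND convention: r is the rank of data, q the rank of indices, and k = indices.shape[-1], so each of the leading q-1 index coordinates selects one k-tuple addressing a slice of the output, and the derived class supplies the reduction (=, +=, or *=). A self-contained sketch of the update rule for the simple 1-D case with q = 2, k = 1 (plain C++, independent of the MIGraphX types; the reduce parameter stands in for the derived reduction()):

#include <cstddef>
#include <functional>
#include <iostream>
#include <vector>

// ScatterND update rule for 1-D data, indices of shape [n, 1]:
// for each update i, reduce(out[indices[i]], updates[i]).
void scatternd_1d(std::vector<float>& out,
                  const std::vector<std::size_t>& indices,
                  const std::vector<float>& updates,
                  const std::function<void(float&, float)>& reduce)
{
    for(std::size_t i = 0; i < indices.size(); ++i)
        reduce(out[indices[i]], updates[i]);
}

int main()
{
    std::vector<float> data = {1, 2, 3, 4, 5, 6, 7, 8};
    // scatternd_none semantics: plain assignment (the canonical ONNX example)
    scatternd_1d(data, {4, 3, 1, 7}, {9, 10, 11, 12}, [](float& x, float y) { x = y; });
    for(float v : data)
        std::cout << v << ' '; // prints: 1 11 3 10 9 6 7 12
    std::cout << '\n';

    // scatternd_add semantics: accumulate instead of overwrite
    scatternd_1d(data, {0, 0}, {1, 1}, [](float& x, float y) { x += y; });
    std::cout << data[0] << '\n'; // prints: 3
}

Note the host implementation writes updates in parallel via par_for, so duplicate index tuples only have a well-defined result when the reduction is order-independent.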
@@ -64,7 +64,6 @@ struct unary : op_name<Derived>
                            input.end(),
                            output.begin(),
                            static_cast<const Derived&>(*this).apply());
-        });
     });
     return result;
......
@@ -41,6 +41,7 @@
 #include <migraphx/op/identity.hpp>
 #include <migraphx/op/if_op.hpp>
 #include <migraphx/op/im2col.hpp>
+#include <migraphx/op/isnan.hpp>
 #include <migraphx/op/leaky_relu.hpp>
 #include <migraphx/op/less.hpp>
 #include <migraphx/op/load.hpp>
@@ -86,6 +87,9 @@
 #include <migraphx/op/rsqrt.hpp>
 #include <migraphx/op/scalar.hpp>
 #include <migraphx/op/scatter.hpp>
+#include <migraphx/op/scatternd_add.hpp>
+#include <migraphx/op/scatternd_none.hpp>
+#include <migraphx/op/scatternd_mul.hpp>
 #include <migraphx/op/sigmoid.hpp>
 #include <migraphx/op/sign.hpp>
 #include <migraphx/op/sinh.hpp>
......
@@ -41,7 +41,6 @@ auto par_dfor(Ts... xs)
     {
         dfor(xs...)(f);
     }
-    };
 }
......
@@ -80,6 +80,8 @@ struct program
     void debug_print(std::ostream& os,
                      instruction_ref ins,
                      const std::unordered_map<instruction_ref, std::string>& names) const;
+    void debug_print(instruction_ref ins,
+                     const std::unordered_map<instruction_ref, std::string>& ins_names) const;
     void print(std::unordered_map<instruction_ref, std::string>& names,
                const std::function<void(instruction_ref,
                                         std::unordered_map<instruction_ref, std::string>)>&
......
@@ -35,7 +35,7 @@ struct shape
     m(int64_type, int64_t) \
     m(uint32_type, uint32_t) \
     m(uint64_type, uint64_t)
-// clang-format on
+    // clang-format on

 #define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
     enum type_t
@@ -131,6 +131,8 @@ struct shape
     shape with_lens(type_t t, const std::vector<std::size_t>& l) const;
     shape with_lens(const std::vector<std::size_t>& l) const;
     shape with_type(type_t t) const;

+    friend bool operator==(const shape& x, const shape& y);
+    friend bool operator!=(const shape& x, const shape& y);
     friend std::ostream& operator<<(std::ostream& os, const shape& x);
@@ -186,8 +188,7 @@ struct shape
         {
             switch(t)
             {
-            case tuple_type:
-            {
+            case tuple_type: {
                 tv();
                 return;
             }
@@ -226,6 +227,7 @@ struct shape
     std::size_t element_space() const;

     private:
+    shape(std::shared_ptr<shape_impl> pimpl);
     std::shared_ptr<const shape_impl> impl;
 };
......
@@ -131,16 +131,17 @@ inline std::string interpolate_string(const std::string& input,
                                       std::string start = "${",
                                       std::string end   = "}")
 {
-    return interpolate_string(input,
-                              [&](auto start_it, auto last_it) {
-                                  auto key = trim({start_it, last_it});
-                                  auto it  = vars.find(key);
-                                  if(it == vars.end())
-                                      throw std::runtime_error("Unknown key: " + key);
-                                  return it->second;
-                              },
-                              std::move(start),
-                              std::move(end));
+    return interpolate_string(
+        input,
+        [&](auto start_it, auto last_it) {
+            auto key = trim({start_it, last_it});
+            auto it  = vars.find(key);
+            if(it == vars.end())
+                throw std::runtime_error("Unknown key: " + key);
+            return it->second;
+        },
+        std::move(start),
+        std::move(end));
 }

 template <class Iterator>
......
@@ -315,8 +315,7 @@ struct value
     {
         switch(this->get_type())
         {
-        case null_type:
-        {
+        case null_type: {
             std::nullptr_t null{};
             if(this->key.empty())
                 v(null);
@@ -325,8 +324,7 @@ struct value
             return;
         }
 #define MIGRAPHX_VALUE_GENERATE_CASE(vt, cpp_type) \
-    case vt##_type:                                \
-    {                                              \
+    case vt##_type: {                              \
         if(this->key.empty())                      \
             v(this->get_##vt());                   \
         else                                       \
@@ -346,15 +344,13 @@ struct value
     {
         switch(this->get_type())
         {
-        case null_type:
-        {
+        case null_type: {
             std::nullptr_t null{};
             v(null);
             return;
         }
 #define MIGRAPHX_VALUE_GENERATE_CASE_VALUE(vt, cpp_type) \
-    case vt##_type:                                      \
-    {                                                    \
+    case vt##_type: {                                    \
         v(this->get_##vt());                             \
         return;                                          \
     }
......
@@ -628,7 +628,7 @@ std::unordered_map<instruction_ref, std::string> module::print(
         var_name.append((this->name().empty() ? "@" : ":@"));
         var_name.append(std::to_string(count));
     }
-    // make instruction index to be the line num in the printed module
+    // count every instruction so index matches loc in the printout program
     count++;
     names.emplace(ins, var_name);
......
@@ -14,44 +14,36 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
     {
         switch(o.type)
        {
-        case msgpack::type::NIL:
-        {
+        case msgpack::type::NIL: {
             v = nullptr;
             break;
         }
-        case msgpack::type::BOOLEAN:
-        {
+        case msgpack::type::BOOLEAN: {
             v = o.as<bool>();
             break;
         }
-        case msgpack::type::POSITIVE_INTEGER:
-        {
+        case msgpack::type::POSITIVE_INTEGER: {
             v = o.as<std::uint64_t>();
             break;
         }
-        case msgpack::type::NEGATIVE_INTEGER:
-        {
+        case msgpack::type::NEGATIVE_INTEGER: {
             v = o.as<std::int64_t>();
             break;
         }
         case msgpack::type::FLOAT32:
-        case msgpack::type::FLOAT64:
-        {
+        case msgpack::type::FLOAT64: {
             v = o.as<double>();
             break;
         }
-        case msgpack::type::STR:
-        {
+        case msgpack::type::STR: {
             v = o.as<std::string>();
             break;
         }
-        case msgpack::type::BIN:
-        {
+        case msgpack::type::BIN: {
             v = migraphx::value::binary{o.via.bin.ptr, o.via.bin.size};
             break;
         }
-        case msgpack::type::ARRAY:
-        {
+        case msgpack::type::ARRAY: {
             migraphx::value r = migraphx::value::array{};
             std::for_each(
                 o.via.array.ptr,
@@ -60,8 +52,7 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
             v = r;
             break;
         }
-        case msgpack::type::MAP:
-        {
+        case msgpack::type::MAP: {
             migraphx::value r = migraphx::value::object{};
             std::for_each(o.via.map.ptr,
                           o.via.map.ptr + o.via.map.size,
@@ -71,7 +62,8 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
             v = r;
             break;
         }
-        case msgpack::type::EXT: { MIGRAPHX_THROW("msgpack EXT type not supported.");
+        case msgpack::type::EXT: {
+            MIGRAPHX_THROW("msgpack EXT type not supported.");
         }
         }
         return o;
......
@@ -7,7 +7,7 @@ target_compile_options(onnx-proto PRIVATE -w)
 target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY})
 set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On)

-file(GLOB ONNX_SRCS *.cpp)
+file(GLOB ONNX_SRCS ${CONFIGURE_DEPENDS} *.cpp)

 add_library(migraphx_onnx ${ONNX_SRCS})
 target_include_directories(migraphx_onnx PRIVATE include)
 set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx)
......
@@ -70,12 +70,14 @@ static literal from_repeated(shape::type_t t, const T& r)
 }

 instruction_ref onnx_parser::node_info::make_contiguous(instruction_ref ins) const
 {
-    if(ins->get_shape().standard())
+    auto attr       = ins->get_operator().to_value();
+    std::string key = "require_std_shape";
+    if((attr.get(key, false)) or (not ins->get_shape().standard()))
     {
-        return ins;
+        return add_instruction(make_op("contiguous"), ins);
     }
-    return add_instruction(make_op("contiguous"), ins);
+    return ins;
 }

 instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_ref>& args,
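In effect, the parser now inserts a contiguous copy not only when the input shape is non-standard but also whenever the producing operator advertises require_std_shape, as reshape now does after the attribute rename above. A boolean restatement of the new predicate (needs_contiguous is a hypothetical helper name, for illustration only):

#include <iostream>

// Restates the condition in make_contiguous above: insert a
// "contiguous" op when the producer requires a standard shape
// or when the shape is not already standard.
bool needs_contiguous(bool require_std_shape, bool shape_is_standard)
{
    return require_std_shape or not shape_is_standard;
}

int main()
{
    std::cout << needs_contiguous(true, true) << '\n';   // 1: flagged op, even with a standard shape
    std::cout << needs_contiguous(false, true) << '\n';  // 0: standard shape, no flag (old behavior)
    std::cout << needs_contiguous(false, false) << '\n'; // 1: non-standard shape (old behavior)
}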
@@ -380,8 +382,7 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
     case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, t.int64_data());
     case onnx::TensorProto::UINT64:
         return create_literal(shape::uint64_type, dims, t.uint64_data());
-    case onnx::TensorProto::FLOAT16:
-    {
+    case onnx::TensorProto::FLOAT16: {
         std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
         std::vector<half> data_half;
         std::transform(data_uint16.begin(),
@@ -451,7 +452,8 @@ shape::type_t get_type(int dtype)
     case 11: return shape::double_type;
     case 12: return shape::uint32_type;
     case 13: return shape::uint64_type;
-    default: { MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
+    default: {
+        MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
     }
     }
 }
......
@@ -28,6 +28,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
         {"Floor", "floor"},
         {"Gather", "gather"},
         {"Identity", "identity"},
+        {"IsNaN", "isnan"},
         {"LeakyRelu", "leaky_relu"},
         {"Log", "log"},
         {"Neg", "neg"},
......