Commit 3df20646 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

manual merge

parents 1005a693 d0543c96
......@@ -33,9 +33,6 @@ static void create_pointwise_modules(module_pass_manager& mpm)
{
if(not ins->get_operator().attributes().get("pointwise", false))
continue;
// Skip convert op for now
if(ins->name() == "convert")
continue;
assert(ins->get_operator().attributes().contains("point_op"));
auto* pm = mpm.create_module(mpm.get_module().name() + ":pointwise" + std::to_string(n++));
pm->set_bypass();
......@@ -129,22 +126,25 @@ static std::vector<instruction_ref> append_pointwise_module(instruction_ref ins,
static bool find_pointwise_modules(module& m)
{
bool changed = false;
auto last = std::prev(m.end());
for(auto ins : iterator_for(m))
{
if(ins->name() != "pointwise")
continue;
if(ins->outputs().empty())
if(ins->outputs().empty() and ins != last)
continue;
auto it = std::find_if(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
return i->name() == "pointwise" and i->outputs().size() == 1;
});
if(it == ins->inputs().end())
continue;
auto input = *it;
auto new_inputs = append_pointwise_module(input, ins);
m.replace_instruction(input, input->get_operator(), new_inputs, input->module_inputs());
m.replace_instruction(ins, input);
m.move_instruction(input, ins);
auto new_inputs = append_pointwise_module(*it, ins);
m.replace_instruction(*it, (*it)->get_operator(), new_inputs, (*it)->module_inputs());
m.replace_instruction(ins, *it);
m.move_instruction(*it, ins);
changed = true;
}
return changed;
......
......@@ -34,6 +34,7 @@ struct cpp_generator
std::string return_type = "void";
std::string name = "";
std::vector<std::string> attributes = {};
std::vector<std::string> tparams = {};
function& set_body(const module& m, const generate_module_callback& g);
function& set_body(const std::string& s)
{
......@@ -52,6 +53,7 @@ struct cpp_generator
}
function& set_types(const module& m);
function& set_types(const module& m, const std::function<std::string(shape)>& parse);
function& set_generic_types(const module& m);
};
cpp_generator();
......@@ -66,6 +68,8 @@ struct cpp_generator
void fmap(const std::function<std::string(std::string)>& f);
void add_point_op(const std::string& op_name, const std::string& code);
std::string generate_point_op(const operation& op, const std::vector<std::string>& args);
std::string str() const;
......
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
......@@ -35,7 +35,7 @@ struct argmax
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
check_shapes{inputs, *this}.has(1);
auto lens = inputs[0].lens();
lens[axis] = 1;
......
......@@ -35,7 +35,7 @@ struct argmin
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
check_shapes{inputs, *this}.has(1);
auto lens = inputs[0].lens();
lens[axis] = 1;
......
......@@ -7,7 +7,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
......@@ -21,25 +21,26 @@ struct clip
{
std::string name() const { return "clip"; }
value attributes() const
{
return {{"pointwise", true},
{"point_op", "${function:min}(${function:max}(${1}, ${0}), ${2})"}};
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(3).same_type();
check_shapes{inputs, *this}.has(3).same_type().same_dims();
return inputs.front();
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0], args[1], args[2])([&](auto output, auto x, auto min, auto max) {
par_for(output_shape.elements(),
[&](auto i) { output[i] = std::min(std::max(min[i], x[i]), max[i]); });
});
visit_all(result, args[0], args[1], args[2])(
[&](auto output, auto input, auto min_val, auto max_val) {
auto max = max_val.front();
auto min = min_val.front();
std::transform(input.begin(), input.end(), output.begin(), [max, min](auto x) {
using type = decltype(x);
return std::min(std::max(type(min), x), type(max));
});
});
return result;
}
};
......
......@@ -32,6 +32,11 @@ struct convert : unary<convert>
return {target_type, inputs.at(0).lens(), inputs.at(0).strides()};
}
std::string point_op() const
{
return "${function:convert}<" + shape::cpp_type(target_type) + ">(${0})";
}
auto apply() const
{
auto type = target_type;
......
......@@ -38,7 +38,7 @@ struct gather
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2).standard();
check_shapes{inputs, *this}.has(2);
auto lens = inputs[0].lens();
auto type = inputs[0].type();
lens.erase(lens.begin() + axis);
......
......@@ -168,7 +168,8 @@ inline std::string to_string_range(const std::initializer_list<T>& r)
}
template <class T>
inline std::string to_string(const T& x)
inline auto to_string(const T& x)
-> decltype((std::declval<std::stringstream>() << x), std::string{})
{
std::stringstream ss;
ss << x;
......
File mode changed from 100755 to 100644
......@@ -179,6 +179,7 @@ instruction_ref module::insert_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args)
{
assert(has_instruction(ins) or is_end(ins, this->end()));
assert(not starts_with(op.name(), "@"));
shape r = compute_shape(op, args);
auto result = impl->insert(ins, {op, r, std::move(args)});
......@@ -200,6 +201,7 @@ instruction_ref module::insert_instruction(instruction_ref ins,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args)
{
assert(has_instruction(ins) or is_end(ins, this->end()));
assert(not starts_with(op.name(), "@"));
auto out_shape = compute_shape(op, args, module_args);
auto result = impl->insert(ins, {op, out_shape, std::move(args), std::move(module_args)});
......@@ -212,6 +214,7 @@ instruction_ref module::replace_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args) MIGRAPHX_TIDY_CONST
{
assert(has_instruction(ins));
assert(not starts_with(op.name(), "@"));
shape r = compute_shape(op, args);
......@@ -225,6 +228,7 @@ instruction_ref module::replace_instruction(instruction_ref ins,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args) MIGRAPHX_TIDY_CONST
{
assert(has_instruction(ins));
assert(not starts_with(op.name(), "@"));
auto out_shape = compute_shape(op, args, module_args);
instruction::replace(ins, op, out_shape, std::move(args), std::move(module_args));
......@@ -291,6 +295,8 @@ instruction_ref module::remove_instructions(instruction_ref first, instruction_r
instruction_ref module::move_instruction(instruction_ref src, instruction_ref dst)
{
assert(has_instruction(src));
assert(has_instruction(dst) or is_end(dst, this->end()));
impl->instructions.splice(dst, impl->instructions, src);
return src;
}
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_greaterorequal : op_parser<parse_greaterorequal>
{
std::vector<op_desc> operators() const { return {{"GreaterOrEqual"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto in_res = info.add_broadcastable_binary_op("less", args[0], args[1]);
if(in_res->get_shape().type() != shape::bool_type)
{
in_res = info.add_instruction(make_op("convert", {{"target_type", shape::bool_type}}),
in_res);
}
return info.add_instruction(make_op("not"), in_res);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_hardsigmoid : op_parser<parse_hardsigmoid>
{
std::vector<op_desc> operators() const { return {{"HardSigmoid"}, {"HardSwish"}}; }
instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
float alpha = 0.2;
float beta = 0.5;
if(opd.onnx_name == "HardSwish")
{
alpha = 1.0 / 6.0;
}
else
{
if(contains(info.attributes, "alpha"))
alpha = info.attributes.at("alpha").f();
if(contains(info.attributes, "beta"))
beta = info.attributes.at("beta").f();
}
auto input_lens = args[0]->get_shape().lens();
auto input_type = args[0]->get_shape().type();
auto mb_alpha = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {alpha}}));
auto mb_beta = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {beta}}));
auto mb_zero = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {0}}));
auto mb_one = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {1}}));
auto mul = info.add_instruction(migraphx::make_op("mul"), mb_alpha, args[0]);
auto add = info.add_instruction(migraphx::make_op("add"), mb_beta, mul);
auto hardsigmoid = info.add_instruction(migraphx::make_op("clip"), add, mb_zero, mb_one);
if(opd.onnx_name == "HardSwish")
return info.add_instruction(migraphx::make_op("mul"), args[0], hardsigmoid);
return hardsigmoid;
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_mean : op_parser<parse_mean>
{
std::vector<op_desc> operators() const { return {{"Mean"}}; }
/// Calculates the element-wise mean of n>=1 input tensors
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto num_data = args.size();
if(num_data == 1)
return args[0];
auto divisor = info.add_literal(
migraphx::literal{migraphx::shape{args[0]->get_shape().type()}, {num_data}});
return std::accumulate(args.begin(), args.end(), args[0], [&](auto& mean, auto& data_i) {
// Pre-divide each tensor element-wise by n to reduce risk of overflow during summation
data_i = info.add_broadcastable_binary_op("div", data_i, divisor);
if(data_i != args[0])
return info.add_broadcastable_binary_op("add", mean, data_i);
return data_i;
});
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -27,11 +27,6 @@ struct parse_multinomial : op_parser<parse_multinomial>
if(contains(info.attributes, "sample_size"))
sample_size = info.attributes.at("sample_size").i();
float seed = static_cast<float>(
std::chrono::high_resolution_clock::now().time_since_epoch().count());
if(contains(info.attributes, "seed"))
seed = info.attributes.at("seed").f();
// Subtract the per-batch maximum log-probability, making the per-batch max 0
auto maxes =
info.add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), args[0]);
......@@ -46,7 +41,10 @@ struct parse_multinomial : op_parser<parse_multinomial>
migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf);
// Pre-compute random distribution
std::mt19937 gen(seed);
std::mt19937 gen(std::chrono::high_resolution_clock::now().time_since_epoch().count());
if(contains(info.attributes, "seed"))
gen.seed(info.attributes.at("seed").f());
std::uniform_real_distribution<> dis(0.0, 1.0);
size_t batch_size = args[0]->get_shape().lens().front();
migraphx::shape dist_shape{migraphx::shape::float_type, {batch_size, sample_size}};
......
......@@ -42,11 +42,6 @@ struct parse_randomnormal_ops : op_parser<parse_randomnormal_ops>
if(contains(info.attributes, "scale"))
scale = info.attributes.at("scale").f();
float seed = static_cast<float>(
std::chrono::high_resolution_clock::now().time_since_epoch().count());
if(contains(info.attributes, "seed"))
seed = info.attributes.at("seed").f();
shape out_shape;
if(contains(info.attributes, "shape"))
{
......@@ -75,7 +70,10 @@ struct parse_randomnormal_ops : op_parser<parse_randomnormal_ops>
": cannot deduce shape without shape attribute or argument.");
}
std::mt19937 gen(seed);
std::mt19937 gen(std::chrono::high_resolution_clock::now().time_since_epoch().count());
if(contains(info.attributes, "seed"))
gen.seed(info.attributes.at("seed").f());
std::normal_distribution<> d(mean, scale);
std::vector<double> rand_vals(out_shape.elements());
std::generate(rand_vals.begin(), rand_vals.end(), [&]() { return d(gen); });
......
......@@ -42,11 +42,6 @@ struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops>
if(contains(info.attributes, "low"))
low = info.attributes.at("low").f();
float seed = static_cast<float>(
std::chrono::high_resolution_clock::now().time_since_epoch().count());
if(contains(info.attributes, "seed"))
seed = info.attributes.at("seed").f();
shape out_shape;
if(contains(info.attributes, "shape"))
{
......@@ -75,7 +70,10 @@ struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops>
": cannot deduce shape without shape attribute or argument.");
}
std::mt19937 gen(seed);
std::mt19937 gen(std::chrono::high_resolution_clock::now().time_since_epoch().count());
if(contains(info.attributes, "seed"))
gen.seed(info.attributes.at("seed").f());
std::uniform_real_distribution<> d(high, low);
std::vector<double> rand_vals(out_shape.elements());
std::generate(rand_vals.begin(), rand_vals.end(), [&]() { return d(gen); });
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_softplus : op_parser<parse_softplus>
{
std::vector<op_desc> operators() const { return {{"Softplus"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
// Apply pointwise formula: y = ln(exp(x) + 1)
auto mb_ones = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", args[0]->get_shape().lens()}}),
info.add_literal(migraphx::literal{migraphx::shape{args[0]->get_shape().type()}, {1}}));
auto exp = info.add_instruction(migraphx::make_op("exp"), args[0]);
auto add = info.add_instruction(migraphx::make_op("add"), exp, mb_ones);
return info.add_instruction(migraphx::make_op("log"), add);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_softsign : op_parser<parse_softsign>
{
std::vector<op_desc> operators() const { return {{"Softsign"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
// Apply pointwise formula: y = x / (1 + |x|)
auto mb_ones = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", args[0]->get_shape().lens()}}),
info.add_literal(migraphx::literal{migraphx::shape{args[0]->get_shape().type()}, {1}}));
auto abs = info.add_instruction(migraphx::make_op("abs"), args[0]);
auto add = info.add_instruction(migraphx::make_op("add"), abs, mb_ones);
return info.add_instruction(migraphx::make_op("div"), args[0], add);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment