Commit 7dc6e3ae authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mi100_opts

parents f94d77fc a275f590
......@@ -47,13 +47,13 @@ struct parse_clip : op_parser<parse_clip>
if(min_used)
{
min_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
min_arg = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
min_arg);
}
if(max_used)
{
max_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
max_arg = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
max_arg);
}
......
......@@ -29,11 +29,11 @@ struct parse_dequantizelinear : op_parser<parse_dequantizelinear>
{
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
x_scale = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}), args[1]);
make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]);
}
else
{
x_scale = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
x_scale = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
args[1]);
}
......@@ -44,13 +44,13 @@ struct parse_dequantizelinear : op_parser<parse_dequantizelinear>
{
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
x_zero_point = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}),
make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}),
x_zero_point);
}
else
{
x_zero_point = info.add_instruction(
make_op("multibroadcast", {{"output_lens", input_lens}}), x_zero_point);
make_op("multibroadcast", {{"out_lens", input_lens}}), x_zero_point);
}
return info.add_instruction(
......
......@@ -24,8 +24,7 @@ struct parse_expand : op_parser<parse_expand>
std::vector<std::size_t> dims;
arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
auto out_lens = compute_broadcasted_lens(in_lens, dims);
return info.add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}),
args[0]);
return info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}), args[0]);
}
};
......
......@@ -63,8 +63,8 @@ struct parse_gather_elements : op_parser<parse_gather_elements>
info.add_literal(literal(ind_s, data_indices.begin(), data_indices.end()));
auto l_dim_idx = info.add_literal(literal(ind_s, vec_axis_ind.begin(), vec_axis_ind.end()));
auto l_stride = info.add_literal(literal{{ind_s.type(), {1}}, {axis_stride}});
l_stride = info.add_instruction(make_op("multibroadcast", {{"output_lens", ind_s.lens()}}),
l_stride);
l_stride =
info.add_instruction(make_op("multibroadcast", {{"out_lens", ind_s.lens()}}), l_stride);
auto dim_diff = info.add_instruction(make_op("sub"), arg_ind, l_dim_idx);
auto delta = info.add_instruction(make_op("mul"), dim_diff, l_stride);
auto ind = info.add_instruction(make_op("add"), l_shape_idx, delta);
......
......@@ -55,13 +55,17 @@ struct parse_gemm : op_parser<parse_gemm>
}
}
l1 = (transa) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), l1) : l1;
auto l2 = (transb) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[1])
: args[1];
l1 =
(transa) ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), l1) : l1;
auto l2 = (transb)
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
: args[1];
auto ret = info.add_instruction(make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), l1, l2);
if(args.size() == 3)
{
if(beta != 0.0f && args[2]->get_shape().elements() > 0)
if(not float_equal(beta, 0.0f) && args[2]->get_shape().elements() > 0)
{
auto out_lens = l1->get_shape().lens();
out_lens.back() = l2->get_shape().lens().back();
......@@ -69,8 +73,8 @@ struct parse_gemm : op_parser<parse_gemm>
auto l3_lens = l3->get_shape().lens();
if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
{
l3 = info.add_instruction(
make_op("multibroadcast", {{"output_lens", out_lens}}), args[2]);
l3 = info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}),
args[2]);
}
auto beta_literal = info.add_literal(beta);
auto beta_l3 = info.add_broadcastable_binary_op("mul", l3, beta_literal);
......@@ -80,12 +84,11 @@ struct parse_gemm : op_parser<parse_gemm>
beta_l3);
}
return info.add_instruction(
make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), l1, l2, beta_l3);
return info.add_instruction(make_op("add"), ret, beta_l3);
}
}
return info.add_instruction(make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), l1, l2);
return ret;
}
};
......
......@@ -40,7 +40,7 @@ struct parse_imagescalar : op_parser<parse_imagescalar>
auto img_scaled =
info.add_instruction(migraphx::make_op("mul"), args.front(), scale_tensor);
auto bias_bcast = info.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", input_lens}}), bias_vals);
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", input_lens}}), bias_vals);
return info.add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast);
}
};
......
......@@ -38,23 +38,23 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
auto mean_bcast =
info.add_instruction(make_op("multibroadcast", {{"output_lens", dims}}), mean);
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto l0 = info.add_instruction(make_op("sqdiff"), x, mean_bcast);
auto variance = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0);
auto l1 = info.add_instruction(make_op("sub"), x, mean_bcast);
auto epsilon_literal = info.add_literal(epsilon);
auto epsilon_bcast = info.add_instruction(
make_op("multibroadcast", {{"output_lens", dims}}), epsilon_literal);
auto epsilon_bcast =
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal);
auto variance_bcast =
info.add_instruction(make_op("multibroadcast", {{"output_lens", dims}}), variance);
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), variance);
auto l2 = info.add_instruction(make_op("add"), variance_bcast, epsilon_bcast);
auto l3 = info.add_instruction(make_op("rsqrt"), l2);
auto l4 = info.add_instruction(make_op("mul"), l1, l3);
auto scale_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"dims", dims}}), scale);
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
;
auto bias_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"dims", dims}}), bias);
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
auto l5 = info.add_instruction(make_op("mul"), l4, scale_bcast);
return info.add_instruction(make_op("add"), l5, bias_bcast);
}
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_loop : op_parser<parse_loop>
{
    std::vector<op_desc> operators() const { return {{"Loop"}}; }

    // Parses an ONNX Loop node: resolves the optional trip-count and condition
    // inputs, compiles the "body" attribute graph into a submodule, and unpacks
    // the tuple produced by the "loop" op into one instruction per output.
    std::vector<instruction_ref> parse(const op_desc& /*opd*/,
                                       onnx_parser& parser,
                                       const onnx_parser::node_info& info,
                                       std::vector<instruction_ref> args) const
    {
        // Start from the parser-wide iteration cap; a constant trip-count
        // input overrides it below.
        int64_t iter_limit = parser.max_loop_iterations;
        if(args.at(0)->name() == "undefined")
        {
            // Trip-count input omitted: materialize the default as a literal.
            shape iter_s{shape::int64_type};
            args[0] = info.add_literal(literal(iter_s, {iter_limit}));
        }
        else
        {
            auto iters_arg = args.at(0)->eval();
            if(not iters_arg.empty())
                iter_limit = iters_arg.at<int64_t>();
        }

        // Condition input omitted: default to "keep looping".
        if(args.at(1)->name() == "undefined")
        {
            shape cond_s{shape::bool_type};
            args[1] = info.add_literal(literal(cond_s, {true}));
        }

        // Compile the loop body subgraph into its own module.
        const auto& body_graph = info.attributes.at("body").g();
        module_ref body_mod    = parser.prog.create_module(info.name + "_loop");
        parser.parse_graph(body_mod, body_graph);

        auto loop_ins = info.add_instruction(
            make_op("loop", {{"max_iterations", iter_limit}}), args, {body_mod});

        // The loop op yields a tuple; expose each element as a separate output.
        auto loop_shape = loop_ins->get_shape();
        assert(loop_shape.type() == shape::tuple_type);
        const auto& elem_shapes = loop_shape.sub_shapes();
        std::vector<instruction_ref> outputs;
        for(std::size_t idx = 0; idx < elem_shapes.size(); ++idx)
        {
            outputs.push_back(
                info.add_instruction(make_op("get_tuple_elem", {{"index", idx}}), loop_ins));
        }
        return outputs;
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -58,12 +58,12 @@ struct parse_matmul : op_parser<parse_matmul>
if(l0_lens != l0_broadcasted_lens)
{
bl0 = info.add_instruction(
make_op("multibroadcast", {{"output_lens", l0_broadcasted_lens}}), l0);
make_op("multibroadcast", {{"out_lens", l0_broadcasted_lens}}), l0);
}
if(l1_lens != l1_broadcasted_lens)
{
bl1 = info.add_instruction(
make_op("multibroadcast", {{"output_lens", l1_broadcasted_lens}}), l1);
make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), l1);
}
}
......
......@@ -45,8 +45,9 @@ struct parse_onehot : op_parser<parse_onehot>
std::vector<int64_t> perm(n_rank - 1);
std::iota(perm.begin(), perm.end(), 0);
perm.insert(perm.begin() + tuned_axis, n_rank - 1);
auto tr_out = info.add_instruction(make_op("transpose", {{"dims", perm}}), gather_out);
auto lens = tr_out->get_shape().lens();
auto tr_out =
info.add_instruction(make_op("transpose", {{"permutation", perm}}), gather_out);
auto lens = tr_out->get_shape().lens();
auto off_val = info.add_instruction(
make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
......@@ -54,9 +55,9 @@ struct parse_onehot : op_parser<parse_onehot>
make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
auto diff = info.add_instruction(make_op("sub"), on_val, off_val);
auto unsq_off_val =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), off_val);
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), off_val);
auto unsq_diff_val =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), diff);
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), diff);
auto l_mul = info.add_instruction(make_op("mul"), tr_out, unsq_diff_val);
return info.add_instruction(make_op("add"), l_mul, unsq_off_val);
}
......
......@@ -29,11 +29,11 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
{
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
y_scale = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}), args[1]);
make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]);
}
else
{
y_scale = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
y_scale = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
args[1]);
}
......@@ -44,13 +44,13 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
{
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
y_zero_point = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}),
make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}),
y_zero_point);
}
else
{
y_zero_point = info.add_instruction(
make_op("multibroadcast", {{"output_lens", input_lens}}), y_zero_point);
make_op("multibroadcast", {{"out_lens", input_lens}}), y_zero_point);
}
return info.add_instruction(make_op("quantizelinear"), args[0], y_scale, y_zero_point);
......
......@@ -35,9 +35,9 @@ struct parse_selu : op_parser<parse_selu>
if(lens != std::vector<std::size_t>{1})
{
l_alpha =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), l_alpha);
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), l_alpha);
l_gamma =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), l_gamma);
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), l_gamma);
}
auto sign_x = info.add_instruction(make_op("sign"), args[0]);
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_thresholdedrelu : op_parser<parse_thresholdedrelu>
{
    std::vector<op_desc> operators() const { return {{"ThresholdedRelu"}}; }

    // y = x where x > alpha, else 0; the threshold attribute defaults to 1.0.
    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        float threshold = 1.0;
        if(contains(info.attributes, "alpha"))
            threshold = parser.parse_value(info.attributes.at("alpha")).at<float>();

        auto in_shape  = args[0]->get_shape();
        auto zero_lit  = info.add_literal(migraphx::literal{migraphx::shape{in_shape.type()}, {0}});
        auto alpha_lit =
            info.add_literal(migraphx::literal{migraphx::shape{in_shape.type()}, {threshold}});
        // Broadcast both scalars to the input's dimensions so the elementwise
        // comparison and selection below line up with the input.
        auto zero_b = info.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", in_shape.lens()}}), zero_lit);
        auto alpha_b = info.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", in_shape.lens()}}), alpha_lit);
        auto above = info.add_instruction(migraphx::make_op("greater"), args[0], alpha_b);
        return info.add_instruction(migraphx::make_op("where"), above, args[0], zero_b);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_topk : op_parser<parse_topk>
{
    std::vector<op_desc> operators() const { return {{"TopK"}}; }

    // Parses an ONNX TopK node into the "topk" op and splits the resulting
    // (values, indices) tuple into two outputs.
    std::vector<instruction_ref> parse(const op_desc& /*opd*/,
                                       const onnx_parser& parser,
                                       onnx_parser::node_info info,
                                       std::vector<instruction_ref> args) const
    {
        // k comes either from the second input (newer opsets) or from the
        // legacy "k" attribute.
        int64_t k = 0;
        if(args.size() == 2)
        {
            auto arg_k = args.at(1)->eval();
            check_arg_empty(arg_k, "PARSE_TopK: k input must be constant");
            // Read as int64_t: the ONNX k input is a 64-bit tensor, and
            // narrowing through int would silently truncate large values.
            k = arg_k.at<int64_t>();
        }
        else if(contains(info.attributes, "k"))
        {
            k = info.attributes.at("k").i();
        }

        // Whether to return the largest elements (default) or the smallest.
        bool largest = true;
        if(contains(info.attributes, "largest"))
        {
            largest = static_cast<bool>(info.attributes.at("largest").i());
        }

        // Axis to select along; defaults to the last dimension.
        int64_t axis = -1;
        if(contains(info.attributes, "axis"))
        {
            axis = parser.parse_value(info.attributes.at("axis")).at<int>();
        }

        auto topk_ret = info.add_instruction(
            make_op("topk", {{"k", k}, {"axis", axis}, {"largest", largest}}), args.at(0));
        // topk yields a tuple of (values, indices); unpack both.
        auto ret_val = info.add_instruction(make_op("get_tuple_elem", {{"index", 0}}), topk_ret);
        auto ret_ind = info.add_instruction(make_op("get_tuple_elem", {{"index", 1}}), topk_ret);
        return {ret_val, ret_ind};
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -21,7 +22,21 @@ struct parse_transpose : op_parser<parse_transpose>
auto&& perm_vals = info.attributes["perm"].ints();
perm = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
}
return info.add_instruction(make_op("transpose", {{"dims", perm}}), args.front());
// if perm is empty, use the default value
auto n_dim = args.front()->get_shape().lens().size();
if(perm.empty())
{
perm.resize(n_dim);
std::iota(perm.rbegin(), perm.rend(), 0);
}
if(perm.size() != n_dim)
{
MIGRAPHX_THROW("PARSE_TRANSPOSE: perm and input have diffferent number of dims!");
}
return info.add_instruction(make_op("transpose", {{"permutation", perm}}), args.front());
}
};
......
......@@ -17,45 +17,28 @@ struct parse_where : op_parser<parse_where>
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto cond =
info.add_instruction(make_op("convert", {{"target_type", shape::int32_type}}), args[0]);
auto lens = compute_broadcasted_lens(cond->get_shape().lens(), args[1]->get_shape().lens());
lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
if(cond->get_shape().lens() != lens)
auto lens =
compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens());
lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
if(args[0]->get_shape().lens() != lens)
{
cond = info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), cond);
args[0] =
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[0]);
}
if(args[1]->get_shape().lens() != lens)
{
args[1] =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), args[1]);
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[1]);
}
if(args[2]->get_shape().lens() != lens)
{
args[2] =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), args[2]);
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[2]);
}
// compute index
auto elem_num = args[1]->get_shape().elements();
// concatenation of input data
auto concat_data = info.add_instruction(make_op("concat", {{"axis", 0}}), args[2], args[1]);
std::vector<int64_t> dims = {static_cast<int64_t>(2 * elem_num)};
auto rsp_data = info.add_instruction(make_op("reshape", {{"dims", dims}}), concat_data);
std::vector<int> ind(elem_num);
std::iota(ind.begin(), ind.end(), 0);
shape ind_s{shape::int32_type, lens};
auto l_ind = info.add_literal(literal(ind_s, ind));
std::vector<int> offset(elem_num, elem_num);
auto l_offset = info.add_literal(literal({shape::int32_type, lens}, offset));
auto ins_offset = info.add_instruction(make_op("mul"), l_offset, cond);
auto ins_ind = info.add_instruction(make_op("add"), ins_offset, l_ind);
return info.add_instruction(make_op("gather", {{"axis", 0}}), rsp_data, ins_ind);
return info.add_instruction(make_op("where"), args[0], args[1], args[2]);
}
};
......
......@@ -40,7 +40,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
std::size_t size = s.bytes();
if(size == 0)
return false;
std::size_t element_size = size / s.elements();
std::size_t element_size = (s.elements() == 0 ? 4 : (size / s.elements()));
live_range& segment = interval->segment;
int vn = segment.vn;
std::priority_queue<live_range*, std::vector<live_range*>, ordering> conflict_queue;
......
......@@ -32,14 +32,6 @@ void validate_pass(module& mod, const pass& p, tracer trace)
trace();
#endif
}
void run_pass(module& mod, const pass& p, tracer trace)
{
trace("Module: ", mod.name(), ", Pass: ", p.name());
assert(mod.validate() == mod.end());
p.apply(mod);
trace(mod);
validate_pass(mod, p, trace);
}
void run_pass(program& prog, const pass& p, tracer trace)
{
trace("Pass: ", p.name());
......@@ -47,11 +39,52 @@ void run_pass(program& prog, const pass& p, tracer trace)
trace(prog);
}
struct module_pm : module_pass_manager
{
module* mod;
program* prog;
tracer* t;
module_pm(module* pmod = nullptr, program* pprog = nullptr, tracer* pt = nullptr)
: mod(pmod), prog(pprog), t(pt)
{
}
template <class... Ts>
void trace(Ts&&... xs) const
{
assert(t);
(*t)(xs...);
}
virtual module& get_module() override
{
assert(mod);
return *mod;
}
virtual module* create_module(const std::string& name) override
{
assert(prog);
return prog->create_module(name);
}
virtual void run_pass(const pass& p) override
{
assert(mod);
trace("Module: ", mod->name(), ", Pass: ", p.name());
assert(mod->validate() == mod->end());
p.apply(*this);
trace(*mod);
validate_pass(*mod, p, *t);
}
};
module& get_module(module_pass_manager& mpm) { return mpm.get_module(); }
void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
{
for(const auto& p : passes)
{
run_pass(mod, p, trace);
module_pm{&mod, nullptr, &trace}.run_pass(p);
}
}
......@@ -62,7 +95,7 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
auto mods = prog.get_modules();
for(const auto& mod : reverse(mods))
{
run_pass(*mod, p, trace);
module_pm{mod, &prog, &trace}.run_pass(p);
}
run_pass(prog, p, trace);
}
......
......@@ -10,6 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
void preallocate_param::apply(module& m) const
{
auto last = std::prev(m.end());
for(auto ins : iterator_for(m))
{
if(ins->name() != "@param")
......@@ -19,7 +20,9 @@ void preallocate_param::apply(module& m) const
std::string id = m.name() + ":" + param;
auto r = m.insert_instruction(ins, model.preallocate(ins->get_shape(), id));
m.replace_instruction(ins, r);
m.move_instruction(ins, m.end());
}
m.remove_instructions(std::next(last), m.end());
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -184,14 +184,16 @@ std::vector<argument> generic_eval(const module* mod,
context& ctx,
std::unordered_map<std::string, argument> params,
std::unordered_map<instruction_ref, argument> results,
F trace)
F make_trace)
{
assert(mod->validate() == mod->end());
results.reserve(mod->size() * 2);
std::vector<argument> values;
values.reserve(16);
auto trace = make_trace(mod);
for(auto ins : iterator_for(*mod))
{
assert(results.find(ins) == results.end());
const auto& name = ins->name();
if(name == "@literal")
{
......@@ -240,7 +242,8 @@ std::vector<argument> generic_eval(const module* mod,
const auto& mod_args = ins->module_inputs();
auto module_eval = [&](module_ref smod,
const std::unordered_map<std::string, argument>& inputs) {
return generic_eval(smod, ctx, inputs, results, trace);
auto ssctx = ctx;
return generic_eval(smod, ssctx, inputs, results, make_trace);
};
results.emplace(ins, trace(ins, [&] {
......@@ -249,6 +252,7 @@ std::vector<argument> generic_eval(const module* mod,
}));
}
assert(results.find(ins) != results.end());
assert(results.at(ins).get_shape() == ins->get_shape());
}
return {results.at(std::prev(mod->end()))};
}
......@@ -257,50 +261,67 @@ template <class F>
std::vector<argument> generic_eval(const program& p,
context& ctx,
std::unordered_map<std::string, argument> params,
F trace)
F make_trace)
{
const module* mm = p.get_main_module();
return generic_eval(mm, ctx, params, {}, trace);
return generic_eval(mm, ctx, params, {}, make_trace);
}
std::vector<argument> program::eval(parameter_map params) const
{
auto& ctx = this->impl->ctx;
#ifndef NDEBUG
auto sctx = ctx;
auto check_context = [&](auto f) {
assert(is_shared(ctx, sctx));
auto x = f();
sctx = ctx;
return x;
auto with_check_context = [&](auto f) {
return [=, &ctx](auto&&) {
auto sctx = std::make_shared<context>(ctx);
auto check_context = [=, &ctx](auto g) {
assert(is_shared(ctx, *sctx));
auto x = g();
*sctx = ctx;
return x;
};
return [=](auto&&... xs) { return f(xs..., check_context); };
};
};
#else
auto check_context = [](auto f) { return f(); };
auto with_check_context = [](auto f) {
return [=](auto&&) {
return [=](auto&&... xs) { return f(xs..., [](auto g) { return g(); }); };
};
};
#endif
auto trace_level = value_of(MIGRAPHX_TRACE_EVAL{});
if(trace_level > 0)
{
return generic_eval(*this, ctx, std::move(params), [&](auto& ins, auto f) {
ctx.finish();
std::cout << "Run instruction: ";
this->debug_print(ins);
timer t{};
auto result = check_context(f);
double t1 = t.record<milliseconds>();
ctx.finish();
double t2 = t.record<milliseconds>();
std::cout << "Time: " << t1 << "ms, " << t2 << "ms" << std::endl;
if(trace_level > 1 and ins->name().front() != '@' and ins->name() != "load")
std::cout << "Output: " << result << std::endl;
return result;
});
return generic_eval(*this,
ctx,
std::move(params),
with_check_context([&](auto& ins, auto f, auto&& check_context) {
ctx.finish();
std::cout << "Run instruction: ";
this->debug_print(ins);
timer t{};
auto result = check_context(f);
double t1 = t.record<milliseconds>();
ctx.finish();
double t2 = t.record<milliseconds>();
std::cout << "Time: " << t1 << "ms, " << t2 << "ms" << std::endl;
if(trace_level > 1 and ins->name().front() != '@' and
ins->name() != "load")
std::cout << "Output: " << result << std::endl;
return result;
}));
}
else
{
return generic_eval(
*this, ctx, std::move(params), [&](auto&, auto f) { return check_context(f); });
return generic_eval(*this,
ctx,
std::move(params),
with_check_context([&](auto&, auto f, auto&& check_context) {
return check_context(f);
}));
}
}
......@@ -502,21 +523,22 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
std::sort(total_vec.begin(), total_vec.end());
std::unordered_map<instruction_ref, std::vector<double>> ins_vec;
// Fill the map
generic_eval(*this, ctx, params, [&](auto ins, auto) {
generic_eval(*this, ctx, params, always([&](auto ins, auto) {
ins_vec[ins].reserve(n);
return argument{};
});
return argument{ins->get_shape(), nullptr};
}));
// Run and time each instruction
for(std::size_t i = 0; i < n; i++)
{
generic_eval(*this, ctx, params, [&](auto ins, auto f) {
generic_eval(*this, ctx, params, always([&](auto ins, auto f) {
argument result;
ins_vec[ins].push_back(time<milliseconds>([&] {
result = f();
ctx.finish();
}));
return result;
});
}));
}
for(auto&& p : ins_vec)
std::sort(p.second.begin(), p.second.end());
......@@ -645,7 +667,9 @@ void program::print_cpp(std::ostream& os) const
void program::dry_run(std::unordered_map<std::string, argument> params) const
{
auto& ctx = this->impl->ctx;
generic_eval(*this, ctx, std::move(params), [](auto&&...) { return argument{}; });
generic_eval(*this, ctx, std::move(params), always([](auto ins, auto&&...) {
return argument{ins->get_shape(), nullptr};
}));
}
void program::annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const
......@@ -745,6 +769,22 @@ void program::remove_module(const std::string& name)
impl->modules.at(name).end(),
[&](auto&& ins) { return references_instruction(impl->modules, ins, name); }) &&
"Instruction referenced in another module");
// if an instruction has an input out side of the current module, need to remove
// the instruction from its input's outputs
auto& mod = impl->modules.at(name);
for(auto ins : iterator_for(mod))
{
auto inputs = ins->inputs();
for(auto in : inputs)
{
if(not mod.has_instruction(in))
{
in->remove_output(ins);
}
}
}
impl->modules.erase(name);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment