Commit 4a39a0f7 authored by Shucai Xiao

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into add-conv_bn_add-test

parents 5564172e bb827865
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/target.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
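// Convert a module to fp16: for each instruction whose operator is listed in
// ins_names (or when "all" is requested), convert its float/double inputs to
// half and insert a convert back to the original output type afterwards.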
static void quantize_module(module& m, const std::vector<std::string>& ins_names)
{
for(auto ins : iterator_for(m))
{
// skip instructions that are not in the set to be quantized
if(not(contains(ins_names, ins->name()) or contains(ins_names, "all")))
continue;
// skip return and convert instructions
if(contains({"@return", "convert"}, ins->name()))
continue;
if(ins->inputs().empty())
continue;
auto mod_inputs = ins->module_inputs();
auto s = ins->get_shape();
// Convert back to original type before quantizing the inputs
if(mod_inputs.empty())
{
auto r = m.insert_instruction(
std::next(ins), make_op("convert", {{"target_type", s.type()}}), ins);
m.replace_instruction(ins, r);
}
// Convert each of the inputs that are floating point to fp16
auto inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
auto input_type = input->get_shape().type();
if(input_type != shape::float_type and input_type != shape::double_type)
return input;
return m.insert_instruction(
ins, make_op("convert", {{"target_type", shape::half_type}}), input);
});
// Replace inputs
m.replace_instruction(ins, ins->get_operator(), inputs, mod_inputs);
}
}
void quantize_fp16_pass::apply(module& m) const { quantize_module(m, ins_names); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/operation.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_int8.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/capture.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/target.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <numeric>
#include <set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
static std::vector<shape::type_t>& get_quantizable_type()
{
static std::vector<shape::type_t> quantizable_types = {
shape::float_type, shape::double_type, shape::half_type};
return quantizable_types;
}
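// Replace each capture instruction with a quantizelinear/dequantizelinear pair
// built from the calibration parameters (quant_params) recorded for that index.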
void quantize_int8_pass::apply(module& m) const // NOLINT
{
const auto& quantizable_types = get_quantizable_type();
for(auto ins : iterator_for(m))
{
if(ins->name() != "capture")
continue;
auto op_val = ins->get_operator().to_value();
assert(op_val.contains("ins_index"));
auto param_index = op_val.at("ins_index").to<std::size_t>();
auto param = quant_params[param_index];
auto input = ins->inputs().front();
auto s = input->get_shape();
if(contains(quantizable_types, s.type()) and s.type() != shape::int8_type)
{
auto zero_point = m.add_literal(static_cast<int8_t>(param.second));
auto scale = m.add_literal(literal({s.type()}, {1.0f / param.first}));
const auto& lens = s.lens();
scale =
m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), scale);
zero_point = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", lens}}), zero_point);
auto q_in =
m.insert_instruction(ins, make_op("quantizelinear"), input, scale, zero_point);
auto dq_in =
m.insert_instruction(ins, make_op("dequantizelinear"), q_in, scale, zero_point);
m.replace_instruction(ins, dq_in);
}
}
}
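// Wrap the inputs of the selected instructions with capture ops so calibration
// data can be collected when the program is evaluated.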
void capture_arguments_pass::apply(module& m) const // NOLINT
{
assert(param_index != nullptr);
for(auto ins : iterator_for(m))
{
if(not contains(ins_names, ins->name()))
{
continue;
}
auto inputs = ins->inputs();
std::vector<instruction_ref> new_args;
for(auto input : inputs)
{
auto new_in = m.insert_instruction(ins, op::capture{(*param_index)++, f}, input);
new_args.push_back(new_in);
}
m.replace_instruction(ins, ins->get_operator(), new_args);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/remap.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/add.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace {
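// Fuse a dot followed by an add into a single dot with beta = 1, using the
// other add operand as the accumulator argument.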
struct find_dot_add
{
auto matcher() const
{
return match::name("add")(match::any_of(
match::args(match::name("dot")(match::nargs(2)).bind("dot"), match::any().bind("a")),
match::args(match::used_once().bind("a"),
match::name("dot")(match::nargs(2)).bind("dot"))));
}
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto dot_ins = r.instructions["dot"];
auto a_ins = r.instructions["a"];
auto dot = any_cast<op::dot>(dot_ins->get_operator());
dot.beta = 1;
p.replace_instruction(ins, dot, dot_ins->inputs()[0], dot_ins->inputs()[1], a_ins);
}
};
} // namespace
void remap::apply(module& p) const { match::find_matches(p, find_dot_add{}); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
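// Lower quantizelinear into elementwise ops: divide by the scale, round, add
// the zero point, clip to the range of the output type, then convert.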
void apply_quantizelinear(module& m, instruction_ref ins)
{
assert(ins->name() == "quantizelinear");
auto x = ins->inputs()[0];
auto y_scale = ins->inputs()[1];
if(x->get_shape().type() != y_scale->get_shape().type())
{
x = m.insert_instruction(ins, make_op("convert", {{"target_type", shape::float_type}}), x);
}
auto div = m.insert_instruction(ins, make_op("div"), x, y_scale);
auto add_zero_point = m.insert_instruction(ins, make_op("round"), div);
if(ins->inputs().size() == 3)
{
auto zero_point = m.insert_instruction(
ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[2]);
add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point);
}
int64_t max_quant = 0;
int64_t min_quant = 0;
ins->get_shape().visit_type([&](auto qt) {
max_quant = qt.max();
min_quant = qt.min();
});
auto s = add_zero_point->get_shape();
std::vector<int> min_data(s.elements(), min_quant);
std::vector<int> max_data(s.elements(), max_quant);
auto min_arg = m.add_literal(literal(s, min_data));
auto max_arg = m.add_literal(literal(s, max_data));
auto saturate = m.insert_instruction(ins, make_op("clip"), add_zero_point, min_arg, max_arg);
m.replace_instruction(
ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate);
}
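// Lower dequantizelinear into elementwise ops: convert to float, subtract the
// zero point if present, then multiply by the scale.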
void apply_dequantizelinear(module& m, instruction_ref ins)
{
assert(ins->name() == "dequantizelinear");
auto x = m.insert_instruction(
ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[0]);
auto x_scale = ins->inputs()[1];
if(ins->inputs().size() == 3)
{
auto x_zero_point = m.insert_instruction(
ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[2]);
x = m.insert_instruction(ins, make_op("sub"), x, x_zero_point);
}
m.replace_instruction(ins, make_op("mul"), x, x_scale);
}
void rewrite_quantization::apply(module& m) const
{
for(auto ins : iterator_for(m))
{
if(ins->name() == "quantizelinear")
{
apply_quantizelinear(m, ins);
}
else if(ins->name() == "dequantizelinear")
{
apply_dequantizelinear(m, ins);
}
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -241,11 +241,11 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
// squeeze and transpose w
std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tran_sw = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sw);
auto tran_sw = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
// squeeze and transpose r
auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto tran_sr = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sr);
auto tran_sr = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);
// initial hidden state
auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
......@@ -263,7 +263,7 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), sbias);
auto wrb = prog.insert_instruction(ins, make_op("add"), wb, rb);
bb = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", sih_lens}}), wrb);
ins, make_op("broadcast", {{"axis", 1}, {"out_lens", sih_lens}}), wrb);
}
instruction_ref hidden_out = prog.end();
......@@ -565,17 +565,17 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
// squeeze the w matrix to 2-D and transpose it
std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tw = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sw);
auto tw = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
// slice r into two parts, zr and h
auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto rzr = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2 * hs}}}), sr);
auto trzr = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), rzr);
auto trzr = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rzr);
auto rh = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), sr);
auto trh = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), rh);
auto trh = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rh);
// initial states
auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
......@@ -592,7 +592,7 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {3 * hs}}}), sbias);
bwb = prog.insert_instruction(
ins,
make_op("broadcast", {{"axis", 1}, {"dims", {bs, static_cast<size_t>(3 * hs)}}}),
make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(3 * hs)}}}),
wb);
auto rb_zr = prog.insert_instruction(
......@@ -605,11 +605,11 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
sbias);
brb_zr = prog.insert_instruction(
ins,
make_op("broadcast", {{"axis", 1}, {"dims", {bs, static_cast<size_t>(2 * hs)}}}),
make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(2 * hs)}}}),
rb_zr);
brb_h = prog.insert_instruction(
ins,
make_op("broadcast", {{"axis", 1}, {"dims", {bs, static_cast<size_t>(hs)}}}),
make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(hs)}}}),
rb_h);
}
......@@ -1038,11 +1038,11 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
std::vector<int64_t> perm{1, 0};
// w matrix, squeeze and transpose
auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tsw = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sw);
auto tsw = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
// r matrix, squeeze and transpose
auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto tsr = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sr);
auto tsr = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);
// initial hidden state
auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
......@@ -1067,7 +1067,7 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
wrb = prog.insert_instruction(
ins,
make_op("broadcast", {{"axis", 1}, {"dims", {bs, 4 * static_cast<size_t>(hs)}}}),
make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, 4 * static_cast<size_t>(hs)}}}),
ub_wrb);
}
......@@ -1081,17 +1081,17 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto pphi = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), spph);
pphi_brcst = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", ic_lens}}), pphi);
ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphi);
auto ppho = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), spph);
ppho_brcst = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", ic_lens}}), ppho);
ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), ppho);
auto pphf = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), spph);
pphf_brcst = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", ic_lens}}), pphf);
ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphf);
}
long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens));
......
......@@ -2,10 +2,12 @@
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dom_info.hpp>
#include <unordered_map>
#include <unordered_set>
#include <queue>
......@@ -16,6 +18,7 @@
#include <set>
#include <deque>
#include <chrono>
#include <iomanip>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -88,7 +91,7 @@ struct stream_info
return args.end();
}
const std::size_t min_partition_threshold = 1;
const std::size_t min_partition_threshold = 2;
sort_args_by_weight(args, std::greater<>{});
auto it = std::lower_bound(std::next(args.begin()),
......@@ -120,11 +123,11 @@ struct stream_info
std::unordered_map<instruction_ref, std::deque<partition>> partitions;
partitions.reserve(weights.size());
fix([&](auto self, auto ins, auto& part) {
assert(ins != p.end());
if(contains(partitions, ins))
return;
assert(not is_end(ins, p.end()));
if(not p.has_instruction(ins))
return;
if(contains(partitions, ins))
return;
// Add an entry so we know the instruction was visited
partitions[ins];
......@@ -236,6 +239,18 @@ struct stream_info
}
}
}
// move dangling parameters to the front so they are not removed
auto ins = std::next(last);
while(ins != p.end())
{
auto next = std::next(ins);
if(ins->name() == "@param")
{
p.move_instruction(ins, p.begin());
}
ins = next;
}
}
void set_stream(const partition& p, std::size_t n)
......@@ -353,6 +368,7 @@ struct stream_info
{
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result;
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from;
dominator_info di = compute_dominator(p);
result.reserve(p.size());
merge_from.reserve(p.size());
for(auto ins : reverse_iterator_for(p))
......@@ -366,8 +382,13 @@ struct stream_info
merge_from[ins].insert(merge_from[arg].begin(), merge_from[arg].end());
}
auto streams = this->get_streams(ins);
if(is_split_point(ins))
{
erase_if(merge_from[ins],
[&](auto merge) { return di.strictly_dominate(ins, merge); });
}
auto streams = this->get_streams(ins);
// Collect concurrent instructions for each merge point.
for(const auto& merge : merge_from[ins])
{
......@@ -396,11 +417,18 @@ struct stream_info
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>
get_conflicts(module& p)
{
using conflict_table_type =
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>;
conflict_table_type conflict_table;
auto concur_ins = this->find_concurrent_instructions(p);
// Compute an index for each instruction
std::unordered_map<instruction_ref, std::size_t> ins2index;
std::size_t index_total = 0;
for(auto ins : iterator_for(p))
ins2index[ins] = index_total++;
std::vector<conflict_table_type> thread_conflict_tables(
std::thread::hardware_concurrency());
std::vector<instruction_ref> index_to_ins;
......@@ -442,14 +470,13 @@ struct stream_info
for(auto ins1 : ins1_set)
{
auto p1 = std::distance(ins1, merge_first);
auto p1 = ins2index.at(ins1);
for(auto ins2 : ins2_set)
{
if(ins1 == ins2)
continue;
auto p2 = std::distance(ins2, merge_first);
// The smaller distance means the instruction occurs later
if(p1 > p2)
auto p2 = ins2index.at(ins2);
if(p2 > p1)
thrd_table[ins2].insert(ins1);
else
thrd_table[ins1].insert(ins2);
......@@ -495,6 +522,9 @@ void schedule::apply(module& p) const
if(enabled(MIGRAPHX_TRACE_COMPILE{}) or enabled(MIGRAPHX_TRACE_SCHEDULE{}))
{
p.annotate(std::cout, [&](auto ins) {
if(ins->name() == "@param" and not contains(si.weights, ins))
return;
std::cout << ":";
std::cout << " weight=" << si.weights.at(ins);
std::cout << " input={";
......@@ -535,11 +565,9 @@ void schedule::apply(module& p) const
{
for(auto i : si.get_recorded_instructions(ins))
{
if(not si.has_stream(i))
continue;
auto istream = si.get_stream(i);
if(stream == istream)
if(not si.has_stream(i) or si.get_stream(i) == stream)
continue;
// Create a new event if it hasn't been recorded
if(not contains(ins2wait, i))
{
......
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/neg.hpp>
#include <migraphx/op/recip.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
......@@ -20,6 +13,7 @@
#include <migraphx/serialize.hpp>
#include <migraphx/algorithm.hpp>
#include <unordered_set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -61,7 +55,7 @@ struct find_mul_conv
auto new_a = p.insert_instruction(
ins,
make_op("broadcast", {{"axis", 0}, {"dims", w_ins->get_shape().lens()}}),
make_op("broadcast", {{"axis", 0}, {"out_lens", w_ins->get_shape().lens()}}),
a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, make_op("mul"), new_a, w_ins);
auto new_conv = p.insert_instruction(
......@@ -126,7 +120,7 @@ struct find_mul_slice_conv
auto new_a = p.insert_instruction(
ins,
make_op("broadcast", {{"axis", 0}, {"dims", slice_w_ins->get_shape().lens()}}),
make_op("broadcast", {{"axis", 0}, {"out_lens", slice_w_ins->get_shape().lens()}}),
a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, make_op("mul"), new_a, slice_w_ins);
......@@ -155,7 +149,8 @@ struct find_mul_slice_conv
assert(ins->get_shape().lens() == slice1->get_shape().lens());
p.replace_instruction(ins, slice1);
// TODO: Check each slice doesn't overlap and that it occurs after slice_ins
for(auto output : conv_ins->outputs())
auto outputs = conv_ins->outputs();
for(auto output : outputs)
if(output != slice_ins)
instruction::replace_argument(output, conv_ins, new_conv);
}
......@@ -403,8 +398,27 @@ struct find_splits
match::any_of[match::outputs()](match::pointwise(), reduction()))));
}
static bool is_dependent(const module& m, instruction_ref ins1, instruction_ref ins2)
{
std::unordered_set<instruction_ref> traversed;
return fix<bool>([&](auto self, auto ins) -> bool {
if(ins == ins2)
return true;
if(contains(traversed, ins))
return false;
traversed.insert(ins);
const auto& inputs = ins->inputs();
return std::any_of(inputs.begin(), inputs.end(), [&](auto in) {
return m.has_instruction(in) and self(in);
});
})(ins1);
}
static std::vector<std::vector<instruction_ref>>
get_split_groups(const std::vector<instruction_ref>& splits)
get_split_groups(const module& m, const std::vector<instruction_ref>& splits)
{
std::vector<std::vector<instruction_ref>> groups;
for(auto out : splits.front()->outputs())
......@@ -421,9 +435,16 @@ struct find_splits
if(it == split->outputs().end())
break;
assert((*it)->name() != "slice");
// If there is a duplicate bail
if(contains(group, *it))
// there should be no dependencies between instructions in the group
if(std::any_of(group.begin(), group.end(), [&](auto i) {
return is_dependent(m, *it, i) or is_dependent(m, i, *it);
}))
{
return {};
}
group.push_back(*it);
}
if(group.size() != splits.size())
......@@ -460,13 +481,12 @@ struct find_splits
void apply(module& p, const match::matcher_result& r) const
{
auto ins = r.result;
auto ins = r.result;
auto splits = get_splits(ins);
if(splits.empty())
return;
for(const auto& group : get_split_groups(splits))
for(const auto& group : get_split_groups(p, splits))
{
auto start = group.front();
auto split_front = splits.front();
......@@ -535,7 +555,8 @@ struct find_splits
auto split = i->inputs()[split_idx];
assert(split->name() == "slice");
// Insert contiguous for reshapes
for(auto output : i->outputs())
auto outputs = i->outputs();
for(auto output : outputs)
{
if(not contains({"reshape", "squeeze", "unsqueeze"}, output->name()))
continue;
......@@ -644,19 +665,6 @@ struct find_add_convs
return x.stride[0] / y.stride[0];
}
static shape compute_stride_shape(const shape& input, std::size_t n)
{
return {input.type(),
{input.lens()[0],
input.lens()[1],
std::size_t(std::max<std::ptrdiff_t>(1, (input.lens()[2] - 1) / n + 1)),
std::size_t(std::max<std::ptrdiff_t>(1, (input.lens()[3] - 1) / n + 1))},
{input.strides()[0],
input.strides()[1],
input.strides()[2] * n,
input.strides()[3] * n}};
}
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
......@@ -687,11 +695,7 @@ struct find_add_convs
return;
new_op = a_op;
b_input = p.insert_instruction(
ins,
make_op(
"as_shape",
{{"shape", to_value(compute_stride_shape(b_input->get_shape(), n))}}),
b_input);
ins, make_op("step", {{"axes", {2, 3}}, {"steps", {n, n}}}), b_input);
}
else if(b_op.stride < a_op.stride)
{
......@@ -700,11 +704,7 @@ struct find_add_convs
return;
new_op = b_op;
a_input = p.insert_instruction(
ins,
make_op(
"as_shape",
{{"shape", to_value(compute_stride_shape(a_input->get_shape(), n))}}),
a_input);
ins, make_op("step", {{"axes", {2, 3}}, {"steps", {n, n}}}), a_input);
}
else
return;
......@@ -989,8 +989,8 @@ struct find_split_transpose
}
// insert a transpose instruction
auto tr =
p.insert_instruction(std::next(input), make_op("transpose", {{"dims", perm}}), input);
auto tr = p.insert_instruction(
std::next(input), make_op("transpose", {{"permutation", perm}}), input);
// compute the axis in the slice
auto axis = any_cast<op::slice>(slc->get_operator()).axes.front();
......
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/register_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
std::unordered_set<std::string> get_quantizable_op_names()
{
static std::unordered_set<std::string> s = {"convolution", "dot"};
return s;
}
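// Matches a literal whose elements all hold the same value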
MIGRAPHX_PRED_MATCHER(has_same_value, instruction_ref ins)
{
if(ins->name() != "@literal")
return false;
bool all_same = false;
ins->get_literal().visit([&](auto s) {
all_same = std::all_of(s.begin() + 1, s.end(), [&](const auto& scale) {
return float_equal(scale, s.front());
});
});
return all_same;
}
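// Rewrite dequantizelinear inputs feeding a convolution or dot into
// quant_convolution/quant_dot on the int8 data, followed by a single
// dequantizelinear with the combined scale.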
struct match_find_quantizable_ops
{
static auto dequantizelinear_op(const std::string& name, const std::string& scale)
{
return match::name("dequantizelinear")(
match::arg(0)(match::skip(match::name("quantizelinear"))(match::any().bind(name))),
match::arg(1)(match::skip_broadcasts(has_same_value().bind(scale))),
match::arg(2)(match::skip_broadcasts(match::all_of(match::has_value(0)))));
}
auto matcher() const
{
return match::name(get_quantizable_op_names())(
match::arg(0)(dequantizelinear_op("x1", "scale1")),
match::arg(1)(dequantizelinear_op("x2", "scale2")));
}
void apply(module& m, match::matcher_result r) const
{
auto qop = r.result;
auto q1 = r.instructions["x1"];
auto q2 = r.instructions["x2"];
auto scale1 = r.instructions["scale1"];
auto scale2 = r.instructions["scale2"];
// Only INT8 type currently supported
if(q1->get_shape().type() != migraphx::shape::int8_type or
q2->get_shape().type() != migraphx::shape::int8_type)
return;
double scale;
visit_all(scale1->get_literal(), scale2->get_literal())(
[&](const auto s1, const auto s2) { scale = s1.front() * s2.front(); });
auto qop_args = qop->inputs();
qop_args.at(0) = q1;
qop_args.at(1) = q2;
instruction_ref dq;
instruction_ref dq_scale;
instruction_ref zero_point;
if(qop->name() == "convolution")
{
auto conv_val = qop->get_operator().to_value();
dq = m.insert_instruction(
qop, migraphx::make_op("quant_convolution", conv_val), qop_args);
}
else if(qop->name() == "dot")
{
dq = m.insert_instruction(qop, migraphx::make_op("quant_dot"), qop_args);
}
auto ins_type = qop->get_shape().type();
dq_scale = m.add_literal(literal({ins_type}, {scale}));
auto lens = dq->get_shape().lens();
auto scale_mb =
m.insert_instruction(qop, make_op("multibroadcast", {{"out_lens", lens}}), dq_scale);
dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, scale_mb);
m.replace_instruction(qop, dq);
}
};
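// Returns true when two (possibly broadcast) literals evaluate to the same
// values, or when both are constants holding a single repeated value.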
bool compare_literals(instruction_ref ins1, instruction_ref ins2)
{
if(ins1->name() == "broadcast" or ins1->name() == "multibroadcast")
ins1 = ins1->inputs().front();
auto x = ins1->eval();
if(x.empty())
return false;
auto literal1 = ins1->get_literal();
if(ins2->name() == "broadcast" or ins2->name() == "multibroadcast")
ins2 = ins2->inputs().front();
auto y = ins2->eval();
if(y.empty())
return false;
auto literal2 = ins2->get_literal();
bool diff_shapes_equal_vals = false;
visit_all(ins1->get_literal(), ins2->get_literal())([&](const auto l1, const auto l2) {
diff_shapes_equal_vals =
std::all_of(
l1.begin() + 1, l1.end(), [&](auto v) { return float_equal(v, l1.front()); }) and
std::all_of(l2.begin(), l2.end(), [&](auto v) { return float_equal(v, l1.front()); });
});
return (x == y) or diff_shapes_equal_vals;
}
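// Remove quantizelinear/dequantizelinear pairs that share the same scale and
// zero point by reconnecting the consumer to the original unquantized input.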
void remove_qdq_pairs(module& m)
{
for(auto ins : iterator_for(m))
{
auto args = ins->inputs();
for(auto&& arg : args)
{
if(arg->name() == "dequantizelinear")
{
auto q = arg->inputs().front();
if((q->name() == "quantizelinear") and
compare_literals(arg->inputs().at(1), q->inputs().at(1)) and
compare_literals(arg->inputs().at(2), q->inputs().at(2)))
{
instruction::replace_argument(ins, arg, q->inputs().front());
}
}
}
}
}
void simplify_qdq::apply(module& m) const
{
match::find_matches(m, match_find_quantizable_ops{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
remove_qdq_pairs(m);
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -153,7 +153,8 @@ struct find_transpose
}
else
{
p.replace_instruction(ins, make_op("transpose", {{"dims", dims}}), t->inputs().front());
p.replace_instruction(
ins, make_op("transpose", {{"permutation", dims}}), t->inputs().front());
}
}
};
......@@ -278,10 +279,12 @@ struct find_concat_transpose
std::vector<instruction_ref> inputs;
std::transform(
ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
return p.insert_instruction(ins, make_op("transpose", {{"dims", permutation}}), i);
return p.insert_instruction(
ins, make_op("transpose", {{"permutation", permutation}}), i);
});
auto concat = p.insert_instruction(ins, op, inputs);
auto t = p.insert_instruction(ins, make_op("transpose", {{"dims", ipermutation}}), concat);
auto t = p.insert_instruction(
ins, make_op("transpose", {{"permutation", ipermutation}}), concat);
assert(ins->get_shape().lens() == t->get_shape().lens());
p.replace_instruction(ins, t);
}
......@@ -376,9 +379,7 @@ struct find_resize
return;
}
arg_ind.visit([&](auto v) { vec_ind.assign(v.begin(), v.end()); });
std::vector<int> index(out_shape.elements());
std::iota(index.begin(), index.end(), 0);
if(not std::all_of(index.begin(), index.end(), [&](auto i) {
if(not all_of(range(out_shape.elements()), [&](auto i) {
auto out_idx = out_shape.multi(i);
auto in_idx = out_idx;
std::transform(out_idx.begin(),
......@@ -420,7 +421,7 @@ struct find_resize
auto rsp_data = p.insert_instruction(
ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp);
auto mb_rsp = p.insert_instruction(
ins_rsp, migraphx::make_op("multibroadcast", {{"output_lens", out_dims}}), rsp_data);
ins_rsp, migraphx::make_op("multibroadcast", {{"out_lens", out_dims}}), rsp_data);
auto std_mb = p.insert_instruction(ins, migraphx::make_op("contiguous"), mb_rsp);
std::vector<int64_t> rsp_dims(out_lens.begin(), out_lens.end());
p.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), std_mb);
......
......@@ -19,6 +19,7 @@ add_library(migraphx_cpu
logsoftmax.cpp
lowering.cpp
lrn.cpp
preallocate.cpp
pooling.cpp
reduction.cpp
reorder.cpp
......@@ -30,12 +31,29 @@ add_library(migraphx_cpu
set_target_properties(migraphx_cpu PROPERTIES EXPORT_NAME cpu)
rocm_set_soversion(migraphx_cpu ${MIGRAPHX_SO_VERSION})
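# Optionally build against ZenDNN (amdZenDNN + BLIS) instead of oneDNN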
set(MIGRAPHX_ENABLE_ZENDNN Off CACHE BOOL "")
find_package(Threads)
find_package(dnnl REQUIRED)
if(MIGRAPHX_ENABLE_ZENDNN)
find_path(ZENDNN_INC_PATH zendnn.hpp)
find_library(ZENDNN_LIB amdZenDNN)
find_library(BLIS_LIB blis)
else()
find_package(dnnl REQUIRED)
endif()
rocm_clang_tidy_check(migraphx_cpu)
if(MIGRAPHX_ENABLE_ZENDNN)
target_compile_definitions(migraphx_cpu PRIVATE -DMIGRAPHX_ENABLE_ZENDNN)
target_include_directories(migraphx_cpu PRIVATE ${ZENDNN_INC_PATH})
message(STATUS "ZENDNN_LIB: ${ZENDNN_LIB}")
target_link_libraries(migraphx_cpu PRIVATE ${BLIS_LIB})
target_link_libraries(migraphx_cpu PRIVATE ${ZENDNN_LIB})
else()
target_link_libraries(migraphx_cpu PRIVATE DNNL::dnnl)
endif()
target_link_libraries(migraphx_cpu PRIVATE migraphx Threads::Threads)
target_link_libraries(migraphx_cpu PRIVATE DNNL::dnnl)
find_package(OpenMP)
target_link_libraries(migraphx_cpu PUBLIC OpenMP::OpenMP_CXX)
......
......@@ -11,6 +11,11 @@ operation cpu_allocation_model::allocate(const shape& s) const
return make_op(name(), {{"shape", to_value(s)}});
}
operation cpu_allocation_model::preallocate(const shape& s, const std::string& id) const
{
return make_op("cpu::preallocate", {{"shape", to_value(s)}, {"id", id}});
}
std::string cpu_allocation_model::copy() const { return "cpu::copy"; }
} // namespace cpu
......
......@@ -14,6 +14,8 @@ struct dnnl_binary : dnnl_op<dnnl_binary, dnnl::binary>
return pack_join(self.reflect_base(self, f), pack(f(self.algo, "algo")));
}
std::string group() const { return this->name() + "::" + algo; }
std::string name() const { return "dnnl::binary"; }
shape compute_shape(std::vector<shape> inputs) const
......@@ -35,7 +37,10 @@ struct dnnl_binary : dnnl_op<dnnl_binary, dnnl::binary>
dnnl::binary::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
return {to_dnnl_algo(algo), m.at(DNNL_ARG_SRC_0), m.at(DNNL_ARG_SRC_1), m.at(DNNL_ARG_DST)};
return {to_dnnl_algo(algo),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC_0)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC_1)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST))};
}
};
......
......@@ -11,7 +11,7 @@ struct dnnl_concat : dnnl_extend_op<dnnl_concat, dnnl::concat, op::concat>
std::vector<int> arg_map(int size) const
{
std::vector<int> result(size);
std::iota(result.begin(), result.end(), DNNL_ARG_MULTIPLE_SRC);
std::iota(result.begin(), result.end(), MIGRAPHX_DNNL_PREFIX(ARG_MULTIPLE_SRC));
return result;
}
// Custom desc class since it's missing in dnnl
......@@ -28,9 +28,9 @@ struct dnnl_concat : dnnl_extend_op<dnnl_concat, dnnl::concat, op::concat>
for(auto i = 0; i < m.size() - 1; i++)
{
srcs.push_back(m.at(DNNL_ARG_MULTIPLE_SRC + i));
srcs.push_back(m.at(MIGRAPHX_DNNL_PREFIX(ARG_MULTIPLE_SRC) + i));
}
return {m.at(DNNL_ARG_DST), std::size_t(op.axis), srcs};
return {m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)), std::size_t(op.axis), srcs};
}
auto get_primitive_desc(const desc& d, const dnnl::primitive_attr& attr) const
......
......@@ -15,7 +15,10 @@ namespace cpu {
struct dnnl_convolution
: dnnl_extend_op<dnnl_convolution, dnnl::convolution_forward, op::convolution>
{
std::vector<int> arg_map(int) const { return {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS}; }
std::vector<int> arg_map(int) const
{
return {MIGRAPHX_DNNL_PREFIX(ARG_SRC), MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)};
}
shape adjust_shape(const shape& x, int i) const
{
......@@ -40,15 +43,18 @@ struct dnnl_convolution
auto dilation = op.dilation;
std::transform(
dilation.begin(), dilation.end(), dilation.begin(), [](auto x) { return x - 1; });
auto kdims = op.kdims();
std::vector<size_t> padding_l(op.padding.begin(), op.padding.begin() + kdims);
std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end());
return {dnnl::prop_kind::forward_inference,
dnnl::algorithm::convolution_auto,
m.at(DNNL_ARG_SRC),
m.at(DNNL_ARG_WEIGHTS),
m.at(DNNL_ARG_DST),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)),
to_dnnl_dims(op.stride),
to_dnnl_dims(dilation),
to_dnnl_dims(op.padding),
to_dnnl_dims(op.padding)};
to_dnnl_dims(padding_l),
to_dnnl_dims(padding_r)};
}
};
......
......@@ -9,7 +9,10 @@ namespace cpu {
struct dnnl_deconvolution
: dnnl_extend_op<dnnl_deconvolution, dnnl::deconvolution_forward, op::deconvolution>
{
std::vector<int> arg_map(int) const { return {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS}; }
std::vector<int> arg_map(int) const
{
return {MIGRAPHX_DNNL_PREFIX(ARG_SRC), MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)};
}
shape adjust_shape(const shape& x, int i) const
{
......@@ -35,9 +38,9 @@ struct dnnl_deconvolution
dilation.begin(), dilation.end(), dilation.begin(), [](auto x) { return x - 1; });
return {dnnl::prop_kind::forward_inference,
dnnl::algorithm::deconvolution_direct,
m.at(DNNL_ARG_SRC),
m.at(DNNL_ARG_WEIGHTS),
m.at(DNNL_ARG_DST),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)),
to_dnnl_dims(op.stride),
to_dnnl_dims(dilation),
to_dnnl_dims(op.padding),
......
......@@ -2,6 +2,9 @@
#if defined(__GNUC__) && __GNUC__ <= 5
namespace std {
#ifdef MIGRAPHX_ENABLE_ZENDNN
namespace dnnl = zendnn;
#endif
template <>
struct hash<dnnl::algorithm>
{
......
......@@ -17,6 +17,8 @@ struct dnnl_eltwise : dnnl_op<dnnl_eltwise, dnnl::eltwise_forward>
pack(f(self.algo, "algo"), f(self.alpha, "alpha"), f(self.beta, "beta")));
}
std::string group() const { return this->name() + "::" + algo; }
std::string name() const { return "dnnl::eltwise"; }
shape compute_shape(std::vector<shape> inputs) const
......@@ -37,7 +39,7 @@ struct dnnl_eltwise : dnnl_op<dnnl_eltwise, dnnl::eltwise_forward>
{
return {dnnl::prop_kind::forward_inference,
to_dnnl_algo(algo),
m.at(DNNL_ARG_SRC_0),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC_0)),
alpha,
beta};
}
......
......@@ -13,13 +13,20 @@ namespace cpu {
struct dnnl_gemm : dnnl_extend_op<dnnl_gemm, dnnl::matmul, op::dot>
{
std::vector<int> arg_map(int) const { return {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS}; }
std::vector<int> arg_map(int) const
{
return {MIGRAPHX_DNNL_PREFIX(ARG_SRC),
MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS),
MIGRAPHX_DNNL_PREFIX(ARG_BIAS)};
}
void required(const check_shapes& cs) const { cs.not_broadcasted(); }
dnnl::matmul::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
return {m.at(DNNL_ARG_SRC), m.at(DNNL_ARG_WEIGHTS), m.at(DNNL_ARG_DST)};
return {m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST))};
}
};
......
......@@ -14,6 +14,7 @@ struct cpu_allocation_model
std::string name() const;
std::string copy() const;
operation allocate(const shape& s) const;
operation preallocate(const shape& s, const std::string& id) const;
};
} // namespace cpu
......
......@@ -7,14 +7,26 @@
#include <migraphx/register_op.hpp>
#include <migraphx/check_shapes.hpp>
#include <unordered_map>
#include <dnnl.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/assert.hpp>
#ifdef MIGRAPHX_ENABLE_ZENDNN
#include <zendnn.hpp>
#else
#include <dnnl.hpp>
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {
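// When ZenDNN is enabled, alias the dnnl namespace to zendnn and map the
// DNNL_* argument macros to their ZENDNN_* equivalents via MIGRAPHX_DNNL_PREFIX.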
#ifdef MIGRAPHX_ENABLE_ZENDNN
namespace dnnl = zendnn;
#define MIGRAPHX_CONCAT_PREFIX(b) ZENDNN_##b // NOLINT
#else
#define MIGRAPHX_CONCAT_PREFIX(b) DNNL_##b // NOLINT
#endif
#define MIGRAPHX_DNNL_PREFIX(b) MIGRAPHX_CONCAT_PREFIX(b) // NOLINT
struct dnnl_context
{
dnnl::engine engine;
......@@ -74,6 +86,25 @@ struct dnnl_op : auto_register_op<Derived>
return reflect_base(self, f);
}
std::string group() const
{
const auto& self = static_cast<const Derived&>(*this);
return self.name();
}
value attributes() const
{
std::vector<std::string> names;
std::transform(post_ops.begin(), post_ops.end(), std::back_inserter(names), [](auto&& op) {
return op.algo;
});
const auto& self = static_cast<const Derived&>(*this);
auto g = self.group();
if(not names.empty())
g += "<" + join_strings(names, ",") + ">";
return {{"group", g}};
}
std::size_t get_extra_post_op_args() const
{
return std::count_if(post_ops.begin(), post_ops.end(), [](const auto& po) {
......@@ -83,7 +114,8 @@ struct dnnl_op : auto_register_op<Derived>
static std::size_t get_binary_post_op_arg(std::size_t pos)
{
return DNNL_ARG_ATTR_MULTIPLE_POST_OP(pos) | DNNL_ARG_SRC_1; // NOLINT
return MIGRAPHX_DNNL_PREFIX(ARG_ATTR_MULTIPLE_POST_OP)(pos) | // NOLINT
MIGRAPHX_DNNL_PREFIX(ARG_SRC_1); // NOLINT
}
static std::vector<shape> to_shapes(const std::vector<argument>& args)
......@@ -98,14 +130,18 @@ struct dnnl_op : auto_register_op<Derived>
{
auto desc = prim.get_primitive_desc();
const char* str = nullptr;
#ifdef MIGRAPHX_ENABLE_ZENDNN
zendnn_primitive_desc_query(desc, zendnn_query_impl_info_str, 0, &str);
#else
dnnl_primitive_desc_query(desc, dnnl_query_impl_info_str, 0, &str);
#endif
return str == nullptr ? "" : str;
}
// Map arg index to arg in dnnl
std::vector<int> arg_map(int size) const
{
std::vector<int> result(size);
std::iota(result.begin(), result.end(), DNNL_ARG_SRC_0);
std::iota(result.begin(), result.end(), MIGRAPHX_DNNL_PREFIX(ARG_SRC_0));
return result;
}
shape base_adjust_shape(const shape& s) const
......@@ -164,8 +200,9 @@ struct dnnl_op : auto_register_op<Derived>
{
const auto& self = static_cast<const Derived&>(*this);
std::unordered_map<int, dnnl::memory::desc> result;
result[DNNL_ARG_DST] = to_dnnl_memory_desc(self.adjust_shape(output_shape, inputs.size()));
auto m = create_arg_map(inputs.size());
result[MIGRAPHX_DNNL_PREFIX(ARG_DST)] =
to_dnnl_memory_desc(self.adjust_shape(output_shape, inputs.size()));
auto m = create_arg_map(inputs.size());
assert(m.size() >= inputs.size());
for(int i = 0; i < inputs.size(); i++)
{
......@@ -182,7 +219,7 @@ struct dnnl_op : auto_register_op<Derived>
if(contains(op.algo, "binary_add"))
{
auto desc = m.at(arg);
if(desc == m.at(DNNL_ARG_DST))
if(desc == m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)))
po.append_sum(1.0f);
else
po.append_binary(to_dnnl_algo(op.algo), m.at(arg));
......@@ -309,7 +346,8 @@ struct dnnl_op : auto_register_op<Derived>
}
#endif
std::unordered_map<int, dnnl::memory> m;
m[DNNL_ARG_DST] = to_dnnl_memory(md.at(DNNL_ARG_DST), args.back());
m[MIGRAPHX_DNNL_PREFIX(ARG_DST)] =
to_dnnl_memory(md.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)), args.back());
for(int i = 0; i < args.size() - 1; i++)
m[arg_lookup[i]] = to_dnnl_memory(md.at(arg_lookup[i]), args[i]);
prim.execute(get_dnnl_context().stream, m);
......