Commit d49d4f66 authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into mlir-c

parents 7d248d46 4d82d761
......@@ -96,6 +96,11 @@ struct module
instruction_ref move_instruction(instruction_ref src, instruction_ref dst);
instruction_ref move_instructions(instruction_ref src, instruction_ref dst);
std::vector<instruction_ref>
insert_module_instructions(instruction_ref ins,
module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
template <class... Ts>
instruction_ref add_literal(Ts&&... xs)
{
......@@ -110,6 +115,8 @@ struct module
instruction_ref add_return(std::vector<instruction_ref> args);
instruction_ref replace_return(std::vector<instruction_ref> args);
std::vector<std::string> get_parameter_names() const;
shape get_parameter_shape(std::string name) const;
......
#ifndef MIGRAPHX_GUARD_OP_POINTWISE_HPP
#define MIGRAPHX_GUARD_OP_POINTWISE_HPP
#include <migraphx/config.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/module.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/par_for.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct pointwise
{
std::string name() const { return "pointwise"; }
shape compute_shape(const std::vector<shape>& inputs, std::vector<module_ref> mods) const
{
if(mods.size() != 1)
{
MIGRAPHX_THROW("should have one submodule.");
}
auto* pm = mods.front();
auto pnames = pm->get_parameter_names();
std::sort(pnames.begin(), pnames.end());
check_shapes{inputs, *this}.has(pnames.size()).same_dims();
for(auto i : range(pnames.size()))
{
auto s1 = pm->get_parameter(pnames[i])->get_shape();
auto s2 = inputs[i];
if(s1.type() != s2.type())
MIGRAPHX_THROW("Mismatch type");
}
if(pm->get_output_shapes().size() != 1)
MIGRAPHX_THROW("submodule should have only one output.");
auto type = pm->get_output_shapes().front().type();
return shape::from_permutation(type, inputs.front().lens(), find_permutation(inputs));
}
argument compute(const shape& output_shape,
const std::vector<argument>& args,
const std::vector<module_ref>& mods,
const std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)>& run) const
{
argument output{output_shape};
auto* pm = mods.front();
auto pnames = pm->get_parameter_names();
std::sort(pnames.begin(), pnames.end());
par_for(output_shape.elements(), [&](auto i) {
std::unordered_map<std::string, argument> params;
std::transform(
pnames.begin(),
pnames.end(),
args.begin(),
std::inserter(params, params.end()),
[&](auto&& name, auto&& arg) { return std::make_pair(name, arg.element(i)); });
auto results = run(pm, params);
assert(results.size() == 1);
visit_all(output, results.front())([&](auto out, auto x) { out[i] = x.front(); });
});
return output;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_OP_POINTWISE_HPP
......@@ -37,7 +37,7 @@ struct rnn_var_sl_shift_output
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
int64_t max_len = static_cast<int64_t>(output_shape.lens()[0]);
int64_t max_len = output_shape.lens()[0];
visit_all(result, args[0])([&](auto output, auto input) {
using value_type = typename decltype(output)::value_type;
args[1].visit([&](auto seq_lens) {
......@@ -76,7 +76,7 @@ struct rnn_var_sl_shift_sequence
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
int64_t max_len = static_cast<int64_t>(output_shape.lens()[0]);
int64_t max_len = output_shape.lens()[0];
visit_all(result, args[0])([&](auto output, auto input) {
using value_type = typename decltype(output)::value_type;
args[1].visit([&](auto seq_lens) {
......
......@@ -23,6 +23,8 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL)
struct program_impl;
struct marker;
/**
* @brief Stores the instruction stream
*/
......@@ -67,6 +69,8 @@ struct program
void perf_report(std::ostream& os, std::size_t n, parameter_map params) const;
void mark(const parameter_map& params, marker&& m);
value to_value() const;
void from_value(const value& v);
......
......@@ -11,49 +11,8 @@ inline namespace MIGRAPHX_INLINE_NS {
static void inline_submodule(module& m, instruction_ref ins, bool cond)
{
const auto& mod_inputs = ins->module_inputs();
const auto* smod = cond ? mod_inputs.at(0) : mod_inputs.at(1);
std::unordered_map<instruction_ref, instruction_ref> map_ins;
std::vector<instruction_ref> mod_outputs;
for(auto sins : iterator_for(*smod))
{
instruction_ref copy_ins{};
if(sins->name() == "@literal")
{
auto l = sins->get_literal();
copy_ins = m.add_literal(l);
}
else if(sins->name() == "@param")
{
auto&& name = any_cast<builtin::param>(sins->get_operator()).parameter;
auto s = sins->get_shape();
copy_ins = m.add_parameter(name, s);
}
else if(sins->name() == "@outline")
{
auto s = sins->get_shape();
copy_ins = m.add_outline(s);
}
else
{
auto mod_args = sins->module_inputs();
auto inputs = sins->inputs();
std::vector<instruction_ref> copy_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return contains(map_ins, i) ? map_ins[i] : i;
});
if(sins->name() == "@return")
{
mod_outputs = copy_inputs;
break;
}
copy_ins = m.insert_instruction(ins, sins->get_operator(), copy_inputs, mod_args);
}
map_ins[sins] = copy_ins;
mod_outputs = {copy_ins};
}
module_ref smod = cond ? mod_inputs.at(0) : mod_inputs.at(1);
auto mod_outputs = m.insert_module_instructions(ins, smod);
auto ins_outputs = ins->outputs();
assert(mod_outputs.size() >= ins_outputs.size());
......
......@@ -468,5 +468,11 @@ std::vector<shape> try_compute_shape(const operation& op, const std::vector<shap
}
return {new_shape};
}
migraphx::instruction* as_address(const instruction_ref& ins) noexcept
{
return std::addressof(*ins);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <iterator>
#include <migraphx/module.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
......@@ -302,6 +303,55 @@ instruction_ref module::move_instructions(instruction_ref src, instruction_ref d
return src;
}
std::vector<instruction_ref> module::insert_module_instructions(
instruction_ref ins, module_ref m, std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
std::vector<instruction_ref> mod_outputs;
for(auto sins : iterator_for(*m))
{
if(contains(map_ins, sins))
continue;
instruction_ref copy_ins;
if(sins->name() == "@literal")
{
auto l = sins->get_literal();
copy_ins = this->add_literal(l);
}
else if(sins->name() == "@param")
{
auto&& name = any_cast<builtin::param>(sins->get_operator()).parameter;
auto s = sins->get_shape();
copy_ins = this->add_parameter(name, s);
}
else if(sins->name() == "@outline")
{
auto s = sins->get_shape();
copy_ins = this->add_outline(s);
}
else
{
auto mod_args = sins->module_inputs();
auto inputs = sins->inputs();
std::vector<instruction_ref> copy_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return contains(map_ins, i) ? map_ins[i] : i;
});
if(sins->name() == "@return")
{
mod_outputs = copy_inputs;
break;
}
copy_ins = this->insert_instruction(ins, sins->get_operator(), copy_inputs, mod_args);
}
map_ins[sins] = copy_ins;
}
if(mod_outputs.empty())
mod_outputs = {map_ins.at(std::prev(m->end()))};
return mod_outputs;
}
instruction_ref module::add_literal(literal l)
{
impl->emplace_front(std::move(l));
......@@ -332,6 +382,20 @@ instruction_ref module::add_return(std::vector<instruction_ref> args)
return result;
}
instruction_ref module::replace_return(std::vector<instruction_ref> args)
{
auto last = std::prev(this->end());
// If there is no return then add a return
if(last->name() != "@return")
return this->add_return(args);
shape r = compute_shape(last->get_operator(), args);
instruction::replace(last, last->get_operator(), r, std::move(args));
assert(last->valid(begin()));
return last;
}
shape module::get_parameter_shape(std::string name) const
{
auto ins = std::find_if(
......
......@@ -20,7 +20,7 @@ auto tune_attribute(const std::vector<int64_t>& vec,
const std::vector<std::size_t>& lens)
{
std::vector<int64_t> result(vec);
int64_t n_rank = static_cast<int64_t>(lens.size());
int64_t n_rank = lens.size();
std::vector<op::normalize_attribute> vec_attrs = val.to_vector<op::normalize_attribute>();
if(contains(vec_attrs, op::normalize_attribute::use_output))
{
......
......@@ -39,7 +39,7 @@ struct parse_gather_elements : op_parser<parse_gather_elements>
int tuned_axis = tune_axis(n_rank, axis, opd.op_name);
auto axis_stride = data_s.strides()[tuned_axis];
int64_t data_elem_num = static_cast<int64_t>(data_s.elements());
int64_t data_elem_num = data_s.elements();
// reshape the input data as one dimension and used as input data
// to the gather operator
arg_data = info.add_instruction(make_op("reshape", {{"dims", {data_elem_num}}}), arg_data);
......
......@@ -9,21 +9,20 @@ namespace onnx {
auto compute_type(shape::type_t t1, shape::type_t t2)
{
const static std::unordered_map<int, int> op_order = {
{static_cast<int>(shape::int8_type), 1},
{static_cast<int>(shape::uint8_type), 2},
{static_cast<int>(shape::int16_type), 3},
{static_cast<int>(shape::uint16_type), 4},
{static_cast<int>(shape::int32_type), 5},
{static_cast<int>(shape::uint32_type), 6},
{static_cast<int>(shape::int64_type), 7},
{static_cast<int>(shape::uint64_type), 8},
{static_cast<int>(shape::half_type), 9},
{static_cast<int>(shape::float_type), 10},
{static_cast<int>(shape::double_type), 11}};
const static std::unordered_map<int, int> op_order = {{shape::int8_type, 1},
{shape::uint8_type, 2},
{shape::int16_type, 3},
{shape::uint16_type, 4},
{shape::int32_type, 5},
{shape::uint32_type, 6},
{shape::int64_type, 7},
{shape::uint64_type, 8},
{shape::half_type, 9},
{shape::float_type, 10},
{shape::double_type, 11}};
int it1 = static_cast<int>(t1);
int it2 = static_cast<int>(t2);
int it1 = t1;
int it2 = t2;
if(!contains(op_order, it1) or !contains(op_order, it2))
{
MIGRAPHX_THROW("PARSE_POW: Input data type not supported!");
......
......@@ -334,7 +334,7 @@ struct parse_resize : op_parser<parse_resize>
auto ins_delta = info.add_literal(dim_s, delta_data);
// slice the data
int64_t slc_stride = static_cast<int64_t>(dim_lens[0]);
int64_t slc_stride = dim_lens[0];
auto low = info.add_instruction(
make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {slc_stride}}}),
data);
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_spacetodepth : op_parser<parse_spacetodepth>
{
std::vector<op_desc> operators() const { return {{"SpaceToDepth"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto s = args[0]->get_shape();
// blocksize attribute of SpaceToDepth
int blocksize = 1; // if blockSize of 1 then, this is a no-op
if(contains(info.attributes, "blocksize"))
{
blocksize = info.attributes.at("blocksize").i();
}
if(blocksize < 1)
{
// blockSize less than 1 would rather result in DepthToSpace instead of SpaceToDepth
MIGRAPHX_THROW("SpaceToDepth: blocksize is less than 1");
}
// calculate dimensions
auto res_lens = s.lens(); // {N, C, H, W}
if(((res_lens[2] % blocksize) == 0) and ((res_lens[3] % blocksize) == 0))
{
// Co = C * (blocksize ^ 2)
res_lens[1] = res_lens[1] * blocksize * blocksize;
// Ho = (H / blocksize)
res_lens[2] = res_lens[2] / blocksize;
// Wo = (W / blocksize)
res_lens[3] = res_lens[3] / blocksize;
} // res_shape = (N, Co, Ho, Wo)
else
MIGRAPHX_THROW("SpaceToDepth: div by blocksize quotient not int ");
auto trans_lens = s.lens(); // {N, C, H, W}
trans_lens[2] = res_lens[2];
trans_lens[3] = blocksize;
trans_lens.push_back(res_lens[3]);
trans_lens.push_back(blocksize); // {N, C, Ho, blocksize, Wo, blocksize}
std::vector<int64_t> perm = {0, 3, 5, 1, 2, 4};
auto temp1 = info.add_instruction(make_op("reshape", {{"dims", trans_lens}}), args[0]);
auto temp2 = info.add_instruction(make_op("transpose", {{"permutation", perm}}), temp1);
return info.add_instruction(make_op("reshape", {{"dims", res_lens}}),
info.make_contiguous(temp2));
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -24,7 +24,7 @@ struct parse_split : op_parser<parse_split>
}
auto lens = args[0]->get_shape().lens();
int64_t n_rank = static_cast<int64_t>(lens.size());
int64_t n_rank = lens.size();
int64_t tuned_axis = tune_axis(n_rank, axis, opd.op_name);
std::vector<int64_t> vec_splits;
......
......@@ -13,6 +13,7 @@
#include <migraphx/algorithm.hpp>
#include <migraphx/output_iterator.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/marker.hpp>
#include <iostream>
#include <sstream>
#include <algorithm>
......@@ -507,6 +508,24 @@ std::string perf_group(const operation& op)
return op.name();
}
void program::mark(const parameter_map& params, marker&& m)
{
auto& ctx = this->impl->ctx;
// Run once by itself
eval(params);
ctx.finish();
// Start marking
m.mark_start(*this);
generic_eval(*this, ctx, params, always([&](auto ins, auto f) {
argument result;
m.mark_start(ins);
result = f();
m.mark_stop(ins);
return result;
}));
m.mark_stop(*this);
}
void program::perf_report(std::ostream& os, std::size_t n, parameter_map params) const
{
auto& ctx = this->impl->ctx;
......
......@@ -269,7 +269,7 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
instruction_ref hidden_out = prog.end();
instruction_ref last_out{};
last_out = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);
long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens));
long seq_len = get_seq_len(prog, seq, seq_lens);
for(long i = 0; i < seq_len; i++)
{
long seq_index = is_forward ? i : (seq_len - 1 - i);
......@@ -556,7 +556,7 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
instruction_ref last_output{};
migraphx::shape seq_shape = seq->get_shape();
migraphx::shape r_shape = r->get_shape();
long hs = static_cast<long>(r_shape.lens()[2]);
long hs = r_shape.lens()[2];
migraphx::shape ss(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
std::vector<float> data(ss.elements(), 1.0f);
......@@ -613,7 +613,7 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
rb_h);
}
long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens));
long seq_len = get_seq_len(prog, seq, seq_lens);
for(long i = 0; i < seq_len; i++)
{
long seq_index = is_forward ? i : (seq_len - 1 - i);
......@@ -1032,7 +1032,7 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
instruction_ref last_cell_output{};
migraphx::shape r_shape = r->get_shape();
long hs = static_cast<long>(r_shape.lens()[2]);
long hs = r_shape.lens()[2];
auto bs = ih->get_shape().lens()[1];
std::vector<int64_t> perm{1, 0};
......@@ -1094,7 +1094,7 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphf);
}
long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens));
long seq_len = get_seq_len(prog, seq, seq_lens);
for(long i = 0; i < seq_len; ++i)
{
long seq_index = is_forward ? i : (seq_len - 1 - i);
......
File mode changed from 100644 to 100755
......@@ -33,8 +33,6 @@ rocm_set_soversion(migraphx_cpu ${MIGRAPHX_SO_VERSION})
set(MIGRAPHX_ENABLE_ZENDNN Off CACHE BOOL "")
find_package(Threads)
if(MIGRAPHX_ENABLE_ZENDNN)
find_path(ZENDNN_INC_PATH zendnn.hpp)
find_library(ZENDNN_LIB amdZenDNN)
......@@ -53,7 +51,7 @@ if(MIGRAPHX_ENABLE_ZENDNN)
else()
target_link_libraries(migraphx_cpu PRIVATE DNNL::dnnl)
endif()
target_link_libraries(migraphx_cpu PRIVATE migraphx Threads::Threads)
target_link_libraries(migraphx_cpu PRIVATE migraphx)
find_package(OpenMP)
target_link_libraries(migraphx_cpu PUBLIC OpenMP::OpenMP_CXX)
......
......@@ -45,7 +45,7 @@ TEST_CASE(if_pl_test)
auto ys = param_shapes["y"];
std::vector<float> yd(ys.bytes() / sizeof(float), 2.0);
pp.add("y", migraphx::argument(ys, yd.data()));
char ccond = static_cast<char>(cond);
char ccond = cond;
pp.add("cond", migraphx::argument(param_shapes["cond"], &ccond));
auto outputs = p.eval(pp);
......
......@@ -8,16 +8,22 @@ TEST_CASE(add_op)
EXPECT(add_op.name() == "add");
}
TEST_CASE(reduce_mean)
TEST_CASE(reduce_mean_without_quotes)
{
auto rm = migraphx::operation("reduce_mean", "{axes : [1, 2, 3, 4]}");
EXPECT(rm.name() == "reduce_mean");
}
TEST_CASE(reduce_mean1)
TEST_CASE(reduce_mean)
{
auto rm = migraphx::operation("reduce_mean", "{\"axes\" : [1, 2, 3, 4]}");
EXPECT(rm.name() == "reduce_mean");
}
TEST_CASE(reduce_mean_with_format)
{
auto rm = migraphx::operation("reduce_mean", "{axes : [%i, %i, %i, %i]}", 1, 2, 3, 4);
EXPECT(rm.name() == "reduce_mean");
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include "migraphx/dead_code_elimination.hpp"
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::fuse_pointwise{}, migraphx::dead_code_elimination{}});
}
template <class F>
migraphx::instruction_ref add_pointwise(migraphx::program& p,
const std::string& name,
std::vector<migraphx::instruction_ref> inputs,
F f)
{
auto* pm = p.create_module(name);
auto* mm = p.get_main_module();
pm->set_bypass();
std::vector<migraphx::instruction_ref> params;
std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) {
return pm->add_parameter("x" + std::to_string(params.size()),
migraphx::shape{input->get_shape().type()});
});
auto r = f(pm, params);
pm->add_return({r});
return mm->add_instruction(migraphx::make_op("pointwise"), inputs, {pm});
}
auto single_pointwise(const std::string& name)
{
return [=](auto* pm, const auto& inputs) {
return pm->add_instruction(migraphx::make_op(name), inputs);
};
}
TEST_CASE(single)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto pass = mm->add_instruction(pass_op{}, add1);
auto add2 = mm->add_instruction(migraphx::make_op("add"), pass, z);
mm->add_return({add2});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto add1 = add_pointwise(p2, "pointwise0", {x, y}, single_pointwise("add"));
auto pass = mm->add_instruction(pass_op{}, add1);
auto add2 = add_pointwise(p2, "pointwise1", {pass, z}, single_pointwise("add"));
mm->add_return({add2});
}
EXPECT(p1 == p2);
}
TEST_CASE(double_add)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, z);
mm->add_return({add2});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto fadd = add_pointwise(p2, "pointwise0", {x, y, z}, [=](auto* pm, const auto& inputs) {
auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]);
});
mm->add_return({fadd});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(used_twice_not_fused)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto pass = mm->add_instruction(pass_op{}, add1);
auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, y);
auto add3 = mm->add_instruction(migraphx::make_op("add"), pass, add2);
mm->add_return({add3});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add1 = add_pointwise(p2, "pointwise0", {x, y}, single_pointwise("add"));
auto pass = mm->add_instruction(pass_op{}, add1);
auto fadd =
add_pointwise(p2, "pointwise1", {add1, y, pass}, [=](auto* pm, const auto& inputs) {
auto add2 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
return pm->add_instruction(migraphx::make_op("add"), inputs[2], add2);
});
mm->add_return({fadd});
}
EXPECT(p1 == p2);
}
TEST_CASE(used_twice_fused)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, x);
auto add3 = mm->add_instruction(migraphx::make_op("add"), add1, y);
auto add4 = mm->add_instruction(migraphx::make_op("add"), add2, add3);
mm->add_return({add4});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto fadd = add_pointwise(p2, "pointwise0", {x, y}, [=](auto* pm, const auto& inputs) {
auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
auto add2 = pm->add_instruction(migraphx::make_op("add"), add1, inputs[0]);
auto add3 = pm->add_instruction(migraphx::make_op("add"), add1, inputs[1]);
return pm->add_instruction(migraphx::make_op("add"), add2, add3);
});
mm->add_return({fadd});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(duplicate_inputs)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, x);
auto pass = mm->add_instruction(pass_op{}, add1);
auto add2 = mm->add_instruction(migraphx::make_op("add"), pass, y);
mm->add_return({add2});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add1 = add_pointwise(p2, "pointwise0", {x}, [=](auto* pm, const auto& inputs) {
return pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[0]);
});
auto pass = mm->add_instruction(pass_op{}, add1);
auto add2 = add_pointwise(p2, "pointwise1", {pass, y}, single_pointwise("add"));
mm->add_return({add2});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(scalar_input)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto one = mm->add_literal(1.0f);
auto y =
mm->add_instruction(migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}), one);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
mm->add_return({add1});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto add1 = add_pointwise(p2, "pointwise0", {x}, [=](auto* pm, const auto& inputs) {
auto y = pm->add_literal(1.0f);
return pm->add_instruction(migraphx::make_op("add"), inputs[0], y);
});
mm->add_return({add1});
}
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment